LOJ2541 猎人杀

一道很棒的\(NTT\)
题目链接


\(30\)分做法

\(n\leq 20\)时,很容易想到令\(f[S]\)表示死亡集合为\(S\)时的几率.\(f[S\cup i(i\notin S)]=f[S\cup i(i\notin S)]+f[S]*\frac{w_i}{\sum_{j\notin S} w}\)

当然,我们还有第二种做法
\(f[S]\)表示\(S\)\(1\)之后死的方案数.
\(ans=\sum_S(-1)^{|S|}f[S],f[S]=\frac{w_1}{w_1+w_S}\)
好像并没有什么用对不对...

\(50\)分做法

我们就可以把它变为50分的.
由于\(\sum w\leq 1000\),考虑对不同的分母(即\(f[S]\)的分母)分别处理.只要处理容斥系数即可.
考虑\(01\)背包.由于新选一个节点会使前面的所有节点奇偶性该边,因此\(f[i][j]=f[i-1][j]-f[i-1][j-a[i]]\).
求出\(f[n][W]\)即可.
核心代码如下
注意这里滚动数组,\(i\)要倒着枚举.

for(int j=2;j<=n;j++)
for(int i=S;i>=0;i--)
if(i>=w[j])f[i]=(f[i]-f[i-w[j]])%P;

\(100\)分做法

考虑上面的方程代表什么.
用生成函数的一套理论代进去.
\(f[i]\)表示分母为\(i\)的时候对答案的容斥系数贡献.
令多项式\(A=\sum_{i=1}^nf[i]x^i\)
新加入一个仇恨度为\(w\)的人,相当于将\(A\)乘上\(1+x^w\).
因此答案就是\(\prod_{i=2}^n(x^w_i+1)\)
然后暴力乘起来就好了.注意每次要分成左右两边处理.这样复杂度是可以保证的,为\(Wlog^2W\)

代码如下

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#define N (400010)
#define P (998244353)
#define inf (0x7f7f7f7f)
#define rg register int
#define Label puts("NAIVE")
#define spa print(' ')
#define ent print('\n')
#define rand() (((rand())<<(15))^(rand()))
typedef long double ld;
typedef long long LL;
typedef unsigned long long ull;
using namespace std;
inline char read(){
	static const int IN_LEN=1000000;
	static char buf[IN_LEN],*s,*t;
	return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x){
	static bool iosig;
	static char c;
	for(iosig=false,c=read();!isdigit(c);c=read()){
		if(c=='-')iosig=true;
		if(c==-1)return;
	}
	for(x=0;isdigit(c);c=read())x=((x+(x<<2))<<1)+(c^'0');
	if(iosig)x=-x;
}
inline char readchar(){
	static char c;
	for(c=read();!isalpha(c);c=read())
	if(c==-1)return 0;
	return c;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN],*ooh=obuf;
inline void print(char c) {
	if(ooh==obuf+OUT_LEN)fwrite(obuf,1,OUT_LEN,stdout),ooh=obuf;
	*ooh++=c;
}
template<class T>
inline void print(T x){
	static int buf[30],cnt;
	if(x==0)print('0');
	else{
		if(x<0)print('-'),x=-x;
		for(cnt=0;x;x/=10)buf[++cnt]=x%10+48;
		while(cnt)print((char)buf[cnt--]);
	}
}
inline void flush(){fwrite(obuf,1,ooh-obuf,stdout);}
LL w[N],ans,mi[N],iv[N];
int n,S,mx,Lim,rev[N];
LL ksm(LL a,LL p){
	LL res=1;
	while(p){
		if(p&1)res=(res*a)%P;
		a=(a*a)%P,p>>=1;
	}
	return res;
}
void NTT(vector<LL> &a,int tp){
	for(int i=0;i<Lim;i++)
	if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int pos=1,s=1;pos<Lim;pos<<=1,s++){
		LL w=mi[s]; if(tp==-1)w=iv[s];
		for(int R=pos<<1,j=0;j<Lim;j+=R){
			LL p=1; 
			for(int k=j;k<j+pos;k++,p=(p*w)%P){
				LL x=a[k],y=(p*a[k+pos])%P;
				a[k]=(x+y)%P,a[k+pos]=(x-y+P)%P;
			}
		}
	}
	if(tp==-1){
		LL inv=ksm(Lim,P-2)%P;
		for(int i=0;i<Lim;i++)a[i]=(a[i]*inv)%P;
	}
}
struct Poly{vector<LL> a;int mx;}a[N];
Poly work(int l,int r){
	if(l==r)return a[l];
	int len=0,mid=(l+r)>>1;
	Poly a=work(l,mid);
	Poly b=work(mid+1,r);
	int sum=a.mx+b.mx; rev[0]=0;
	for(Lim=1;Lim<=sum;Lim<<=1)len++;
	for(int i=0;i<Lim;i++)
	rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
	a.a.resize(Lim),b.a.resize(Lim);
	NTT(a.a,1),NTT(b.a,1);
	for(int i=0;i<Lim;i++)
	a.a[i]=(a.a[i]*b.a[i])%P;
	NTT(a.a,-1),a.mx=sum;
	return a;
}
int main(){
	read(n);
	for(int i=1;i<=n;i++)read(w[i]),S+=w[i];
	for(int i=1;(1<<i)<=P;i++)mi[i]=ksm(3ll,(P-1)/(1<<i))%P,iv[i]=ksm(mi[i],P-2)%P;
	for(int i=2;i<=n;i++)a[i].a.resize(w[i]+1),a[i].a[0]=1,a[i].a[w[i]]=P-1,a[i].mx=w[i];
	Poly f=work(2,n);
	for(int i=0;i<=S;i++)(ans+=f.a[i]*w[1]%P*ksm((LL)w[1]+i,P-2)%P+P)%=P;
	printf("%lld\n",ans);
}
posted @ 2018-12-01 09:06  Romeolong  阅读(195)  评论(0编辑  收藏  举报