loj#2541. 「PKUWC2018」猎人杀

传送门

思路太清奇了……

考虑容斥,即枚举至少有哪几个是在\(1\)号之后被杀的。设\(A=\sum_{i=1}^nw_i\)\(S\)为那几个在\(1\)号之后被杀的人的\(w\)之和。关于杀了人之后分母的变化,我们可以假设这个人被杀之后还活着(说好的人被杀就会死呢),不过如果选到了它要再选一次,这个和之前的是等价的。于是这几个人在\(1\)之后被杀的概率为$$P=\sum_{i=0}^\infty (1-\frac{S+w_1}{A})^i\frac{w_1}{A}$$

\[P=\frac{w_1}{A}\sum_{i=0}^\infty (1-\frac{S+w_1}{A})^i \]

\[P=\frac{w_1}{A}\times \frac{1}{1-1+\frac{S+w_1}{A}} \]

\[P=\frac{w_1}{S+w_1} \]

直接暴力枚举不现实,由于\(\sum w_i\leq 10^5\),所以我们可以把每个\(S\)的系数,即每个\(S\)出现了多少次给求出来,然后直接计算

由于\(S\)是一堆\(w_i\)乘起来的,而且因为容斥系数所以每多乘上一个\(w_i\)就要变一次号,所以我们可以把每一个\(w_i\)写成生成函数的形式\(1-x^{w_i}\),然后求出\(\prod_{i=2}^n(1-x^{w_i})\),那么\(S\)的系数就是\(x^S\)

然后它这个分治\(NTT\)的意思大概就是……如果我们直接把这几个多项式乘起来复杂度是\(O(nm\log m)\)的(其中\(m\)为最大的次数),因为所有多项式的次数之和为\(m\),我们可以把多项式两两合并,那么每一次多项式的个数都会减少一半,于是总的层数为\(O(\log n)\),又因为每一层多项式的次数之和为\(m\),于是每一层的时间复杂度都是\(O(m\log m)\),那么总的时间复杂度就是\(O(m\log n\log m)\)

据说还有用生成函数的乱七八糟的姿势以及多项式\(\exp\)做到\(O(m\log m)\)的,然而我不会啊2333

//minamoto
#include<bits/stdc++.h>
#define R register
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){
    R int res,f=1;R char ch;
    while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
    for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
    return res*f;
}
double readdb()
{
    R double x=0,y=0.1,f=1;R char ch;
    while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
    for(x=ch-'0';(ch=getc())>='0'&&ch<='9';x=x*10+ch-'0');
    for(ch=='.'&&(ch=getc());ch>='0'&&ch<='9';x+=(ch-'0')*y,y*=0.1,ch=getc());
    return x*f;
}
char sr[1<<21],z[20];int C=-1,Z=0;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
void print(R int x){
    if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x;
    while(z[++Z]=x%10+48,x/=10);
    while(sr[++C]=z[Z],--Z);sr[++C]='\n';
}
const int N=5e5+5,P=998244353,Gi=332748118;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
inline int mul(R int x,R int y){return 1ll*x*y-1ll*x*y/P*P;}
int ksm(R int x,R int y){
	R int res=1;
	for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);
	return res;
}
int A[35][N],O[N],r[N],w[N],deg[N];
int n,m,tot,sum,ans;
void NTT(int *A,int ty,int lim){
	fp(i,0,lim-1)if(i<r[i])swap(A[i],A[r[i]]);
	for(R int mid=1;mid<lim;mid<<=1){
		R int I=(mid<<1),Wn=ksm(ty==1?3:Gi,(P-1)/I);O[0]=1;
		fp(i,1,mid-1)O[i]=mul(O[i-1],Wn);
		for(R int j=0;j<lim;j+=I)for(R int k=0;k<mid;++k){
			int x=A[j+k],y=mul(O[k],A[j+k+mid]);
			A[j+k]=add(x,y),A[j+k+mid]=dec(x,y);
		}
	}if(ty==-1)for(R int i=0,inv=ksm(lim,P-2);i<lim;++i)A[i]=mul(A[i],inv);
}
void solve(int ql,int qr){
	if(ql==qr){
		++tot,A[tot][0]=1,A[tot][w[ql]]=-1,deg[tot]=w[ql];
		fp(i,1,w[ql]-1)A[tot][i]=0;
		return;
	}int mid=(ql+qr)>>1;solve(ql,mid),solve(mid+1,qr);
	int lim=1,l=0,x=tot-1,y=tot,m=deg[x]+deg[y];
	while(lim<=m)lim<<=1,++l;
	fp(i,0,lim-1)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	fp(i,deg[x]+1,lim-1)A[x][i]=0;
	fp(i,deg[y]+1,lim-1)A[y][i]=0;
	NTT(A[x],1,lim),NTT(A[y],1,lim);
	fp(i,0,lim-1)A[x][i]=mul(A[x][i],A[y][i]);
	NTT(A[x],-1,lim);
	--tot,deg[tot]=m;
}
int main(){
//	freopen("testdata.in","r",stdin);
	n=read();fp(i,1,n)w[i]=read(),sum+=w[i];
	sum-=w[1];
	if(n==1)return puts("1"),0;
	solve(2,n);
	fp(i,0,sum)ans=add(ans,mul(w[1],mul(A[1][i],ksm(i+w[1],P-2))));
	printf("%d\n",ans);return 0;
}
posted @ 2019-01-02 16:06  bztMinamoto  阅读(197)  评论(0编辑  收藏  举报
Live2D