【XSY3241】暴风士兵

【XSY3241】暴风士兵

他是暴风士兵,我是伞兵。

我们考虑令\(C(x)=\sum_{i=0}^{exp}(exp-i)x^i\)\(P(x)\)为扣\(i\)滴血的概率\(P_i\)的生成函数。

那么不难发现,对于一个时间\(t\),答案即为:

\[ans_t=\sum_{i=0}^{exp}C_iP_i \]

然后我们不难发现,每经过一个时间点\(t\)\(P(x)\times=(P_ix+(1-P_i))\)

但这样似乎还是\(n^2\)的,做不了呀。

我们考虑随便设一个断点\(k\),然后让\(A(x)=\prod_{i=1}^k(P_ix+(1-P_i))\),\(B(x)=\prod_{i=k+1}^t(P_ix+(1-P_i))\),于是就有:

\[\begin{aligned} ans_t &= \sum_{i=0}^{exp} C_i [x^i]A(x)B(x)\\ &= \sum_{i=0}^{exp} C_i \sum_{j=0}^i A_jB_{i-j}\\ &= \sum_{i=0}^{exp} B_i \sum_{j=0}^{exp-i} C_{j+i}A_j\\ &= \sum_{i=0}^{exp} C'i[x^i]B(x) \end{aligned} \]

于是我们对于每一个\(t\),令\(k=t-1\),可以用分治\(NTT\)加上减法卷积算出\([1,t-1]\)\(C'\),然后点乘上\(i\)处的\(B\)就可以了。此处的\(B\)只有两项,\(O(1)\)即可。

(于是一个强制在线问题被分治搞掉了,真的高。)

(还有更草的,作为蒟蒻的我之前居然没有试过减法NTT的分治www然后卡了半天(

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x<y?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline int qpow(int n,int k){
    int ret=1;
    while(k){
        if(k&1)ret=mul(ret,n);
        n=mul(n,n);
        k>>=1;
    }
    return ret;
}
int G[2][270010][20];
void init(int lim){
    for(int mid=1,dep=0;mid<lim;mid<<=1,dep++){
        int len=mid<<1;
        int gn=qpow(3,(mod-1)/len);
        int ign=qpow(gn,mod-2);
        int g=1,ig=1;
        for(int j=0;j<mid;++j,g=mul(g,gn),ig=mul(ig,ign)){
            G[1][j][dep]=g,G[0][j][dep]=ig;
        }
    }
}
int rev[270010];
void NTT(int *A,int lim,int opt){
    for(int i=0;i<lim;++i){
        rev[i]=(rev[i>>1]>>1)|((i&1)*(lim>>1));
        if(i<rev[i])swap(A[i],A[rev[i]]);
    }
    for(int mid=1,dep=0;mid<lim;mid<<=1,dep++){
        int len=mid<<1;
        for(int i=0;i<lim;i+=len){
            for(int j=0;j<mid;++j){
                int x=A[i+j],y=mul(G[opt][j][dep],A[i+j+mid]);
                A[i+j]=add(x,y);
                A[i+j+mid]=dec(x,y);
            }
        }
    }
    if(!opt){
        int div=qpow(lim,mod-2);
        for(int i=0;i<lim;++i)A[i]=mul(A[i],div);
    }
}
int lst;
vector<int> c[400010];
vector<int> p[400010];
#define ls (o<<1)
#define rs (o<<1|1)
void solve(int o,int l,int r){
	if(l==r){
		int nowp;
		scanf("%d",&nowp);
		p[o].push_back(dec(1,add(nowp,lst)));p[o].push_back(add(nowp,lst));
//		cout<<p[o][0]<<" "<<p[o][1]<<" "<<c[o][0]<<" "<<c[o][1]<<endl;
		printf("%d\n",lst=(add(mul(c[o][0],p[o][0]),mul(c[o][1],p[o][1]))));
		return;
	}
	static int A[270010],B[270010];
	int mid=(l+r)/2;
	int len=(r-l+1),lenl=(mid-l+1),lenr=(r-mid);
	for(int i=0;i<=lenl;++i)c[ls].push_back(c[o][i]);
	solve(ls,l,mid);
	int lim=1;
	while(lim<=len+lenl)lim<<=1;
	for(int i=0;i<=len;++i)A[i]=c[o][i];
	for(int i=0;i<=lenl;++i)B[i]=p[ls][lenl-i];
	NTT(A,lim,1),NTT(B,lim,1);
	for(int i=0;i<lim;++i)A[i]=mul(A[i],B[i]);
	NTT(A,lim,0);
	for(int i=0;i<=lenr;++i)c[rs].push_back(A[i+lenl]);
	for(int i=0;i<lim;++i)A[i]=B[i]=0;
	solve(rs,mid+1,r);
	lim=1;
	while(lim<=lenl+lenr)lim<<=1;
	for(int i=0;i<=lenl;++i)A[i]=p[ls][i];
	for(int i=0;i<=lenr;++i)B[i]=p[rs][i];
	NTT(A,lim,1),NTT(B,lim,1);
	for(int i=0;i<lim;++i)A[i]=mul(A[i],B[i]);
	NTT(A,lim,0);
	for(int i=0;i<=len;++i)p[o].push_back(A[i]);
	for(int i=0;i<lim;++i)A[i]=B[i]=0;
	p[ls].clear(),p[rs].clear();
}
#undef ls
#undef rs
int main(){
	int exp,n;
	scanf("%d%d",&exp,&n);
	lst=exp;
	int lim=1;
	while(lim<=(n<<1))lim<<=1;
	init(lim);
	for(int i=0;i<=exp;++i)c[1].push_back(exp-i);
	for(int i=exp+1;i<=n;++i)c[1].push_back(0);
	solve(1,1,n);
}
posted @ 2021-08-02 19:11  FakeDragon  阅读(49)  评论(0编辑  收藏  举报