Loj #6261. 一个人的高三楼

link : https://loj.ac/problem/6261

 

一看就是一个已经退役的大佬出的题。。。

我一开始还是too young了,忘了看时限,,以为NTT+多项式快速幂就能水过的。。。

于是就写了个我代码里被我注释掉的东西。。。。。

后来被卡了之后才想起来全是1的多项式的N次方的各个项的系数是可以直接用组合数算出来的。。。。

具体的说,设 A = {1,x,x^2,,,,,,,} ^ N 。

A中的 x^i 的系数就是 C(N+i-1,i) ,因为这就是可重组合的定义式吧23333。

而根据组合数的通项公式,我们可以很容易的从 x^i 的系数来递推 x^(i+1) 的系数。

O(N)  计算出 A 之后,直接一遍NTT 让 a卷一下它,得到结果。

 

写的时候犯了很多SB错误,想想都想打自己23333.

    1.一开始求补0之后序列长度的逆元求成补0之前的了,,,虽然这个样例并不能看出来因为样例 4=2^2   2333.

    2. 数组大小一定要开到 > 卷积后最高次数 的最小的 2^k 。

   3.K比较大,,,乘积的时候得先%ha之后再运算233

 

#include<bits/stdc++.h>
#define ll long long
#define maxn 400005 
using namespace std;
const int ha=998244353;
const int root=3;
const int inv=ha/3+1;
ll k;
int a[maxn],b[maxn],n,N;
int INV,e[maxn],r[maxn],l;
int object[2][maxn];
int ni[maxn];

inline int ksm(int x,int y){
	int an=1;
	for(;y;y>>=1,x=x*(ll)x%ha) if(y&1) an=an*(ll)x%ha;
	return an;
}

inline int add(int x,int y){
	x+=y;
	return x>=ha?x-ha:x;
}

inline void NTT(int *c,int f){
	for(int i=0;i<N;i++) if(i<r[i]) swap(c[i],c[r[i]]);
	
	for(int i=1,o=1;i<N;i<<=1,o++){
		int omega=object[f==-1][o];

		for(int p=i<<1,j=0;j<N;j+=p){
			int now=1;
			for(int u=0;u<i;u++,now=now*(ll)omega%ha){
				int x=c[j+u],y=c[j+u+i]*(ll)now%ha;
				c[j+u]=add(x,y);
				c[j+u+i]=add(x,ha-y);
			}
		}
	}
	
	if(f==-1) for(int i=0;i<N;i++) c[i]=c[i]*(ll)INV%ha;
}

inline void calc(){
	b[0]=1;
	for(int i=1;i<n;i++) b[i]=b[i-1]*((ll)(k+i-1)%ha)%ha*(ll)ni[i]%ha;
}

inline void solve(){
	int len=(n-1)<<1;
	for(N=1,l=0;N<=len;N<<=1) l++;
	for(int i=0;i<N;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	for(int i=1;i<=l;i++){
		object[0][i]=ksm(root,(ha-1)/(1<<i));
		object[1][i]=ksm(inv,(ha-1)/(1<<i));
	}
	INV=ksm(N,ha-2);
	ni[1]=1;
	for(int i=2;i<=n;i++) ni[i]=-ni[ha%i]*(ll)(ha/i)%ha+ha;
	
	calc();
	
	/*
	while(k){
		if(k&1){
			NTT(a,1),NTT(b,1);
			for(int i=0;i<N;i++) a[i]=a[i]*(ll)b[i]%ha;
			NTT(a,-1),NTT(b,-1);
			fill(a+n,a+N,0);
		}
		
		NTT(b,1);
		for(int i=0;i<N;i++) b[i]=b[i]*(ll)b[i]%ha;
		NTT(b,-1);
		fill(b+n,b+N,0);
		k>>=1;
	}
	*/
	
	NTT(a,1),NTT(b,1);
	for(int i=0;i<N;i++) a[i]=a[i]*(ll)b[i]%ha;
	NTT(a,-1);
}

int main(){
	scanf("%d%lld",&n,&k);
	for(int i=0;i<n;i++){
		scanf("%d",a+i);
	}
	
	solve();

	for(int i=0;i<n;i++) printf("%d\n",a[i]);
	return 0;
}

  

posted @ 2018-03-10 18:44  蒟蒻JHY  阅读(474)  评论(0编辑  收藏  举报