LOJ 2409「THUPC 2017」小 L 的计算题 / Sum
思路
和「玩游戏」一题的思路类似。
定义\(A_k(x)=\sum_{i=0}^\infty a_k^ix^i=\frac{1}{1-a_kx}\)
记\(\ln'(x)=\frac{1}{x}\)（即\(\ln\)的导函数），于是\(A_k(x)=\ln'(1-a_kx)\)，
所以就是求下面这个生成函数（它的\(x^k\)项系数正是\(\sum_{i=1}^n a_i^k\)）：
\[f(x)=\sum_{i=1}^n A_i(x)=\sum_{i=1}^n \ln'(1-a_ix)
\]
但这样没法直接快速计算，
所以再设\(G(x)=\sum_{i=1}^n \left(\ln(1-a_ix)\right)'\)（对\(x\)求导）
所以
\[G(x)=\sum_{i=1}^n\frac{-a_i}{1-a_ix}
\]
又因为\(-x\cdot\frac{-a_i}{1-a_ix}=\frac{a_ix}{1-a_ix}=\frac{1}{1-a_ix}-1\)，所以
\[f(x)=-xG(x)+n
\]
并且
\[G(x)=\left(\ln\prod_{i=1}^n (1-a_ix)\right)'
\]
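于是先用分治+NTT求出\(\prod_{i=1}^n(1-a_ix)\)（\(O(n\log^2n)\)），再做一次多项式\(\ln\)并求导得到\(G(x)\)（\(O(n\log n)\)），最后由\(f(x)=-xG(x)+n\)读出各项系数，总时间复杂度为\(O(n\log^2n)\)。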
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
// Every `int` in this file (including function parameters, locals and the
// globals below) is redefined as 64-bit `long long`; this is why I/O uses
// %lld and the entry point is declared `signed main()`.
#define int long long
using namespace std;
// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
const int MOD = 998244353;
// Primitive root of MOD; powers of G give the forward-NTT roots of unity.
const int G = 3;
// Inverse of G modulo MOD (3 * 332748118 = MOD + 1); used by the inverse NTT.
const int invG = 332748118;
// Capacity of every global/static coefficient buffer; all NTT lengths must fit.
const int MAXN = 2000000;
// Modular exponentiation by repeated squaring: returns a^b mod MOD.
// The base is not reduced first, so callers pass 0 <= a < MOD; b must be
// non-negative because the loop relies on b shifting down to 0.
// Used with b = MOD-2 for modular inverses and b = (MOD-1)/len for NTT roots.
int pow(int a,int b){
    int result=1;
    for(;b;b>>=1){
        if(b&1)
            result=(1LL*result*a)%MOD;
        a=(1LL*a*a)%MOD;
    }
    return result;
}
// rev[]: bit-reversal permutation for the current NTT length (filled by cal_rev).
// inv_val[]: table of modular inverses, inv_val[i] = i^{-1} mod MOD (filled by cal_inv).
int rev[MAXN],inv_val[MAXN];
// Fill rev[0..n-1] with the bit-reversal permutation for an NTT of length
// n = 2^lim. Each entry is derived from rev[i/2] shifted right, with i's
// lowest bit moved to the top position (lim-1). Requires lim >= 1.
void cal_rev(int n,int lim){
    const int topbit=lim-1;
    for(int i=0;i<n;++i){
        const int fromHalf=rev[i>>1]>>1;
        const int lowBit=i&1;
        rev[i]=fromHalf|(lowBit<<topbit);
    }
}
// Fill inv_val[0..n] with modular inverses using the linear recurrence
// i^{-1} = -(MOD / i) * (MOD % i)^{-1}  (mod MOD). inv_val[0] is set to 0.
void cal_inv(int n){
    inv_val[0]=0;
    inv_val[1]=1;
    for(int k=2;k<=n;++k){
        const int quot=MOD/k;
        const int rem=MOD%k;
        inv_val[k]=(1LL*(MOD-quot)*inv_val[rem])%MOD;
    }
}
// In-place iterative number-theoretic transform of a[0..n-1] modulo MOD.
//   opt != 0 : forward transform, roots of unity taken from powers of G.
//   opt == 0 : inverse transform, roots from invG, then every entry is
//              multiplied by n^{-1}.
// n must be a power of two and rev[] must already hold the bit-reversal
// permutation for n (see cal_rev). The lim parameter is not used here.
void NTT(int *a,int opt,int n,int lim){
    for(int i=0;i<n;++i){
        const int j=rev[i];
        if(i<j)
            std::swap(a[i],a[j]);
    }
    for(int half=1;(half<<1)<=n;half<<=1){
        const int span=half<<1;
        // Primitive span-th root of unity (or its inverse for opt == 0).
        const int wn=pow(opt?G:invG,(MOD-1)/span);
        for(int base=0;base<n;base+=span){
            int w=1;
            for(int off=0;off<half;++off){
                int &lo=a[base+off];
                int &hi=a[base+off+half];
                const int prod=1LL*hi*w%MOD;
                hi=(lo-prod+MOD)%MOD;
                lo=(lo+prod)%MOD;
                w=1LL*w*wn%MOD;
            }
        }
    }
    if(opt)
        return;
    const int invN=pow(n,MOD-2);
    for(int i=0;i<n;++i)
        a[i]=1LL*a[i]*invN%MOD;
}
// Polynomial product in place: a <- a * b (mod MOD), via NTT.
//   a  : first operand (degree at); overwritten with the product.
//   b  : second operand (degree bt); only read.
//   at : degree of a on input; set to the product degree at+bt on return.
// Only a[0..at] and b[0..bt] are used as input: anything stored past those
// degrees is ignored, so callers need not pre-clear their buffers. a must have
// room for the NTT length (smallest power of two >= at+bt+2). On return a is
// zero from index at+1 up to that length.
void mul(int *a,int *b,int &at,int bt){
    int midlen=1,midlim=0;
    while(midlen<at+bt+2)
        midlen<<=1,midlim++;
    cal_rev(midlen,midlim);
    static int tmp[MAXN];
    // Zero-pad both operands explicitly: stale values past the declared
    // degrees would otherwise wrap around in the cyclic convolution.
    for(int i=0;i<midlen;i++){
        tmp[i]=(i<=bt)?b[i]:0;
        if(i>at)
            a[i]=0;
    }
    NTT(a,1,midlen,midlim);
    NTT(tmp,1,midlen,midlim);
    for(int i=0;i<midlen;i++)
        a[i]=(1LL*a[i]*tmp[i])%MOD;
    NTT(a,0,midlen,midlim);
    at+=bt;
    // Leave the scratch buffer clean and a free of anything past the product.
    for(int i=0;i<midlen;i++){
        tmp[i]=0;
        if(i>at)
            a[i]=0;
    }
}
// Power-series inverse by Newton iteration: b <- a^{-1} mod x^dep (a[0] != 0).
// The inverse modulo x^{ceil(dep/2)} is computed recursively, then refined with
// b <- b * (2 - a*b), evaluated pointwise in NTT space.
// midlen/midlim carry the current NTT length and its log2 across recursion
// levels and only ever grow (ln() starts them at 1 and 0).
// b is expected to be zero past the coefficients being computed; on return
// b[dep..midlen-1] is zero.
void inv(int *a,int *b,int dep,int &midlen,int &midlim){
    if(dep==1){
        // Constant term of the inverse: a[0]^{MOD-2} = a[0]^{-1} (MOD is prime).
        b[0]=pow(a[0],MOD-2);
        return;
    }
    const int half=(dep+1)>>1;
    inv(a,b,half,midlen,midlim);
    static int tmp[MAXN];
    // The length must hold 2*dep coefficients so that a*b*b (degree < 2*dep)
    // does not wrap around.
    while(midlen<2*dep){
        midlen<<=1;
        ++midlim;
    }
    for(int i=0;i<midlen;++i)
        tmp[i]=(i<dep)?a[i]:0;
    cal_rev(midlen,midlim);
    NTT(tmp,1,midlen,midlim);
    NTT(b,1,midlen,midlim);
    for(int i=0;i<midlen;++i){
        const int ab=1LL*tmp[i]*b[i]%MOD;
        const int factor=(2-ab+MOD)%MOD;
        b[i]=1LL*b[i]*factor%MOD;
    }
    NTT(b,0,midlen,midlim);
    for(int i=dep;i<midlen;++i)
        b[i]=0;
}
// Differentiate in place: a(x) <- a'(x) (mod MOD).
// t is the degree on input; it is decremented on return. Reads a[1..t] and
// clears a[t] (the old leading slot).
void qd(int *a,int &t){
    for(int k=1;k<=t;++k)
        a[k-1]=1LL*a[k]*k%MOD;
    a[t--]=0;
}
// Integrate in place with zero constant term: a(x) <- integral of a(x) (mod MOD).
// t is the degree on input; it is incremented on return. Uses inv_val[1..t+1],
// which must already have been filled by cal_inv.
void jf(int *a,int &t){
    ++t;
    int k=t;
    while(k>0){
        a[k]=1LL*a[k-1]*inv_val[k]%MOD;
        --k;
    }
    a[0]=0;
}
// Polynomial logarithm: b <- ln(a) mod x^n, for a with a[0] != 0 (normally 1).
// Uses ln a = integral(a' * a^{-1}), built from inv (series inverse),
// qd (derivative), mul (product) and jf (integral).
//   a : input, only a[0..n-1] is read.
//   b : output; must have room for the NTT lengths used internally.
// On return the result is in b[0..n-1].
void ln(int *a,int *b,int n){
    static int tmp[MAXN];
    for(int i=0;i<n;i++)
        tmp[i]=0,b[i]=a[i];
    // qd() reads b[n]. The input only has n coefficients, so force it to zero
    // instead of depending on whatever the caller left in that slot.
    b[n]=0;
    int midlen=1,midlim=0;
    inv(a,tmp,n,midlen,midlim);   // tmp = a^{-1} mod x^n
    int t=n;
    qd(b,t);                      // b = a', t = n-1
    mul(b,tmp,t,n);               // b = a' * a^{-1}, t = 2n-1
    // ln a mod x^n needs only coefficients 0..n-2 of a'/a. Drop the rest
    // before integrating, so jf() never reads inv_val[] past index n-1 and
    // no garbage high coefficients are produced.
    for(int i=n-1;i<=t;i++)
        b[i]=0;
    t=n-2;
    jf(b,t);                      // b = ln a mod x^n, t = n-1
    for(int i=0;i<n;i++)
        tmp[i]=0;
    for(int i=n;i<midlen;i++)
        b[i]=tmp[i]=0;
}
// n: size of the current test; a[1..n]: input values;
// b: prod_{i=1}^{n} (1 - a_i x) produced by solve();
// c: ln(b), then turned into the coefficients of f(x) = n - x * (ln b)' in main().
int n,a[MAXN],b[MAXN],c[MAXN];
// Scratch rows for solve(). Each internal recursion level borrows two rows
// (barrel[cnt], barrel[cnt+1]), so 40 rows support recursion depth up to 20,
// i.e. n <= 2^20. NOTE(review): with `int` = long long this array alone is
// 40 * MAXN * 8 bytes = 640 MB of static storage; each row is only written up
// to the NTT length of the node using it. Confirm this fits the memory limit.
int cnt=0,barrel[40][MAXN];
// Divide and conquer: val <- prod_{i=l}^{r} (1 - a[i] x) (mod MOD), and
// len <- its degree (r - l + 1).
// The two halves are built in a pair of scratch rows borrowed from barrel[]
// (used as a stack indexed by cnt), multiplied with an NTT, and the rows are
// zeroed again before cnt is restored. val must have room for that NTT length.
void solve(int l,int r,int *val,int &len){
    if(l==r){
        // Leaf: the single linear factor 1 - a[l] x.
        val[0]=1;
        val[1]=MOD-a[l];
        len=1;
        return;
    }
    const int base=cnt;
    int *left=barrel[base];
    int *right=barrel[base+1];
    cnt=base+2;
    int lenl,lenr;
    const int mid=(l+r)>>1;
    solve(l,mid,left,lenl);
    solve(mid+1,r,right,lenr);
    int width=1,logw=0;
    for(;width<lenl+lenr+2;width<<=1)
        ++logw;
    cal_rev(width,logw);
    NTT(left,1,width,logw);
    NTT(right,1,width,logw);
    for(int i=0;i<width;++i)
        val[i]=(1LL*left[i]*right[i])%MOD;
    NTT(val,0,width,logw);
    for(int i=0;i<width;++i)
        left[i]=right[i]=0;
    len=lenl+lenr;
    cnt=base;
}
signed main(){
// freopen("test.in","r",stdin);
// freopen("test.out","w",stdout);
int T;
scanf("%lld",&T);
while(T--){
// printf("ok %lld\n",T);
scanf("%lld",&n);
cal_inv(n+10);
for(int i=1;i<=n;i++)
scanf("%lld",&a[i]);
int len=0;
solve(1,n,b,len);
// for(int i=0;i<=n;i++)
// printf("%lld ",b[i]);
// printf("\n");
ln(b,c,n+1);
// for(int i=0;i<=n;i++)
// printf("%lld ",c[i]);
// printf("\n");
int t=n;
qd(c,t);
for(int i=n;i>=1;i--)
c[i]=MOD-c[i-1];
c[0]=n;
// for(int i=0;i<=n;i++)
// printf("%lld ",c[i]);
// printf("\n");
// for(int i=1;i<=n;i++)
// printf("%lld ",c[i]);
// printf("\n");
int ans=0;
for(int i=1;i<=n;i++)
ans^=c[i];
printf("%lld\n",ans);
for(int i=0;i<=n;i++)
a[i]=b[i]=c[i]=0;
}
return 0;
}