【题解】洛谷 P7468 [NOI Online 2021 提高组] 愤怒的小N【自然数幂和 结论题】
题意
定义串 \(A_0=\texttt{a}\),\(A_i=A_{i-1}+\overline{A_{i-1}}\)(其中上划线代表 a b 互换),考虑串 \(A_\infty[0..n-1]\)。给定 \(k-1\) 次多项式 \(f\),求 \(\sum_{i=0}^{n-1}[A_\infty[i]=\texttt{b}]f(i)\)。\(n\leq 2^{500000}\),\(k\leq 500\)。
题解
我们首先考虑 \(O(k^2\log n)\) 的做法:我们维护多项式 \(F_{i,\texttt{a}/\texttt{b}}(x)\) 代表把字符串 \(A_i\) 平移到 \(x\) 处后,\(\texttt{a}/\texttt{b}\) 的贡献。
我们有 \(F_{i,\texttt{a}}(x)=F_{i-1,\texttt{a}}(x)+F_{i-1,\texttt{b}}(x+2^{i-1})\)(\(F_{i,\texttt{b}}\) 同理),处理系数平移需要 \(O(k^2)\)(二项式定理展开),一共要递推到 \(F_{\log n,\texttt{a}/\texttt{b}}\)。
如果你在调试时输出了多项式,你可能会发现在 \(i\geq k\) 时,\(F_{i,\texttt{a}}=F_{i,\texttt{b}}\)。证明可以参见官方题解。(考试时我输出了调试信息,但是没发现结论)
知道这个结论后,我们可以只处理到 \(F_{k-1,\texttt{a}/\texttt{b}}\),剩下的部分借助自然数幂和计算。复杂度 \(O(k^3+\log n)\)。
#include<bits/stdc++.h>
using namespace std;
int getint(){
int ans=0;
char c=getchar();
while(!isdigit(c))c=getchar();
while(isdigit(c)){
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}
const int K=503,N=500010,mod=1e9+7;
int f[2][K],g[2][K];
int c[K][K];
char s[N];int n;
int p2[N];
int p[N],q[N];
int sp[N];
int qpow(int x,int y){
int ans=1;
while(y){
if(y&1)ans=ans*1ll*x%mod;
x=x*1ll*x%mod;
y>>=1;
}
return ans;
}
int main(){
freopen("angry.in","r",stdin);
freopen("angry.out","w",stdout);
scanf("%s",s+1);
n=strlen(s+1);
int k=getint();
if(n<k){
for(int i=n;i;i--)s[k-n+i]=s[i];
for(int i=1;i<=k-n;i++)s[i]='0';
n=k;
}
for(int i=0;i<k;i++)f[0][i]=getint();
for(int i=c[0][0]=1;i<=k;i++)
for(int j=c[i][0]=1;j<=k;j++)
c[i][j]=(c[i-1][j]+c[i-1][j-1])%mod;
p2[0]=1;for(int i=1;i<=n;i++)p2[i]=p2[i-1]<<1,p2[i]>=mod&&(p2[i]-=mod);
for(int i=1;i<=n;i++){
q[n-i]=q[n-i+1]^(s[i]-'0');
p[n-i]=(p[n-i+1]+(s[i]-'0')*p2[n-i])%mod;
}
int ans=0;
int m=p[0];
// cerr<<"> "<<m<<endl;
sp[0]=m;
for(int i=1;i<k;i++){
sp[i]=qpow(m,i+1);
for(int j=0;j<i;j++)
sp[i]=(sp[i]-c[i+1][j]*1ll*sp[j]%mod+mod)%mod;
sp[i]=sp[i]*1ll*qpow(i+1,mod-2)%mod;
// cerr<<"> "<<sp[i]<<endl;
}
for(int i=0;i<k;i++)ans=(ans+f[0][i]*1ll*sp[i])%mod;
for(int i=0;i<k;i++){
// for(int i=0;i<k;i++)cerr<<">"<<f[0][i];cerr<<endl;
// for(int i=0;i<k;i++)cerr<<">"<<f[1][i];cerr<<endl;cerr<<endl;
int q=::q[i],p=::p[i+1];
if(s[n-i]-'0'){
int v=0;
for(int j=0;j<k;j++)
v=(v*1ll*p+f[q][k-j-1])%mod;
ans=(ans+v)%mod;
v=0;
for(int j=0;j<k;j++)
v=(v*1ll*p+f[q^1][k-j-1])%mod;
ans=(ans+mod-v)%mod;
}
memcpy(g,f,sizeof(g));
int e=p2[i];
for(int j=0;j<k;j++){
int pe=1;
for(int l=j;l>=0;l--,pe=pe*1ll*e%mod){
int t=pe*1ll*c[j][l]%mod;
g[1][l]=(g[1][l]+f[0][j]*1ll*t)%mod;
g[0][l]=(g[0][l]+f[1][j]*1ll*t)%mod;
}
}
memcpy(f,g,sizeof(f));
}
ans=(mod+1ll)/2*ans%mod;
cout<<ans;
}