题意简述
数列 \(\{b_n\}\) 构造如下:(即 Thue-Morse
数列,OEIS A010060)
- \(b_0=0\)
- \(\forall i\in [2^n,2^{n+1}),b_i=1-b_{i-2^n}\)
输入 \(n\) 和一个次数为 \(m-1\) 次的多项式 \(f(x)=\sum\limits_{i=0}^{m-1}a_ix^i\) ,求 \(\sum\limits_{i=0}^{n-1}b_if(i)\) ,答案对 \(10^9+7\) 取模。
\(0\le \log_2 n\le 5\cdot 10^5\ ,\ 1\le m\le 500\ ,\ 0\le a_i< 10^9+7\ ,\ a_{m-1}\neq 0\) 。
算法分析
我们首先尝试找到 \(\{b_n\}\) 的通项,方便我们下面的求解。
由构造过程,不难发现,我们在 \(x\) 的最高位前添上一个 \(1\) (设为 \(y\) ),则 \(b_{y}=b_{x}\mathrm{xor}1\) ,这个过程中最明显的变化就是二进制中 \(1\) 的个数 \(+1\) ,因此我们有如下结论:
\(\mathrm{popcount}(n)\) 即为 \(n\) 的二进制表示下有多少个 \(1\) 。
但这个\(\pmod 2\) 并不好处理,我们转化成乘积,即:
因此答案如下:
稍作化简:
由于 \(f(x)\) 是一个 \(m-1\) 次多项式,因此可以归纳证明 \(f(x)\) 的前缀和是一个 \(m\) 次多项式,用拉格朗日插值即可求出,时间复杂度 \(O(m^2)\) 。
现在我们只看后面一部分,即:
把 \(f(i)\) 展开,交换求和顺序:
我们只要能求出 \(\sum\limits_{i=0}^{n-1}(-1)^{\mathrm{popcount}(i)}i^k\),即可 \(\mathcal{O}(m)\) 计算。
考虑类似于 数位dp
的做法,把 \([0,n)\) 这个区间拆解为若干个形如 \([x,x+2^t)(x>2^t)\) 的区间,这样我们所有的区间长度都是 \(2\) 的整次幂,方便我们计算。
举个例子,\(n=\left(10101101\right)_2\),我们拆解为如下几个区间:
- \(\left[\left(00000000\right)_2,\left(10000000\right)_2\right)\)
- \(\left[\left(10000000\right)_2,\left(10100000\right)_2\right)\)
- \(\left[\left(10100000\right)_2,\left(10101000\right)_2\right)\)
- \(\left[\left(10101000\right)_2,\left(10101100\right)_2\right)\)
- \(\left[\left(10101100\right)_2,\left(10101101\right)_2\right)\)
接下来我们的工作就是要计算 \(\sum\limits_{i=x}^{x+2^t-1}(-1)^{\mathrm{popcount}(i)}i^k\),不妨记此为 \(f(x,t,k)\),则:
我们划分区间的优势现在就体现出来了,因为 \(x>2^t\) ,所以 \(\mathrm{lowbit}(x)>\mathrm{highbit}(2^t)\) ,因此 \(\mathrm{popcount}\) 我们可以独立计算,即:
后面的 \((i+x)^k\) 用二项式定理展开,即:
继续交换求和顺序,并独立出 \(x\) ,即:
子问题出现了,即:
但 \(f(0,t,j)\) 并不能由我们上面的式子计算出,考虑 \([0,2^t)=[0,2^{t-1})\cup[2^{t-1},2^t)\),而这种区间的划分仍然满足上面的条件(\(x>2^t\)),因此:
综上,我们可以在 \(\mathcal{O}(m)\) 的时间复杂度实现 \(f(x,t,k)\) 的转移,后面两维状态数是 \(\mathcal{O}(m^2)\) 的,总共有 \(\mathcal{O}(\log_2 n+m)\)个可能的\(x\),因此因此总复杂度为 \(\mathcal{O}((\log_2n+m)m^2)\) 的,可以通过 \(60\%\) 的数据。
通过打表,发现 \(f(x,t,k)\) 在 \(x>2^k\) 或者 \(t>k\) 时都是 \(0\) 。这个也是可以归纳证明的,(但我不会证),此处略去。
因此所有 \(>m\) 的 \(x\) 都没有用,因此实际上只有 \(\mathcal{O}(m)\) 个可能的不同的 \(x\) ,复杂度为 \(\mathcal{O}(\log_2n+m^3)\) ,可以通过。
代码实现
#include<bits/stdc++.h>
using namespace std;
#define maxn 1000005
#define maxm 2000005
#define inf 0x3f3f3f3f
#define LL long long
#define ull unsigned long long
#define db double
#define ldb long double
#define mod 1000000007
#define eps 1e-9
#define local
void file(string s){freopen((s+".in").c_str(),"r",stdin);freopen((s+".out").c_str(),"w",stdout);}
template <typename Tp> void read(Tp &x){
int fh=1;char c=getchar();x=0;
while(c>'9'||c<'0'){if(c=='-'){fh=-1;}c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c&15);c=getchar();}x*=fh;
}
int ksm(int b,int p=mod-2,int Mod=mod){int ret=1;while(p){if(p&1)ret=1ll*ret*b%Mod;b=1ll*b*b%Mod;p>>=1;}return ret;}
int iv2=((mod+1)/2);
int n,m;
char ss[maxn];
int a[maxn];
struct Larg{//拉格朗日插值求前缀和
int calc_f(int x){
int ret=0;
for(int i=m-1;i>=0;--i){
ret=(1ll*ret*x+a[i])%mod;
}
return ret;
}
int frac[505],frinv[505],xi[505],yi[505],syi[505],N;
int pv[505],sf[505];
void prepr(){
N=m+1;
frac[0]=1;for(int i=1;i<=N;++i)frac[i]=1ll*frac[i-1]*i%mod;
frinv[N]=ksm(frac[N]);for(int i=N;i;--i)frinv[i-1]=1ll*frinv[i]*i%mod;
yi[0]=calc_f(0);
for(int i=1;i<=N;++i){
xi[i]=i;yi[i]=(yi[i-1]+calc_f(i))%mod;
int prv=frinv[i-1],suf=frinv[N-i];
if((i^N)&1)suf=(mod-suf);
syi[i]=1ll*yi[i]*prv%mod*suf%mod;
}
}
int calc_sf(int x){
x%=mod;
pv[0]=sf[N+1]=1;
for(int i=1;i<=N;++i)pv[i]=1ll*pv[i-1]*(x-xi[i])%mod;
for(int i=N;i;--i)sf[i]=1ll*sf[i+1]*(x-xi[i])%mod;
int ret=0;
for(int i=1;i<=N;++i){
ret=(ret+1ll*syi[i]*pv[i-1]%mod*sf[i+1]%mod)%mod;
}
return (ret+mod)%mod;
}
}La;
int ans1,ans2;
int C[505][505];
int p2[maxn];
void prep(){
for(int i=0;i<=m;++i){
C[i][0]=C[i][i]=1;
for(int j=1;j<i;++j){
C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
}
}
p2[0]=1;
for(int i=1;i<=n;++i)p2[i]=2ll*p2[i-1]%mod;
}
int ff[505][505];
int calc(int tn,int k){//计算 f(0,tn,k)
if(tn>k)return 0;
if(tn==0)return k==0;
if(ff[tn][k]!=-1)return ff[tn][k];
int ret=calc(tn-1,k);
int X=p2[tn-1];
int px=1;
for(int j=k;j>=0;--j,px=1ll*px*X%mod){
ret=(ret+mod-1ll*C[k][j]*px%mod*calc(tn-1,j)%mod)%mod;
}
return ff[tn][k]=ret;
}
void solve(){
int tnm=0,tppc=0;
for(int i=0;i<n;++i){
int cnm=(2ll*tnm+ss[i]-'0')%mod,cppc=(tppc^(ss[i]-'0'));
if(n-i-1<=m){//无用的x直接剪枝减掉
if(ss[i]=='1'){//划分区间
int X=1ll*tnm*p2[n-i]%mod;
for(int k=0;k<m;++k){
int sm=0;
int px=1;
for(int j=k;j>=0;--j,px=1ll*px*X%mod){
sm=(sm+1ll*C[k][j]*px%mod*calc(n-i-1,j)%mod)%mod;
}
if(tppc)sm=(mod-sm);
ans2=(ans2+1ll*a[k]*sm%mod)%mod;
}
}
}
tnm=cnm,tppc=cppc;
}
}
signed main(){
#ifndef local
file("angry");
#endif
memset(ff,-1,sizeof(ff));
scanf("%s",ss);n=strlen(ss);
read(m);
for(int i=0;i<m;++i)read(a[i]);
La.prepr();prep();
int nm=0;
for(int i=0;i<n;++i)nm=(2ll*nm+ss[i]-'0')%mod;
nm=(nm+mod-1)%mod;
ans1=La.calc_sf(nm);
solve();
int ans=(ans1+mod-ans2)%mod;
ans=1ll*ans*iv2%mod;
printf("%d\n",ans);
fclose(stdin);
fclose(stdout);
return 0;
}