Codeforces 712D Memory and Scores
题目大意
链接:CF712D
给定\(a,b,k,t(1\leq a,b\leq 100,1\leq k\leq 1000,1\leq t\leq 100)\)。
取\(2t\)次数,每次取数的范围在\([-k,k]\)之间,求满足最终取出的数之和严格大于\(a-b\)的方案数。
答案对\(1e9+7\)取模。
题目分析
首先,我们可以想到一个非常暴力的做法,直接DP,时间复杂度\(O(k\cdot t^2)\)。
嗯那我们怎么优化呢?
根本不用优化,这个时间复杂度非常优秀,可以AC。QWQ
代码实现:
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<cstdio>
#include<iomanip>
#include<cstdlib>
#define MAXN 0x7fffffff
typedef long long LL;
const int N=1005,mod=1e9+7;
using namespace std;
inline int Getint(){register int x=0,f=1;register char ch=getchar();while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}return x*f;}
int f[205][N*405];
int main(){
int a=Getint(),b=Getint(),k=Getint(),t=Getint()<<1;
int lim=a-b;
f[0][k*t]=1;
for(int i=1;i<=t;i++){
int l=k*t-k*i,r=l,ret=0;
for(int j=l,lim=k*t+k*i;j<=lim;j++){
while(r<=j+k&&r<=lim)ret=(ret+f[i-1][r])%mod,r++;
while(l<j-k)ret=(ret-f[i-1][l]+mod)%mod,l++;
f[i][j]=ret;
}
}
int ans=0;
for(int i=k*t-lim+1;i<=k*t*2;i++)ans=(ans+f[t][i])%mod;
cout<<ans;
return 0;
}
现在,我想说说该怎么让这个时间复杂度变得更加优秀。
(以下\(t\)均等于输入的\(t*2\))
我们可以先进行偏移,向右偏移\(kt\),可以列出
\[(1+x+x^2+\cdots+x^{2k})^{2t}
\]
而最终答案为指数\(> kt-(a-b)\)的项的系数和。
我们只需化简该式即可。
\[\begin{split}
ans&=(1+x+x^2+\cdots+x^{2k})^{t}\\
&=(\frac {1-x^{2k+1}}{1-x})^{t}\\
&=(1+x^{2k+1})^t\cdot(\frac 1{1-x})^t
\end{split}
\]
由二项式定理得
\[(1+x^{2k+1})^t=\sum_{i=0}^{t}\binom ti(-1)^ix^{(2k+1)(t-i)}
\]
由广义二项式定理得
\[(\frac 1{1-x})^t=1+\binom t{t-1}x+\binom {t+1}{t-1}x^2+\cdots
\]
所以
\[ans=(\sum_{i=0}^t\binom ti(-1)^ix^{(2k+1)(t-i)})\cdot(1+\binom t{t-1}x+\binom{t+1}{t-1}x^2+\cdots)
\]
其中,只有系数\(i\)满足\(kt-(a-b)<i\leq 2kt\)的项会对答案产生贡献。
可以预处理出右侧的前缀和\(sum\),枚举左边的\(i\),找到可以产生贡献的区间\([l,r]\)。
\[ans=\sum_{i=0}^t\binom ti(-1)^i(sum[r]-sum[l-1])
\]
最终时间复杂度\(O(kt)\)。
代码实现
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<cstdio>
#include<iomanip>
#include<cstdlib>
#define MAXN 0x7fffffff
typedef long long LL;
const int N=800005,T=1005,mod=1e9+7;
using namespace std;
inline int Getint(){register int x=0,f=1;register char ch=getchar();while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}return x*f;}
int ksm(int x,int k){
int ret=1;
while(k){
if(k&1)ret=(LL)ret*x%mod;
x=(LL)x*x%mod,k>>=1;
}
return ret;
}
int fac[N],inv[N];
int C(int n,int m){if(n<m)return 0;return (LL)fac[n]*inv[m]%mod*inv[n-m]%mod;}
int sum[N];
int main(){
int a=Getint(),b=Getint(),k=Getint(),t=Getint()<<1;
int lim=k*t-a+b+1;
fac[0]=1;
for(int i=1;i<=410000;i++)fac[i]=(LL)fac[i-1]*i%mod;
inv[410000]=ksm(fac[410000],mod-2);
for(int i=410000-1;~i;i--)inv[i]=(LL)inv[i+1]*(i+1)%mod;
sum[0]=1;for(int i=1;i<=2*k*t;i++)sum[i]=(sum[i-1]+C(t+i-1,t-1))%mod;
int ans=0;
for(int i=0;i<=t;i++){
int nw=(2*k+1)*(t-i),l=max(lim-nw,0),r=2*k*t-nw;
if(l>r||r<0)continue;
if(l>r)swap(l,r);
ans=(ans+(LL)C(t,i)*((i&1)?-1:1)*((sum[r]-(l?sum[l-1]:0)+mod)%mod)%mod+mod)%mod;
}
cout<<ans;
return 0;
}