CodeChef BINOMSUM
由于问题id,把答案写成带组合数的形式:
\(\sum_{i=1}^T{D+i-1\choose L}\sum_{j=0}^{K-1}\left(D+i-1\right)^jA^{K-j-1}{K-j-1\choose j}\)
\(j\)是在枚举吃了多少道菜。
后面的组合数是指:把所有吃菜的行为插入到所有行为内部的方案数。
\(K\)只有\(10^5\),这启示我们把后面写成一个项数为\(K/2\)的多项式。
把后面的式子写成上升幂多项式
(貌似这个套路在具体数学上有讲)
\(F(x)=\sum_{i=0}^{K-1}a_i\prod_{j=1}^i(x+j)\)
考虑枚举\(D+i-1\),答案就是\(\sum_{i=D}^{D+T-1}{i\choose L}F(i)\)
\(=\sum_{i=D}^{D+T-1}{i\choose L}\sum_{j=0}^{K-1}a_j\prod_{k=1}^j(i+k)\)
\(=\sum_{i=D}^{D+T-1}{i\choose L}\sum_{j=0}^{K-1}a_j\frac{(i+j)!}{i!}\)
\(=\sum_{i=0}^{K-1}a_i\sum_{j=D}^{D+T-1}{j\choose L}\frac{(i+j)!}{j!}\)
\(=\frac{1}{L!}\sum_{i=0}^{K-1}a_i\sum_{j=D}^{D+T-1}\frac{(i+j)!}{(j-L)!}\)
\(=\frac{1}{L!}\sum_{i=0}^{K-1}a_i(i+L)!\sum_{j=D}^{D+T-1}{i+j\choose i+L}\)
\(=\frac{1}{L!}\sum_{i=0}^{K-1}a_i(i+L)!\left({D+T+i\choose L+i+1}-{D+i\choose L+i+1}\right)\)
后面的套路在求自然数幂和时会用到,在具体数学上也有讲过。
如果我们求出\(a\),每次询问就可以在\(O(K)\)的时间内解决。
考虑怎么求出\(a\)。
\(F(x)=\sum_{i=0}^{K-1}a_i\cdot i!{x+i\choose i}\)
把\(-1,-2,...-k\)带入原式
\(F(-n-1)=\sum_{i=0}^{K-1}a_i\cdot i!{-n-1+i\choose i}=\sum_{i=0}^na_i\cdot i!(-1)^i{n\choose i}\)
二项式反演得到:
\(a_n\cdot n!=\sum_{i=0}^{K-1}(-1)^i{n\choose i}F(-i-1)\)
\(a_n=\sum_{i=0}^{K-1}(-1)^i\frac{1}{i!(n-i)!}F(-i-1)\)
令\(b_i=\frac{(-1)^i}{i!}F(-i-1),c_i=\frac{1}{(n-i)!}\)
把\(b\)和\(c\)卷积即可求出\(a\)。
\(F\)就是\(\sum_{j=0}^{K-1}\left(D+i-1\right)^jA^{K-j-1}{K-j-1\choose j}\)。
按照题意进行矩阵乘法优化dp即可在\(\log_2K\)的时间内求出一个点值。
被卡常数的代码:
#include<bits/stdc++.h>
using namespace std;
#define mo 998244353
#define N 500010
#pragma GCC optimize(3)
#define ll unsigned long long
#define int long long
#define pl vector<int>
int qp(int x,int y){
int r=1;
for(;y;y>>=1,x=1ll*x*x%mo)
if(y&1)r=1ll*r*x%mo;
return r;
}
int n,m,rev[N],v,le,w[N],p[N],ans[N],jc[N*40],ij[N*40],k,a,q;
void deb(pl x){
for(int i:x)cout<<i<<' ';
puts("");
}
void init(int n){
v=1;
le=0;
while(v<n)le++,v*=2;
for(int i=0;i<v;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(le-1));
int g=qp(3,(mo-1)/v);
w[v/2]=1;
for(int i=v/2+1;i<v;i++)
w[i]=1ull*w[i-1]*g%mo;
for(int i=v/2-1;~i;i--)
w[i]=w[i*2];
}
void fft(int v,pl &a,int t){
static unsigned long long b[N];
int s=le-__builtin_ctz(v);
for(int i=0;i<v;i++)
b[rev[i]>>s]=a[i];
int c=0;
w[0]=1;
for(int i=1;i<v;i*=2,c++)
for(int r=i*2,j=0;j<v;j+=r)
for(int k=0;k<i;k++){
int tx=b[j+i+k]*w[k+i]%mo;
b[j+i+k]=b[j+k]+mo-tx;
b[j+k]+=tx;
}
for(int i=0;i<v;i++)
a[i]=b[i]%mo;
if(t==0)return;
int iv=qp(v,mo-2);
for(int i=0;i<v;i++)
a[i]=1ull*a[i]*iv%mo;
a.resize(v);
reverse(a.begin()+1,a.end());
}
pl operator *(pl x,pl y){
int s=x.size()+y.size()-1;
if(x.size()<=20||y.size()<=20){
pl r;
r.resize(s);
for(int i=0;i<x.size();i++)
for(int j=0;j<y.size();j++)
r[i+j]=(r[i+j]+x[i]*y[j])%mo;
return r;
}
init(s);
x.resize(v);
y.resize(v);
fft(v,x,0);
fft(v,y,0);
//deb(x);
//deb(y);
for(int i=0;i<v;i++)
x[i]=x[i]*y[i]%mo;
fft(v,x,1);
x.resize(s);
return x;
}
void inv(int n,pl &b,pl &a){
if(n==1){
b[0]=qp(a[0],mo-2);
return;
}
inv((n+1)/2,b,a);
static pl c;
init(n*2);
c.resize(v);
b.resize(v);
for(int i=0;i<n;i++)
c[i]=a[i];
fft(v,c,0);
//deb(c);
fft(v,b,0);
//deb(b);
for(int i=0;i<v;i++)
b[i]=1ll*(2ll-1ll*c[i]*b[i]%mo+mo)%mo*b[i]%mo;
//deb(b);
fft(v,b,1);
b.resize(n);
//deb(b);
}
void ad(pl &x,pl y,int l){
x.resize(max((int)x.size(),(int)y.size()+l));
for(int i=0;i<y.size();i++)
x[i+l]=(x[i+l]+y[i])%mo;
}
pl operator +(pl x,pl y){
ad(x,y,0);
return x;
}
pl iv(pl x){
pl y;
int n=x.size();
y.resize(n);
inv(n,y,x);
y.resize(n);
return y;
}
pl operator /(pl a,pl y){
int n=a.size()-1,m=y.size()-1;
pl x,b,t;
x.resize(n+1);
b.resize(m+1);
for(int i=0;i<=n;i++)
x[n-i]=a[i];
for(int i=0;i<=m;i++)
b[m-i]=y[i];
for(int i=n-m+2;i<=m;i++)
b[i]=0;
b.resize(n-m+1);
t=iv(b);
//deb(t);
//deb(x);
//deb(t);
x=x*t;
//deb(x);
x.resize(n-m+1);
reverse(x.begin(),x.end());
return x;
}
pl operator -(pl x,pl y){
int s=max(x.size(),y.size());
x.resize(s);
y.resize(s);
for(int i=0;i<s;i++)
x[i]=(x[i]-y[i]+mo)%mo;
return x;
}
pl operator %(pl x,pl y){
int n=(int)x.size()-1,m=(int)y.size()-1;
if(x.size()<y.size())return x;
if(!m){
pl a;
a.resize(1);
return a;
}
x=x-(x/y)*y;
x.resize(m);
return x;
}
pl qd(pl x){
pl y;
int n=x.size();
y.resize(n-1);
//deb(x);
for(int i=0;i<n-1;i++)
y[i]=x[i+1]*(i+1)%mo;
//deb(y);
return y;
}
pl jf(pl x){
int n=x.size();
pl y;
y.resize(n+1);
for(int i=1;i<=n;i++)
y[i]=x[i-1]*qp(i,mo-2)%mo;
return y;
}
pl ln(pl x){
int n=x.size();
pl y=qd(x),z=iv(x);
y=y*z;
y=jf(y);
y.resize(n);
return y;
}
inline char nc(){
static char buf[500000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,500000,stdin),p1==p2)?EOF:*p1++;
}
inline int rd(){
char ch=nc();int sum=0;
while(!(ch>='0'&&ch<='9'))ch=nc();
while(ch>='0'&&ch<='9')sum=sum*10+ch-48,ch=nc();
return sum;
}
char bf[100];
void wr(int x){
if(!x){
putchar('0');
putchar(' ');
return;
}
int ct=0;
while(x){
bf[++ct]=x%10;
x/=10;
}
for(int i=ct;i;i--)
putchar(bf[i]+'0');
putchar(' ');
}
namespace qz{
ll b[N],ans[N];
pl t;
void fz(int o,int l,int r,pl &p,pl *a){
if(l==r){
a[o].resize(2);
a[o][0]=(mo-p[l])%mo;
a[o][1]=1;
return;
}
int md=(l+r)/2;
fz(o*2,l,md,p,a);
fz(o*2+1,md+1,r,p,a);
a[o]=a[o*2]*a[o*2+1];
//deb(a[o]);
}
void ga(int o,int l,int r,pl &ans,pl *a,pl *c){
if(l==r){
ans[l]=c[o][0];
return;
}
int md=(l+r)/2;
c[o*2]=c[o]%a[o*2];
c[o*2+1]=c[o]%a[o*2+1];
ga(o*2,l,md,ans,a,c);
ga(o*2+1,md+1,r,ans,a,c);
}
void gt(pl t,pl &ans){
int n=ans.size();
static pl a[N],b[N];
fz(1,0,n-1,ans,a);
if(n>=m)b[1]=t%a[1];
ga(1,0,n-1,ans,a,b);
}
void d2(int o,int l,int r,pl &y,pl *a,pl *b){
if(l==r){
b[o].resize(1);
b[o][0]=y[l];
return;
}
int md=(l+r)/2;
d2(o*2,l,md,y,a,b);
d2(o*2+1,md+1,r,y,a,b);
b[o]=b[o*2]*a[o*2+1]+b[o*2+1]*a[o*2];
}
};
pl cz(int n,pl &x,pl &y){
static pl a[N],b[N];
qz::fz(1,0,n-1,x,a);
qz::gt(qd(a[1]),x);
//deb(x);
for(int i=0;i<n;i++)
y[i]=y[i]*qp(x[i],mo-2)%mo;
//deb(y);
qz::d2(1,0,n-1,y,a,b);
return b[1];
}
void gt(int n,pl &y,pl x){
if(n==1){
y.resize(1);
y[0]=1;
return;
}
gt((n+1)/2,y,x);
pl z=x,a;
z.resize(n);
y.resize(n);
a.resize(1);
a[0]=1;
y=y*(a-ln(y)+z);
y.resize(n);
}
pl ep(pl x){
pl y;
int n=x.size();
gt(n,y,x);
return y;
}
void put(pl a){
for(int i=0;i<a.size();i++)
printf("%lld ",a[i]);
puts("");
}
struct no{
int a[2][2];
};
no operator *(no x,no y){
no ans;
memset(ans.a,0,sizeof(ans.a));
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
for(int k=0;k<2;k++)
ans.a[i][j]=(ans.a[i][j]+x.a[i][k]*y.a[k][j]%mo)%mo;
return ans;
}
no qp(no x,int y){
no ans;
memset(ans.a,0,sizeof(ans.a));
ans.a[0][0]=ans.a[1][1]=1;
for(;y;y>>=1,x=x*x)
if(y&1)
ans=ans*x;
return ans;
}
int cal(int x){
if(k==2)
return a;
x=(x+mo)%mo;
no va;
va.a[0][0]=0;
va.a[0][1]=x*a%mo;
va.a[1][0]=1;
va.a[1][1]=a;
va=qp(va,k-3);
return (a*va.a[0][0]%mo*x%mo+a*va.a[0][1]%mo+a*a%mo*va.a[1][0]%mo*x%mo+a*a%mo*va.a[1][1]%mo)%mo;
}
int c(int y,int x){
if(y<0||x<0||y<x)
return 0;
return jc[y]*ij[x]%mo*ij[y-x]%mo;
}
signed main(){
//freopen("bloom.in","r",stdin);
//freopen("bloom.out","w",stdout);
jc[0]=ij[0]=1;
for(int i=1;i<N*40;i++)
jc[i]=jc[i-1]*i%mo;
ij[N*40-1]=qp(jc[N*40-1],mo-2);
for(int i=N*40-1;i;i--)
ij[i-1]=ij[i]*i%mo;
scanf("%lld%lld%lld",&k,&a,&q);
vector<int>x,y;
x.resize(k);
y.resize(k);
for(int i=0;i<k;i++){
x[i]=qp(mo-1,i)*cal(-i-1)%mo*ij[i]%mo;
y[i]=ij[i]%mo;
}
x=x*y;
while(q--){
int l,d,t,ans=0;
scanf("%lld%lld%lld",&l,&d,&t);
for(int i=0;i<k;i++)
ans=(ans+x[i]*jc[i+l]%mo*(c(d+t+i,i+l+1)-c(d+i,i+l+1)+mo)%mo)%mo;
printf("%lld\n",ans*ij[l]%mo);
}
}