【UR #8】宿命多项式
有个多项式\(f(x)=\sum_{i=0}^na_ix^i\)。规定\(i=0..n,f(i)\in[1,c_i]\)。问有多少组合法的\(a_i\)(\(a_i\)都是整数)
\(n\le 6\)
为了方便把范围定成\([0,c_i-1]\)。
首先写成下降幂多项式,好处是:\(f(i)\)的表达式可以由\(a_{0..i}\)确定。设\(f(x)=\sum_{i=0}^n q_i x^{\underline i}\)。
假如决定了\(q_{0..i-1}\),对于\(f(i)\)的表达式,要求:\(0\le q_ii!+\sum_{j=0}^{i-1}q_ji^{\underline j}<c_i\)。设\(C=\sum_{j=0}^{i-1}q_ji^{\underline j}\),则解的个数为\(\lfloor\frac{c_i}{i!}\rfloor+[c_i\mod i!>c\mod i!]\)。
注意到\(q_ji^{\underline j}=(q_j\mod (i-j)!)i^{\underline j} \pmod {i!}\)。这样我们就将需要知道的\(q_j\)的取值限定了。
设\(r_i=q_i\mod (n-i)!\)。枚举\(r_i\)。
于是对于\(f(i)\),令\(q_i=(n-i)!t+r_i\),则要求:\(0\le (n-i)!i!t+r_ii!+\sum_{j=0}^{i-1}q_ji^{\underline j}<c_i\)。
类似地,需要求出\((r_ii!+\sum_{j=0}^{i-1}q_ji^{\underline j})\mod (n-i)!i!\),它等于\(\sum_{j=0}^i r_ii^{\underline j}\mod (n-i)!i!\)。(因为\((n-i)!i!|(n-j)!i^{\underline j}\),相除就是个组合数)
于是就得到了:\(O(n\prod_{i=0}^{n} i!)\)的时间。
然后可以最后枚举\(r_0\)。可以发现根据每个位置的贡献不同将\(r_0\)分段,位置\(i\)会分出\(O(\frac{n}{(n-i)!i!})\)段,加起来刚好就是\(O(2^n)\)。搞一搞可以将时间中的一个\(n!\)换成\(2^n\lg 2^n\)。题解说这个\(\lg\)可以去掉说是用计数排序(好像要离线一下排序量比较大的时候才做)。
然而感觉常数才是最大的问题,卡不过去,TLE90爬
using namespace std;
#include <bits/stdc++.h>
#define ll long long
#define fi first
#define se second
#define mp make_pair
const int N=6;
const int mo=998244353;
const double _mo=1.0/mo;
ll qpow(ll x,ll y=mo-2){
ll r=1;
for (;y;y>>=1,x=x*x%mo)
if (y&1)
r=r*x%mo;
return r;
}
int n;
ll fac[N+1];
ll c[N+1];
ll r[N+1],C[N+1],f[N+1],tran[(N+1)*2],g[N+1],b[N+1];
ll ans;
pair<int,int> q[1<<N+2];
void dfs(int i){
if (i>n){
int k=0;
q[k++]=mp(0,0);
q[k++]=mp(g[0],0);
for (int j=1;j<=n;++j){
ll m=fac[j]*fac[n-j];
for (int t=0;t*m-C[j]<fac[n];++t){
q[k++]=mp(t*m-C[j],j);
q[k++]=mp(t*m+g[j]-C[j],j);
}
}
sort(q,q+k);
//return;
int cnt0=0;
ll pro=1;
for (int j=0;j<=n;++j){
b[j]=0;
if (f[j]==0) cnt0++; else pro=pro*f[j]%mo;
}
ll lst=0;
for (int t=0;t<k && q[t].fi<fac[n];++t){
if (lst<q[t].fi){
if (!cnt0)
ans+=pro*(q[t].fi-lst);
lst=q[t].fi;
}
int w=q[t].se;
if (f[w]){
pro=pro*tran[w<<1|b[w]]%mo;
//pro*=tran[w<<1|b[w]];
//pro=pro-(ll)(pro*_mo)*mo;
}
else
cnt0+=(b[w]?1:-1);
b[w]^=1;
}
if (!cnt0)
ans+=pro*(fac[n]-lst);
ans%=mo;
return;
}
C[i]=0;
ll m=fac[i]*fac[n-i];
for (int j=0;j<i;++j)
C[i]=(C[i]+r[j]*(fac[i]/fac[i-j]))%m;
for (int v=0;v<fac[n-i];++v){
r[i]=v;
dfs(i+1);
C[i]=(C[i]+fac[i])%m;
}
}
int main(){
fac[0]=1;
for (int i=1;i<=N;++i)
fac[i]=fac[i-1]*i;
int T;
scanf("%d",&T);
while (T--){
scanf("%d",&n);
for (int i=0;i<=n;++i)
scanf("%lld",&c[i]);
ans=0;
for (int i=0;i<=n;++i){
f[i]=c[i]/(fac[i]*fac[n-i]);
tran[i<<1]=qpow(f[i])*(f[i]+1)%mo;
tran[i<<1|1]=qpow(f[i]+1)*(f[i])%mo;
g[i]=c[i]%(fac[i]*fac[n-i]);
}
dfs(1);
ans%=mo;
printf("%lld\n",ans);
}
return 0;
}