1007
题意:有一个 \(k\) 维空间,第 \(k\) 维的范围为 \([0,N_k]\)。一个人在 \((0,\cdots,0)\),要走到 \((N_1,\cdots,N_k)\)。他的移动方式有 \(n\) 种,第 \(i\) 种用 \((y_{i,1},\cdots,y_{i,k})\) 表示,代表他在第 \(j\) 维上前进 \(y_{i,j}\) 个单位长。(当然他不能走出去)另外还给定了 \(m\) 个位置不能走。数据范围是 \(2\le k\le 16,1\le n \le 10^5,1\le m \le 1000\)。注意PDF的数据范围错了,应该是 \(2\le k \le16\)。
看到这个问题,我首先想到了一个类似的CF题。那个题是有一个巨大的棋盘,可以向上向右走,也是给定了 \(m\) 个位置不能走,问从左下角到右上角的方案数。这个题的解法是容斥:用 \(dp_i\) 表示到达第 \(i\) 个不能走的位置,此前不经过别的这种位置的方案数。那么 \(dp_i=ways((0,0),v_i)-\sum_{j=1}^{i-1}dp_j\cdot ways(v_j,v_i)\)。\(v_i\) 表示第 \(i\) 个不能走的点,\(ways(u,v)\) 表示从 \(u\) 走到 \(v\) 的方案数。把终点视作最后一个不能走的位置,跑这个 DP,就可以 \(O(m^2)\) 过这道题。这个题的题号是 CF559C。
回到 1007,我们发现这个做法可以照搬下来,那么就剩下求 \(ways\)。题目给了总点数 \(N=\prod_{i=1}^k(N_i+1)\le 2\cdot 10^5\),稍微扩大一下,我们来求 \((0,...,0)\) 到各个点的方案数。
如果熟悉多项式方面的事情,应该可以想到用生成函数处理。定义
那么 \(Y^j\) 中项 \(x_1^{c_1}\cdots x_k^{c_k}\) 的系数就表示移动 \(j\) 次到达 \((c_1,c_2,\cdots,c_k)\) 的方法数。我们的目标函数就是 \(\sum_{i\ge 0}Y^i=\frac{1}{1-Y}\),它的各项系数是我们要求的答案。
做一个 \(k\) 维的多项式求逆多少是有点逆天的,所以换一种描述方式。对所有点按照字典序重编号,那么每个 \(Y_i\) 其实都可以描述成一个下标的加法,只是有些位置因为越界了不能加。在多项式中对于越界,常常会选择“直接丢掉”的方式。但是如果只有这一个数的刻画也没法丢。
我对于这道题的思考到这里就进行不下去了。(所以我会认为前面都很简单doge
再来回顾一下刚才的定义:点 \((x_1,\cdots,x_k)\) 对应数 \(x_1+x_2n_1+\cdots+x_kn_1\cdots n_{k-1}\)。(记 \(n_j=N_j+1\))这时直接对多项式求逆会出问题,也就是多项式乘法出了问题。
这时题解给出了一个相当妙的操作:增加一个变元 \(t\),它在 \(x^i\) 项的次数是 \(\lambda(i)=\lfloor \frac{i}{n_1} \rfloor+\lfloor \frac{i}{n_1n_2} \rfloor+\cdots+\lfloor \frac{i}{n_1n_2\cdots n_{k-1}} \rfloor\)。在做多项式乘法时,如果 \(x^i\) 乘以 \(x^j\) 能够贡献给 \(x^{i+j}\),就要求 \(\lambda(i)+\lambda(j)=\lambda(i+j)\)。观察到 \(\lfloor \frac{i}{n} \rfloor+\lfloor \frac{j}{n} \rfloor-\lfloor \frac{i+j}{n} \rfloor \in \{-1,0\}\)(高斯函数的常用性质),所以 \(\lambda(i)+\lambda(j)-\lambda(i+j) \in \{-k+1,\cdots,0\}\)。那么我们只要做 \(\bmod (t^k-1)\) 的多项式乘法,保留合法的项就可以了。
具体到实现细节上,假设要对这样的两个多项式 \(f,g\) 卷积,设 \(F_{i,j}=[\lambda(j)==i]f_jx^j\),\(G\) 类似。然后对各个 \(F_{i},G_{i}\) DFT 转成点值。要在 \(t\) 这一维上处理这些点值,因为 \(k\) 很小,我们直接对 \(t\) 这个变元暴力卷积,也就是 \(F_{i_1,j}\cdot G_{i_2,j}\)(点值)累加到 \(F_{(i_1+i_2)\bmod k,j}\)。再把每个 \(F_i\) IDFT。最后只保留 \(f_i=F_{\lambda(i),i}\) 作为卷积的结果。
这样我们就可以对这个多项式求逆了!这样我们就彻底解决这个题了!
讲题后UPD:增加一个数学意义上的严谨解释。
我们重载了多项式乘法:\(f(x)=\sum_{i=0}^{n-1} a_ix^i\) 与 \(g(x)=\sum_{j=0}^{m-1} b_jx^j\) 的乘积是 \(f(x)g(x)=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1} a_ib_jx^{i+j}[\lambda(i)+\lambda(j)==\lambda(i+j)]\)。这里对 \(\lambda\) 函数的所有运算都在 \(\bmod k\) 意义下。
为此定义
再定义
则
从而
我们通过 DFT 求出 \(F_p,G_q\) 在各个原根位置的点值,要求 \(\sum_{p=0}^{k-1}\sum_{q=0}^{k-1}F_p(x)G_q(x)[p+q\equiv r \pmod k]\) 在各个原根位置的点值,那太容易了,暴力就行。
卷完之后只保留需要的项的系数就行。
#include<bits/stdc++.h>
using namespace std;
const int K=22,N=1e6+5,P=998244353,_G=3,invG=(P+1)/3;
int T,k,b[K],facb[K],n,m,lm[N];
int Y[N],ways[N],tr[N];
int qpow(int a,int b=P-2){int c=1;for(;b;b>>=1,a=1ll*a*a%P)if(b&1)c=1ll*c*a%P;return c;}
void tpre(int n){
for(int i=0;i<n;i++)
tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
}
void NTT(int *f,int flag,int n){
for(int i=0;i<n;i++)if(i<tr[i])swap(f[i],f[tr[i]]);
for(int p=2;p<=n;p<<=1){
int len=p>>1;
int tG=qpow(flag==1?_G:invG,(P-1)/p);
for(int k=0;k<n;k+=p){
int buf=1;
for(int l=k;l<k+len;l++){
int tmp=1ll*buf*f[len+l]%P;
f[len+l]=(f[l]-tmp+P)%P;f[l]=(f[l]+tmp)%P;
buf=1ll*buf*tG%P;
}
}
}
}
int F[K][N],G[K][N],H[N],res[K];
void mul(int*A,int*B,int n,int type,int *ans){
int len=1;for(;len<=n*1.5;len<<=1);tpre(len);
for(int i=0;i<k;i++)
for(int j=0;j<len;j++){F[i][j]=0;if(type)G[i][j]=0;}
for(int i=0;i<n;i++){F[lm[i]][i]=A[i];if(type)G[lm[i]][i]=B[i];}
for(int i=0;i<k;i++){NTT(F[i],1,len);if(type)NTT(G[i],1,len);}
for(int i=0;i<len;i++){
for(int j=0;j<k;j++)res[j]=0;
for(int j=0;j<k;j++)
for(int l=0;l<k;l++){
int c=(j+l)%k;
res[c]=(res[c]+1ll*F[j][i]*G[l][i]%P)%P;
}
for(int j=0;j<k;j++)F[j][i]=res[j];
}
for(int i=0;i<k;i++)NTT(F[i],-1,len);
int invn=qpow(len);
for(int i=0;i<n;i++)ans[i]=1ll*F[lm[i]][i]*invn%P;
}
void inv(int*a,int n,int*ans){
if(n==1){ans[0]=qpow(a[0]);return;}
inv(a,(n+1)>>1,ans);mul(a,ans,n,1,H);mul(H,ans,n,0,H);
for(int i=0;i<n;i++)ans[i]=(2ll*ans[i]-H[i]+P)%P;
}
int p[1005],dp[1005];
bool cmp(int u,int v){
for(int i=1;i<=k;i++)
if(u/facb[i-1]%b[i]>v/facb[i-1]%b[i])return false;
return true;
}
int main(){
facb[0]=1;
scanf("%d",&T);
while(T--){
scanf("%d",&k);
for(int i=1;i<=k;i++)
scanf("%d",b+i),++b[i],facb[i]=facb[i-1]*b[i];
scanf("%d%d",&n,&m);
while(n--){
int tmp=0;
for(int i=1,x;i<=k;i++)
scanf("%d",&x),tmp+=x*facb[i-1];
--Y[tmp];
}
Y[0]=1-P;
for(int i=0;i<facb[k];i++)Y[i]=(Y[i]+P)%P;
memset(lm,0,sizeof(lm));
for(int i=0;i<facb[k];i++){
for(int j=1;j<k;j++)lm[i]+=i/facb[j];
lm[i]%=k;
}
memset(ways,0,sizeof(ways));
inv(Y,facb[k],ways);
memset(Y,0,sizeof(Y));
memset(p,0,sizeof(p));
for(int i=1;i<=m;i++)
for(int j=1,x;j<=k;j++)
scanf("%d",&x),p[i]+=x*facb[j-1];
sort(p+1,p+m+1);
p[++m]=facb[k]-1;
memset(dp,0,sizeof(dp));
for(int i=1;i<=m;i++){
dp[i]=ways[p[i]];
for(int j=1;j<i;j++)if(cmp(p[j],p[i]))
dp[i]=(dp[i]-1ll*ways[p[i]-p[j]]*dp[j]%P+P)%P;
}
printf("%d\n",dp[m]);
}
return 0;
}