【hdu 6067】Big Integer
题意
给你一个 \((k-1)\times (n+1)\) 的 \(01\) 矩阵 \(g\),求满足下列条件的 \(k(k\le 10)\) 进制整数的数量:
1. 不超过 \(n\) 位且数的最高位非 \(0\)
2. 没有出现 \(0\)
3. 对于 \(0\) 以外的数字 \(i\),对于 \(j∈[0,n]\),若 \(g(i,j)=1\),则允许数字 \(i\) 恰好出现 \(j\) 次;若 \(g(i,j)=0\),则不允许数字 \(i\) 恰好出现 \(j\) 次。
这个问题太简单了,于是有 \(m\) 次修改操作,每次将 \(g(i,j)\) 单点取反。让你求修改前及每次修改操作后的答案之和 \(\mod 786433\)。
\(786433=2^{18}\times 3 + 1\),是个质数。
\(k\le 10,\space n\le 14000,\space m\le 200\)
题解
所以放一个这么明显的 \(\text{NTT}\) 模数是什么意思
前置普及组知识:你有 \(x_1\) 个 \(1\),\(x_2\) 个 \(2\),……,\(x_n\) 个 \(n\),用这 \(x_1+x_2+\cdots+x_n\) 个数构成的不同排列数为 \(\frac{(x_1+x_2+\cdots+x_n)!}{x_1! x_2! \cdots x_n!}\)。
构造指数生成函数 \(f_i(x) = \sum\limits_{j=0}^{n} g(i,j) \frac{x^j}{j!}\),将这 \(k-1\) 个多项式卷积成一个生成函数后,记 \(i\) 次项系数为 \(a_i\),则答案为 \(\sum\limits_{i=1}^n a_i i!\)。
可以用 \(\text{NTT}\) 在 \(O(nk^2\log (nk))\) 的复杂度内预处理出初始答案。
下面考虑修改。注意到我们只关心所有答案的和,故可以在 \(\text{DFT}\) 意义下直接累加答案,最后再将结果 \(\text{IDFT}\) 回来。
对于单点修改操作,可以看成是给某个多项式 \(A\) 叠加上一个只有一项系数不为 \(0\) 的多项式 \(B\)。
因为 \(A\) 正处于点值表示法,所以我们把 \(B\) 也转化成点值表示法(其长度需要扩到与 \(A\) 相等)。这需要 \(O(nk\log(nk))\) 的时间由于只有一项系数不为 \(0\),我们考虑暴力 \(\text{DFT}\)。
观察指数生成函数 \(\text{NTT}\) 的公式:
$$y_n = \sum\limits_{i=0}^{d-1}\frac{x_n}{n!}\times (g\frac{p-1}{d})\mod p$$
那么之前的多项式 \(B\) 在 \(\text{DFT}\) 后的结果是一个等比数列,故直接对原始多项式叠加一个等比数列即可。
然而如果直接暴力叠加的话,修改部分的复杂度是 \(O(nmk^2)\)(不过实测能卡过)。
发现把 \(k-1\) 个 \(\text{DFT}\) 后的多项式放成 \(k-1\) 行,依次对齐每次项,由于点值表示法下,这些多项式卷起来的第 \(i\) 位是这些多项式第 \(i\) 位的乘积,显然只要有一个是 \(0\),这一列就废了。考虑记录 \(\text{DFT}\) 下每一列 \(0\) 的数量以及非 \(0\) 数的乘积,这样每次单点修改时就只需要修改该点所在的一行多项式的信息。
复杂度 \(O(nk^2\log(nk) + nmk)\)。
#include<bits/stdc++.h>
#define ll long long
#define N 140010
#define mod 786433
#define G 10
#define invG 235930
using namespace std;
inline int read(){
int x=0; bool f=1; char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=0;
for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+c-'0';
if(f) return x; return 0-x;
}
int Pow(int x, int y){
int ret=1;
while(y){
if(y&1) ret=(ll)ret*x%mod;
x=(ll)x*x%mod;
y>>=1;
}
return ret;
}
struct Poly{
int n,bit,r[N];
void init(int x){
for(n=1,bit=0; n<x; n<<=1,++bit);
for(int i=1; i<n; ++i) r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
//cout<<"n:"<<n<<endl;
}
void dft(int *a, int f){
for(int i=0; i<n; ++i) if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=1; i<n; i<<=1){
int wn = Pow(f==1 ? G : invG, (mod-1)/(i<<1));
for(int j=0; j<n; j+=(i<<1)){
int w=1,x,y;
for(int k=0; k<i; ++k,w=(ll)w*wn%mod)
x=a[j+k], y=(ll)w*a[j+i+k]%mod,
a[j+k]=(x+y)%mod, a[j+i+k]=(x-y+mod)%mod;
}
}
if(f==-1){
int mul=Pow(n,mod-2);
for(int i=0; i<n; ++i) a[i]=(ll)a[i]*mul%mod;
}
}
}NTT;
char c[11][N];
int k,n,m,e[11][N],f[11][N],g[N],ans;
int inv[mod+5],jc[N],jcn[N];
int mul[N],zero_cnt[N];
int main(){
k=read(), n=read(), m=read();
int num=(k-1)*n;
inv[0]=inv[1]=jc[0]=jcn[0]=1;
for(int i=2; i<mod; ++i) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
for(int i=1; i<=num; ++i) jc[i]=(ll)jc[i-1]*i%mod, jcn[i]=inv[jc[i]];
for(int i=1; i<k; ++i){
scanf("%s",c[i]);
for(int j=0; j<=n; ++j){
e[i][j]=c[i][j]-'0';
f[i][j] = e[i][j] ? jcn[j] : 0;
}
}
NTT.init(num+1);
for(int i=0; i<NTT.n; ++i) mul[i]=1, g[i]=1;
for(int i=1; i<k; ++i){
NTT.dft(f[i],1);
for(int j=0; j<NTT.n; ++j){
g[j]=(ll)g[j]*f[i][j]%mod;
if(!f[i][j]) ++zero_cnt[j];
else mul[j]=(ll)mul[j]*f[i][j]%mod;
//cout<<f[i][j]<<endl;
}
}
//for(int i=0; i<NTT.n; ++i) cout<<g[i]<<endl;
int x,y;
while(m--){
x=read(), y=read();
e[x][y]^=1;
//for(int i=0; i<NTT.n; ++i) cout<<mul[i]<<' '<<zero_cnt[i]<<endl;
for(int i=0; i<NTT.n; ++i){
if(f[x][i]) mul[i]=(ll)mul[i]*inv[f[x][i]]%mod;
else --zero_cnt[i];
}
int val=jcn[y], tol=Pow(G,(mod-1)/NTT.n*y%(mod-1));
if(!e[x][y]) val=mod-val;
for(int i=0; i<NTT.n; ++i){
f[x][i]=(f[x][i]+val)%mod;
//cout<<f[x][i]<<endl;
val=(ll)val*tol%mod;
}
for(int i=0; i<NTT.n; ++i){
if(f[x][i]) mul[i]=(ll)mul[i]*f[x][i]%mod;
else ++zero_cnt[i];
if(!zero_cnt[i]) g[i]=(g[i]+mul[i])%mod;
}
}
NTT.dft(g,-1);
for(int i=1; i<=num; ++i){
ans=(ans+(ll)g[i]*jc[i]%mod)%mod;
//cout<<f[k-1][i]<<' '<<jc[i]<<endl;
}
cout<<ans<<endl;
return 0;
}
/*
3 2 0
101
010
3 2 0
111
010
3 2 0
110
010
3 2 1
101
010
1 1
*/