hdu 6116 路径计数 百度之星初赛B补--计数问题+NTT
题意与经典的“交错排列”问题（多重集排列中相邻元素不能相同的计数问题）很像：对每种元素按“切成若干非空段”做容斥，得到四个指数型生成函数，再用 NTT 求这四个多项式的乘积（三次卷积）即可。
// HDU 6116 (Baidu Astar 2017 preliminary round B): counting via
// inclusion-exclusion + NTT.
//
// For each input line "a b c d" the program prints, modulo 998244353,
//
//   sum_{k=1}^{n} (-1)^(n-k) * k! * [x^k] prod_{j=1..4} P_{c_j}(x),
//   P_d(x) = sum_{i=1}^{d} C(d-1, i-1) * x^i / i!   (P_0(x) = 1),
//
// where n = a+b+c+d. This is the inclusion-exclusion count of sequences
// containing c_j copies of symbol j (j = 1..4) with no two equal adjacent
// symbols: C(d-1, i-1) is the number of ways to cut d copies into i non-empty
// runs, and k!/prod(i_j!) arranges the runs.
// NOTE(review): the mapping from the judge's path-counting statement to this
// formula is taken from the original solution; confirm against the problem text.
#include <algorithm>
#include <array>
#include <iostream>
#include <utility>
#include <vector>

namespace {

using LL = long long;

constexpr LL kMod = 998244353;  // 119 * 2^23 + 1, supports NTT lengths up to 2^23
constexpr LL kRoot = 3;         // primitive root of kMod

// Returns base^exp mod m for exp >= 0.
LL quick_mod(LL base, LL exp, LL m) {
    LL result = 1 % m;
    base %= m;
    if (base < 0) base += m;
    while (exp > 0) {
        if (exp & 1) result = result * base % m;
        base = base * base % m;
        exp >>= 1;
    }
    return result;
}

// Factorials and inverse factorials mod kMod. They grow on demand so that no
// fixed table size limits the input (the old fixed 16064-entry arrays could
// overflow).
std::vector<LL> fact{1};
std::vector<LL> inv_fact{1};

// Makes fact/inv_fact valid for every index in [0, n].
void ensure_factorials(int n) {
    const int old_size = static_cast<int>(fact.size());
    if (n < old_size) return;
    fact.resize(n + 1);
    inv_fact.resize(n + 1);
    for (int i = old_size; i <= n; ++i) fact[i] = fact[i - 1] * i % kMod;
    inv_fact[n] = quick_mod(fact[n], kMod - 2, kMod);
    // 1/(i-1)! = i * (1/i)!, filled downward to the previously valid prefix.
    for (int i = n; i > old_size; --i) inv_fact[i - 1] = inv_fact[i] * i % kMod;
}

// In-place iterative number-theoretic transform.
// a.size() must be a power of two (<= 2^23) and every entry must lie in
// [0, kMod). invert == true computes the inverse transform, including the
// 1/n scaling.
void ntt(std::vector<LL>& a, bool invert) {
    const int n = static_cast<int>(a.size());

    // Bit-reversal permutation.
    for (int i = 1, j = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }

    for (int len = 2; len <= n; len <<= 1) {
        LL wlen = quick_mod(kRoot, (kMod - 1) / len, kMod);  // primitive len-th root
        if (invert) wlen = quick_mod(wlen, kMod - 2, kMod);
        const int half = len >> 1;
        for (int start = 0; start < n; start += len) {
            LL w = 1;
            for (int k = start; k < start + half; ++k) {
                const LL u = a[k];
                const LL v = a[k + half] * w % kMod;
                a[k] = (u + v >= kMod) ? u + v - kMod : u + v;
                a[k + half] = (u - v < 0) ? u - v + kMod : u - v;
                w = w * wlen % kMod;
            }
        }
    }

    if (invert) {
        const LL n_inv = quick_mod(n, kMod - 2, kMod);
        for (LL& x : a) x = x * n_inv % kMod;
    }
}

// Coefficients of P_d(x) padded with zeros to `size` entries (size > d):
// coef[i] = C(d-1, i-1) / i! for 1 <= i <= d. For d == 0 the polynomial is the
// constant 1, so an absent symbol does not zero out the whole product.
// Requires ensure_factorials(d) to have been called.
std::vector<LL> run_polynomial(int d, int size) {
    std::vector<LL> coef(size, 0);
    if (d == 0) {
        coef[0] = 1;
        return coef;
    }
    for (int i = 1; i <= d; ++i) {
        coef[i] = fact[d - 1] * inv_fact[i - 1] % kMod * inv_fact[d - i] % kMod
                  * inv_fact[i] % kMod;
    }
    return coef;
}

// Evaluates the inclusion-exclusion sum from the file header for the four
// counts. Counts are assumed to be non-negative.
LL count_arrangements(const std::array<int, 4>& counts) {
    int total = 0;
    for (int c : counts) total += c;
    ensure_factorials(std::max(total, 1));

    // The product has degree <= total, so any length > total avoids
    // cyclic wrap-around.
    int size = 1;
    while (size <= total) size <<= 1;

    // Transform each factor once, multiply pointwise, and invert once.
    // That is 4 forward + 1 inverse NTT instead of 3 separate convolutions.
    std::vector<LL> product = run_polynomial(counts[0], size);
    ntt(product, false);
    for (int j = 1; j < 4; ++j) {
        std::vector<LL> factor = run_polynomial(counts[j], size);
        ntt(factor, false);
        for (int i = 0; i < size; ++i) product[i] = product[i] * factor[i] % kMod;
    }
    ntt(product, true);

    LL answer = 0;
    for (int k = 1; k <= total; ++k) {
        const LL term = fact[k] * product[k] % kMod;
        if ((total - k) % 2 != 0) {
            answer = (answer - term + kMod) % kMod;  // sign (-1)^(total-k) = -1
        } else {
            answer = (answer + term) % kMod;
        }
    }
    return answer;
}

}  // namespace

// Reads lines of four counts until EOF and prints one answer per line.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::array<int, 4> counts{};
    while (std::cin >> counts[0] >> counts[1] >> counts[2] >> counts[3]) {
        std::cout << count_arrangements(counts) << '\n';
    }
    return 0;
}