cf1251 F. Red-White Fence NTT + 组合数学
传送门
比较好的一个理解多项式乘法的题。
首先,周长\(C = (红板长度 + 板子数目)\times 2\)
那么对于每一个查询\(Q\)即周长,可以得到选取的板子数目为\(\frac{Q}{2} - 红板长度\), 白板个数为\(\frac{Q}{2} - 红板长度 - 1\)
那么只需要求出对于每一个红板的情况,最后对于每一个查询,枚举在该红板下取白板数,求和即可。
现考虑对于每一个红板,求白板的情况。
首先,对于只出现一次的白板,有\(tot1\)个,那么取\(i\)个白板的情况为\(2^i\times C_{tot1}^{ i}\), 因为有两种可能,在左边或在右边。
对于出现次数大于等于2的白板,设有\(tot2\)个,那么取\(i\)个白板的情况为\(C_{2tot2}^{i}\),相同的白板就放在两侧,大于等于2的都等价于等于2的情况。
相当于就是\(ans[k] = f_i + g_{k - i}\), \(g_i\) = \(2^i\times C_{tot1}^{ i}\), \(g_i =C_{2tot2}^{i}\)
那么对于每一个红板长度,就要去求出所有的\(ans_i\),这就用多项式乘法就行了,这就相当于是多项式乘法的定义了。
其实有个地方不好理解,就是对于出现次数大于等于2的白板,如果你是成对地取,相当于是\(i\)是偶数时比较好理解,如果是取奇数个情况呢?那不就得乘2吗,因为两侧都可以放,但其实取奇数时是存在重复的,比如两个1的时候,我是\(C_2^1\),这时取重复了,但其实是\(C_1^1\),所以取奇数个时,奇数那部分得先除以2再乘以2,相当于不乘。
#include <bits/stdc++.h>
#define ll long long
#define CASE int Kase = 0; cin >> Kase; for(int kase = 1; kase <= Kase; kase++)
using namespace std;
template<typename T = long long> inline T read() {
T s = 0, f = 1; char ch = getchar();
while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)) {s = (s << 3) + (s << 1) + ch - 48; ch = getchar();}
return s * f;
}
const int N = 3e5 + 5, MAXN = 3e5 + 5, MOD = 1e9 + 7, CM = 998244353, INF = 0x3f3f3f3f;
namespace NTT{
ll ksm(ll a, ll b, ll p){ ll ans = 1; a %= p; while(b){ if(b & 1) ans = ans * a % p; a = a * a % p; b >>= 1;} return ans; }
const int P = 998244353, G = 3, inv_G = 332748118; // p的原根和原根在p下的逆元
const int N = 2e6 + 5; // n * 4
int r[N], A[N], B[N], n, m;
void NTT(int *a, int n, int op){
for(int i = 0; i < n; i++)
if(i < r[i]) swap(a[i], a[r[i]]);
for(int mid = 1; mid < n; mid <<= 1){
int x = ksm(op == 1 ? G: inv_G, (P - 1) / (mid << 1), P);
for(int j = 0; j < n; j += (mid << 1)){
int w = 1;
for(int k = 0; k < mid; k++, w = 1ll * w * x % P){
int t1 = a[j + k], t2 = 1ll * w * a[j + k + mid] % P;
a[j + k] = (t1 + t2) % P;
a[j + k + mid] = (t1 - t2 + P) % P;
}
}
}
if(op == -1) {
int inv = ksm(n, P - 2, P);
for(int i = 0; i <= n; i++) A[i] = 1ll * A[i] * inv % P;
}
}
void mul(int *a, int *b, int *c, int nn, int mm, int &k){
int l = 0;
n = nn, m = mm;
for(m = n + m, n = 1; n <= m; n <<= 1, l++);
for(int i = 0; i < n; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
A[i] = B[i] = 0;
}
for(int i = 0; i <= nn; i++) A[i] = a[i];
for(int i = 0; i <= mm; i++) B[i] = b[i];
NTT(A, n, 1); NTT(B, n, 1);
for(int i = 0; i < n; i++) A[i] = 1ll * A[i] * B[i] % P;
NTT(A, n, -1); k = m;
for(int i = 0; i <= m; i++) c[i] = A[i];
}
};
ll ksm(ll a, ll b, ll p){
ll ans = 1; a %= p;
while(b){
if(b & 1) ans = ans * a % p;
a = a * a % p;
b >>= 1;
}
return ans;
}
namespace Combination{
ll fac[MAXN], invfac[MAXN], mod;
void init(int n, ll MOD){ // 线性求[1, n]的组合数和逆元
fac[0] = 1; mod = MOD;
for(int i = 1; i <= n; i++)
fac[i] = fac[i - 1] * i % mod;
invfac[n] = ksm(fac[n], mod - 2, mod);
for(int i = n; i >= 1; i--)
invfac[i - 1] = invfac[i] * i % mod;
}
ll C(ll n, ll m){
return n >= m ? fac[n] * invfac[n - m] % mod * invfac[m] % mod: 0;
}
}
int a[N], b[N], cnt[N], n, k;
int ans[N][6];
int aa[N], bb[N], cc[N];
void cal(int pos, int red) {
for(int i = 1; i <= n; i++) cnt[a[i]] = 0;
for(int i = 1; i <= n; i++) if(a[i] < red) cnt[a[i]]++;
int one = 0, two = 0;
for(int i = 1; i < red; i++) one += cnt[i] == 1, two += cnt[i] >= 2;
two *= 2;
for(int i = 0; i <= one; i++) aa[i] = 1ll * ksm(2, i, CM) * Combination::C(one, i) % CM;
for(int i = 0; i <= two; i++) bb[i] = Combination::C(two, i);
int three = 0;
NTT::mul(aa, bb, cc, one, two, three);
for(int i = 0; i <= three; i++) ans[i][pos] = cc[i];
}
void solve(int kase){
Combination::init(N - 4, CM);
n = read(), k = read();
for(int i = 1; i <= n; i++) a[i] = read();
for(int i = 1; i <= k; i++) b[i] = read();
for(int i = 1; i <= k; i++) cal(i, b[i]);
int q = read();
for(int i = 1; i <= q; i++) {
int Q = read();
int res = 0;
for(int j = 1; j <= k; j++) {
int num = Q / 2 - b[j] - 1;
if(num < 0) continue;
res += ans[num][j];
res %= CM;
}
printf("%d\n", res);
}
}
const bool DUO = 0;
int main(){
clock_t start, finish; double totaltime; start = clock();
if(DUO) {CASE solve(kase);} else solve(1);
finish = clock();
#ifdef ONLINE_JUDGE
return 0;
#endif
printf("\nTime: %lfms\n", (double)(finish - start) / CLOCKS_PER_SEC * 1000);
return 0;
}
I‘m Stein, welcome to my blog