loj #575. 「LibreOJ NOI Round #2」不等关系
前缀和优化的DP没有什么前途,我们考虑容斥
先把所有的\(“>”\)全部强制满足,并把每个\(“<”\)是否满足看作一位\(0/1\)
那么要求的就是\(111...111\)(所有\(“<”\)都满足)
容斥一下就是
\(111...111=111...11?-111...110\\=111...11?-111...1?0+111...100\)
以此类推(其中\(?\)表示该位不作限制)
设\(dp[i]\)表示钦定长度为\(i\)的前缀合法的排列数
\[dp[i]=\sum_{j=0}^{i-1}[j=0\ \lor\ st[j]=='<'] \ \ dp[j]\times (-1)^{cnt[i-1]-cnt[j]} \times \binom{i}{i-j}
\]
\(cnt[i]\)表示前\(i\)个有几个是\(“<”\)
显然可以用分治NTT优化到\(O(n\log^2 n)\)
code:
#include<bits/stdc++.h>
#define N 400050
#define poly vector<int>
#define mod 998244353
using namespace std;
// Modular addition for operands already reduced into [0, mod).
// A single conditional subtraction brings the sum back into range.
int add(int x, int y) {
    const int sum = x + y;
    return (sum >= mod) ? sum - mod : sum;
}
// Modular subtraction for operands already reduced into [0, mod).
// A negative difference is lifted back by adding mod once.
int sub(int x, int y) {
    const int diff = x - y;
    return (diff < 0) ? diff + mod : diff;
}
// Modular multiplication: the product is formed in 64 bits so it cannot overflow
// before the reduction by mod.
int mul(int x, int y) {
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % mod);
}
// Binary exponentiation: returns x^y reduced modulo mod.
// Scans the exponent bits from least significant upward, squaring the base each step.
int qpow(int x, int y) {
    int result = 1;
    int base = x;
    while (y) {
        if (y & 1) {
            result = mul(result, base);
        }
        base = mul(base, base);
        y >>= 1;
    }
    return result;
}
// 3 is a primitive root of 998244353 (the value of mod); ntt() uses G for the
// forward transform and its modular inverse Ginv for the inverse transform.
const int G = 3;
const int Ginv = qpow(G, mod - 2);
// Bit-reversal permutation table; recomputed at the start of every ntt() call.
int rev[N << 1];
// In-place iterative number-theoretic transform of a[0..n) over Z_mod.
//   n : transform length; must be a power of two with n | (mod - 1)
//       (mod = 119 * 2^23 + 1, so n <= 2^23) and n <= N << 1 (size of rev[]).
//   o : 1 for the forward transform (root G),
//      -1 for the inverse transform (root Ginv, then scaling by n^{-1}).
// Overwrites the global rev[] table.
void ntt(int *a, int n, int o) {
    // rev[i] = bit-reversal of i over log2(n) bits; rev[0] stays 0 (never written).
    for(int i = 1; i < n; i ++) rev[i] = (rev[i >> 1] >> 1) | ((n >> 1) * (i & 1));
    for(int i = 1; i < n; i ++) if(i > rev[i]) swap(a[i], a[rev[i]]);
    for(int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        // len-th root of unity (inverse root for the inverse transform)
        const int w0 = qpow((o == 1)? G : Ginv, (mod - 1) / len);
        for(int j = 0; j < n; j += len) {
            int wn = 1;
            for(int k = j; k < j + half; k ++, wn = mul(wn, w0)) {
                const int X = a[k], Y = mul(wn, a[k + half]);
                a[k] = add(X, Y), a[k + half] = sub(X, Y);
            }
        }
    }
    if(o == -1) {
        // Only the inverse transform needs n^{-1}; skip the modexp for forward calls.
        const int ninv = qpow(n, mod - 2);
        for(int i = 0; i < n; i ++) a[i] = mul(a[i], ninv);
    }
}
// Container size as int, so index loops compare int with int.
#define sz(A) ((int)A.size())
// Global NTT scratch buffers for operator*. They start zero-initialized, and
// operator* zeroes the range it used before returning.
int a[N << 1], b[N << 1];
// Polynomial product A * B with coefficients modulo mod, computed with ntt().
// Uses the global scratch buffers a[] and b[], which must be all-zero on entry
// and are zeroed again before returning. The transform length (the smallest power
// of two >= sz(A) + sz(B) - 1) must fit in N << 1.
// Returns an empty polynomial if either operand is empty.
poly operator * (const poly& A, const poly& B) {
    if(A.empty() || B.empty()) return poly();
    // Number of coefficients in the product.
    const int need = sz(A) + sz(B) - 1;
    // A cyclic convolution of length >= need has no wrap-around.
    int len = 1;
    while(len < need) len <<= 1;
    for(int i = 0; i < sz(A); i ++) a[i] = A[i];
    for(int i = 0; i < sz(B); i ++) b[i] = B[i];
    ntt(a, len, 1), ntt(b, len, 1);
    for(int i = 0; i < len; i ++) a[i] = mul(a[i], b[i]);
    ntt(a, len, -1);
    poly C(a, a + need);
    // ntt() only touches [0, len), so clearing that range restores the all-zero invariant.
    for(int i = 0; i < len; i ++) a[i] = b[i] = 0;
    return C;
}
// Factorial tables filled by init(n): fac[i] = i! mod mod, ifac[i] = (i!)^{-1} mod mod.
int fac[N], ifac[N];
// Fill fac[0..n] with factorials and ifac[0..n] with their modular inverses.
// Only one modular inverse (of n!) is computed; the rest follow from
// (k-1)!^{-1} = k!^{-1} * k.
void init(int n) {
    fac[0] = 1;
    for (int k = 1; k <= n; ++k) {
        fac[k] = mul(fac[k - 1], k);
    }
    ifac[n] = qpow(fac[n], mod - 2);
    for (int k = n; k > 0; --k) {
        ifac[k - 1] = mul(ifac[k], k);
    }
}
// Input relation string, stored 1-indexed in st[1..n-1]. st[0] stays '\0', and
// scanf's terminator lands at st[n].
char st[N];
// n: permutation length (|input| + 1); f[i]: dp[i] / i!, accumulated by cdq();
// cnt[i]: number of '<' among st[1..i].
int n, f[N], cnt[N];
void cdq(int l, int r) {
if(l == r) {
if(l) {
if(cnt[l - 1] & 1) f[l] = (mod - f[l]) % mod;
}
return ;
}
int mid = (l + r) >> 1;
cdq(l, mid);
poly a, b;
for(int i = l; i <= mid; i ++) {
int o = f[i];
if(cnt[i] & 1) o = (mod - o) % mod;
if(st[i] == '>') o = 0;
a.push_back(o);
}
for(int i = l; i <= r; i ++) b.push_back(ifac[i - l]);
// printf("--- %d %d %d\n", l, mid, r);
// for(int i = 0; i < sz(a); i ++) printf("%d ", a[i]); printf("\n");
// for(int i = 0; i < sz(a); i ++) printf("%d ", b[i]); printf("\n");
poly c = a * b;
// for(int i = 0; i < sz(c); i ++) printf("%d ", c[i]); printf("\n\n");
for(int i = mid + 1; i <= r; i ++) f[i] = add(f[i], c[i - l]);
cdq(mid + 1, r);
}
// Reads a string over {'<', '>'} and prints the number of permutations of
// length |s| + 1 that satisfy every relation, modulo mod.
int main() {
    // The field width (N - 2 = 400048) keeps scanf inside st[] (reading starts at st + 1).
    // An unbounded %s could overflow the buffer.
    // NOTE(review): other buffers (e.g. the NTT scratch arrays) may impose a
    // smaller practical length limit.
    scanf("%400048s", st + 1);
    n = strlen(st + 1) + 1;
    init(n);
    // cnt[i] = number of '<' among st[1..i]
    for(int i = 1; i < n; i ++) cnt[i] = cnt[i - 1] + (st[i] == '<');
    f[0] = 1;
    cdq(0, n);
    // f[] stores dp[i] / i!; multiply back by n! to get the count.
    f[n] = mul(f[n], fac[n]);
    printf("%d", f[n]);
    return 0;
}