「LibreOJ NOI Round #2」不等关系
「LibreOJ NOI Round #2」不等关系
解题思路
令 \(F(k)\) 为恰好有 \(k\) 个大于号不满足的答案,\(G(k)\) 表示钦点了 \(k\) 个大于号不满足,剩下随便填的方案数。
枚举有多少个大于号被钦点了,\(F(0)=\sum_{i=0}^n G(i)(-1)^i\) 。
对于一个只有小于号限制的序列的方案数就是每一个小于号链接的联通块里分配的数字顺序固定,块与块之间随便排,令 \(sz[i]\) 表示第 \(i\) 个联通块的大小,方案数也就是 $ \dfrac{n!}{\prod_{i=1}^msz[i]!}$
令 \(dp[i]\) 表示前 \(i\) 个数,\(s_{i}\) 不强制改为小于号的 \(\prod_{i=1}^m (-1)^m \dfrac{1}{sz[i]!}\) 乘上容斥系数之和之和,令 \(p[i]\) 表示前 \(i\) 个字符的大于号数量。\(i\) 是一个联通块的右边界,考虑枚举左边界 \(j\) 计算贡献得到转移。
\[dp[i]=
\begin{cases}
s_i =0 & dp[i]=(-1)^{p[i-1]}\sum_{j=0}^{i-1}dp[j]\times (-1)^{p[j]}\dfrac{1}{(i-j)!}
\\ otherwise & dp[i] = 0
\end{cases}
\]
答案的式子就是 \(Ans = n!\times dp[n+1]\) ,
显然可以分治FFT优化转移,复杂度 \(O(n\log^2n)\) 。
code
/*program by mangoyang*/
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 1 << 22, mod = 998244353, G = 3;
char s[N];
int js[N], p[N], a[N], b[N], inv[N], dp[N], n;
inline void up(int &x, int y){
x = x + y >= mod ? x + y - mod : x + y;
}
namespace poly{
int rev[N], len, lg;
inline int Pow(int a, int b){
int ans = 1;
for(; b; b >>= 1, a = 1ll * a * a % mod)
if(b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline void timesinit(int lenth){
for(len = 1, lg = 0; len <= lenth; len <<= 1, lg++);
for(int i = 0; i < len; i++)
rev[i] = (rev[i>>1] >> 1) | ((i & 1) << (lg - 1));
}
inline void dft(int *a, int sgn){
for(int i = 0; i < len; i++)
if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int k = 2; k <= len; k <<= 1){
int w = Pow(G, (mod - 1) / k);
if(sgn == -1) w = Pow(w, mod - 2);
for(int i = 0; i < len; i += k){
int now = 1;
for(int j = i; j < i + (k >> 1); j++){
int x = a[j], y = 1ll * a[j+(k>>1)] * now % mod;
a[j] = x + y >= mod ? x + y - mod : x + y;
a[j+(k>>1)] = x - y < 0 ? x - y + mod : x - y;
now = 1ll * now * w % mod;
}
}
}
if(sgn == -1){
int Inv = Pow(len, mod - 2);
for(int i = 0; i < len; i++) a[i] = 1ll * a[i] * Inv % mod;
}
}
}
using poly::Pow;
inline void solve(int l, int r){
if(l == r){
if(l) dp[l] = p[l-1] & 1 ? mod - dp[l] : dp[l];
if(l != n + 1) dp[l] = p[l] & 1 ? mod - dp[l] : dp[l];
if(l && s[l] == '<') dp[l] = 0;
return;
}
int mid = (l + r) >> 1;
solve(l, mid);
poly::timesinit(2 * (r - l + 1));
for(int i = l; i <= mid; i++) a[i-l] = dp[i];
for(int i = 1; i <= r - l + 1; i++) b[i] = inv[i];
poly::dft(a, 1), poly::dft(b, 1);
for(int i = 0; i < poly::len; i++)
a[i] = 1ll * a[i] * b[i] % mod;
poly::dft(a, -1);
for(int i = mid + 1; i <= r; i++) up(dp[i], a[i-l]);
for(int i = 0; i < poly::len; i++) a[i] = b[i] = 0;
solve(mid + 1, r);
}
int main(){
scanf("%s", s + 1), n = strlen(s + 1);
js[0] = inv[0] = 1;
for(int i = 1; i <= n + 1; i++){
js[i] = 1ll * js[i-1] * i % mod;
inv[i] = Pow(js[i], mod - 2);
}
for(int i = 1; i <= n; i++)
p[i] = p[i-1] + (s[i] == '>');
dp[0] = 1;
solve(0, n + 1);
cout << 1ll * js[n+1] * dp[n+1] % mod;
return 0;
}