题解 「LibreOJ NOI Round #2」不等关系
因为完全不会所以就直接说正解了
考虑对大小的限制是难以处理的
但是若限制只有 <
而没有 >
就很好处理了
这种情况下是将 \(n\) 个数放入若干个递增序列中,使用可重集排列即可
那么考虑用容斥处理 >
的限制,枚举钦定不满足的,剩下的任意
这样就只有 <
和大小任意的限制了
那么一个 \(O(2^n)\) 的做法是直接枚举不合法的位置,然后可重集排列
尝试 DP 优化这个做法
那么需要将容斥系数记到权值中
令 \(f_i\) 为前 \(i\) 个数(前 \(i-1\) 个符号)所有满足/不满足情况带容斥系数的权值和
有
\[f_i=\sum\limits_{j=0}^{i-1}[s_j\neq >]f_j(-1)^{cnt_{i-1}-cnt_j}\binom{i}{j}
\]
在做的事情是枚举以 \(i\) 为结尾的一段递增序列长度
-1 的那个次数是在考虑容斥系数,钦定了在这段递增序列中的 >
均不满足
那么发现这个式子可以分治 NTT
于是可以做到 \(O(n\log^2 n)\)
这题是某模拟赛的正解的一部分,所以这份代码会有点怪
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define ll long long
//#define int long long
int n;
char str[N];
const ll mod=998244353, rt=3, phi=mod-1;
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
namespace force{
int sta[N], tem[N], top;
ll rec[21][1<<20], ans;
void decode(int len, int s, int* pos, int& tot) {
tot=0;
for (int i=0; i<len; ++i) tem[i]=s&(1<<i)?1:0;
for (int p1=0,p2=1; p1<len; p1=p2) {
while (p2<len && tem[p2]==tem[p1]) ++p2;
pos[++tot]=p1;
}
}
ll dfs(int len, int s) {
if (len==1) return 1;
if (~rec[len][s]) return rec[len][s];
int pos[21], tot;
decode(len, s, pos, tot);
ll *t=&rec[len][s]; *t=0;
for (int i=1; i<=tot; ++i) {
int lim=(1<<pos[i])-1;
*t=(*t+dfs(len-1, (s&lim)|(s>>(pos[i]+1)<<pos[i]) ))%mod;
}
return *t;
}
void solve() {
// cout<<double(sizeof(rec))/1000/1000<<endl; exit(0);
memset(rec, -1, sizeof(rec));
for (int i=1; i<=n; ++i)
if (str[i]=='?') sta[top++]=i;
else str[i]-='0';
int lim=1<<top;
for (int s=0; s<lim; ++s) {
for (int i=0; i<top; ++i)
if (s&(1<<i)) str[sta[i]]=1;
else str[sta[i]]=0;
int t=0;
for (int i=1; i<=n; ++i) t|=str[i]<<(i-1);
ans=(ans+dfs(n, t))%mod;
}
cout<<ans<<endl;
}
}
namespace task{
char sta[N];
int cnt[N], rev[N], top, bln, bct;
ll fac[N], inv[N], f[N], t1[N], t2[N], ans;
inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
// ll calc() {
// // cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
// f[0]=1; sta[top+1]=0;
// for (int i=1; i<=top+1; ++i) cnt[i]=cnt[i-1]+(sta[i]=='>');
// for (int i=1; i<=top+1; ++i) f[i]=0;
// for (int i=1; i<=top+1; ++i)
// for (int j=0; j<i; ++j) if (sta[j]!='<')
// f[i]=(f[i]+((cnt[i-1]-cnt[j])&1?-1:1)*f[j]*C(i, j))%mod;
// // cout<<"return: "<<f[top+1]<<endl;
// return f[top+1];
// }
// ll calc() {
// f[0]=1; sta[top+1]=0;
// for (int i=1; i<=top+1; ++i) cnt[i]=cnt[i-1]+(sta[i]=='>');
// for (int i=1; i<=top+1; ++i) f[i]=0;
// for (int i=1; i<=top+1; ++i) {
// for (int j=0; j<i; ++j) if (sta[j]!='<')
// f[i]=(f[i]+f[j]*(cnt[j]&1?-1:1)*inv[j]%mod*inv[i-j])%mod;
// f[i]=f[i]*fac[i]*(cnt[i-1]&1?-1:1)%mod;
// }
// return f[top+1];
// }
void ntt(ll* a, int len, int op) {
for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
ll w, wn, t;
for (int i=1; i<len; i<<=1) {
wn=qpow(rt, (op*phi/(i<<1)+phi)%phi);
for (int j=0,step=i<<1; j<len; j+=step) {
w=1;
for (int k=j; k<j+i; ++k,w=w*wn%mod) {
ll t=w*a[k+i]%mod;
a[k+i]=(a[k]-t)%mod;
a[k]=(a[k]+t)%mod;
}
}
}
if (op==-1) {
ll inv=qpow(len, mod-2);
for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
}
}
void solve(int l, int r, int bct) {
if (l+1==r) {
if (l==0) return ;
if (l==top+1) f[l]=f[l]*fac[l]*(cnt[l-1]&1?-1:1)%mod;
if (sta[l]!='<') f[l]=f[l]*(cnt[l-1]&1?-1:1)*(cnt[l]&1?-1:1)%mod;
else f[l]=0;
return ;
}
// cout<<"solve: "<<l<<' '<<r<<endl;
int mid=(l+r)>>1, len=r-l;
solve(l, mid, bct-1);
for (int i=l; i<mid; ++i) t1[i-l]=f[i];
for (int i=mid; i<r; ++i) t1[i-l]=0;
for (int i=0; i<len; ++i) t2[i]=inv[i];
for (int i=0; i<len; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bct-1));
ntt(t1, len, 1); ntt(t2, len, 1);
for (int i=0; i<len; ++i) t1[i]=t1[i]*t2[i]%mod;
ntt(t1, len, -1);
for (int i=mid; i<r; ++i) f[i]=(f[i]+t1[i-l])%mod;
solve(mid, r, bct-1);
}
ll calc() {
f[0]=1; sta[top+1]=0;
for (int i=1; i<=top+1; ++i) cnt[i]=cnt[i-1]+(sta[i]=='>');
for (int i=1; i<=top+1; ++i) f[i]=0;
for (bln=1,bct=0; bln<=top+1; bln<<=1,++bct) ;
solve(0, bln, bct);
// cout<<"f: "; for (int i=0; i<=top+1; ++i) cout<<(f[i]%mod+mod)%mod<<' '; cout<<endl;
return f[top+1];
}
void solve() {
fac[0]=fac[1]=1; inv[0]=inv[1]=1;
for (int i=2; i<=n+1; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=2; i<=n+1; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
for (int i=2; i<=n+1; ++i) inv[i]=inv[i-1]*inv[i]%mod;
ans=fac[n+1];
for (int p1=1,p2; p1<=n; p1=p2) {
while (p1<=n && str[p1]=='?') ++p1;
if (p1>n) break;
for (p2=p1+1; p2<=n&&str[p2]!='?'; ++p2) ;
top=0;
for (int i=p1; i<p2; ++i)
if (str[i]=='0') sta[++top]='>';
else sta[++top]='<';
ans=ans*calc()%mod*inv[top+1]%mod;
}
cout<<(ans%mod+mod)%mod<<endl;
}
}
signed main()
{
// freopen("a.in", "r", stdin);
// freopen("a.out", "w", stdout);
scanf("%s", str+1);
n=strlen(str+1);
for (int i=1; i<=n; ++i)
if (str[i]=='>') str[i]='0';
else str[i]='1';
task::solve();
return 0;
}