题解 「LibreOJ NOI Round #2」不等关系

传送门

因为完全不会所以就直接说正解了

全是错误的官方题解
OID的题解

考虑对大小的限制是难以处理的
但是若限制只有 < 而没有 > 就很好处理了
这种情况下是将 \(n\) 个数放入若干个递增序列中,使用可重集排列即可
那么考虑用容斥处理 > 的限制,枚举钦定不满足的,剩下的任意
这样就只有 < 和大小任意的限制了
那么一个 \(O(2^n)\) 的做法是直接枚举不合法的位置,然后可重集排列
尝试 DP 优化这个做法
那么需要将容斥系数记到权值中
\(f_i\) 为前 \(i\) 个数(前 \(i-1\) 个符号)所有满足/不满足情况带容斥系数的权值和

\[f_i=\sum\limits_{j=0}^{i-1}[s_j\neq >]f_j(-1)^{cnt_{i-1}-cnt_j}\binom{i}{j} \]

在做的事情是枚举以 \(i\) 为结尾的一段递增序列长度
-1 的那个次数是在考虑容斥系数,钦定了在这段递增序列中的 > 均不满足
那么发现这个式子可以分治 NTT
于是可以做到 \(O(n\log^2 n)\)

这题是某模拟赛的正解的一部分,所以这份代码会有点怪
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define ll long long
//#define int long long

int n;
char str[N];
const ll mod=998244353, rt=3, phi=mod-1;
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}

namespace force{
	int sta[N], tem[N], top;
	ll rec[21][1<<20], ans;
	void decode(int len, int s, int* pos, int& tot) {
		tot=0;
		for (int i=0; i<len; ++i) tem[i]=s&(1<<i)?1:0;
		for (int p1=0,p2=1; p1<len; p1=p2) {
			while (p2<len && tem[p2]==tem[p1]) ++p2;
			pos[++tot]=p1;
		}
	}
	ll dfs(int len, int s) {
		if (len==1) return 1;
		if (~rec[len][s]) return rec[len][s];
		int pos[21], tot;
		decode(len, s, pos, tot);
		ll *t=&rec[len][s]; *t=0;
		for (int i=1; i<=tot; ++i) {
			int lim=(1<<pos[i])-1;
			*t=(*t+dfs(len-1, (s&lim)|(s>>(pos[i]+1)<<pos[i]) ))%mod;
		}
		return *t;
	}
	void solve() {
		// cout<<double(sizeof(rec))/1000/1000<<endl; exit(0);
		memset(rec, -1, sizeof(rec));
		for (int i=1; i<=n; ++i)
			if (str[i]=='?') sta[top++]=i;
			else str[i]-='0';
		int lim=1<<top;
		for (int s=0; s<lim; ++s) {
			for (int i=0; i<top; ++i)
				if (s&(1<<i)) str[sta[i]]=1;
				else str[sta[i]]=0;
			int t=0;
			for (int i=1; i<=n; ++i) t|=str[i]<<(i-1);
			ans=(ans+dfs(n, t))%mod;
		}
		cout<<ans<<endl;
	}
}

namespace task{
	char sta[N];
	int cnt[N], rev[N], top, bln, bct;
	ll fac[N], inv[N], f[N], t1[N], t2[N], ans;
	inline ll C(int n, int k) {return fac[n]*inv[k]%mod*inv[n-k]%mod;}
	// ll calc() {
	// 	// cout<<"sta: "; for (int i=1; i<=top; ++i) cout<<sta[i]<<' '; cout<<endl;
	// 	f[0]=1; sta[top+1]=0;
	// 	for (int i=1; i<=top+1; ++i) cnt[i]=cnt[i-1]+(sta[i]=='>');
	// 	for (int i=1; i<=top+1; ++i) f[i]=0;
	// 	for (int i=1; i<=top+1; ++i)
	// 		for (int j=0; j<i; ++j) if (sta[j]!='<')
	// 			f[i]=(f[i]+((cnt[i-1]-cnt[j])&1?-1:1)*f[j]*C(i, j))%mod;
	// 	// cout<<"return: "<<f[top+1]<<endl;
	// 	return f[top+1];
	// }
	// ll calc() {
	// 	f[0]=1; sta[top+1]=0;
	// 	for (int i=1; i<=top+1; ++i) cnt[i]=cnt[i-1]+(sta[i]=='>');
	// 	for (int i=1; i<=top+1; ++i) f[i]=0;
	// 	for (int i=1; i<=top+1; ++i) {
	// 		for (int j=0; j<i; ++j) if (sta[j]!='<')
	// 			f[i]=(f[i]+f[j]*(cnt[j]&1?-1:1)*inv[j]%mod*inv[i-j])%mod;
	// 		f[i]=f[i]*fac[i]*(cnt[i-1]&1?-1:1)%mod;
	// 	}
	// 	return f[top+1];
	// }
	void ntt(ll* a, int len, int op) {
		for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i], a[rev[i]]);
		ll w, wn, t;
		for (int i=1; i<len; i<<=1) {
			wn=qpow(rt, (op*phi/(i<<1)+phi)%phi);
			for (int j=0,step=i<<1; j<len; j+=step) {
				w=1;
				for (int k=j; k<j+i; ++k,w=w*wn%mod) {
					ll t=w*a[k+i]%mod;
					a[k+i]=(a[k]-t)%mod;
					a[k]=(a[k]+t)%mod;
				}
			}
		}
		if (op==-1) {
			ll inv=qpow(len, mod-2);
			for (int i=0; i<len; ++i) a[i]=a[i]*inv%mod;
		}
	}
	void solve(int l, int r, int bct) {
		if (l+1==r) {
			if (l==0) return ;
			if (l==top+1) f[l]=f[l]*fac[l]*(cnt[l-1]&1?-1:1)%mod;
			if (sta[l]!='<') f[l]=f[l]*(cnt[l-1]&1?-1:1)*(cnt[l]&1?-1:1)%mod;
			else f[l]=0;
			return ;
		}
		// cout<<"solve: "<<l<<' '<<r<<endl;
		int mid=(l+r)>>1, len=r-l;
		solve(l, mid, bct-1);
		for (int i=l; i<mid; ++i) t1[i-l]=f[i];
		for (int i=mid; i<r; ++i) t1[i-l]=0;
		for (int i=0; i<len; ++i) t2[i]=inv[i];
		for (int i=0; i<len; ++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bct-1));
		ntt(t1, len, 1); ntt(t2, len, 1);
		for (int i=0; i<len; ++i) t1[i]=t1[i]*t2[i]%mod;
		ntt(t1, len, -1);
		for (int i=mid; i<r; ++i) f[i]=(f[i]+t1[i-l])%mod;
		solve(mid, r, bct-1);
	}
	ll calc() {
		f[0]=1; sta[top+1]=0;
		for (int i=1; i<=top+1; ++i) cnt[i]=cnt[i-1]+(sta[i]=='>');
		for (int i=1; i<=top+1; ++i) f[i]=0;
		for (bln=1,bct=0; bln<=top+1; bln<<=1,++bct) ;
		solve(0, bln, bct);
		// cout<<"f: "; for (int i=0; i<=top+1; ++i) cout<<(f[i]%mod+mod)%mod<<' '; cout<<endl;
		return f[top+1];
	}
	void solve() {
		fac[0]=fac[1]=1; inv[0]=inv[1]=1;
		for (int i=2; i<=n+1; ++i) fac[i]=fac[i-1]*i%mod;
		for (int i=2; i<=n+1; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
		for (int i=2; i<=n+1; ++i) inv[i]=inv[i-1]*inv[i]%mod;
		ans=fac[n+1];
		for (int p1=1,p2; p1<=n; p1=p2) {
			while (p1<=n && str[p1]=='?') ++p1;
			if (p1>n) break;
			for (p2=p1+1; p2<=n&&str[p2]!='?'; ++p2) ;
			top=0;
			for (int i=p1; i<p2; ++i)
				if (str[i]=='0') sta[++top]='>';
				else sta[++top]='<';
			ans=ans*calc()%mod*inv[top+1]%mod;
		}
		cout<<(ans%mod+mod)%mod<<endl;
	}
}

signed main()
{
	// freopen("a.in", "r", stdin);
	// freopen("a.out", "w", stdout);

	scanf("%s", str+1);
	n=strlen(str+1);
	for (int i=1; i<=n; ++i)
		if (str[i]=='>') str[i]='0';
		else str[i]='1';
	task::solve();

	return 0;
}
posted @ 2022-04-08 21:32  Administrator-09  阅读(15)  评论(0编辑  收藏  举报