题解 史莱姆A

传送门

设 $f_i$ 为前 $i$ 个数的 DP 值（$f_0=1$），$\operatorname{mex}_{l,r}$ 表示 $a_l,\dots,a_r$ 的 mex，则 DP 式子为

\[f_i=\sum\limits_{j=0}^{i-1}f_j\operatorname{mex}_{j+1, i} \]

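这是一道原题。对固定的左端点，在值域上维护“$0\sim k$ 中各值首次出现位置的前缀最大值”：它关于 $k$ 分段常数且单调不降，区间 mex 不小于 $k+1$ 当且仅当右端点不小于该值。左端点右移时，只会把从 $a_{\text{左端点}}$ 开始的若干段推平为同一个值，因此用 Chtholly 树（珂朵莉树 / ODT）维护即可。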

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define fir first
#define sec second
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar(): refills buf from stdin in 1<<21-byte chunks.
// Data bytes are returned as unsigned char (0..255) so that a 0xFF byte can never
// be confused with EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Reads one decimal integer from the buffered stdin.
// Non-digit characters before the number are skipped; each '-' among them flips
// the sign. The character following the last digit is consumed.
// Returns 0 if input runs out before any digit is seen (instead of looping forever).
inline int read() {
	int ans=0, f=1;
	int c=getchar();  // int, not char: must be able to hold EOF distinctly
	while (c!=EOF && (c<'0' || c>'9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0' && c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

int n;                         // length of the input sequence
int a[N];                      // a[1..n]: the input sequence, filled by main (1-indexed)
constexpr ll mod=998244353;    // all DP values are reported modulo this number

namespace force{
	// f[i]: DP value for the prefix a[1..i]; f[0]=1.
	ll f[N];
	// vis[v]: whether value v occurs in the current window a[j..i].
	bool vis[N];
	// O(n^2) reference solver ("pull" form):
	//   f[i] = sum_{j=1..i} f[j-1] * mex(a[j..i])   (mod `mod`)
	// The window is grown leftwards from i so the mex can only increase.
	// Prints f[n] normalised into [0, mod).
	void solve() {
		f[0]=1;
		for (int i=1; i<=n; ++i) {
			// A window of at most n elements has mex <= n, so only vis[0..n] is
			// ever read. Clearing just that prefix avoids an O(N) memset per i.
			memset(vis, 0, sizeof(bool)*(n+1));
			int pos=0;  // current mex of a[j..i]
			for (int j=i; j; --j) {
				// Values outside [0, n] can never change the mex; skipping them
				// also keeps the vis[] index in bounds.
				if (0<=a[j] && a[j]<=n) vis[a[j]]=1;
				while (vis[pos]) ++pos;
				f[i]=(f[i]+f[j-1]*pos)%mod;
			}
		}
		printf("%lld\n", (f[n]%mod+mod)%mod);
	}
}

namespace task1{
	// f[j]: DP value for the prefix a[1..j]; f[0]=1.
	ll f[N];
	// vis[v]: whether value v occurs in the current window a[i+1..j].
	bool vis[N];
	// O(n^2) solver ("push" form): once f[i] is final, it is added into every
	// later f[j] with weight mex(a[i+1..j]). The mex is maintained incrementally
	// as the window grows to the right. Prints f[n] normalised into [0, mod).
	void solve() {
		f[0]=1;
		for (int i=0; i<n; ++i) {
			// At most n elements are in any window, so the mex is <= n and only
			// vis[0..n] is ever read. Clear just that prefix instead of all N entries.
			memset(vis, 0, sizeof(bool)*(n+1));
			int pos=0;  // current mex of a[i+1..j]
			for (int j=i+1; j<=n; ++j) {
				// Values outside [0, n] cannot affect the mex; skipping them also
				// keeps the vis[] index in bounds.
				if (0<=a[j] && a[j]<=n) vis[a[j]]=1;
				while (vis[pos]) ++pos;
				f[j]=(f[j]+f[i]*pos)%mod;
			}
		}
		printf("%lld\n", (f[n]%mod+mod)%mod);
	}
}

namespace task2{
	// f[]: until step `now` reaches index x, f[x] accumulates pending
	//      contributions (a difference array). At step x it is prefix-summed into
	//      the DP value for the prefix a[1..x].
	// sum[x]: (f[1] + ... + f[x]) mod `mod`. The contribution of f[0]=1 is
	//      injected directly when the initial segments are built.
	// nxt[i]: next position after i holding the value a[i] (n+1 if none).
	// buc[v]: after the backward scan, the first position holding value v (n+1 if none).
	ll f[N], sum[N];
	int nxt[N], buc[N], now;
	// ODT node over the value axis: for every k in [l, r], the maximum
	// first-occurrence position of the values 0..k equals `val`. Contributions of
	// left endpoints 1..tim have already been pushed into f[val].
	struct node{int l; mutable int r, val, tim;};
	inline bool operator < (node a, node b) {return a.l<b.l;}
	set<node> odt;
	// Push the lazily accumulated contributions of left endpoints tim+1..now-1
	// (each counted once per value in [l, r]) into f[val], then mark them pushed.
	void spread(set<node>::iterator it) {
		ll val=(sum[now-1]-sum[it->tim]+mod)%mod*(it->r-it->l+1)%mod;
		f[it->val]=(f[it->val]+val)%mod;
		it->tim=now-1;
	}
	// Make sure some node starts exactly at l and return it. Requires 0 <= l <= n,
	// so that l lies inside the covered range [0, n].
	auto split(int l) {
		auto it=--odt.upper_bound({l, 0, 0, 0});
		if (it->l==l) return it;
		int tl=it->l, tr=it->r, val=it->val, tim=it->tim;
		odt.erase(it);
		odt.insert({tl, l-1, val, tim});
		return odt.insert({l, tr, val, tim}).first;
	}
	// Computes f[n] for f[i] = sum_{j<i} f[j] * mex(a[j+1..i]) and prints it mod `mod`.
	void solve() {
		for (int i=0; i<=n; ++i) buc[i]=n+1;
		for (int i=n; i; --i) {
			// The mex of at most n elements lies in [0, n], so values outside
			// [0, n] never matter. Skipping them here and below also keeps
			// buc[] in bounds and split() inside the covered range.
			if (a[i]<0 || a[i]>n) {nxt[i]=n+1; continue;}
			nxt[i]=buc[a[i]], buc[a[i]]=i;
		}
		// Initial segments for subarrays starting at position 1: group consecutive
		// values k that share the same prefix maximum of first occurrences. f[0]=1
		// contributes once per such k at position maxn.
		for (int l=0,r,maxn; l<=n; l=r+1) {
			maxn=buc[l]; r=l;
			while (r<n && buc[r+1]<=maxn) ++r;
			odt.insert({l, r, maxn, 0});
			f[maxn]+=r-l+1;
		}
		for (now=1; now<=n; ++now) {
			// Every val is >= now here, and only the first node can have val == now.
			// Flush it so f[now] is complete before it is prefix-summed.
			spread(odt.begin());
			f[now]=(f[now-1]+f[now])%mod;
			sum[now]=(sum[now-1]+f[now])%mod;
			if (a[now]<0 || a[now]>n) continue;  // cannot affect any mex
			// Moving the left endpoint past position `now`: value a[now] now first
			// occurs at nxt[now]. Every node from a[now] onward with val <= nxt[now]
			// is raised to nxt[now] and merged into one node [l, r].
			int l=a[now], r=-1;
			for (auto it=split(a[now]); it!=odt.end()&&it->val<=nxt[now]; it=odt.erase(it)) {
				spread(it);
				r=it->r;
			}
			if (l<=r) odt.insert({l, r, nxt[now], now-1});
		}
		printf("%lld\n", (f[n]%mod+mod)%mod);
	}
}

signed main()
{
	// Input is taken from a.in and the answer is written to a.out.
	freopen("a.in", "r", stdin);
	freopen("a.out", "w", stdout);

	n = read();
	for (int *p = a + 1; p != a + n + 1; ++p) *p = read();  // a[1..n]

	// Slower reference solvers, kept for cross-checking:
	// force::solve();
	// task1::solve();
	task2::solve();
	return 0;
}
posted @ 2022-04-12 16:26  Administrator-09  阅读(2)  评论(0)  编辑  收藏  举报