题解 序列划分

传送门

首先有一个 \(O(n^2)\) 的 DP

\[f_{i+1}=\sum\limits_{j=1}\operatorname{mex}(j, i)f_j \]

那么考虑怎么优化这个 DP
容易想到用线段树取维护这个 mex
那么我们每次要加入一个点 \(i\),原来满足 \(\operatorname{mex}=a_i\) 的区间 \([l, r]\) 可能会分裂成若干个新区间
那么怎么找到这些新区间呢?
考虑 \(\operatorname{mex}(r, i)=x\),则 \([\max(lst_x+1), r]\) 的 mex 应被赋值为 \(x\),然后令 \(r=lst_x\),重复此过程
那么区间总数是 \(O(n)\) 级别的,所以暴力做这个事情的复杂度均摊下来是 \(O(n\log n)\)
但是我被卡常了

那么沈老师做法:
将转移写成这样:

\[f_i\times\operatorname{mex}(i, j-1)\to f_j \]

然后

\[\operatorname{mex}(i, j-1)=\sum\limits_k[\operatorname{mex}(i, j-1)>k] \]

\(b_i\) 为最小的位置满足 \(\operatorname{mex}>i\)
那么

\[b_i=\max\limits_{k\leqslant i}pos_k \]

其中 \(pos_i\)\(i\) 第一次出现的位置
那么一次转移相当于给每个 \([b_j, n]\) 加上 \(dp_i\)
那么差分一下变成给每个 \(b_j\) 加上 \(dp_i\)
用珂朵莉树维护所有的 \(b_i\)
在每个节点维护 \(tim\) 为这个节点的值最后一次实际下放到答案序列的时间
那么每次下放相当于

\[dp_{b_i}\gets 节点长度\times(dp_{now}-dp_{tim}) \]

于是就可以维护了
复杂度也是 \(O(n\log n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 1000010
#define fir first
#define sec second
#define ll long long
//#define int long long

char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
	int ans=0, f=1; char c=getchar();
	while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
	return ans*f;
}

int n;
set<int> pos[N];
ll sum[N], dp[N];
const ll mod=998244353;
int a[N], buc[N], nxt[N], now;
struct node{int l, r; mutable int val, tim;};
inline bool operator < (node a, node b) {return a.l<b.l;}
set<node> odt;

inline void spread(node it) {dp[it.val+1]=(dp[it.val+1]+1ll*(it.r-it.l+1)*(sum[now]-sum[it.tim]))%mod;}

auto split(int x) {
	if (x>n) return odt.end();
	auto it=--odt.upper_bound({x, 0, 0, 0});
	if (it->l==x) return it;
	spread(*it);
	int l=it->l, r=it->r, val=it->val, tim=it->tim;
	odt.erase(it);
	odt.insert({l, x-1, val, now});
	return odt.insert({x, r, val, now}).fir;
}

void modify(int l, int val) {
	auto it=split(l); int r=l-1;
	while (it!=odt.end() && it->val<=val) spread(*it), r=it->r, it=odt.erase(it);
	if (l<=r) odt.insert({l, r, val, now});
}

signed main()
{
	freopen("divide.in", "r", stdin);
	freopen("divide.out", "w", stdout);

	n=read();
	for (int i=1; i<=n; ++i) pos[a[i]=min(read(), n)].insert(i);
	for (int i=0; i<=n; ++i) pos[i].insert(n+2);
	pos[n+1].insert(INF);
	int lst=0, val=*pos[0].begin();
	for (int i=0; i<=n+1; ++i) if (*pos[i].begin()>val) {
		odt.insert({lst, i-1, val, 0});
		lst=i; val=*pos[i].begin();
	}
	// cout<<"odt: "; for (auto it:odt) cout<<"("<<it.l<<','<<it.r<<','<<it.val<<','<<it.tim<<") "; cout<<endl;
	dp[1]=sum[1]=1;
	for (now=1; now<=n; ++now) {
		// cout<<"now: "<<now<<endl;
		// cout<<"odt: "; for (auto it:odt) cout<<"("<<it.l<<','<<it.r<<','<<it.val<<','<<it.tim<<") "; cout<<endl;
		// for (auto it:odt) dp[it.val+1]=(dp[it.val+1]+dp[now]*(it.r-it.l+1))%mod;
		spread(*odt.begin()); odt.begin()->tim=now;
		modify(a[now], *pos[a[now]].erase(pos[a[now]].begin()));
		if (now==1) --dp[now];
		dp[now+1]=(dp[now+1]+dp[now])%mod;
		sum[now+1]=(sum[now]+dp[now+1])%mod;
	}
	// cout<<"dp: "; for (int i=1; i<=n+1; ++i) cout<<dp[i]<<' '; cout<<endl;
	cout<<(dp[n+1]%mod+mod)%mod<<endl;

	return 0;
}
posted @ 2022-03-30 08:24  Administrator-09  阅读(2)  评论(0编辑  收藏  举报