题解 十

传送门

发现是对子集计数,不好直接做
赛时试图固定右端点,计算包含这个右端点的子集个数,无果
结果正解的基础 DP 比这个要暴力亿点
直接枚举一个子集最左、右的两个元素
那么能随意选的元素的限制区间必须包括这两个端点
于是对右端点扫描线
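线段树下标 \(a\) 表示子集的最左元素,当前右端点 \(b\) 的贡献为区间 \([l_b, b]\) 的和。对每个 \(r_i\) 不小于当前右端点的 \(i\):\(i\) 作为最左元素时,在 \(i\) 处单点激活;
\(i\) 作为中间可随意选的元素时,对 \([l_i, i-1]\) 区间乘 2,扫过 \(r_i\) 后再乘 \(2^{-1}\) 撤销。单元素子集另计,共 \(n\) 个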
复杂度 \(O(n\log n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 200010
#define fir first
#define sec second
#define pb push_back
#define ll long long
//#define int long long

// Fast input: stdin is pulled in 2 MiB chunks into buf; [p1, p2) is the not-yet-consumed window.
char buf[1<<21], *p1=buf, *p2=buf;
// Next input byte as a value in [0, 255], or EOF once stdin is exhausted.
// The unsigned char cast keeps bytes >= 0x80 non-negative, so they are safe to pass to isdigit().
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Parses the next decimal integer. Non-digit characters before it are skipped, and every '-'
// among them flips the sign. Returns 0 if the input ends before any digit is found.
inline int read() {
	int ans=0, f=1, c=getchar();
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (c!=EOF && isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// Number of elements (read first in main).
int n;
// Per-element bounds, 1-indexed: element i may only be chosen in a subset whose smallest
// element is >= l[i] and whose largest element is <= r[i].
int l[N];
int r[N];
// Answer modulus; inv2 is the modular inverse of 2, since 2*(mod+1)/2 = mod+1 ≡ 1.
const long long mod=998244353, inv2=(mod+1)>>1;
// Returns a^b mod `mod` by binary exponentiation. `a` may be any value (negative or >= mod):
// it is reduced into [0, mod) first, so a*a never overflows long long. Requires b >= 0.
inline long long qpow(long long a, long long b) {
	a%=mod;
	if (a<0) a+=mod;
	long long ans=1;
	for (; b; a=a*a%mod, b>>=1)
		if (b&1) ans=ans*a%mod;
	return ans;
}

namespace force{
	int ans;
	void solve() {
		int lim=1<<n;
		for (int s=1; s<lim; ++s) {
			int lmax=INF, rmax=0;
			for (int i=1; i<=n; ++i) if (s&(1<<(i-1))) lmax=min(lmax, i), rmax=max(rmax, i);
			for (int i=1; i<=n; ++i) if (s&(1<<(i-1))) {
				if (l[i]>lmax || r[i]<rmax) goto jump;
			}
			++ans;
			jump: ;
		}
		cout<<ans<<endl;
	}
}

namespace task1{
	int sta[N], top, ans=0;
	// Brute force over the largest element: for each i, gather the candidates l[i]..i-1 into
	// sta[1..top], try every subset of them together with i, and count it when each chosen j
	// satisfies l[j] <= min(subset) and r[j] >= i. Exponential in top; prints the raw count.
	void solve() {
		for (int i=1; i<=n; ++i) {
			top=0;
			for (int j=l[i]; j<i; ++j) sta[++top]=j;
			for (int mask=0; mask<(1<<top); ++mask) {
				int lo=i;
				for (int k=1; k<=top; ++k)
					if (mask>>(k-1)&1) lo=min(lo, sta[k]);
				bool ok=true;
				for (int k=1; k<=top && ok; ++k)
					if ((mask>>(k-1)&1) && (l[sta[k]]>lo || r[sta[k]]<i)) ok=false;
				ans+=ok;
			}
		}
		cout<<ans<<endl;
	}
}

namespace task{
	// Sweep-line solution (the one main() calls). The segment tree is indexed by position a,
	// the candidate smallest element of a subset; the sweep index i is its largest element.
	ll ans;
	// del[x] collects {l[j], j} for every j with r[j] == x, so j can be retired once the sweep passes x.
	vector<pair<int, int>> del[N];
	// Segment tree with a lazy range-multiply tag:
	//   tl/tr - node interval; tag - pending multiplier not yet pushed to the children;
	//   val   - at a leaf, the product of every multiplier applied to that position;
	//   cnt   - at a leaf, the net number of point activations (+1 / -1 via the point upd);
	//   sum   - at a leaf val*(2^cnt - 1) mod p (0 while cnt == 0), at an inner node the children's sum.
	int tl[N<<2], tr[N<<2]; ll val[N<<2], tag[N<<2], sum[N<<2], cnt[N<<2];
	// Accessor shorthands. A #define is not scoped by the namespace: these macros stay
	// defined for the rest of the translation unit.
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	#define val(p) val[p]
	#define sum(p) sum[p]
	#define tag(p) tag[p]
	#define cnt(p) cnt[p]
	// Recomputes an inner node's sum from its two children.
	#define pushup(p) sum(p)=(sum(p<<1)+sum(p<<1|1))%mod
	// Pushes node p's pending multiplier into both children (their val, sum and tag).
	void spread(int p) {
		if (tag(p)==1) return ;
		val(p<<1)=val(p<<1)*tag(p)%mod; sum(p<<1)=sum(p<<1)*tag(p)%mod; tag(p<<1)=tag(p<<1)*tag(p)%mod;
		val(p<<1|1)=val(p<<1|1)*tag(p)%mod; sum(p<<1|1)=sum(p<<1|1)*tag(p)%mod; tag(p<<1|1)=tag(p<<1|1)*tag(p)%mod;
		tag(p)=1;
	}
	// Builds the tree on [l, r]: sets node intervals, tag 1 and leaf multiplier val 1.
	// sum and cnt are not set here; they rely on the arrays being zero-initialized globals,
	// so this presumably must run only once.
	void build(int p, int l, int r) {
		tl(p)=l; tr(p)=r; tag(p)=1;
		if (l==r) {val(p)=1; return ;}
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
		pushup(p);
	}
	// Range update: multiplies every position in [l, r] by dat (mod p).
	void upd(int p, int l, int r, ll dat) {
		if (l<=tl(p)&&r>=tr(p)) {val(p)=val(p)*dat%mod; sum(p)=sum(p)*dat%mod; tag(p)=tag(p)*dat%mod; return ;}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid) upd(p<<1, l, r, dat);
		if (r>mid) upd(p<<1|1, l, r, dat);
		pushup(p);
	}
	// Point update: adds k to cnt at leaf pos and recomputes that leaf's sum as val*(2^cnt - 1).
	void upd(int p, int pos, ll k) {
		if (tl(p)==tr(p)) {cnt(p)+=k; sum(p)=val(p)*(qpow(2, cnt(p))-1)%mod; return ;}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (pos<=mid) upd(p<<1, pos, k);
		else upd(p<<1|1, pos, k);
		pushup(p);
	}
	// Returns the sum over positions [l, r] (mod p).
	ll query(int p, int l, int r) {
		if (l<=tl(p)&&r>=tr(p)) return sum(p);
		spread(p);
		int mid=(tl(p)+tr(p))>>1; ll ans=0;
		if (l<=mid) ans=(ans+query(p<<1, l, r))%mod;
		if (r>mid) ans=(ans+query(p<<1|1, l, r))%mod;
		return ans;
	}
	// Sweeps the largest chosen element i = 1..n. For each i it adds the weighted count of valid
	// smallest elements a in [l[i], i], then adds n for single-element subsets; prints the total mod p.
	// NOTE(review): appears to assume 1 <= l[i] <= i <= r[i] <= n for every i; confirm against the constraints.
	void solve() {
		build(1, 1, n); // NOTE(review): requires n >= 1; build(1, 1, 0) never reaches l==r and recurses forever.
		// Bucket every i by r[i] so it can be retired at the end of sweep step r[i].
		for (int i=1; i<=n; ++i) del[r[i]].pb({l[i], i});
		for (int i=1; i<=n; ++i) {
			// i is the current largest element: sum over candidate smallest elements a in [l[i], i].
			// Position i itself is not activated yet, so singletons are not counted here.
			// cout<<"i: "<<i<<' '<<query(1, l[i], i)<<endl;
			ans=(ans+query(1, l[i], i))%mod;
			// Until retired, i may be freely included in subsets whose smallest element lies in
			// [l[i], i-1]: double those positions.
			if (l[i]<i) upd(1, l[i], i-1, 2);
			// Activate position i as a candidate smallest element.
			upd(1, i, 1); //, cout<<"upd: "<<l[i]<<endl;
			// Retire every j with r[j] == i: deactivate position j, then undo its doubling with inv2.
			for (auto it:del[i]) upd(1, it.sec, -1); //, cout<<"del: "<<it.fir<<endl;
			for (auto it:del[i])
				if (it.fir<it.sec) upd(1, it.fir, it.sec-1, inv2);
		}
		// + n: presumably the n single-element subsets, which the sweep above never counts.
		cout<<(ans+n)%mod<<endl;
	}
}

signed main()
{
	freopen("ten.in", "r", stdin);
	freopen("ten.out", "w", stdout);

	n=read();
	for (int i=1; i<=n; ++i) l[i]=read();
	for (int i=1; i<=n; ++i) r[i]=read();
	// if (n<=20) force::solve();
	// else task1::solve();
	task::solve();

	return 0;
}
posted @ 2022-03-17 19:14  Administrator-09  阅读(2)  评论(0编辑  收藏  举报