题解 十
发现题目是对子集计数，不好直接做。
赛时尝试固定右端点、统计包含该右端点的合法子集，没有做出来。
结果正解的基础 DP 比这个思路还要暴力得多：
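直接枚举子集中最左、最右的两个元素 \(L<R\)。
那么其余可以随意选的元素 \(j\)（\(L<j<R\)）的限制区间必须同时包含这两个端点，即 \(l_j\le L\) 且 \(r_j\ge R\)；此外两个端点自身也要满足 \(r_L\ge R\)、\(l_R\le L\)。记满足条件的 \(j\) 的个数为 \(c(L,R)\)，则这对端点贡献 \(2^{c(L,R)}\)，单元素子集另计 \(n\) 个，即 \(\text{ans}=n+\sum_{L<R,\ r_L\ge R,\ l_R\le L}2^{c(L,R)}\)。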
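于是对右端点 \(R\) 做扫描线，用线段树在每个左端点位置 \(L\) 上维护 \(2^{c(L,R)}\)（当 \(r_L<R\) 时该位置记为 \(0\)）。
扫到 \(R=i\) 时，先把区间 \([l_i, i-1]\) 的和加入答案，这正是以 \(i\) 为最右元素、且 \(l_i\le L\) 的所有方案数；
随后 \(i\) 可以作为之后右端点的中间元素：对 \([l_i, i-1]\) 区间乘 \(2\)，并把位置 \(i\) 激活为候选左端点；扫描线越过 \(r_i\) 时，再对 \([l_i, i-1]\) 区间乘 \(2^{-1}\) 并取消激活，撤销这两步。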
复杂度 \(O(n\log n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Sentinel "infinity" (0x3f3f3f3f, about 1.06e9) used as the starting value of running minima.
#define INF 0x3f3f3f3f
// Capacity of every per-index array (indices are 1-based).
// NOTE(review): assumes n < N; confirm against the problem's limits.
#define N 200010
// Shorthands for std::pair members, vector::push_back and the 64-bit integer type.
#define fir first
#define sec second
#define pb push_back
#define ll long long
// Left disabled: would widen every `int` to long long (main is declared `signed`
// so it would stay well-formed either way).
//#define int long long
// Fast input: stdin is pulled in 2 MiB chunks with fread and getchar() hands out
// one byte at a time, yielding EOF once fread has nothing more to return.
// Bytes are widened through unsigned char so they never collide with EOF and are
// always valid arguments for isdigit().
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:(unsigned char)*p1++)
// Reads the next decimal integer from stdin.
// Every '-' met before the first digit flips the sign; other non-digits are skipped.
// If input ends before any digit is found, returns 0 instead of spinning forever on EOF.
inline int read() {
    int ans=0, f=1;
    int c=getchar();
    // Stop skipping at EOF, otherwise a truncated/missing input file hangs here.
    while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
    // isdigit(EOF) is false, so this loop also ends cleanly at end of input.
    while (isdigit(c)) {ans=ans*10+(c-'0'); c=getchar();}
    return ans*f;
}
// Number of elements; read first from the input.
int n;
// Per-element constraint interval (1-indexed): a subset may contain element i only
// if its smallest chosen index is >= l[i] and its largest chosen index is <= r[i].
int l[N];
int r[N];
// Prime modulus used for every answer, and the modular inverse of 2 under it.
const long long mod=998244353, inv2=(mod+1)>>1;
// Returns a^b mod `mod` by binary exponentiation in O(log |b|) multiplications.
// `a` may be any value; it is first reduced into [0, mod).
// A negative `b` is reduced modulo mod-1 (Fermat's little theorem, valid because
// mod is prime), giving the modular inverse power whenever a is not a multiple of
// mod. This matters because the shift loop alone never terminates for negative b:
// b>>=1 on -1 stays -1.
inline long long qpow(long long a, long long b) {
    a%=mod;
    if (a<0) a+=mod;
    if (b<0) {
        b%=mod-1;
        if (b<0) b+=mod-1;
    }
    long long ans=1;
    for (; b; a=a*a%mod, b>>=1) if (b&1) ans=ans*a%mod;
    return ans;
}
namespace force{
int ans;
void solve() {
int lim=1<<n;
for (int s=1; s<lim; ++s) {
int lmax=INF, rmax=0;
for (int i=1; i<=n; ++i) if (s&(1<<(i-1))) lmax=min(lmax, i), rmax=max(rmax, i);
for (int i=1; i<=n; ++i) if (s&(1<<(i-1))) {
if (l[i]>lmax || r[i]<rmax) goto jump;
}
++ans;
jump: ;
}
cout<<ans<<endl;
}
}
namespace task1{
// Scratch list of candidate indices (1-based), its length, and the running count.
int sta[N], top, ans=0;
// Brute force that fixes the largest element i of the subset. Because l[i] must be
// <= the subset minimum, every other member lies in [l[i], i-1], so all subsets of
// that range are tried. A choice is accepted when each chosen j has
// l[j] <= min and r[j] >= i.
// Prints the raw count (no modulus). Exponential in i - l[i]; the mask is an int.
void solve() {
    for (int i=1; i<=n; ++i) {
        top=0;
        for (int j=l[i]; j<i; ++j) sta[++top]=j;
        const int masks=1<<top;
        for (int s=0; s<masks; ++s) {
            // smallest chosen index (i itself when nothing else is picked)
            int lo=i;
            for (int k=0; k<top; ++k)
                if (s>>k&1) lo=min(lo, sta[k+1]);
            bool good=true;
            for (int k=0; k<top && good; ++k)
                if ((s>>k&1) && (l[sta[k+1]]>lo || r[sta[k+1]]<i)) good=false;
            ans+=good;
        }
    }
    cout<<ans<<endl;
}
}
// Full solution in O(n log n): sweep the right end R = i of the subset and keep,
// for every candidate left end L, the number of valid subsets with min L and max R
// in a lazy multiplicative segment tree.
// NOTE(review): assumes 1 <= l[i] <= i <= r[i] <= n for every i (not checked here).
namespace task{
// Running total over subsets with at least two elements (singletons are added at the end).
ll ans;
// del[x] lists (l[j], j) for every element j with r[j] == x; such elements stop
// being usable once the sweep has finished position x.
vector<pair<int, int>> del[N];
// Segment-tree node arrays (root 1, children p<<1 and p<<1|1):
//   tl/tr - index range covered by the node
//   val   - at a leaf L, the product of every multiplier applied to L so far
//   tag   - pending multiplier not yet pushed down to the children
//   sum   - sum of the leaf sums in the subtree, mod `mod`
//   cnt   - at a leaf L, 1 while element L is active as a left endpoint, else 0
// A leaf's sum is val*(2^cnt-1): val while L is active, 0 otherwise.
int tl[N<<2], tr[N<<2]; ll val[N<<2], tag[N<<2], sum[N<<2], cnt[N<<2];
// Accessor shorthands (note: these macros stay defined for the rest of the file).
#define tl(p) tl[p]
#define tr(p) tr[p]
#define val(p) val[p]
#define sum(p) sum[p]
#define tag(p) tag[p]
#define cnt(p) cnt[p]
#define pushup(p) sum(p)=(sum(p<<1)+sum(p<<1|1))%mod
// Pushes node p's pending multiplier into both children (val, sum and tag), then clears it.
void spread(int p) {
if (tag(p)==1) return ;
val(p<<1)=val(p<<1)*tag(p)%mod; sum(p<<1)=sum(p<<1)*tag(p)%mod; tag(p<<1)=tag(p<<1)*tag(p)%mod;
val(p<<1|1)=val(p<<1|1)*tag(p)%mod; sum(p<<1|1)=sum(p<<1|1)*tag(p)%mod; tag(p<<1|1)=tag(p<<1|1)*tag(p)%mod;
tag(p)=1;
}
// Builds node p over [l, r] (these parameters shadow the global l/r arrays).
// Leaves start with val=1; cnt and sum stay 0 from zero-initialized globals; all tags are 1.
// NOTE(review): needs l <= r; for n = 0, build(1, 1, 0) never reaches the l==r base case.
void build(int p, int l, int r) {
tl(p)=l; tr(p)=r; tag(p)=1;
if (l==r) {val(p)=1; return ;}
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
pushup(p);
}
// Range update: multiplies val and sum of every position in [l, r] by dat (lazily via tag).
void upd(int p, int l, int r, ll dat) {
if (l<=tl(p)&&r>=tr(p)) {val(p)=val(p)*dat%mod; sum(p)=sum(p)*dat%mod; tag(p)=tag(p)*dat%mod; return ;}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid) upd(p<<1, l, r, dat);
if (r>mid) upd(p<<1|1, l, r, dat);
pushup(p);
}
// Point update: adds k to cnt at position pos and recomputes that leaf's sum as
// val*(2^cnt-1). If cnt ever went negative, qpow would get a negative exponent.
void upd(int p, int pos, ll k) {
if (tl(p)==tr(p)) {cnt(p)+=k; sum(p)=val(p)*(qpow(2, cnt(p))-1)%mod; return ;}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (pos<=mid) upd(p<<1, pos, k);
else upd(p<<1|1, pos, k);
pushup(p);
}
// Returns the sum of leaf sums over [l, r], mod `mod` (local ans shadows task::ans).
ll query(int p, int l, int r) {
if (l<=tl(p)&&r>=tr(p)) return sum(p);
spread(p);
int mid=(tl(p)+tr(p))>>1; ll ans=0;
if (l<=mid) ans=(ans+query(p<<1, l, r))%mod;
if (r>mid) ans=(ans+query(p<<1|1, l, r))%mod;
return ans;
}
// Sweeps i = 1..n as the largest element of the subset and prints the total mod `mod`.
void solve() {
build(1, 1, n);
// Bucket every element by its right bound so it can be retired when the sweep reaches r[i].
for (int i=1; i<=n; ++i) del[r[i]].pb({l[i], i});
for (int i=1; i<=n; ++i) {
// i is the largest element. Each active leaf L in [l[i], i-1] holds
// 2^(#j with L<j<i, l[j]<=L, r[j]>=i), i.e. the number of valid subsets with
// min L and max i. Leaf i itself is not active yet and contributes 0.
ans=(ans+query(1, l[i], i))%mod;
// From now on i may be a free middle element for every left end L in [l[i], i-1].
if (l[i]<i) upd(1, l[i], i-1, 2);
upd(1, i, 1); // activate i as a candidate left endpoint (cnt 0 -> 1)
// Retire elements whose right bound is i: they cannot reach any later right end.
for (auto it:del[i]) upd(1, it.sec, -1); // deactivate j = it.sec as a left endpoint
// ...and undo the x2 that each retired j applied over [l[j], j-1] as a middle element.
for (auto it:del[i])
if (it.fir<it.sec) upd(1, it.fir, it.sec-1, inv2);
}
// Add the n single-element subsets.
cout<<(ans+n)%mod<<endl;
}
}
signed main()
{
freopen("ten.in", "r", stdin);
freopen("ten.out", "w", stdout);
n=read();
for (int i=1; i<=n; ++i) l[i]=read();
for (int i=1; i<=n; ++i) r[i]=read();
// if (n<=20) force::solve();
// else task1::solve();
task::solve();
return 0;
}