题解 [CF1129D] Isolation
首先有一个朴素 DP:设 \(f_r\) 为前 \(r\) 个数的划分方案数,枚举以 \(r\) 为右端点的区间的左端点转移,直接做是 \(O(n^2)\) 的
对每个左端点 \(l\) 维护二元组 \((x_l, y_l)\),其中 \(x_l\) 是区间 \([l, r]\) 中恰好出现一次的数的个数,\(y_l=f_{l-1}\)。右端点右移时需要做若干次区间 \(\pm 1\) 修改,并查询 \(\sum_{l:\,x_l\leqslant k} y_l\),其中 \(k\) 是定值
容易想到分块,每次查询在块内二分,复杂度 \(O(n\sqrt n\log n)\)
然而这并没有用到只有 \(\pm 1\) 的操作以及查询的 \(k\) 是定值的性质
考虑对每个块按权值开桶,指针维护 \(x_i\leqslant k\) 的位置
因为只有 \(\pm 1\) 的操作所以指针的移动是单次 \(O(1)\) 的
对散块修改时可以归并重构
特别注意桶是按权值种类开的,即多个权值相等的元素合并到同一个位置;否则一次 \(\pm 1\) 可能让指针跨过大量等值元素,单次移动不再是 \(O(1)\),复杂度就不对了
最终复杂度 \(O(n\sqrt n)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// INF: sentinel head used by build()'s merge; assumed larger than any value a block stores.
#define INF 0x3f3f3f3f
// N: capacity of per-position and per-value arrays; assumes n and every a[i] stay below it.
#define N 100010
// Shorthands for the members of pair<int, int>.
#define fir first
#define sec second
#define ll long long
// Disabled switch; main is declared `signed main()`, so it would stay valid if this were re-enabled.
//#define int long long
// Input buffer: getchar() below refills `buf` from stdin 1<<21 bytes at a time
// and yields EOF once fread returns nothing more.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Parse the next integer. Non-digits before the number are skipped, and every
// '-' among them flips the sign; parsing stops at the first non-digit after it.
inline int read() {
    char ch = getchar();
    bool negative = false;
    for (; !isdigit(ch); ch = getchar())
        if (ch == '-') negative = !negative;
    int value = 0;
    for (; isdigit(ch); ch = getchar())
        value = value * 10 + (ch - '0');
    return negative ? -value : value;
}
// n: array length; k: threshold on a left endpoint's value (an entry counts toward f only while its value <= k).
int n, k;
// f[i]: DP value for the prefix a[1..i] (f[0] = 1), kept as a residue modulo `mod`.
// sum[b]: sum of dat[b][1..pos[b]].sec, i.e. the weight of block b's entries whose value is currently <= k.
ll f[N], sum[325];
const ll mod=998244353;
// Per-block arrays are sized 325: enough for ~sqrt(n) blocks of ~sqrt(n) entries when n is
// around 1e5 (NOTE(review): confirm against the problem's limits before raising N).
// siz[b]: entries in buc[b]; tag[b]: lazy amount to add to every stored value in block b;
// pos[b]: how many leading dat[b] groups have value + tag[b] <= k; usiz[b]: groups in dat[b].
int siz[325], tag[325], pos[325], usiz[325];
// a: input array; bel[i]: block containing position i; tem[v]: last index seen so far with value v;
// lst[i]: previous index j < i with a[j] == a[i] (0 if none); sqr: block size;
// top1/top2: number of staged entries in sta1/sta2.
int a[N], bel[N], tem[N], lst[N], sqr, top1, top2;
// buc[b][1..siz[b]]: (value, left endpoint l) pairs sorted by value (value excludes tag[b]),
//   where value = number of distinct values occurring exactly once in a[l..r] for the current r;
// dat[b][1..usiz[b]]: one (value, sum of f[l-1] mod `mod`) group per distinct value, ascending;
// sta1/sta2: value-sorted staging lists that build() merges into a block.
pair<int, int> buc[325][325], dat[325][325], sta1[325], sta2[325];
void build(int id) {
int p1=1, p2=1, lst=-1; siz[id]=usiz[id]=0;
while (p1<=top1||p2<=top2) {
int minn=INF;
if (p1<=top1) minn=min(minn, sta1[p1].fir);
if (p2<=top2) minn=min(minn, sta2[p2].fir);
if (minn!=lst) dat[id][++usiz[id]]={minn, 0}, lst=minn;
if (p1<=top1 && minn==sta1[p1].fir) buc[id][++siz[id]]=sta1[p1], dat[id][usiz[id]].sec=(dat[id][usiz[id]].sec+f[sta1[p1++].sec-1])%mod;
if (p2<=top2 && minn==sta2[p2].fir) buc[id][++siz[id]]=sta2[p2], dat[id][usiz[id]].sec=(dat[id][usiz[id]].sec+f[sta2[p2++].sec-1])%mod;
}
pos[id]=sum[id]=tag[id]=0;
// while (pos[id]<siz[id] && buc[id][pos[id]+1].fir<=k) sum[id]=(sum[id]+f[buc[id][++pos[id]].sec-1])%mod;
while (pos[id]<usiz[id] && dat[id][pos[id]+1].fir<=k) sum[id]=(sum[id]+dat[id][++pos[id]].sec)%mod;
}
void add(int l, int r, int val) {
if (l>r) return ;
int sid=bel[l], eid=bel[r]; top1=top2=0;
if (sid==eid) {
for (int i=1; i<=siz[sid]; ++i)
if (l<=buc[sid][i].sec&&buc[sid][i].sec<=r) sta1[++top1]={buc[sid][i].fir+tag[sid]+val, buc[sid][i].sec};
else sta2[++top2]={buc[sid][i].fir+tag[sid], buc[sid][i].sec};
build(sid);
return ;
}
top1=top2=0;
for (int i=1; i<=siz[sid]; ++i)
if (l<=buc[sid][i].sec&&buc[sid][i].sec<=r) sta1[++top1]={buc[sid][i].fir+tag[sid]+val, buc[sid][i].sec};
else sta2[++top2]={buc[sid][i].fir+tag[sid], buc[sid][i].sec};
build(sid);
for (int i=sid+1; i<eid; ++i) {
tag[i]+=val;
// while (pos[i]<siz[i] && buc[i][pos[i]+1].fir+tag[i]<=k) sum[i]=(sum[i]+f[buc[i][++pos[i]].sec-1])%mod;
// while (pos[i] && buc[i][pos[i]].fir+tag[i]>k) sum[i]=(sum[i]-f[buc[i][pos[i]--].sec-1])%mod;
while (pos[i]<usiz[i] && dat[i][pos[i]+1].fir+tag[i]<=k) sum[i]=(sum[i]+dat[i][++pos[i]].sec)%mod;
while (pos[i] && dat[i][pos[i]].fir+tag[i]>k) sum[i]=(sum[i]-dat[i][pos[i]--].sec)%mod;
}
top1=top2=0;
for (int i=1; i<=siz[eid]; ++i)
if (l<=buc[eid][i].sec&&buc[eid][i].sec<=r) sta1[++top1]={buc[eid][i].fir+tag[eid]+val, buc[eid][i].sec};
else sta2[++top2]={buc[eid][i].fir+tag[eid], buc[eid][i].sec};
build(eid);
}
// Total weight (sum of f[l-1] over entries with value <= k) across blocks
// 1..bel[r], i.e. every block up to and including the one holding position r.
ll query(int r) {
    const int lastBlock = bel[r];
    ll total = 0;
    for (int b = 1; b <= lastBlock; ++b)
        total = (total + sum[b]) % mod;
    return total;
}
void push(int id, pair<int, int> dat) {
top1=top2=0;
for (int i=1; i<=siz[id]; ++i) sta1[++top1]={buc[id][i].fir+tag[id], buc[id][i].sec};
sta2[++top2]=dat;
build(id);
}
// Read n, k and the array, precompute block ids and previous occurrences, then
// sweep the right endpoint r: update every left endpoint's count of values that
// occur exactly once in a[l..r], insert l = r, and read off f[r].
signed main()
{
    n = read();
    k = read();
    sqr = sqrt(n);  // block size
    for (int i = 1; i <= n; ++i) {
        a[i] = read();
        bel[i] = (i - 1) / sqr + 1;
        lst[i] = tem[a[i]];  // previous index holding the same value, 0 if none
        tem[a[i]] = i;
    }
    f[0] = 1;
    for (int r = 1; r <= n; ++r) {
        const int prev = lst[r];
        // a[r] now occurs exactly once in a[l..r] for l in (prev, r-1] ...
        add(prev + 1, r - 1, 1);
        // ... and twice for l in (lst[prev], prev], so it stops counting there.
        add(lst[prev] + 1, prev, -1);
        push(bel[r], {1, r});  // new left endpoint l = r: a[r] alone occurs once
        f[r] = query(r);
    }
    printf("%lld\n", (f[n] % mod + mod) % mod);
    return 0;
}