【XSY3813】漏网之鱼(线段树)
在 “XSY未完成” 题单里咕了好久的题。
题面
题解
为了方便表述,不妨设 \(mex(l,r)\) 表示 \(mex(a[l],a[l+1],\cdots,a[r])\)。
首先,对于权值大于 \(n+1\) 的 \(a[i]\),我们直接把它设为 \(n+1\) 就行了,因为 \(mex\) 肯定不会大于 \(n\)。
然后 \(O(n)\) 预处理出 \(nxt_i\) 表示 \(a_i\) 下一次出现的位置(不存在则为 \(n+1\))。
先考虑只有一次询问且询问的是 \([1,n]\) 的情况。
我们考虑移动左端点,每次算出左端点固定的情况下的所有区间的答案。
开一棵线段树维护区间 \(mex\) 和,其中第 \(i\) 个叶子节点表示 \(mex(tim,i)\)。(其中 \(tim\) 为当前的左端点)
首先当左端点在 \(1\) 的时候,每个区间的 \(mex\) 是可以预处理的,我们用这个来建树。
考虑左端点从 \(i\) 移动到 \(i+1\):此时对于右端点在 \([i+1,nxt[i]-1]\) 中的区间,都没有了 \(a_i\) 的贡献,所以这些区间的 \(mex\) 要对 \(a_i\) 取 \(\min\)。
但是我们如何同时维护区间取 \(\min\) 和区间求和呢?
发现左端点固定、右端点递增时,区间的 \(mex\) 值是单调不降的。所以我们对一个区间和 \(a_i\) 取 \(\min\) 实际上是对这个区间的右边一部分区间赋值为 \(a_i\)。
那我们可以记录一个最大值,然后在线段树上二分查找第一个大于 \(a_i\) 的位置,区间赋值就行了。
每次移动左端点时统计一下答案即可。
考虑如何扩展到多组区间询问的情况。
我们可以参照主席树的思想:当左端点移动到 \(tim\) 时,线段树中的节点存储的不再是当前的值,而是历史信息的和。
具体地,假设当前左端点移动到 \(tim\),线段树中的某一节点代表的区间为 \([x,y]\),那么原本这个节点存储的值是 \(\sum\limits_{r=x}^{y}mex(tim,r)\),现在变成了存储 \(\sum\limits_{l=1}^{tim}\sum\limits_{r=x}^{y}mex(l,r)\),也就是把之前 \(tim\) 移动到每一个位置的答案都统计了起来。
那么我们可以把一组询问 \([l,r]\) 拆成两个:当 \(tim=r\) 时加上区间 \([l,r]\) 的答案(即加上当左端点在 \(1\sim r\)、右端点在 \(l\sim r\) 时的答案);当 \(tim=l-1\) 时减去区间 \([l,r]\) 的答案(即加上当左端点在 \(1\sim l-1\)、右端点在 \(l\sim r\) 时的答案)。
那么现在的问题就是如何维护历史信息的和了,也就是如何维护这棵线段树。
发现对于某个叶子节点来说,如果一直都没有修改操作,那么它的贡献是关于 \(tim\) 的一次函数形式,因为当左端点 \(tim\) 右移的时候,它的增加量是一定的。
因为一次函数满足可加性,所以可以设一个 \(sumk\) 和 \(sumb\) 表示区间内所有叶子节点的一次函数的 \(k\) 和 \(b\) 的和。
那么如果有修改呢?
那么此时的斜率就变了,那么就会变成一个分段函数:
但我们不可能每个点都维护一个分段函数啊……
于是需要将拆完后的询问离线,并把他们按左端点排序,然后我们每次只需要维护最新的一条直线即可。
不妨设当前左端点从 \(tim\) 移动到 \(tim+1\),然后要将一个区间的直线的斜率全部区间赋值为 \(k'=a_{tim}\)。
那么对于这个区间内的某个叶子节点所代表的的直线,假设它原来的解析式为 \(f(x)=k_jx+b_j\),又由于斜率变化后还是要经过点 \((tim,f(tim))\) 的,所以可以得到新的直线方程为:
就是说由原来红的线变成了蓝的线:
那么对于整个区间来说:
但是这样做是不行的,因为你发现这个更新有先后顺序,所以不方便打懒标记。
那怎么办呢?
我们考虑当区间内的 \(k\) 都相同的时候才更新,自己推推发现这样可以改成区间加,变成一个只和 \(size\) 有关而和 \(tim\) 无关的式子。(详见代码)
代码如下:
#include<bits/stdc++.h>
#define N 1000010
#define ll long long
using namespace std;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
struct Query
{
int pos,l,r,id,opt;
Query(){};
Query(int a,int l1,int r1,int b,int c){pos=a,l=l1,r=r1,id=b,opt=c;}
}q[N<<1];
inline bool operator < (Query a,Query b)
{
return a.pos<b.pos;
}
int type,n,Q,numq,mex,a[N];
int last[N],nxt[N];
int id[N<<2],minn[N<<2],maxn[N<<2],lazyk[N<<2];
ll sumk[N<<2],sumb[N<<2],lazyb[N<<2];
bool vis[N],tag[N<<2];
ll Ans[N];
inline void up(int k)
{
minn[k]=min(minn[k<<1],minn[k<<1|1]);
maxn[k]=max(maxn[k<<1],maxn[k<<1|1]);
sumk[k]=sumk[k<<1]+sumk[k<<1|1];
sumb[k]=sumb[k<<1]+sumb[k<<1|1];
}
inline void downn(int k,int l,int r,int nowk,ll nowb)
{
lazyk[k]+=nowk;
lazyb[k]+=nowb;
sumk[k]+=1ll*(r-l+1)*nowk;
sumb[k]+=1ll*(r-l+1)*nowb;
minn[k]+=nowk,maxn[k]+=nowk;
tag[k]=1;
}
inline void down(int k,int l,int r,int mid)
{
if(tag[k])
{
downn(k<<1,l,mid,lazyk[k],lazyb[k]);
downn(k<<1|1,mid+1,r,lazyk[k],lazyb[k]);
lazyk[k]=lazyb[k]=tag[k]=0;
}
}
inline void build(int k,int l,int r)
{
if(l==r)
{
id[l]=k;
vis[a[l]]=1;
while(vis[mex]) mex++;
minn[k]=maxn[k]=sumk[k]=mex;
return;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
up(k);
}
inline void update(int k,int l,int r,int ql,int qr,int x,int tim)
{
if(maxn[k]<=x) return;
if(ql<=l&&r<=qr&&minn[k]==maxn[k])
{
downn(k,l,r,x-maxn[k],1ll*(maxn[k]-x)*tim);
return;
}
int mid=(l+r)>>1;
down(k,l,r,mid);
if(ql<=mid) update(k<<1,l,mid,ql,qr,x,tim);
if(qr>mid) update(k<<1|1,mid+1,r,ql,qr,x,tim);
up(k);
}
inline ll query(int k,int l,int r,int ql,int qr,int x)
{
if(ql<=l&&r<=qr) return 1ll*sumk[k]*x+sumb[k];
int mid=(l+r)>>1;
down(k,l,r,mid);
ll ans=0;
if(ql<=mid) ans+=query(k<<1,l,mid,ql,qr,x);
if(qr>mid) ans+=query(k<<1|1,mid+1,r,ql,qr,x);
return ans;
}
int main()
{
type=read(),n=read();
for(int i=1;i<=n;i++)
{
a[i]=read();
if(a[i]>n+1) a[i]=n+1;
}
for(int i=0;i<=n+1;i++) last[i]=n+1;
for(int i=n;i>=1;i--) nxt[i]=last[a[i]],last[a[i]]=i;
Q=read();
for(int i=1;i<=Q;i++)
{
int l=read(),r=read();
q[++numq]=Query(r,l,r,i,1);
q[++numq]=Query(l-1,l,r,i,-1);
}
sort(q+1,q+numq+1);
build(1,1,n);
int tmp=1;
while(tmp<=numq&&(!q[tmp].pos)) tmp++;
for(int tim=1;tim<=n;tim++)
{
while(tmp<=numq&&q[tmp].pos==tim)
{
Ans[q[tmp].id]+=1ll*q[tmp].opt*query(1,1,n,q[tmp].l,q[tmp].r,tim);
tmp++;
}
if(tim+1<=nxt[tim]-1) update(1,1,n,tim+1,nxt[tim]-1,a[tim],tim);
update(1,1,n,tim,tim,0,tim);
}
for(int i=1;i<=Q;i++)
printf("%lld\n",Ans[i]);
return 0;
}
/*
0
5
1 0 1 0 2
5
2 2
2 3
1 2
2 4
1 5
*/
/*
0
10
2 4 1 10794 0 0 5 11706 2 21850
1
3 9
*/