题解 Nasty Donchik 一道数据结构题
题目大意
给定一个长度为\(n\)的序列\(a_1,a_2\dots,a_n\)。保证\(\forall i:1\leq a_i\leq n\)。请你求出,序列里有多少三元组\((i,j,k)\),满足\(a[i,j]\)里的所有数,都在\(a[j+1,k]\)里出现过;且\(a[j+1,k]\)里所有数,都在\(a[i,j]\)里出现过。
\(n\leq 2\times 10^5\)。
本题题解
枚举\(k\)。对每个\(j\),维护使三元组\((i,j,k)\)合法的最小的和最大的\(i\),分别记为\(\text{mini}[j],\text{maxi}[j]\)。那么,当前\(k\)的三元组数量就是:\(\sum_{j=1}^{k-1}(\text{maxi}[j]-\text{mini}[j]+1)\)。考虑分别计算\(\text{maxi}\)的和和\(\text{mini}\)的和。
记每个位置\(t\)上的数上一次和下一次出现的位置分别为\(\text{pre}[t]\)和\(\text{nxt}[t]\),特别地,如果前面/后面没有相同的数,则\(\text{pre}[t]=0\)或\(\text{nxt}[t]=n+1\)。那么,我们发现,三元组\((i,j,k)\)合法的充分必要条件是:\(\max_{t=i}^{j}(\text{nxt}[t])\leq k\),且\(\min_{t=j+1}^{k}(\text{pre}[t])\geq i\)。
由此可知,\(\text{maxi}[j]\)就是满足\(\min_{t=j+1}^{k}(\text{pre}[t])\geq i\)的最大的\(i\),\(\text{mini}[j]\)就是满足\(\max_{t=i}^{j}(\text{nxt[}t])\leq k\)的最小的\(i\)。
\(\text{maxi}\)比较好维护,他就等于\(\min_{t=j+1}^{k}(\text{pre}[j])\)。当从\(k-1\)变到\(k\)时,我们让所有\(j\in[1,k-1]\)的\(\text{maxi}[j]\)对\(\text{pre}[k]\)取\(\min\)即可。
考虑\(\text{mini}\)。我们称\(\text{nxt}[t]>k\)的位置为不合法的,其他位置为合法的。那么对于每个\(j\),\(\text{mini}[j]\)就相当于\(j\)前面、最靠近\(j\)的那个不合法的位置\(+1\)。特别地,如果\(j\)本身就不合法,我们认为\(\text{mini}[j]=j+1\)。从\(k-1\)变到\(k\),会使得所有\(\text{nxt}[t]=k\)的位置,从不合法变成合法。相当于把两段\(\text{mini}\)的区间“合并”起来(令后一段区间的值等于前一段区间的值)。而\(\text{nxt}[t]=k\)的位置最多只有一个:就是\(\text{pre}[k]\)。所以每次对一段区间执行区间覆盖(或者区间取\(\min\))即可(事实上因为\(\text{maxi}\)要支持的是区间取\(\min\),所以都用区间取\(\min\)反而更好写)。
还有一个要注意的点是,我们要始终保证,\(\text{mini}[j]\leq\text{maxi}[j]+1\),所以对\(\text{maxi}\)取\(\min\)的时候,要对\(\text{mini}\)做一样的操作。
总结来说,需要支持区间对一个数取\(\min\),区间求和,可以用吉老师线段树实现。另外,我们还要对一个位置求它前面、最靠近它的不合法的位置,同时要支持单点修改(把某个位置从不合法变为合法),这个可以用线段上二分实现。
时间复杂度\(O(n\log n)\)。
参考代码:
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
if(S==T){
T=(S=buf)+fread(buf,1,MAXN,stdin);
if(S==T)return EOF;
}
return *S++;
}
}
#ifdef ONLINE_JUDGE
#define getchar Fread::getchar
#endif
inline int read(){
int f=1,x=0;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
inline ll readll(){
ll f=1,x=0;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
/* ------ by:duyi ------ */ // dysyn1314
const int MAXN=2e5;
int n;
/*
struct Baoli{
int a[MAXN+5],val[MAXN+5],val2[MAXN+5];
int get_nxt0(int p){
for(int i=p;i<=n+1;++i)if(a[i]==0)return i;
throw;
}
int get_pre0(int p){
for(int i=p;i>=0;--i)if(a[i]==0)return i;
throw;
}
void set1(int p){
a[p]=1;
}
void init(){
for(int i=1;i<=n;++i)val[i]=val2[i]=i;
}
void modify_min_mxi(int l,int r,int x){
for(int i=l;i<=r;++i)val[i]=min(val[i],x);
}
void modify_min_mni(int l,int r,int x){
for(int i=l;i<=r;++i)val2[i]=min(val2[i],x);
}
int get_sum_mxi(){
int res=0;
for(int i=1;i<=n;++i)res+=val[i]*a[i];
return res;
}
int get_sum_mni(){
int res=0;
for(int i=1;i<=n;++i)res+=val2[i]*a[i];
return res;
}
}T;
*/
class SegmentTree{
private:
int sz[MAXN*4+5],mx[2][MAXN*4+5],se[2][MAXN*4+5],ct[2][MAXN*4+5];
ll sum[2][MAXN*4+5];
void _pu(int p,int *mx,int *se,int *ct,ll *sum){
sum[p]=sum[p<<1]+sum[p<<1|1];
if(mx[p<<1]>mx[p<<1|1]){
mx[p]=mx[p<<1];
se[p]=max(se[p<<1],mx[p<<1|1]);
ct[p]=ct[p<<1];
}
else if(mx[p<<1]<mx[p<<1|1]){
mx[p]=mx[p<<1|1];
se[p]=max(mx[p<<1],se[p<<1|1]);
ct[p]=ct[p<<1|1];
}
else{
mx[p]=mx[p<<1];
se[p]=max(se[p<<1],se[p<<1|1]);
ct[p]=ct[p<<1]+ct[p<<1|1];
}
}
void push_up(int p){
sz[p]=sz[p<<1]+sz[p<<1|1];
_pu(p,mx[0],se[0],ct[0],sum[0]);
_pu(p,mx[1],se[1],ct[1],sum[1]);
}
void _pd(int p,int *mx,int *ct,ll *sum){
if(mx[p]<mx[p<<1]){
sum[p<<1]-=(ll)ct[p<<1]*(mx[p<<1]-mx[p]);
mx[p<<1]=mx[p];
}
if(mx[p]<mx[p<<1|1]){
sum[p<<1|1]-=(ll)ct[p<<1|1]*(mx[p<<1|1]-mx[p]);
mx[p<<1|1]=mx[p];
}
}
void push_down(int p){
_pd(p,mx[0],ct[0],sum[0]);
_pd(p,mx[1],ct[1],sum[1]);
}
void build(int p,int l,int r){
if(l==r){
mx[0][p]=mx[1][p]=l;
se[0][p]=se[1][p]=-1;
return;
}
int mid=(l+r)>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
push_up(p);
}
void modify1(int p,int l,int r,int pos){
if(l==r){
sz[p]=1;
ct[0][p]=ct[1][p]=1;
sum[0][p]=mx[0][p];
sum[1][p]=mx[1][p];
return;
}
push_down(p);
int mid=(l+r)>>1;
if(pos<=mid)modify1(p<<1,l,mid,pos);
else modify1(p<<1|1,mid+1,r,pos);
push_up(p);
}
int __first0(int p,int l,int r){
if(l==r){assert(sz[p]==0);return l;}
push_down(p);
int mid=(l+r)>>1;
if(sz[p<<1]<mid-l+1)return __first0(p<<1,l,mid);
else return __first0(p<<1|1,mid+1,r);
}
int _nxt0(int p,int l,int r,int ql,int qr){
if(ql<=l && qr>=r){
if(sz[p]==r-l+1)return n+1;
else return __first0(p,l,r);
}
push_down(p);
int mid=(l+r)>>1,res=n+1;
if(ql<=mid&&sz[p<<1]<mid-l+1)res=_nxt0(p<<1,l,mid,ql,qr);
if(res!=n+1)return res;
if(qr>mid&&sz[p<<1|1]<r-mid)return _nxt0(p<<1|1,mid+1,r,ql,qr);
else return n+1;
}
int __last0(int p,int l,int r){
if(l==r){assert(sz[p]==0);return l;}
push_down(p);
int mid=(l+r)>>1;
if(sz[p<<1|1]<r-mid)return __last0(p<<1|1,mid+1,r);
else return __last0(p<<1,l,mid);
}
int _pre0(int p,int l,int r,int ql,int qr){
if(ql<=l && qr>=r){
if(sz[p]==r-l+1)return 0;
else return __last0(p,l,r);
}
push_down(p);
int mid=(l+r)>>1,res=0;
if(qr>mid&&sz[p<<1|1]<r-mid)res=_pre0(p<<1|1,mid+1,r,ql,qr);
if(res)return res;
if(ql<=mid&&sz[p<<1]<mid-l+1)return _pre0(p<<1,l,mid,ql,qr);
else return 0;
}
void modify2(int p,int l,int r,int ql,int qr,int x,int t){
//区间对x取min
if(x>=mx[t][p])return;
if(ql<=l && qr>=r && se[t][p]<x){
sum[t][p]-=(ll)ct[t][p]*(mx[t][p]-x);
mx[t][p]=x;
return;
}
push_down(p);
int mid=(l+r)>>1;
if(ql<=mid)modify2(p<<1,l,mid,ql,qr,x,t);
if(qr>mid)modify2(p<<1|1,mid+1,r,ql,qr,x,t);
push_up(p);
}
public:
//mxi tree0
//mni tree1
void set1(int p){modify1(1,1,n,p);}
int get_nxt0(int p){
if(p>n)return n+1;
if(p<1)return 0;
return _nxt0(1,1,n,p,n);
}
int get_pre0(int p){
if(p>n)return n+1;
if(p<1)return 0;
return _pre0(1,1,n,1,p);
}
void modify_min_mxi(int l,int r,int x){
if(l>r)return;
modify2(1,1,n,l,r,x,0);
}
void modify_min_mni(int l,int r,int x){
if(l>r)return;
modify2(1,1,n,l,r,x,1);
}
ll get_sum_mxi(){return sum[0][1];}
ll get_sum_mni(){return sum[1][1];}
void init(){build(1,1,n);}
}T;
int a[MAXN+5],nxt[MAXN+5],pre[MAXN+5],pos[MAXN+5];
int main(){
n=read();
for(int i=1;i<=n;++i){a[i]=read();pre[i]=pos[a[i]];pos[a[i]]=i;}
for(int i=1;i<=n;++i)pos[i]=n+1;
for(int i=n;i>=1;--i){nxt[i]=pos[a[i]];pos[a[i]]=i;}
T.init();
ll ans=0;
for(int k=1;k<=n;++k){
if(pre[k]){
int x=T.get_nxt0(pre[k]+1)-1;
//cout<<"* "<<x<<" "<<T.get_pre0(pre[k]-1)<<endl;
T.modify_min_mni(pre[k],x,T.get_pre0(pre[k]-1));
T.set1(pre[k]);
}
T.modify_min_mxi(1,k-1,pre[k]);
T.modify_min_mni(1,k-1,pre[k]);
ans+=T.get_sum_mxi()-T.get_sum_mni();
}
cout<<ans<<endl;
return 0;
}