[做题笔记] 浅谈势能线段树在特殊区间问题上的应用
区间最值操作
题目描述
维护一个数据结构支持区间取最小值,查询区间最大值,查询区间和。
解法
线段树上每个节点维护 \(mx\) 表示区间最大值,\(cx\) 表示区间严格次大值,对于修改我们这样做:
- 如果 \(mx\leq t\),那么忽略这次取最小值的操作。
- 如果 \(mx>t>cx\),设区间中 \(mx\) 有 \(num\) 个,那么打上标记,把区间和减去 \(num\cdot(mx-t)\)
- 如果 \(t\leq cx\),暴力往下递归。
可以设计势能函数 \(h(x)\) 表示线段树上节点 \(x\) 的代表区间中互不相同的元素个数,考虑无论是打标记还是往下递归都是花费 \(O(1)\) 的时间将势能减少 \(1\),初始势能是 \(n\log n\),所以时间复杂度 \(O(n\log n)\)
总结
对于一些奇怪的区间操作,可以考虑势能线段树。
我们可以先尽量多想一些剪枝,然后用势能函数证明时间复杂度。
关于势能函数的定义,可以考虑关键操作会让什么量减少,尝试把它定义成势能函数。
#include <cstdio>
#include <iostream>
using namespace std;
const int M = 1000005;
#define ll long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int T,n,m,mx[4*M],cx[4*M],num[4*M];ll s[4*M];
void up(int i)
{
num[i]=0;
mx[i]=max(mx[i<<1],mx[i<<1|1]);
cx[i]=max(cx[i<<1],cx[i<<1|1]);
if(mx[i<<1]!=mx[i<<1|1])
cx[i]=max(cx[i],min(mx[i<<1],mx[i<<1|1]));
if(mx[i]==mx[i<<1]) num[i]+=num[i<<1];
if(mx[i]==mx[i<<1|1]) num[i]+=num[i<<1|1];
s[i]=s[i<<1]+s[i<<1|1];
}
void fuck(int i,int c)
{
if(mx[i]<=c) return ;
s[i]-=1ll*(mx[i]-c)*num[i];
mx[i]=c;
}
void down(int i)
{
fuck(i<<1,mx[i]);
fuck(i<<1|1,mx[i]);
}
void build(int i,int l,int r)
{
if(l==r)
{
s[i]=mx[i]=read();
num[i]=1;cx[i]=-1;
return ;
}
int mid=(l+r)>>1;
build(i<<1,l,mid);
build(i<<1|1,mid+1,r);
up(i);
}
void zxy(int i,int l,int r,int c)
{
if(mx[i]<=c) return ;
if(mx[i]>c && c>cx[i])
{
fuck(i,c);
return ;
}
if(l==r)
{
mx[i]=s[i]=min(c,mx[i]);
return ;
}
int mid=(l+r)>>1;down(i);
zxy(i<<1,l,mid,c);
zxy(i<<1|1,mid+1,r,c);
up(i);
}
void upd(int i,int l,int r,int L,int R,int c)
{
if(L>r || l>R) return ;
if(L<=l && r<=R)
{
zxy(i,l,r,c);
return ;
}
int mid=(l+r)>>1;down(i);
upd(i<<1,l,mid,L,R,c);
upd(i<<1|1,mid+1,r,L,R,c);
up(i);
}
int askmax(int i,int l,int r,int L,int R)
{
if(L>r || l>R) return 0;
if(L<=l && r<=R) return mx[i];
int mid=(l+r)>>1;down(i);
return max(askmax(i<<1,l,mid,L,R),
askmax(i<<1|1,mid+1,r,L,R));
}
ll asksum(int i,int l,int r,int L,int R)
{
if(L>r || l>R) return 0;
if(L<=l && r<=R) return s[i];
int mid=(l+r)>>1;down(i);
return asksum(i<<1,l,mid,L,R)+
asksum(i<<1|1,mid+1,r,L,R);
}
void write(ll x)
{
if (x < 0) x = ~x + 1, putchar('-');
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
}
signed main()
{
T=read();
while(T--)
{
n=read();m=read();
build(1,1,n);
while(m--)
{
int op=read(),l=read(),r=read();
if(op==0) upd(1,1,n,l,r,read());
if(op==1) write(askmax(1,1,n,l,r)),puts("");
if(op==2) write(asksum(1,1,n,l,r)),puts("");
}
}
}
带区间加法的区间除法
题目描述
解法
考虑除法减少得是很快的,而加法只是把区间权值整体抬升,所以我们考虑定义势能函数 \(h(x)=\lg (mx-mi)\),也就是达到状态 \(mx-mi\leq 1\) 需要被除的次数。
显然初始时势能总和是 \(n\log n\log c\),对于整个被除的区间,如果我们向下递归,那么势能一定减少 \(1\),这说明我们花费了 \(O(1)\) 的时间让势能减少 \(1\)
再考虑操作中带来的势能增加,考虑一次加法操作部分影响的区间有 \(\log n\) 个,单个区间增加的势能不超过 \(\log c\),所以总势能增加不超过 \(q\log n\log c\);考虑除法只除到了一个区间的部分,这部分的增量也是类似的 \(q\log n\log c\)
所以总时间复杂度 \(O((n+q)\log n\log c)\),具体实现中我们判断 \(mx-\frac{mx}{d}=mi-\frac{mi}{d}\) 就打减法标记。
#include <cstdio>
#include <iostream>
#include <cmath>
using namespace std;
const int M = 100005;
#define int long long
const int inf = 1e18;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,q,fl[4*M],s[4*M],mi[4*M],mx[4*M];
void work(int x,int y,int len)
{
fl[x]+=y;s[x]+=y*len;
mi[x]+=y;mx[x]+=y;
}
void up(int i)
{
s[i]=s[i<<1]+s[i<<1|1];
mi[i]=min(mi[i<<1],mi[i<<1|1]);
mx[i]=max(mx[i<<1],mx[i<<1|1]);
}
void down(int i,int l,int r)
{
int mid=(l+r)>>1;
if(!fl[i]) return ;
work(i<<1,fl[i],mid-l+1);
work(i<<1|1,fl[i],r-mid);
fl[i]=0;
}
void add(int i,int l,int r,int L,int R,int x)
{
if(L>r || l>R) return ;
if(L<=l && r<=R)
{
work(i,x,r-l+1);
return ;
}
int mid=(l+r)>>1;down(i,l,r);
add(i<<1,l,mid,L,R,x);
add(i<<1|1,mid+1,r,L,R,x);
up(i);
}
int asksum(int i,int l,int r,int L,int R)
{
if(L>r || l>R) return 0;
if(L<=l && r<=R) return s[i];
int mid=(l+r)>>1;down(i,l,r);
return asksum(i<<1,l,mid,L,R)
+asksum(i<<1|1,mid+1,r,L,R);
}
int askmin(int i,int l,int r,int L,int R)
{
if(L>r || l>R) return inf;
if(L<=l && r<=R) return mi[i];
int mid=(l+r)>>1;down(i,l,r);
return min(askmin(i<<1,l,mid,L,R),
askmin(i<<1|1,mid+1,r,L,R));
}
int wxk(int x,int y)
{
return (int)floor(1.0*x/y);
}
void zxy(int i,int l,int r,int c)
{
if(l==r)
{
mx[i]=mi[i]=s[i]=wxk(mx[i],c);
return ;
}
if(mx[i]-wxk(mx[i],c)==mi[i]-wxk(mi[i],c))
{
work(i,wxk(mx[i],c)-mx[i],r-l+1);
return ;
}
int mid=(l+r)>>1;down(i,l,r);
zxy(i<<1,l,mid,c);
zxy(i<<1|1,mid+1,r,c);
up(i);
}
void div(int i,int l,int r,int L,int R,int c)
{
if(L>r || l>R) return ;
if(L<=l && r<=R) {zxy(i,l,r,c);return ;}
int mid=(l+r)>>1;down(i,l,r);
div(i<<1,l,mid,L,R,c);
div(i<<1|1,mid+1,r,L,R,c);
up(i);
}
signed main()
{
n=read();q=read();
for(int i=1;i<=n;i++)
add(1,1,n,i,i,read());
while(q--)
{
int op=read(),l=read()+1,r=read()+1;
if(op==1) add(1,1,n,l,r,read());
if(op==2) div(1,1,n,l,r,read());
if(op==3) printf("%lld\n",askmin(1,1,n,l,r));
if(op==4) printf("%lld\n",asksum(1,1,n,l,r));
}
}