[做题笔记] 浅谈势能线段树在特殊区间问题上的应用

区间最值操作

题目描述

点此看题

维护一个数据结构支持区间取最小值,查询区间最大值,查询区间和。

解法

线段树上每个节点维护 \(mx\) 表示区间最大值,\(cx\) 表示区间严格次大值,对于修改我们这样做:

  • 如果 \(mx\leq t\),那么忽略这次取最小值的操作。
  • 如果 \(mx>t>cx\),设区间中 \(mx\)\(num\) 个,那么打上标记,把区间和减去 \(num\cdot(mx-t)\)
  • 如果 \(t\leq cx\),暴力往下递归。

可以设计势能函数 \(h(x)\) 表示线段树上节点 \(x\) 的代表区间中互不相同的元素个数,考虑无论是打标记还是往下递归都是花费 \(O(1)\) 的时间将势能减少 \(1\),初始势能是 \(n\log n\),所以时间复杂度 \(O(n\log n)\)

总结

对于一些奇怪的区间操作,可以考虑势能线段树。

我们可以先尽量多想一些剪枝,然后用势能函数证明时间复杂度。

关于势能函数的定义,可以考虑关键操作会让什么量减少,尝试把它定义成势能函数。

#include <cstdio>
#include <iostream>
using namespace std;
const int M = 1000005;
#define ll long long
int read()
{
    int x=0,f=1;char c;
    while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
    while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
    return x*f;
}
int T,n,m,mx[4*M],cx[4*M],num[4*M];ll s[4*M];
void up(int i)
{
    num[i]=0;
    mx[i]=max(mx[i<<1],mx[i<<1|1]);
    cx[i]=max(cx[i<<1],cx[i<<1|1]);
    if(mx[i<<1]!=mx[i<<1|1])
        cx[i]=max(cx[i],min(mx[i<<1],mx[i<<1|1]));
    if(mx[i]==mx[i<<1]) num[i]+=num[i<<1];
    if(mx[i]==mx[i<<1|1]) num[i]+=num[i<<1|1];
    s[i]=s[i<<1]+s[i<<1|1];
}
void fuck(int i,int c)
{
    if(mx[i]<=c) return ;
    s[i]-=1ll*(mx[i]-c)*num[i];
    mx[i]=c;
}
void down(int i)
{
    fuck(i<<1,mx[i]);
    fuck(i<<1|1,mx[i]);
}
void build(int i,int l,int r)
{
    if(l==r)
    {
        s[i]=mx[i]=read();
        num[i]=1;cx[i]=-1;
        return ;
    }
    int mid=(l+r)>>1;
    build(i<<1,l,mid);
    build(i<<1|1,mid+1,r);
    up(i);
}
void zxy(int i,int l,int r,int c)
{
    if(mx[i]<=c) return ;
    if(mx[i]>c && c>cx[i])
    {
        fuck(i,c);
        return ;
    }
    if(l==r)
    {
        mx[i]=s[i]=min(c,mx[i]);
        return ;
    }
    int mid=(l+r)>>1;down(i);
    zxy(i<<1,l,mid,c);
    zxy(i<<1|1,mid+1,r,c);
    up(i);
}
void upd(int i,int l,int r,int L,int R,int c)
{
    if(L>r || l>R) return ;
    if(L<=l && r<=R)
    {
        zxy(i,l,r,c);
        return ;
    }
    int mid=(l+r)>>1;down(i);
    upd(i<<1,l,mid,L,R,c);
    upd(i<<1|1,mid+1,r,L,R,c);
    up(i);
}
int askmax(int i,int l,int r,int L,int R)
{
    if(L>r || l>R) return 0;
    if(L<=l && r<=R) return mx[i];
    int mid=(l+r)>>1;down(i);
    return max(askmax(i<<1,l,mid,L,R), 
    askmax(i<<1|1,mid+1,r,L,R));
}
ll asksum(int i,int l,int r,int L,int R)
{
    if(L>r || l>R) return 0;
    if(L<=l && r<=R) return s[i];
    int mid=(l+r)>>1;down(i);
    return asksum(i<<1,l,mid,L,R)+
    asksum(i<<1|1,mid+1,r,L,R);
}
void write(ll x)
{
    if (x < 0) x = ~x + 1, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
signed main()
{
    T=read();
    while(T--)
    {
        n=read();m=read();
        build(1,1,n);
        while(m--)
        {
            int op=read(),l=read(),r=read();
            if(op==0) upd(1,1,n,l,r,read());
            if(op==1) write(askmax(1,1,n,l,r)),puts("");
            if(op==2) write(asksum(1,1,n,l,r)),puts("");
        }
    }
}

带区间加法的区间除法

题目描述

点此看题

解法

考虑除法减少得是很快的,而加法只是把区间权值整体抬升,所以我们考虑定义势能函数 \(h(x)=\lg (mx-mi)\),也就是达到状态 \(mx-mi\leq 1\) 需要被除的次数。

显然初始时势能总和是 \(n\log n\log c\),对于整个被除的区间,如果我们向下递归,那么势能一定减少 \(1\),这说明我们花费了 \(O(1)\) 的时间让势能减少 \(1\)

再考虑操作中带来的势能增加,考虑一次加法操作部分影响的区间有 \(\log n\) 个,单个区间增加的势能不超过 \(\log c\),所以总势能增加不超过 \(q\log n\log c\);考虑除法只除到了一个区间的部分,这部分的增量也是类似的 \(q\log n\log c\)

所以总时间复杂度 \(O((n+q)\log n\log c)\),具体实现中我们判断 \(mx-\frac{mx}{d}=mi-\frac{mi}{d}\) 就打减法标记。

#include <cstdio>
#include <iostream>
#include <cmath>
using namespace std;
const int M = 100005;
#define int long long
const int inf = 1e18;
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,q,fl[4*M],s[4*M],mi[4*M],mx[4*M];
void work(int x,int y,int len)
{
	fl[x]+=y;s[x]+=y*len;
	mi[x]+=y;mx[x]+=y;
}
void up(int i)
{
	s[i]=s[i<<1]+s[i<<1|1];
	mi[i]=min(mi[i<<1],mi[i<<1|1]);
	mx[i]=max(mx[i<<1],mx[i<<1|1]);
}
void down(int i,int l,int r)
{
	int mid=(l+r)>>1;
	if(!fl[i]) return ;
	work(i<<1,fl[i],mid-l+1);
	work(i<<1|1,fl[i],r-mid);
	fl[i]=0;
}
void add(int i,int l,int r,int L,int R,int x)
{
	if(L>r || l>R) return ;
	if(L<=l && r<=R)
	{
		work(i,x,r-l+1);
		return ;
	}
	int mid=(l+r)>>1;down(i,l,r);
	add(i<<1,l,mid,L,R,x);
	add(i<<1|1,mid+1,r,L,R,x);
	up(i);
}
int asksum(int i,int l,int r,int L,int R)
{
	if(L>r || l>R) return 0;
	if(L<=l && r<=R) return s[i];
	int mid=(l+r)>>1;down(i,l,r);
	return asksum(i<<1,l,mid,L,R)
	+asksum(i<<1|1,mid+1,r,L,R);
}
int askmin(int i,int l,int r,int L,int R)
{
	if(L>r || l>R) return inf;
	if(L<=l && r<=R) return mi[i];
	int mid=(l+r)>>1;down(i,l,r);
	return min(askmin(i<<1,l,mid,L,R),
	askmin(i<<1|1,mid+1,r,L,R));
}
int wxk(int x,int y)
{
	return (int)floor(1.0*x/y);
}
void zxy(int i,int l,int r,int c)
{
	if(l==r)
	{
		mx[i]=mi[i]=s[i]=wxk(mx[i],c);
		return ;
	}
	if(mx[i]-wxk(mx[i],c)==mi[i]-wxk(mi[i],c))
	{
		work(i,wxk(mx[i],c)-mx[i],r-l+1);
		return ;
	}
	int mid=(l+r)>>1;down(i,l,r);
	zxy(i<<1,l,mid,c);
	zxy(i<<1|1,mid+1,r,c);
	up(i);
}
void div(int i,int l,int r,int L,int R,int c)
{
	if(L>r || l>R) return ;
	if(L<=l && r<=R) {zxy(i,l,r,c);return ;}
	int mid=(l+r)>>1;down(i,l,r);
	div(i<<1,l,mid,L,R,c);
	div(i<<1|1,mid+1,r,L,R,c);
	up(i);
}
signed main()
{
	n=read();q=read();
	for(int i=1;i<=n;i++)
		add(1,1,n,i,i,read());
	while(q--)
	{
		int op=read(),l=read()+1,r=read()+1;
		if(op==1) add(1,1,n,l,r,read());
		if(op==2) div(1,1,n,l,r,read());
		if(op==3) printf("%lld\n",askmin(1,1,n,l,r));
		if(op==4) printf("%lld\n",asksum(1,1,n,l,r)); 
	}
}
posted @ 2021-08-24 21:40  C202044zxy  阅读(177)  评论(0编辑  收藏  举报