【模板】线段树 Segment Tree / Segtree

posted on 2021-07-26 20:14:08 | under 学术 | source
posted on 2021-11-19 15:19:41 | under 模板 | source

稍微地封装了一下,还是比较好用的。

这个线段树必跑满,注意一下常数。

复杂度为 \(O(n)-O(\log n)\)

现在感觉就是,模板没用,线段树的题千变万化,今天值域 \(10^9\),明天就上树了。奉劝大家:线段树一定要学好,它只是工具,而以后不会再有一道只考工具的题了。

区间加+区间和(RSQ)

https://www.luogu.com.cn/blog/pks-LOVING/senior-data-structure-qian-tan-xian-duan-shu-segment-tree

点击查看代码
template<class T,int N> struct SegmentTree{
    T tag[N<<2],ans[N<<2];
    SegmentTree(){memset(tag,0,sizeof tag),memset(ans,0,sizeof ans);}
    int lc(int p){return p<<1;}
    int rc(int p){return p<<1|1;}
    void pushup(int p){
        ans[p]=ans[lc(p)]+ans[rc(p)];
    }
    void update(T k,int l,int r,int p){
        tag[p]+=k;
        ans[p]+=k*(r-l+1);
    }
    void pushdown(int l,int r,int p){
        if(tag[p]){
            int mid=(l+r)>>1;
            update(tag[p],l,mid,lc(p));
            update(tag[p],mid+1,r,rc(p));
            tag[p]=0;
        }
    }
    void build(T ipt[],int l=1,int r=N,int p=1){
        if(l==r) return (void)(ans[p]=ipt[l]);
        int mid=(l+r)>>1;
        build(ipt,l,mid,lc(p));
        build(ipt,mid+1,r,rc(p));
        pushup(p);
    }
    void add(T k,int ql,int qr,int l=1,int r=N,int p=1){
        if(ql<=l&&r<=qr) return update(k,l,r,p);
        pushdown(l,r,p);
        int mid=(l+r)>>1;
        if(ql<=mid)   add(k,ql,qr,l,mid,lc(p));
        if(mid+1<=qr) add(k,ql,qr,mid+1,r,rc(p));
        pushup(p);
    }
    T query(int ql,int qr,int l=1,int r=N,int p=1){
        if(ql<=l&&r<=qr) return ans[p];
        pushdown(l,r,p);
        T ans=0;
        int mid=(l+r)>>1;
        if(ql<=mid)   ans+=query(ql,qr,l,mid,lc(p));
        if(mid+1<=qr) ans+=query(ql,qr,mid+1,r,rc(p));
        pushup(p);
        return ans;
    }
};

区间赋值+区间和

其实就是改一下 update/spread 再改一下函数名。

点击查看代码
template<class T,int N> struct SegmentTree{
    T tag[N<<2],ans[N<<2];
    SegmentTree(){memset(tag,-1,sizeof tag),memset(ans,0,sizeof ans);}
    int lc(int x){return x<<1;}
    int rc(int x){return x<<1|1;}
    void pushup(int p){
        ans[p]=ans[lc(p)]+ans[rc(p)];
    }
    void spread(T k,int l,int r,int p){
        tag[p]=k;
        ans[p]=k*(r-l+1);
    }
    void pushdown(int l,int r,int p){
        if(tag[p]==-1) return ;
        int mid=(l+r)>>1;
        spread(tag[p],l,mid,lc(p));
        spread(tag[p],mid+1,r,rc(p));
        tag[p]=-1;
    }
    void build(T ipt[],int l=0,int r=N,int p=1){
        if(l==r) return (void)(ans[l]=ipt[l]);
        int mid=(l+r)>>1;
        build(ipt,l,mid,lc(p));
        build(ipt,mid+1,r,rc(p));
        pushup(p);
    }
    void assign(T k,int ql,int qr,int l=0,int r=N,int p=1){
        if(ql<=l&&r<=qr) return spread(k,l,r,p);
        pushdown(l,r,p);
        int mid=(l+r)>>1;
        if(ql<=mid)   assign(k,ql,qr,l,mid,lc(p));
        if(mid+1<=qr) assign(k,ql,qr,mid+1,r,rc(p));
        pushup(p);
    }
    T query(int ql,int qr,int l=0,int r=N,int p=1){
        if(ql<=l&&r<=qr) return ans[p];
        T ans=0;
        pushdown(l,r,p);
        int mid=(l+r)>>1;
        if(ql<=mid)   ans+=query(ql,qr,l,mid,lc(p));
        if(mid+1<=qr) ans+=query(ql,qr,mid+1,r,rc(p));
        pushup(p);
        return ans;
    }
};

区间最小/大值(RMQ)

点击查看代码
template<class T,int N,char cmp,T inf> struct SegmentTree{
    T fun(T a,T b){return cmp=='<'?(a<b?a:b):(a<b?b:a);}
    T ans[N<<2];
    SegmentTree(){memset(ans,inf,sizeof ans);}
    int lc(int p){return p<<1;}
    int rc(int p){return p<<1|1;}
    void build(T ipt[],int l=1,int r=N,int p=1){
        if(l==r) return (void)(ans[p]=ipt[l]);
        int mid=(l+r)>>1;
        build(ipt,l,mid,lc(p));
        build(ipt,mid+1,r,rc(p));
        ans[p]=fun(ans[lc(p)],ans[rc(p)]);
    }
    T query(int ql,int qr,int l=1,int r=N,int p=1){
        if(ql<=l&&r<=qr) return ans[p];
        T ans=inf;
        int mid=(l+r)>>1;
        if(ql<=mid)   ans=fun(ans,query(ql,qr,l,mid,lc(p)));
        if(mid+1<=qr) ans=fun(ans,query(ql,qr,mid+1,r,rc(p)));
        return ans;
    }
};

由于使用方式有点离谱,说明一下:

SegmentTree<int,1000010,'<',0x7fffffff> a;
SegmentTree<int,1000010,'>',-0x7fffffff> b;
//Segment<类型,大小,'<'或'>',极限数> name;
//极限数(inf)需要满足 fun(inf,x)=x

或者这种:

点击查看代码
template<class T,int N,T (*fun)(T,T)> struct SegmentTree{
    T ans[N*4+10];
    SegmentTree(){memset(ans,0,sizeof ans);}
    int lc(int x){return x<<1;}
    int rc(int x){return x<<1|1;}
    int build(T a[],int l=1,int r=N,int p=1){
        if(l==r) return ans[p]=a[l];
        int mid=(l+r)>>1;
        return ans[p]=max(build(a,l,mid,lc(p)),build(a,mid+1,r,rc(p)));
    }
    int query(int ql,int qr,int l=1,int r=N,int p=1){
        if(qr<l||r<ql) return 0;
        if(ql<=l&&r<=qr) return ans[p];
        int mid=(l+r)>>1;
        return max(query(ql,qr,l,mid,lc(p)),query(ql,qr,mid+1,r,rc(p)));
    }
};

通用线段树

把懒标记绑定,注意优先级

点击查看代码
template<class T,int N> struct SegmentTree{
    T tag[N<<2],ans[N<<2];
    SegmentTree(){memset(tag,0,sizeof tag),memset(ans,0,sizeof ans);}
    int lc(int p){return p<<1;}
    int rc(int p){return p<<1|1;}
    void pushup(int p){ans[p]=ans[lc(p)]+ans[rc(p)];}
    void build(T ipt[],int l=1,int r=N,int p=1){
        if(l==r) return (void)(ans[p]=ipt[l]);
        int mid=(l+r)>>1;
        build(ipt,l,mid,lc(p));
        build(ipt,mid+1,r,rc(p));
        pushup(p);
    }
    void update(T k,int l,int r,int p){
        tag[p]+=k;
        ans[p]+=k*(r-l+1);
    }
    void pushdown(int l,int r,int p){
        if(!tag[p]) return ;
        int mid=(l+r)>>1;
        update(tag[p],l,mid,lc(p));
        update(tag[p],mid+1,r,rc(p));
        tag[p]=0;
    }
    void add(T k,int ql,int qr,int l=1,int r=N,int p=1){
        if(ql<=l&&r<=qr) return update(k,l,r,p);
        pushdown(l,r,p);
        int mid=(l+r)>>1;
        if(ql<=mid)   add(k,ql,qr,l,mid,lc(p));
        if(mid+1<=qr) add(k,ql,qr,mid+1,r,rc(p));
        pushup(p);
    }
    T query(int ql,int qr,int l=1,int r=N,int p=1){
        if(ql<=l&&r<=qr) return ans[p];
        pushdown(l,r,p);
        T ans=0;
        int mid=(l+r)>>1;
        if(ql<=mid)   ans+=query(ql,qr,l,mid,lc(p));
        if(mid+1<=qr) ans+=query(ql,qr,mid+1,r,rc(p));
        pushup(p);
        return ans;
    }
};

另一种重度压行的实现:

点击查看代码
template<int N,class T,class A> struct segtree{
	#define mid ((l+r)>>1)
	T tag[4*N+10];A ans[4*N+10];
	A add(T k,int p){return tag[p]+=k,ans[p]+=k;}
	void psdw(int p){add(tag[p],p<<1),add(tag[p],p<<1|1),tag[p]=T();}
	A build(A a[],int l=1,int r=N,int p=1){
		if(tag[p]=T(),ans[p]=A(),l==r) return ans[p]=a[l];
		return ans[p]=build(a,l,mid,p<<1)+build(a,mid+1,r,p<<1|1);
	}
	A modify(T k,int L,int R,int l=1,int r=N,int p=1){
		if(r<L||R<l||L>R) return A();if(L<=l&&r<=R) return add(k,p);
		return psdw(p),ans[p]=modify(k,L,R,l,mid,p<<1)+modify(k,L,R,mid+1,r,p<<1|1);
	}
	A query(int L,int R,int l=1,int r=N,int p=1){
		if(r<L||R<l||L>R) return A();if(L<=l&&r<=R) return ans[p];
		return psdw(p),query(L,R,l,mid,p<<1)+query(L,R,mid+1,r,p<<1|1);
	}
	#undef mid
};

单点修改,区间查询:(没有懒标记)

点击查看代码
template<int N,class A,int RN=N*2> struct segtree{
    A ans[RN+10];int ch[RN+10][2],cnt,root;
    segtree():cnt(-1){root=0;newnode();}
    void add(A k,int &p){if(!p) p=newnode();ans[p]=k;}
    int newnode(){return ans[++cnt]=A(),ch[cnt][0]=ch[cnt][1]=0,cnt;}
    void maintain(int &p){ans[p]=ans[ch[p][0]]+ans[ch[p][1]];}
    void modify(A k,int L,int R,int &p,int l=1,int r=N){
        if(!p) p=newnode();
        if(L<=l&&r<=R) return add(k,p);
        int mid=(l+r)>>1;
        if(L<=mid) modify(k,L,R,ch[p][0],l,mid);
        if(mid+1<=R) modify(k,L,R,ch[p][1],mid+1,r);
        maintain(p);
    }
    A query(int L,int R,int &p,int l=1,int r=N){
        if(!p) return A();
        if(L<=l&&r<=R) return ans[p];
        int mid=(l+r)>>1;A res=A();
        if(L<=mid) res=res+query(L,R,ch[p][0],l,mid);
        if(mid+1<=R) res=res+query(L,R,ch[p][1],mid+1,r);
        return res;
    }
};

单点修改,全局询问,那还要线段树干什么?

点击查看代码
template<int N,class T> struct segtree{
    T ans[N<<2];
    segtree(){memset(ans,0,sizeof ans);}
    void modify(int x,T k,int p=1,int l=1,int r=N){
        if(l==r) return void(ans[p]=k);
        int mid=(l+r)>>1;
        if(x<=mid) modify(x,k,p<<1,l,mid);
        else modify(x,k,p<<1|1,mid+1,r);
        ans[p]=ans[p<<1]+ans[p<<1|1];
    }
};
//query=t.ans[1]

线段树二分

点击查看代码
typedef long long LL;
template<int N> struct segtree{
	int ans[N<<2];
	void build(int b[],int p=1,int l=1,int r=N){
		if(l==r) return ans[p]=b[l],void();
		int mid=(l+r)>>1;
		build(b,p<<1,l,mid),build(b,p<<1|1,mid+1,r);
		ans[p]=min(ans[p<<1],ans[p<<1|1]);
	}
	int query(int L,int R,LL k,int p=1,int l=1,int r=N){
		if(L<=l&&r<=R&&ans[p]>k) return 0;
		if(L<=l&&r<=R&&l==r) return l;
		int mid=(l+r)>>1,res;
		if(L<=mid&&(res=query(L,R,k,p<<1,l,mid))) return res;
		if(mid<R&&(res=query(L,R,k,p<<1|1,mid+1,r))) return res;
		return 0;
	}
};

动态开点

一次修改 \(+O(\log n)\) 个结点

无删除:

点击查看代码
template<int N,class T,class A> struct segtree{
	T tag[N*20+10];A ans[N*20+10];
	int ch[N*20+10][2],cnt,root;
	segtree():cnt(-1){root=0;newnode();}
	void add(T k,int &p,int l,int r){if(!p) p=newnode();ans[p].add(k,l,r),tag[p]+=k;}
	int newnode(){return tag[++cnt]=T(),ans[cnt]=A(),ch[cnt][0]=ch[cnt][1]=0,cnt;}
	void maintain(int &p){ans[p]=ans[ch[p][0]]+ans[ch[p][1]];}
	void pushdown(int &p,int l,int r){
        if(tag[p].empty()) return ;
        int mid=(l+r)>>1;
        add(tag[p],ch[p][0],l,mid);
        add(tag[p],ch[p][1],mid+1,r);
        tag[p]=T();
    }
	void modify(T k,int L,int R,int &p,int l=1,int r=N){
		if(!p) p=newnode();
		if(L<=l&&r<=R) return add(k,p,l,r);
		int mid=(l+r)>>1;
		pushdown(p,l,r);
		if(L<=mid) modify(k,L,R,ch[p][0],l,mid);
		if(mid+1<=R) modify(k,L,R,ch[p][1],mid+1,r);
		maintain(p);
	}
	A query(int L,int R,int &p,int l=1,int r=N){
		if(!p) return A();
		if(L<=l&&r<=R) return ans[p];
		int mid=(l+r)>>1;A res=A();
		pushdown(p,l,r);
		if(L<=mid) res=res+query(L,R,ch[p][0],l,mid);
		if(mid+1<=R) res=res+query(L,R,ch[p][1],mid+1,r);
		return res;
	}
};

如果空间卡的很死,试试垃圾回收

点击查看代码
template<int N,class T,class A> struct segtree{
	T tag[N*20+10];A ans[N*20+10];
	int ch[N*20+10][2],cnt,root[N+10],nod[N*20+10],tot;
	segtree():cnt(-1),tot(0){memset(root,0,sizeof root);newnode();}
	void add(T k,int &p,int l,int r){
        if(!p) p=newnode();
        ans[p].add(k,l,r),tag[p]+=k;
    }
	int newnode(){
        int k=tot==0?++cnt:nod[tot--];
        ch[k][0]=ch[k][1]=0;
        tag[k]=T(),ans[k]=A();
        return k;
    }
    void delnode(int &p){
        if(!ch[p][0]&&!ch[p][1]&&tag[p]==T()&&ans[p]==A()){
            nod[++tot]=p;
            p=0;
        }
    }
	void pushup(int &p){
        ans[p]=ans[ch[p][0]]+ans[ch[p][1]];
    }
	void pushdown(int &p,int l,int r){
        if(tag[p]==T()) return ;
        int mid=(l+r)>>1;
        add(tag[p],ch[p][0],l,mid);
        add(tag[p],ch[p][1],mid+1,r);
        tag[p]=T();
    }
	void modify(T k,int L,int R,int &p,int l=1,int r=N){
		if(!p) p=newnode();
		if(L<=l&&r<=R) return add(k,p,l,r),delnode(p);
		int mid=(l+r)>>1;
		pushdown(p,l,r);
		if(L<=mid) modify(k,L,R,ch[p][0],l,mid);
		if(mid+1<=R) modify(k,L,R,ch[p][1],mid+1,r);
		pushup(p);
        delnode(p);
	}
	A query(int L,int R,int &p,int l=1,int r=N){
		if(!p) return A();
		if(L<=l&&r<=R) return ans[p];
		int mid=(l+r)>>1;A res=A();
		pushdown(p,l,r);
		if(L<=mid) res=res+query(L,R,ch[p][0],l,mid);
		if(mid+1<=R) res=res+query(L,R,ch[p][1],mid+1,r);
        delnode(p);
		return res;
	}
};

examples

接口:Tag 要重载 += 表示懒标记结合,Ans 重载 + 表示答案结合和 add(k,l,r) 表示答案与懒标记结合。

P3372:区间 add+sum

ans.add(k,l,r)==ans+k*(r-l+1),其他不变

P1253:区间 add+cover+max

把 add 和 cover 绑定在一起:

点击查看代码
struct Tag{
	LL add,ass;
	bool isass;
	Tag(LL add=0,LL ass=-1e18,bool isass=0):add(add),ass(ass),isass(isass){}
	Tag operator+=(Tag b){
		if(b.isass){
			add=0;
			ass=b.ass;
			isass=1;
		}
		add+=b.add;
		return *this;
	}
};
struct Ans{
	LL x;
	Ans(LL x=-1e18):x(x){}
	friend Ans operator+(Ans a,Ans b){
		return Ans(max(a.x,b.x));
	}
	Ans add(Tag k,int,int){
		if(k.isass){
			x=k.ass;
		}
		x+=k.add;
		return *this;
	}
};

P3373:区间 add+mul+sum

同样绑在一起,先乘后加:(之前写的,有点奇怪)

点击查看代码
struct tag {
    LL mul, add;
    tag(LL mul = 1, LL add = 0): mul(mul % P), add(add % P) {}
    friend tag operator+(tag b, tag a) {
        return tag(a.mul % P * b.mul % P, (a.mul % P * b.add % P + a.add % P) % P);
    }
    LL mix(LL sum, int l, int r) {
        return (sum % P * mul % P + add % P * (r - l + 1) % P) % P;
    }
};

可持久化线段树

单点修改,区间查询:

点击查看代码
template<int N,class A,int logN=23> struct exsegtree{
	A ans[N*logN];//logN=logn+2=23
	int tot,lc[N*logN],rc[N*logN];
	int newnode(int p=0){int q=++tot;return ans[q]=ans[p],lc[q]=lc[p],rc[q]=rc[p],q;}
	void pushup(int p){ans[p]=ans[lc[p]]+ans[rc[p]];}
	exsegtree():tot(0){lc[0]=rc[0]=0;}
	int build(A a[],int l=1,int r=N){
		int q=newnode();
		if(l==r) return ans[q]=a[l],q;
		int mid=(l+r)>>1;
		lc[q]=build(a,l,mid),rc[q]=build(a,mid+1,r);
		return pushup(q),q;
	}
	int modify(int p,int x,A k,int l=1,int r=N){
		int q=newnode(p);
		if(l==r) return ans[q]=k,q;
		int mid=(l+r)>>1;
		if(x<=mid) lc[q]=modify(lc[p],x,k,l,mid);
		else rc[q]=modify(rc[p],x,k,mid+1,r);
		return pushup(q),q; 
	}
	A query(int p,int L,int R,int l=1,int r=N){
		if(L<=l&&r<=R) return ans[p];
		int mid=(l+r)>>1;A res=A(0);
		if(L<=mid) res=query(lc[p],L,R,l,mid)+res;
		if(mid<R)  res=res+query(rc[p],L,R,mid+1,r);
		return res;
	}
};

特殊用法:查询异或关系

考虑一个问题:你有一棵 0/1 trie(其实也可以说是线段树),里面有很多数,现在给定 \(y,L,R\) 查询是否存在 \(x\) 使得 \(x\oplus y\in [L,R]\)

查看解法

考虑对 0/1 trie 进行 \(\oplus x\) 操作:对于第 \(i\) 层,如果 \(x\) 的第 \(i\) 位是 \(1\),就交换这一层所有的左右儿子。相当于是你将所有根节点到叶子的路径全都“异或 \(x\)”,换到了正确的位置、

做完这个操作之后,直接查询 \([L,R]\) 的区间和。复杂度是线段树复杂度 \(O(\log V)\)。实现的时候我们可以不去换它,在进入左右儿子的时候偷偷交换一下。

template<int N,int M=1073741823> struct segtree{
	int ans[N+10],tot,ch[N+10][2];
	int newnode(){int p=++tot;return ch[p][0]=ch[p][1]=0,ans[p]=-1e9,p;}
	segtree():tot(-1){newnode();}
	void modify(int x,int k,int &p,int l=0,int r=M){
		if(!p) p=newnode();
		if(l==r) return (void)(ans[p]=max(ans[p],k));
//		if(l==r) return printf("call modify(%d,%d,%d,%d,%d)\n",x,k,p,l,r),(void)(ans[p]=max(ans[p],k));
		int mid=(l+r)>>1;
		if(x<=mid) modify(x,k,ch[p][0],l,mid);
		else modify(x,k,ch[p][1],mid+1,r);
		ans[p]=max(ans[ch[p][0]],ans[ch[p][1]]);
	}
	int query(int L,int R,int x,int p,int l=0,int r=M){
//		if(!p||(L<=l&&r<=R)) return printf("call query(%d,%d,%d,%d,%d,%d), line 69, return %d\n",L,R,x,p,l,r,ans[p]),ans[p];
		if(!p||(L<=l&&r<=R)) return ans[p];
		int mid=(l+r)>>1,res=-1e9;
		if((r-l+1)>>1&x){
			if(L<=mid) res=max(res,query(L,R,x,ch[p][1],l,mid));
			if(mid<R) res=max(res,query(L,R,x,ch[p][0],mid+1,r));
		}else{
			if(L<=mid) res=max(res,query(L,R,x,ch[p][0],l,mid));
			if(mid<R) res=max(res,query(L,R,x,ch[p][1],mid+1,r));
		}
//		printf("call query(%d,%d,%d,%d,%d,%d), return %d\n",L,R,x,p,l,r,res);
		return res;
	}
};

单点修改区间查询

2n 空间非递归

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
int n, m, ans[1000010];
void modify(int x, int k) {
  for (x += n; x; x >>= 1) ans[x] += k;
}
int query(int l, int r) {
  int res = 0;
  for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
    if (l & 1) res += ans[l++];
    if (r & 1) res += ans[--r];
  }
  return res;
}
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  cin >> n >> m;
  for (int i = n; i < n * 2; i++) cin >> ans[i];
  for (int i = n - 1; i >= 1; i--) ans[i] = ans[i << 1] + ans[i << 1 | 1];
  while (m--) {
    int op, x, y;
    cin >> op >> x >> y;
    if (op == 1) modify(x - 1, y);
    else cout << query(x - 1, y) << endl;
  }
  return 0;
}

posted @ 2022-11-15 18:01  caijianhong  阅读(129)  评论(0编辑  收藏  举报