BZOJ_3159_决战

题目链接

分析:

我使用树剖+splay维护这个东西。
对每条重链维护一棵splay,链加和查询正常做,剩下的链反转如下。
由于一定是深度递增的一条链,我们树剖将它分成从左到右log个区间,提取出对应子树,插入到一个新的splay中。
然后打标记进行反转,将子树归还给log个区间。
时间复杂度\(O(nlogn^2)\)

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
#define N 200050
typedef long long ll;
#define ls ch[p][0]
#define rs ch[p][1]
#define db(x) cerr<<#x<<" = "<<x<<endl
#define Db(x) cerr<<#x<<endl
#define get(x) (ch[f[x]][1]==x)
int ch[N][2],f[N],sz1[N],rev[N];
int head[N],to[N<<1],nxt[N<<1],fa[N],son[N],top[N],sz2[N],dep[N],cnt,root;
int idx[N],pid[N],id[N],n,m,TOT,tot;
ll sum[N],tag[N],mn[N],mx[N],num[N];
struct Splay {
	int rt,bg,ed;
	int newnode() {
		int p=++tot; return p;
	}
	void init() {
		rt=newnode(); int p=newnode();
		ch[rt][1]=p; f[p]=rt; pushup(p); pushup(rt);
	}
	void pushup(int p) {
		sum[p]=sum[ls]+sum[rs]+num[p];
		mn[p]=min(mn[ls],min(mn[rs],num[p]));
		mx[p]=max(mx[ls],max(mx[rs],num[p]));
		sz1[p]=sz1[ls]+sz1[rs]+1;
	}
	void giv1(int p) {
		rev[p]^=1; swap(ls,rs);
	}
	void giv2(int p,ll d) {
		tag[p]+=d; sum[p]+=sz1[p]*d; mn[p]+=d; mx[p]+=d; num[p]+=d;
	}
	void pushdown(int p) {
		if(rev[p]) {
			if(ls) giv1(ls);
			if(rs) giv1(rs); 
			rev[p]=0;
		}
		if(tag[p]) {
			if(ls) giv2(ls,tag[p]); 
			if(rs) giv2(rs,tag[p]); 
			tag[p]=0;
		}
	}
	void UPD(int x) {
		if(x!=rt) UPD(f[x]);
		pushdown(x);
	}
	void rotate(int x) {
		int y=f[x],z=f[y],k=get(x);
		ch[y][k]=ch[x][!k]; f[ch[y][k]]=y;
		ch[x][!k]=y; f[y]=x; f[x]=z;
		if(z) ch[z][ch[z][1]==y]=x;
		if(y==rt) rt=x;
		pushup(y); pushup(x);
	}
	void splay(int x,int y) {
		UPD(x);
		for(int d;(d=f[x])!=y;rotate(x)) if(f[d]!=y) rotate(get(x)==get(d)?d:x);
	}
	int find(int x) {
		int p=rt;
		while(1) {
			pushdown(p);
			if(sz1[ls]>=x) p=ls;
			else {
				x-=sz1[ls]+1;
				if(!x) return p;
				p=rs;
			}
		}
	}
	int BUILD(int l,int r,int fa) {
		int mid=(l+r)>>1;
		int p=newnode();
		f[p]=fa;
		if(l<mid) ls=BUILD(l,mid-1,p);
		if(r>mid) rs=BUILD(mid+1,r,p);
		pushup(p);
		return p;
	}
	void build(int x) {
		rt=BUILD(1,x+2,0);	
	}
	void update(int x,int y,int z) {
		x=x-bg+1,y=y-bg+1;
		x=find(x),y=find(y+2);
		splay(x,0); splay(y,x);
		giv2(ch[y][0],z);
		pushup(y); pushup(x);
	}
	ll qsum(int x,int y) {
		x=x-bg+1,y=y-bg+1;
		x=find(x),y=find(y+2);
		splay(x,0); splay(y,x);
		//db(x),db(y);
		return sum[ch[y][0]];
	}
	ll qmin(int x,int y) {
		x=x-bg+1,y=y-bg+1;
		x=find(x),y=find(y+2);
		splay(x,0); splay(y,x);
		return mn[ch[y][0]];
	}
	ll qmax(int x,int y) {
		x=x-bg+1,y=y-bg+1;
		x=find(x),y=find(y+2);
		splay(x,0); splay(y,x);
		return mx[ch[y][0]];
	}
	int split(int x,int y) {
		x=x-bg+1,y=y-bg+1;
		x=find(x),y=find(y+2);
		splay(x,0); splay(y,x);
		int p=ch[y][0];
		f[p]=0,ch[y][0]=0; 
		pushup(y); pushup(x);
		return p;
	}
	void insert(int x,int p) {
		int y,t=x;
		x=find(t+1),y=find(t+2);
		splay(x,0); splay(y,x);
		ch[y][0]=p; f[p]=y; 
		pushup(y); pushup(x);
	}
	void findbug(int p) {
		pushdown(p);
		if(ls) findbug(ls);
		//printf("p=%d num[p]=%lld sz1[p]=%d sum=%lld\n",p,num[p],sz1[p],sum[p]);
		if(rs) findbug(rs);
	}
}G[N],TMP;
inline void add(int u,int v) {
	to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt;
}
void df1(int x,int y) {
	int i; sz2[x]=1; fa[x]=y;
	dep[x]=dep[y]+1;
	for(i=head[x];i;i=nxt[i]) if(to[i]!=y) {
		df1(to[i],x); sz2[x]+=sz2[to[i]];
		if(sz2[to[i]]>sz2[son[x]]) son[x]=to[i];
	}
}
void df2(int x,int t) {
	int i;
	top[x]=t;
	idx[x]=++idx[0]; pid[idx[0]]=x;
	if(son[x]) df2(son[x],t);
	for(i=head[x];i;i=nxt[i]) if(to[i]!=fa[x]&&to[i]!=son[x]) df2(to[i],to[i]);
}
char opt[12];
void INCREASE(int x,int y,int z) {
	while(top[x]!=top[y]) {
		if(dep[top[x]]>dep[top[y]]) swap(x,y);
		G[id[idx[y]]].update(idx[top[y]],idx[y],z);
		y=fa[top[y]];
	}
	if(dep[x]<dep[y]) swap(x,y);
	G[id[idx[y]]].update(idx[y],idx[x],z);
}
ll SUM(int x,int y) {
	ll re=0;
	while(top[x]!=top[y]) {
		if(dep[top[x]]>dep[top[y]]) swap(x,y);
		re+=G[id[idx[y]]].qsum(idx[top[y]],idx[y]);
		y=fa[top[y]];
	}
	if(dep[x]<dep[y]) swap(x,y);
	re+=G[id[idx[y]]].qsum(idx[y],idx[x]);
	return re;
}
ll MIN(int x,int y) {
	ll re=1ll<<60;
	while(top[x]!=top[y]) {
		if(dep[top[x]]>dep[top[y]]) swap(x,y);
		re=min(re,G[id[idx[y]]].qmin(idx[top[y]],idx[y]));
		y=fa[top[y]];
	}
	if(dep[x]<dep[y]) swap(x,y);
	re=min(re,G[id[idx[y]]].qmin(idx[y],idx[x]));
	return re;
}
ll MAX(int x,int y) {
	ll re=0;
	while(top[x]!=top[y]) {
		if(dep[top[x]]>dep[top[y]]) swap(x,y);
		re=max(re,G[id[idx[y]]].qmax(idx[top[y]],idx[y]));
		y=fa[top[y]];
	}
	if(dep[x]<dep[y]) swap(x,y);
	re=max(re,G[id[idx[y]]].qmax(idx[y],idx[x]));
	return re;
}
int lca(int x,int y) {
	while(top[x]!=top[y]) {
		if(dep[top[x]]>dep[top[y]]) swap(x,y);
		y=fa[top[y]];
	}
	return dep[x]<dep[y]?x:y;
}
struct A {
	int l,r,id,p;
}a[N];
void clr(int p) {
	num[p]=tag[p]=rev[p]=mn[p]=mx[p]=sum[p]=ls=rs=f[p]=sz1[p]=0;
}
void INVERSE(int x,int y) {
	if(dep[x]>dep[y]) swap(x,y);
	int la=0;
	while(top[x]!=top[y]) {
		a[++la]=(A){idx[top[y]],idx[y],id[idx[y]],G[id[idx[y]]].split(idx[top[y]],idx[y])};
		y=fa[top[y]];
	}
	a[++la]=(A){idx[x],idx[y],id[idx[x]],G[id[idx[x]]].split(idx[x],idx[y])};
	int i;
	for(i=la;i;i--) {
		TMP.insert(sz1[TMP.rt]-2,a[i].p);
	}

	TMP.splay(1,0); TMP.splay(2,1);
	TMP.giv1(ch[2][0]); TMP.pushup(2); TMP.pushup(1);

	for(i=la;i;i--) {
		int p=TMP.split(1,a[i].r-a[i].l+1);
		G[a[i].id].insert(a[i].l-G[a[i].id].bg,p);
	}

	clr(1),clr(2); ch[1][1]=2; f[2]=1; sz1[1]=2; sz1[2]=1; TMP.rt=1;
}
int main() {
	mn[0]=1ll<<60;
	scanf("%d%d%d",&n,&m,&root);
	int i,x,y,j=0;
	for(i=1;i<n;i++) {
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	df1(root,0),df2(root,root);
	for(i=1;i<=n;i++) {
		if(top[pid[i]]!=top[pid[i-1]]) j++,G[j].bg=i;
		id[i]=j; G[j].ed=i;
	}
	TMP.init(); TMP.bg=1;
	TOT=j;
	for(i=1;i<=TOT;i++) {
		G[i].build(G[i].ed-G[i].bg+1);
	}
	int z;
	for(i=1;i<=m;i++) {
		scanf("%s",opt);
		if(opt[0]=='I') {
			if(opt[2]=='c') {
				scanf("%d%d%d",&x,&y,&z);
				INCREASE(x,y,z);
			}else {
				scanf("%d%d",&x,&y);
				INVERSE(x,y);
			}
		}else if(opt[0]=='S') {
			scanf("%d%d",&x,&y);
			printf("%lld\n",SUM(x,y));
		}else {
			if(opt[1]=='a') {
				scanf("%d%d",&x,&y);
				printf("%lld\n",MAX(x,y));
			}else {
				scanf("%d%d",&x,&y);
				printf("%lld\n",MIN(x,y));
			}
		}
	}
}
/*
5 8 1
1 2
2 3
3 4
4 5
Sum 2 4
Increase 3 5 3
Minor 1 4
Sum 4 5
Invert 1 3
Major 1 2
Increase 1 5 2
Sum 1 5
*/

posted @ 2018-11-19 22:28  fcwww  阅读(221)  评论(0编辑  收藏  举报