题解 黑白树

传送门

神仙题,并且我有理由怀疑是新一代 lxl 搞的

  • LCT 不支持子树加!

发现 LCT 能且仅能维护出这个连通块,但无法维护和修改相关的任何信息
那么考虑不用 LCT
直接放题解神仙思路好了
将每个同色连通块中深度最浅的点称作 管辖点
那么每次对同色连通块的修改对应的是这个连通块管辖点 dfs 序中的几个子区间
那么考虑魔改线段树
对每个线段树节点维护出它在哪些异色点的子树中
这个数量级比较大,可以用一个类似标记永久化的东西
只在每个恰好被包含的地方的 set 插入,取的时候从根下来带着一个值走就行了
那么可以对每个线段树节点维护出它到根的链上的异色点个数,记为 cnt
发现能 pushup/spread 的条件是子区间的 cnt 等于当前节点的 cnt
修改时先查出管辖点的 cnt,记为 lim
那么向下递归的条件是子区间的 cnt \(\leqslant\) lim
查询同理
那么找管辖点可以利用线段树 set 查出到根节点路径上深度最深的异色点
向下找一个儿子就可以了
剩下的操作可以直接树剖,线段树上要再维护一个无条件下传的标记
完全没必要对每种颜色单独开一棵线段树,全放在一棵线段树中反色操作会方便很多(注意:下面的代码仍然是每种颜色各开一棵,即 seg[2])
不考虑其它操作,仅线段树+反色的复杂度是 \(O(n\log^2 n)\),复杂度瓶颈在于线段树+set
其实加上其他操作复杂度也还是 \(O(n\log^2 n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 200010
#define fir first
#define sec second
#define ll long long
#define int long long

// Buffered input: buf is a 2 MiB staging area filled from stdin with fread;
// p1 is the read cursor and p2 points one past the last valid byte.
char buf[1<<21], *p1=buf, *p2=buf;
// Overrides getchar(): refills buf whenever the cursor reaches the end of the
// valid data and yields EOF once fread delivers no more bytes.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
/**
 * Reads the next decimal integer from the input via getchar().
 *
 * Every non-digit character before the number is skipped; each '-' seen while
 * skipping flips the sign. Digits are then accumulated until the first
 * non-digit character (which is consumed).
 *
 * @return the parsed value, or 0 if the input ends before any digit is found.
 */
inline int read() {
	int ans=0, f=1;
	// Keep the character in an int so EOF stays distinguishable from real bytes.
	int c=getchar();
	// Stop at EOF: without this guard an exhausted input would spin forever,
	// because EOF is never a digit.
	while (c!=EOF && !(c>='0'&&c<='9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0'&&c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

// n, m: the two integers read first in main (m bounds the operation loop of every solver).
int n, m;
// Per-vertex values as read from the input.
ll val[N];
// head[v]: index of the newest edge leaving v (-1 when none, see main);
// col: vertex colours; deg: vertex degrees; ecnt: number of stored edges;
// endpos: a vertex of degree <= 1, chosen in main.
int head[N], col[N], deg[N], ecnt, endpos;
// Forward-star adjacency storage: `next` links to the previous edge of the same source.
struct edge{int to, next;}e[N<<1];
// Stores the directed edge s -> t at the front of s's edge list.
inline void add(int s, int t) {
	const int slot=++ecnt;
	e[slot].to=t;
	e[slot].next=head[s];
	head[s]=slot;
}

namespace force{
	int back[N], dep[N], sta[N], top;
	void dfs(int u, int fa) {
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dep[v]=dep[u]+1;
			back[v]=u;
			dfs(v, u);
		}
	}
	void get_block(int u, int fa) {
		sta[++top]=u;
		for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa && col[e[i].to]==col[u])
			get_block(e[i].to, u);
	}
	void add(int u, int fa, int t) {
		val[u]+=t;
		for (int i=head[u]; ~i; i=e[i].next)
			if (e[i].to!=fa) add(e[i].to, u, t);
	}
	void solve() {
		ll ans;
		dfs(1, 0);
		for (int i=1,op,u,x,y,z; i<=m; ++i) {
			op=read();
			if (op==1) col[read()]^=1;
			else if (op==2) {
				x=read(); y=read(); top=0;
				get_block(x, 0);
				for (int j=1; j<=top; ++j) val[sta[j]]+=y;
			}
			else if (op==3) {
				x=read(); ans=top=0;
				get_block(x, 0);
				for (int j=1; j<=top; ++j) ans=max(ans, val[sta[j]]);
				printf("%lld\n", ans);
			}
			else if (op==4) {
				x=read(); y=read(); z=read();
				while (1) {
					if (dep[x]<dep[y]) swap(x, y);
					val[x]+=z;
					if (x==y) break;
					x=back[x];
				}
			}
			else {
				x=read(); y=read();
				add(x, back[x], y);
			}
		}
	}
}

namespace task1{
	int bit[N], id[N], rk[N], tot;
	inline void upd(int i, int dat) {for (; i<=n; i+=i&-i) bit[i]+=dat;}
	inline int query(int i) {int ans=0; for (; i; i-=i&-i) ans+=bit[i]; return ans;}
	int tl[N<<2], tr[N<<2]; ll mx[N<<2], tag[N<<2];
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	#define mx(p) mx[p]
	#define tag(p) tag[p]
	#define pushup(p) mx(p)=max(mx(p<<1), mx(p<<1|1))
	void spread(int p) {
		if (!tag(p)) return ;
		mx(p<<1)+=tag(p); tag(p<<1)+=tag(p);
		mx(p<<1|1)+=tag(p); tag(p<<1|1)+=tag(p);
		tag(p)=0;
	}
	void build(int p, int l, int r) {
		tl(p)=l; tr(p)=r;
		if (l==r) {mx(p)=val[rk[l]]; return ;}
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
		pushup(p);
	}
	void upd(int p, int l, int r, int dat) {
		if (l<=tl(p)&&r>=tr(p)) {mx(p)+=dat; tag(p)+=dat; return ;}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid) upd(p<<1, l, r, dat);
		if (r>mid) upd(p<<1|1, l, r, dat);
		pushup(p);
	}
	ll query(int p, int l, int r) {
		if (l<=tl(p)&&r>=tr(p)) return mx(p);
		spread(p);
		int mid=(tl(p)+tr(p))>>1; ll ans=0;
		if (l<=mid) ans=max(ans, query(p<<1, l, r));
		if (r>mid) ans=max(ans, query(p<<1|1, l, r));
		return ans;
	}
	pair<int, int> get_block(int x) {
		pair<int, int> ans;
		if (col[x]) {
			x=id[x];
			int l, r, mid;
			l=x; r=n;
			while (l<=r) {
				mid=(l+r)>>1;
				if (query(mid)-query(x)==mid-x) l=mid+1;
				else r=mid-1;
			}
			ans.sec=l-1;
			l=1, r=x;
			while (l<=r) {
				mid=(l+r)>>1;
				if (query(x)-query(mid-1)==x-mid+1) r=mid-1;
				else l=mid+1;
			}
			ans.fir=r+1;
		}
		else {
			x=id[x];
			int l, r, mid;
			l=x; r=n;
			while (l<=r) {
				mid=(l+r)>>1;
				if (query(mid)-query(x)==0) l=mid+1;
				else r=mid-1;
			}
			ans.sec=l-1;
			l=1, r=x;
			while (l<=r) {
				mid=(l+r)>>1;
				if (query(x)-query(mid-1)==0) r=mid-1;
				else l=mid+1;
			}
			ans.fir=r+1;
		}
		return ans;
	}
	void dfs(int u, int fa) {
		rk[id[u]=++tot]=u;
		for (int i=head[u]; ~i; i=e[i].next)
			if (e[i].to!=fa) dfs(e[i].to, u);
	}
	void solve() {
		ll ans;
		dfs(endpos, 0);
		build(1, 1, n);
		for (int i=1; i<=n; ++i) if (col[i]) upd(id[i], 1);
		// cout<<"id: "; for (int i=1; i<=n; ++i) cout<<id[i]<<' '; cout<<endl;
		for (int i=1,op,u,x,y,z; i<=m; ++i) {
			op=read();
			if (op==1) {
				x=read();
				if (col[x]) upd(id[x], -1);
				else upd(id[x], 1);
				col[x]^=1;
			}
			else if (op==2) {
				x=read(); y=read();
				pair<int, int> range=get_block(x);
				upd(1, range.fir, range.sec, y);
			}
			else if (op==3) {
				x=read(); ans=0;
				pair<int, int> range=get_block(x);
				printf("%lld\n", query(1, range.fir, range.sec));
			}
			else if (op==4) {
				x=read(); y=read(); z=read();
				if (id[x]>id[y]) swap(x, y);
				upd(1, id[x], id[y], z);
			}
			else {
				x=read(); y=read();
				if (x==1) upd(1, 1, n, y);
				else if (id[x]>id[1]) upd(1, id[x], n, y);
				else upd(1, 1, id[x], y);
			}
		}
	}
}

namespace task2{
	int back[N], dep[N], sta[N], top[N], rk[N], stop;
	int id[N], siz[N], msiz[N], mson[N], tot;
	int tl[N<<2], tr[N<<2]; ll mx[N<<2], tag[N<<2];
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	#define mx(p) mx[p]
	#define tag(p) tag[p]
	#define pushup(p) mx(p)=max(mx(p<<1), mx(p<<1|1))
	void spread(int p) {
		if (!tag(p)) return ;
		mx(p<<1)+=tag(p); tag(p<<1)+=tag(p);
		mx(p<<1|1)+=tag(p); tag(p<<1|1)+=tag(p);
		tag(p)=0;
	}
	void build(int p, int l, int r) {
		tl(p)=l; tr(p)=r;
		if (l==r) {mx(p)=val[rk[l]]; return ;}
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
		pushup(p);
	}
	void upd(int p, int l, int r, int dat) {
		if (l<=tl(p)&&r>=tr(p)) {mx(p)+=dat; tag(p)+=dat; return ;}
		spread(p);
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid) upd(p<<1, l, r, dat);
		if (r>mid) upd(p<<1|1, l, r, dat);
		pushup(p);
	}
	ll query(int p, int l, int r) {
		if (l<=tl(p)&&r>=tr(p)) return mx(p);
		spread(p);
		int mid=(tl(p)+tr(p))>>1; ll ans=0;
		if (l<=mid) ans=max(ans, query(p<<1, l, r));
		if (r>mid) ans=max(ans, query(p<<1|1, l, r));
		return ans;
	}
	void dfs1(int u, int fa) {
		siz[u]=1;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dep[v]=dep[u]+1;
			back[v]=u;
			dfs1(v, u);
			siz[u]+=siz[v];
			if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v;
		}
	}
	void dfs2(int u, int fa, int t) {
		top[u]=t;
		rk[id[u]=++tot]=u;
		if (!mson[u]) return ;
		dfs2(mson[u], u, t);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa||v==mson[u]) continue;
			dfs2(v, u, v);
		}
	}
	void get_block(int u, int fa) {
		sta[++stop]=u;
		for (int i=head[u]; ~i; i=e[i].next) if (e[i].to!=fa && col[e[i].to]==col[u])
			get_block(e[i].to, u);
	}
	void upd(int u, int v, int dat) {
		while (top[u]!=top[v]) {
			if (dep[top[u]]<dep[top[v]]) swap(u, v);
			upd(1, id[top[u]], id[u], dat);
			u=back[top[u]];
		}
		if (dep[u]>dep[v]) swap(u, v);
		upd(1, id[u], id[v], dat);
	}
	void solve() {
		ll ans;
		dfs1(1, 0); dfs2(1, 0, 1); build(1, 1, n);
		for (int i=1,op,u,x,y,z; i<=m; ++i) {
			op=read();
			if (op==1) col[read()]^=1;
			else if (op==2) {
				x=read(); y=read(); stop=0;
				get_block(x, 0);
				for (int j=1; j<=stop; ++j) upd(1, id[sta[j]], id[sta[j]], y);
			}
			else if (op==3) {
				x=read(); ans=stop=0;
				get_block(x, 0);
				for (int j=1; j<=stop; ++j) ans=max(ans, query(1, id[sta[j]], id[sta[j]]));
				printf("%lld\n", ans);
			}
			else if (op==4) {
				x=read(); y=read(); z=read();
				upd(x, y, z);
			}
			else {
				x=read(); y=read();
				upd(1, id[x], id[x]+siz[x]-1, y);
			}
		}
	}
}

// Full solver (the one main actually calls): heavy-light decomposition plus two
// modified segment trees, seg[0] and seg[1], one per colour.
namespace task{
	// vis[v]: scratch flag filled in solve() before each build — whether vertex v
	// belongs to the tree that is about to be built.
	bool vis[N];
	// id: vertex -> dfs2 position, siz: subtree size, msiz/mson: size and id of the largest child.
	int id[N], siz[N], msiz[N], mson[N], tot;
	// back: parent, dep: depth, top: chain head, rk: position -> vertex,
	// tval: scratch initial leaf values for build. sta and stop are not used in this namespace.
	int back[N], dep[N], sta[N], top[N], rk[N], tval[N], stop;
	struct seg{
		// inuse: build sets it to 1 on every node and then overwrites leaves with vis;
		// reset() clears a leaf, setup() sets it. spread() consults it for `tag`.
		bool inuse[N<<2];
		// st[p]: the (depth, vertex) pairs that cover() inserted at node p.
		set<pair<int, int>> st[N<<2];
		// tl/tr: node bounds; cnt: number of covers applying to the node
		// (own insertions plus those pushed down through stag).
		int tl[N<<2], tr[N<<2], cnt[N<<2];
		// mx: node maximum; stag: lazy delta for cnt; btag: lazy add that is only
		// passed to children whose cnt equals the parent's; tag: lazy add passed
		// to children with inuse set.
		ll mx[N<<2], stag[N<<2], btag[N<<2], tag[N<<2];
		#define tl(p) tl[p]
		#define tr(p) tr[p]
		#define mx(p) mx[p]
		#define st(p) st[p]
		#define cnt(p) cnt[p]
		#define tag(p) tag[p]
		#define stag(p) stag[p]
		#define btag(p) btag[p]
		#define inuse(p) inuse[p]
		// Drop any earlier function-like macro named pushup so the member below can be declared.
		#undef pushup
		// mx(p) becomes the maximum over those children whose cnt equals cnt(p);
		// children with a different cnt are ignored, and the result is never below 0.
		void pushup(int p) {
			mx(p)=0;
			if (cnt(p<<1)==cnt(p)) mx(p)=max(mx(p), mx(p<<1));
			if (cnt(p<<1|1)==cnt(p)) mx(p)=max(mx(p), mx(p<<1|1));
		}
		// Pushes the three lazies of p to its children. Order matters: cnt deltas
		// go first so that the btag test below compares up-to-date counts.
		void spread(int p) {
			// cout<<"spread: "<<p<<endl;
			if (stag(p)) {
				cnt(p<<1)+=stag(p); stag(p<<1)+=stag(p);
				cnt(p<<1|1)+=stag(p); stag(p<<1|1)+=stag(p);
				stag(p)=0;
			}
			if (btag(p)) {
				// cout<<cnt(p<<1|1)<<' '<<cnt(p)<<endl;
				// Block add: only children with the same cnt as p receive it.
				if (cnt(p<<1)==cnt(p)) mx(p<<1)+=btag(p), btag(p<<1)+=btag(p);
				if (cnt(p<<1|1)==cnt(p)) mx(p<<1|1)+=btag(p), btag(p<<1|1)+=btag(p);
				btag(p)=0;
			}
			if (tag(p)) {
				// Plain add: skipped for children whose inuse flag is cleared.
				if (inuse(p<<1)) mx(p<<1)+=tag(p), tag(p<<1)+=tag(p);
				if (inuse(p<<1|1)) mx(p<<1|1)+=tag(p), tag(p<<1|1)+=tag(p);
				tag(p)=0;
			}
		}
		// Builds node p for positions [l, r]; leaves take their value from tval and
		// their inuse flag from vis.
		void build(int p, int l, int r) {
			tl(p)=l; tr(p)=r; inuse(p)=1;
			if (l==r) {mx(p)=tval[rk[l]]; inuse(p)=vis[rk[l]]; return ;}
			int mid=(l+r)>>1;
			build(p<<1, l, mid);
			build(p<<1|1, mid+1, r);
			pushup(p);
		}
		// Range add of dat on [l, r] through `tag`. A fully covered node is
		// updated without looking at its inuse flag.
		void upd(int p, int l, int r, int dat) {
			if (l<=tl(p)&&r>=tr(p)) {mx(p)+=dat; tag(p)+=dat; return ;}
			spread(p);
			int mid=(tl(p)+tr(p))>>1;
			if (l<=mid) upd(p<<1, l, r, dat);
			if (r>mid) upd(p<<1|1, l, r, dat);
			pushup(p);
		}
		// Range add of dat on [l, r] through `btag`, descending only into children
		// whose cnt does not exceed lim.
		void bupd(int p, int l, int r, int dat, int lim) {
			if (l<=tl(p)&&r>=tr(p)) {mx(p)+=dat; btag(p)+=dat; return ;}
			spread(p);
			int mid=(tl(p)+tr(p))>>1;
			if (l<=mid&&cnt(p<<1)<=lim) bupd(p<<1, l, r, dat, lim);
			if (r>mid&&cnt(p<<1|1)<=lim) bupd(p<<1|1, l, r, dat, lim);
			pushup(p);
		}
		// Registers dat on [l, r]: inserted into st and counted (cnt, stag) at every
		// node fully inside the range.
		void cover(int p, int l, int r, pair<int, int> dat) {
			if (l<=tl(p)&&r>=tr(p)) {st[p].insert(dat); ++cnt(p); ++stag(p); return ;}
			spread(p);
			int mid=(tl(p)+tr(p))>>1;
			if (l<=mid) cover(p<<1, l, r, dat);
			if (r>mid) cover(p<<1|1, l, r, dat);
			pushup(p);
		}
		// Exact inverse of cover() for the same range and pair.
		void uncover(int p, int l, int r, pair<int, int> dat) {
			if (l<=tl(p)&&r>=tr(p)) {st[p].erase(dat); --cnt(p); --stag(p); return ;}
			spread(p);
			int mid=(tl(p)+tr(p))>>1;
			if (l<=mid) uncover(p<<1, l, r, dat);
			if (r>mid) uncover(p<<1|1, l, r, dat);
			pushup(p);
		}
		// Walks from p down to leaf pos and keeps the pair with the largest first
		// component seen in the st sets on the way (starting from dat); returns the
		// second component of the winner.
		int qtop(int p, int pos, pair<int, int> dat) {
			if (st[p].size() && st[p].rbegin()->fir>dat.fir) dat=*st[p].rbegin();
			if (tl(p)==tr(p)) return dat.sec;
			spread(p);
			int mid=(tl(p)+tr(p))>>1;
			if (pos<=mid) return qtop(p<<1, pos, dat);
			else return qtop(p<<1|1, pos, dat);
		}
		// Returns cnt of leaf pos after pushing pending lazies down to it.
		int qcnt(int p, int pos) {
			if (tl(p)==tr(p)) return cnt(p);
			spread(p);
			int mid=(tl(p)+tr(p))>>1;
			if (pos<=mid) return qcnt(p<<1, pos);
			else return qcnt(p<<1|1, pos);
		}
		// Maximum on [l, r], descending only into children whose cnt does not
		// exceed lim; the result is never below 0.
		ll query(int p, int l, int r, int lim) {
			if (l<=tl(p)&&r>=tr(p)) return mx(p);
			spread(p);
			int mid=(tl(p)+tr(p))>>1; ll ans=0;
			if (l<=mid&&cnt(p<<1)<=lim) ans=max(ans, query(p<<1, l, r, lim));
			if (r>mid&&cnt(p<<1|1)<=lim) ans=max(ans, query(p<<1|1, l, r, lim));
			return ans;
		}
		// Takes leaf pos out of this tree: returns its current mx, then zeroes it
		// and clears inuse.
		ll reset(int p, int pos) {
			if (tl(p)==tr(p)) {int ans=mx(p); mx(p)=0; inuse(p)=0; return ans;}
			spread(p);
			int mid=(tl(p)+tr(p))>>1, ans;
			if (pos<=mid) ans=reset(p<<1, pos);
			else ans=reset(p<<1|1, pos);
			pushup(p);
			return ans;
		}
		// Puts leaf pos into this tree with value val and sets inuse.
		// (The parameter val shadows the global array of the same name.)
		void setup(int p, int pos, int val) {
			if (tl(p)==tr(p)) {mx(p)=val; inuse(p)=1; return ;}
			spread(p);
			int mid=(tl(p)+tr(p))>>1;
			if (pos<=mid) setup(p<<1, pos, val);
			else setup(p<<1|1, pos, val);
			pushup(p);
		}
		// Debug helper: prints bounds, cnt and mx of every node under p (it also spreads lazies).
		void show(int p) {
			cout<<setw(2)<<p<<" ["<<tl(p)<<','<<tr(p)<<"]: "<<"cnt("<<cnt(p)<<"), mx("<<mx(p)<<")"<<endl;
			if (tl(p)==tr(p)) return ;
			spread(p);
			show(p<<1); show(p<<1|1);
		}
	}seg[2];
	// First pass: parent, depth, subtree size and largest child of every vertex.
	void dfs1(int u, int fa) {
		siz[u]=1;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dep[v]=dep[u]+1;
			back[v]=u;
			dfs1(v, u);
			siz[u]+=siz[v];
			if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v;
		}
	}
	// Second pass: assigns positions; the largest child is visited first, so it
	// always sits at position id[u]+1 and shares u's chain head t.
	void dfs2(int u, int fa, int t) {
		top[u]=t;
		rk[id[u]=++tot]=u;
		if (!mson[u]) return ;
		dfs2(mson[u], u, t);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa||v==mson[u]) continue;
			dfs2(v, u, v);
		}
	}
	// Path add of dat on u..v, applied to both colour trees chain by chain.
	void upd(int u, int v, int dat) {
		while (top[u]!=top[v]) {
			if (dep[top[u]]<dep[top[v]]) swap(u, v);
			seg[0].upd(1, id[top[u]], id[u], dat);
			seg[1].upd(1, id[top[u]], id[u], dat);
			u=back[top[u]];
		}
		if (dep[u]>dep[v]) swap(u, v);
		seg[0].upd(1, id[u], id[v], dat);
		seg[1].upd(1, id[u], id[v], dat);
	}
	// Returns the child of u on the way down to v: climbs from v chain by chain;
	// if a chain head hangs directly under u that head is the answer, otherwise
	// v ends up on u's own chain and the answer is the vertex right after u.
	// solve() also calls this with u==0 (qtop found nothing); since back[1]==0
	// the climb then returns vertex 1.
	int anc(int u, int v) {
		while (top[u]!=top[v]) {
			if (back[top[v]]==u) return top[v];
			// cout<<"v: "<<v<<' '<<top[v]<<endl;
			v=back[top[v]];
		}
		return rk[id[u]+1];
	}
	// Reads and executes the m operations.
	void solve() {
		ll ans;
		// The root gets depth 1, so every real vertex compares greater than the
		// {0, 0} sentinel passed to qtop below.
		dep[1]=1; dfs1(1, 0); dfs2(1, 0, 1);
		// seg[c] starts with the values of colour-c vertices; other leaves hold 0 and are not in use.
		for (int i=1; i<=n; ++i) {tval[i]=!col[i]?val[i]:0; vis[i]=!col[i]?1:0;} seg[0].build(1, 1, n);
		for (int i=1; i<=n; ++i) {tval[i]=col[i]?val[i]:0; vis[i]=col[i]?1:0;} seg[1].build(1, 1, n);
		// Every vertex covers its own subtree range in the tree of the opposite colour.
		for (int i=1; i<=n; ++i) seg[col[i]^1].cover(1, id[i], id[i]+siz[i]-1, {dep[i], i});
		// for (int i=1; i<=n; ++i) cout<<seg[col[i]].qtop(1, id[i], {0, 1})<<' '; cout<<endl;
		// for (int i=1; i<=n; ++i) cout<<qcnt(1, id[i])<<endl;
		// cout<<"id: "; for (int i=1; i<=n; ++i) cout<<id[i]<<' '; cout<<endl;
		// seg[0].show(1);
		// seg[0].bupd(1, id[3], id[3]+siz[3]-1, 2, 3); //, cout<<id[3]<<' '<<id[3]+siz[3]-1<<endl;
		// cout<<seg[0].query(1, id[2], id[2]+siz[2]-1)<<endl;
		// Note: the loop-local `top` shadows the chain-head array inside this loop body.
		for (int i=1,op,u,x,y,z,top; i<=m; ++i) {
			// cout<<"i: "<<i<<endl;
			op=read();
			if (op==1) {
				// Colour flip: move x's value from its old tree to the new one and
				// move its subtree cover to the tree of the (new) opposite colour.
				x=read();
				y=seg[col[x]].reset(1, id[x]);
				seg[col[x]^1].uncover(1, id[x], id[x]+siz[x]-1, {dep[x], x});
				col[x]^=1;
				seg[col[x]^1].cover(1, id[x], id[x]+siz[x]-1, {dep[x], x});
				seg[col[x]].setup(1, id[x], y);
			}
			else if (op==2) {
				// Block add: qtop yields the deepest covering vertex recorded for x in
				// its own colour tree; the child of that vertex towards x is the
				// block's shallowest vertex. Add y on its subtree range, limited to
				// positions whose cnt does not exceed that vertex's cnt.
				x=read(); y=read();
				// cout<<"qtop: "<<seg[col[x]].qtop(1, id[x], {0, 0})<<endl;
				top=anc(seg[col[x]].qtop(1, id[x], {0, 0}), x);
				// cout<<"top: "<<top<<endl;
				seg[col[x]].bupd(1, id[top], id[top]+siz[top]-1, y, seg[col[x]].qcnt(1, id[top]));
			}
			else if (op==3) {
				// Block maximum: same localisation as op 2, then a cnt-limited range query.
				x=read();
				// cout<<"qtop: "<<seg[col[x]].qtop(1, id[x], {0, 0})<<endl;
				top=anc(seg[col[x]].qtop(1, id[x], {0, 0}), x);
				// cout<<"top: "<<top<<endl;
				// cout<<"cnt: "<<seg[col[x]].qcnt(1, id[top])<<endl;
				printf("%lld\n", seg[col[x]].query(1, id[top], id[top]+siz[top]-1, seg[col[x]].qcnt(1, id[top])));
			}
			else if (op==4) {
				// Path add on both trees.
				x=read(); y=read(); z=read();
				upd(x, y, z);
			}
			else {
				// Subtree add on both trees.
				x=read(); y=read();
				seg[0].upd(1, id[x], id[x]+siz[x]-1, y);
				seg[1].upd(1, id[x], id[x]+siz[x]-1, y);
			}
		}
	}
}

signed main()
{
	freopen("astill.in", "r", stdin);
	freopen("astill.out", "w", stdout);

	n=read(); m=read();
	memset(head, -1, sizeof(head));
	bool is_chain=1;
	for (int i=1,u,v; i<n; ++i) {
		u=read(); v=read();
		add(u, v); add(v, u);
		++deg[u]; ++deg[v];
	}
	for (int i=1; i<=n; ++i) col[i]=read();
	for (int i=1; i<=n; ++i) val[i]=read();
	for (int i=1; i<=n; ++i)
		if (deg[i]<=1) endpos=i;
		else if (deg[i]>2) is_chain=0;
	// if (n<=1000&&m<=1000) force::solve();
	// else if (is_chain) task1::solve();
	// else task2::solve();
	task::solve();

	return 0;
}
posted @ 2022-03-27 18:12  Administrator-09  阅读(4)  评论(0编辑  收藏  举报