题解 校门外歪脖树上的鸽子

传送门

毒瘤数据结构

正解思路很特别
对于一个闭区间修改 \([l, r]\),将其写成开区间 \((l-1, r+1)\)
于是类似zkw线段树,我们发现在原树上向上跳链(到lca的孙子辈)的过程中应该修改的节点恰好是访问到的节点的兄弟
于是分成左链和右链分别树剖,每个节点维护其兄弟的信息
然后因为写成了开区间,需要加两个虚点0和n+1
发现这两个虚点恰好使得包含边界的修改可以打在相应的节点上了(相当于提高了边界的优先级/为边界提供了兄弟)
于是可以树剖修改了
复杂度 \(O(n\log^2 n)\)

Code:
#include <bits/stdc++.h>
using namespace std;
// Conventional "infinity" constant (not referenced anywhere in the visible code).
#define INF 0x3f3f3f3f
// Capacity of node-indexed arrays; task::solve uses node ids up to 2n+2.
#define N 400010
#define ll long long
// Textually turns every later `int` into a 64-bit type (hence `signed main` below).
#define int long long

// Buffered input: buf holds up to 1<<21 bytes read from stdin; [p1, p2) is the unread part.
char buf[1<<21], *p1=buf, *p2=buf;
// Refill buf when it is exhausted; yields EOF once fread returns nothing.
// Bytes are returned as unsigned char so they never collide with EOF or go negative.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
// Read one integer from the buffered input.
// Every '-' seen before the first digit flips the sign.
// If input ends before any digit is found, returns 0 instead of spinning forever on EOF.
inline int read() {
	int ans=0, f=1;
	int c=getchar();
	while (c!=EOF && !(c>='0'&&c<='9')) {if (c=='-') f=-f; c=getchar();}
	while (c>='0'&&c<='9') {ans=ans*10+(c-'0'); c=getchar();}
	return ans*f;
}

int n;         // number of real leaves (ids 1..n)
int m;         // number of operations to process
// tl[p] / tr[p]: leftmost / rightmost leaf id under node p (filled by force::build)
int tl[N];
int tr[N];
// ls[p] / rs[p]: left / right child of node p; main() initialises both to -1 ("none")
int ls[N];
int rs[N];
int rot;       // id of the current root
bool vis[N];   // vis[i] is set when node i appears as some node's child

namespace force{
	// Brute-force reference solution: every operation walks the tree top-down the way a
	// segment tree walks its nodes. dat[p] is the value stored at node p.
	ll dat[N];
	// Fill tl[p]/tr[p] (leftmost/rightmost leaf id under p) for the whole subtree of p.
	void build(int p) {
		// main() marks a missing child with -1 (not 0), so a leaf is recognised by ls[p]==-1.
		if (ls[p]==-1) {tl[p]=tr[p]=p; return ;}
		build(ls[p]); build(rs[p]);
		tl[p]=tl[ls[p]]; tr[p]=tr[rs[p]];
	}
	// For each maximal node whose leaf range lies inside [l, r], add val times its range length.
	void upd(int p, int l, int r, int val) {
		if (l<=tl[p]&&r>=tr[p]) {dat[p]+=1ll*val*(tr[p]-tl[p]+1); return ;}
		if (l<=tr[ls[p]]) upd(ls[p], l, r, val);
		if (r>=tl[rs[p]]) upd(rs[p], l, r, val);
	}
	// Sum of dat over the same maximal nodes that upd(p, l, r, ...) would touch.
	ll query(int p, int l, int r) {
		if (l<=tl[p]&&r>=tr[p]) return dat[p];
		ll ans=0;
		if (l<=tr[ls[p]]) ans+=query(ls[p], l, r);
		if (r>=tl[rs[p]]) ans+=query(rs[p], l, r);
		return ans;
	}
	// Process m operations: an odd type reads "l r d" and updates, an even type reads
	// "l r" and prints the query answer. Terminates the program when done.
	void solve() {
		build(rot);
		for (int i=1,l,r,d; i<=m; ++i) {
			if (read()&1) {
				l=read(); r=read(); d=read();
				upd(rot, l, r, d);
			}
			else {
				l=read(); r=read();
				printf("%lld\n", query(rot, l, r));
			}
		}
		exit(0);
	}
}

namespace task{
	int dep[N], top[N], id[N], rk[N], siz[N], siz2[N], fa[N], mson[N], tot, num;
	ll extra;
	struct seg{
		bool whitch; // 1->left
		int tl[N<<2], tr[N<<2];
		ll tag[N<<2], k[N<<2], dat[N<<2];
		seg(bool t):whitch(t){}
		#define tl(p) tl[p]
		#define tr(p) tr[p]
		#define tag(p) tag[p]
		#define k(p) k[p]
		#define dat(p) dat[p]
		#define pushup(p) dat(p)=dat(p<<1)+dat(p<<1|1)
		void spread(int p) {
			if (!tag(p)) return ;
			dat(p<<1)+=tag(p)*k(p<<1); tag(p<<1)+=tag(p);
			dat(p<<1|1)+=tag(p)*k(p<<1|1); tag(p<<1|1)+=tag(p);
			tag(p)=0;
		}
		void build(int p, int l, int r) {
			tl(p)=l; tr(p)=r;
			if (l==r) {k(p)=whitch?(ls[fa[rk[l]]]==rk[l]?siz2[rs[fa[rk[l]]]]:0):(rs[fa[rk[l]]]==rk[l]?siz2[ls[fa[rk[l]]]]:0); return ;}
			int mid=(l+r)>>1;
			build(p<<1, l, mid);
			build(p<<1|1, mid+1, r);
			k(p)=k(p<<1)+k(p<<1|1);
		}
		void upd(int p, int l, int r, ll val) {
			if (l<=tl(p)&&r>=tr(p)) {dat(p)+=val*k(p); tag(p)+=val; return ;}
			spread(p);
			int mid=(tl(p)+tr(p))>>1;
			if (l<=mid) upd(p<<1, l, r, val);
			if (r>mid) upd(p<<1|1, l, r, val);
			pushup(p);
		}
		ll query(int p, int l, int r) {
			if (l<=tl(p)&&r>=tr(p)) return dat(p);
			spread(p);
			int mid=(tl(p)+tr(p))>>1; ll ans=0;
			if (l<=mid) ans+=query(p<<1, l, r);
			if (r>mid) ans+=query(p<<1|1, l, r);
			return ans;
		}
	}left(1), right(0);
	void dfs1(int u) {
		// cout<<"dfs1: "<<u<<endl;
		siz[u]=u<=n+1; siz2[u]=u&&u<=n;
		if (~ls[u]) {
			dep[ls[u]]=dep[u]+1; fa[ls[u]]=u;
			dfs1(ls[u]);
			siz[u]+=siz[ls[u]]; siz2[u]+=siz2[ls[u]];
		}
		if (~rs[u]) {
			dep[rs[u]]=dep[u]+1; fa[rs[u]]=u;
			dfs1(rs[u]);
			siz[u]+=siz[rs[u]]; siz2[u]+=siz2[rs[u]];
		}
		mson[u]=siz[ls[u]]>siz[rs[u]]?ls[u]:rs[u];
	}
	void dfs2(int u, int t) {
		top[u]=t;
		id[u]=++tot; rk[tot]=u;
		if (ls[u]==-1 && rs[u]==-1) return ;
		if (mson[u]==ls[u]) dfs2(ls[u], t), dfs2(rs[u], rs[u]);
		else dfs2(rs[u], t), dfs2(ls[u], ls[u]);
	}
	int lca(int a, int b) {
		while (top[a]!=top[b]) {
			if (dep[top[a]]<dep[top[b]]) swap(a, b);
			a=fa[top[a]];
		}
		return dep[a]<dep[b]?a:b;
	}
	void ladd(int s, int t, ll val) {
		// cout<<"ladd: "<<s<<' '<<t<<' '<<val<<endl;
		while (top[s]!=top[t]) {
			left.upd(1, id[top[s]], id[s], val);
			// cout<<"show: "<<s<<' '<<top[s]<<endl;
			// cout<<"add: "<<id[top[s]]<<' '<<id[s]<<endl;
			s=fa[top[s]];
		}
		if (id[t]+1<=id[s]) left.upd(1, id[t]+1, id[s], val); //, cout<<"add2: "<<id[t]+1<<' '<<id[s]<<endl;
	}
	void radd(int s, int t, ll val) {
		// cout<<"radd: "<<s<<' '<<t<<' '<<val<<endl;
		while (top[s]!=top[t]) {
			right.upd(1, id[top[s]], id[s], val);
			// cout<<"add: "<<id[top[s]]<<' '<<id[s]<<endl;
			s=fa[top[s]];
		}
		if (id[t]+1<=id[s]) right.upd(1, id[t]+1, id[s], val); //, cout<<"add2: "<<id[t]+1<<' '<<id[s]<<endl;
	}
	ll lqsum(int s, int t) {
		ll ans=0;
		while (top[s]!=top[t]) {
			ans+=left.query(1, id[top[s]], id[s]);
			// cout<<"ans: "<<ans<<endl;
			s=fa[top[s]];
		}
		if (id[t]+1<=id[s]) ans+=left.query(1, id[t]+1, id[s]);
		return ans;
	}
	ll rqsum(int s, int t) {
		// cout<<"rqsum: "<<s<<' '<<t<<endl;
		ll ans=0;
		while (top[s]!=top[t]) {
			ans+=right.query(1, id[top[s]], id[s]);
			s=fa[top[s]];
		}
		if (id[t]+1<=id[s]) ans+=right.query(1, id[t]+1, id[s]);
		return ans;
	}
	void upd(int l, int r, ll val) {
		if (l<1&&r>n) {extra+=val; return ;}
		// cout<<"upd: "<<l<<' '<<r<<' '<<val<<endl;
		int t=lca(l, r);
		// cout<<"lca: "<<t<<endl;
		ladd(l, ls[t], val); radd(r, rs[t], val);
	}
	ll query(int l, int r) {
		if (l<1&&r>n) return extra*n;
		// cout<<"query: "<<l<<' '<<r<<endl;
		int t=lca(l, r);
		// cout<<"sum: "<<lqsum(l, ls[t])<<' '<<rqsum(r, rs[t])<<endl;
		return lqsum(l, ls[t])+rqsum(r, rs[t]);
	}
	void solve() {
		int num=n*2;
		++num; ls[num]=0; rs[num]=rot; rot=num;
		++num; ls[num]=rot; rs[num]=n+1; rot=num;
		dep[rot]=1; dfs1(rot); dfs2(rot, rot);
		// cout<<"id: "; for (int i=0; i<=num; ++i) cout<<id[i]<<' '; cout<<endl;
		left.build(1, 1, num); right.build(1, 1, num);
		for (int i=1,l,r,d; i<=m; ++i) {
			if (read()&1) {
				l=read(); r=read(); d=read();
				upd(l-1, r+1, d);
			}
			else {
				l=read(); r=read();
				printf("%lld\n", query(l-1, r+1));
			}
		}
		// cout<<"qval: "<<right.query(1, 4, 4)<<endl;
		// cout<<"qval: "<<right.query(1, 8, 8)<<endl;
		exit(0);
	}
}

signed main()
{
	// Redirect the standard streams to the judge's files.
	freopen("pigeons.in", "r", stdin);
	freopen("pigeons.out", "w", stdout);

	n=read(); m=read();
	// -1 marks "no child": any node that never receives children stays a leaf.
	memset(ls, -1, sizeof(ls));
	memset(rs, -1, sizeof(rs));
	// Input line i gives the two children of internal node n+i+1. Input ids above n are
	// shifted up by one, which keeps id n+1 free for task::solve's right sentinel leaf.
	for (int i=1; i<n; ++i) {
		int a=read();
		int b=read();
		if (a>n) ++a;
		if (b>n) ++b;
		const int node=n+i+1;
		ls[node]=a;
		rs[node]=b;
		vis[a]=1;
		vis[b]=1;
	}
	// The root is a node (other than the reserved id n+1) that is nobody's child;
	// if several qualified, the largest such id would be taken.
	for (int i=2*n; i>=1; --i) {
		if (i!=n+1 && !vis[i]) {rot=i; break;}
	}
	// force::solve();  // brute-force reference solution
	task::solve();

	return 0;
}
posted @ 2021-11-08 20:51  Administrator-09  阅读(1)  评论(0)  编辑  收藏  举报