[ZJOI2019][LOJ3046]语言(线段树合并)

题面

https://loj.ac/problem/3046

题解

前置知识:

首先对求的值做一个转化:相当于对\(1 \leq i \leq n\),求出\(S[i]\)表示1~n中可以与i开展贸易的点数。

一个点j(\(\neq i\))与i能够开展贸易的充要条件是\(\exists x \in[1,m]\)使得路径\(path (s_x,t_x)\)通过点i,j。

因此,\(S [i]\)就是所有通过i的\(path(s_x,t_x)\),这些路径的并集中点的个数。也就是这些路径的端点形成的虚树。

性质:k个点\(u_1,u_2,…,u_k\)(按dfs序升序)形成的虚树的大小是\(\sum_{i=1}^kdep_{u_i}-\sum_{i=1}^{k} dep_{lca(u_i,u_{i + 1})}\),其中\(u_{k+1}=u_1\)

​ 可以使用dfs序证明,这里略去。

因此考虑对于原树中的每一个点u,维护\(f[u],g[u],left[u],right[u]\),使得:

  1. 所有通过u的“统一语言”路径,它们的端点按照dfs序排序后,形成\(v_1,v_2,…,v_k\)的序列。

  2. \(f[u]=\sum_{i=1}^{k}dep_{v_k}\)

  3. \(g[u]=\sum_{i=1}^{k-1}dep_{lca(v_k,v_{k+1})}\)

  4. \(left[u]=v_1,right[u]=v_k\)

其中\(left,right\)是用于支持合并以及统计答案。

实现时,可以使用线段树。首先对于每一个\(x \in [1,m]\)

  • \(s_x\)处打\((s_x,1)\)\((t_x,1)\)的标记。
  • \(t_x\)处打\((s_x,1)\)\((t_x,1)\)的标记。
  • \(lca(s_x,t_x)\)处打\((s_x,-1)\)\((t_x,-1)\)的标记。
  • \(lca(s_x,t_x)\)的父亲处打\((s_x,-1)\)\((t_x,-1)\)的标记。

然后,对于原树进行一次dfs,每一个原树上节点u对应的线段树首先是它所有子节点的线段树之并;

其次,按照u节点上打的每一个标记,对u对应的线段树进行更新

那么\(S[u]\)就是\(f[u]-g[u]-dep_{lca(left[u],right[u])}\)啦。

总时间复杂度\(O(n \log n)\)

代码

#include<bits/stdc++.h>

using namespace std;

#define rg register
#define In inline
#define ll long long

const int N = 1e5;
const int TN = 9 * 17 * N;

typedef pair<int,int>pii;

namespace IO{
	In int read(){
		int s = 0,ww = 1;
		char ch = getchar();
		while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
		while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
		return s * ww;
	}
	In void write(int x){
		if(x < 0)putchar('-'),x = -x;
		if(x > 9)write(x / 10);
		putchar('0' + x % 10);
	}
}
using namespace IO;

struct edge{
	int des,next;
}e[2*N+5];

int s[N+5],t[N+5],fa[N+5];
int head[N+5],in[N+5],dfn[N+5],D[N+5],E[2*N+5];
ll dep[2*N+5];
int cnt,En,dn;
int n,m;

In void addedge(int a,int b){
	cnt++;
	e[cnt].des = b;
	e[cnt].next = head[a];
	head[a] = cnt;
}

void dfs1(int u,int f){
	E[++En] = u;
	in[u] = En;
	D[++dn] = u;
	dfn[u] = dn;
	fa[u] = f;
	dep[u] = dep[fa[u]] + 1;
	for(rg int i = head[u];i;i = e[i].next){
		int v = e[i].des;
		if(v == f)continue;
		dfs1(v,u);
		E[++En] = u;	
	}
}

int lg[2*N+5];

struct ST{
	int m[2*N+5][21];
	void prepro(){
		for(rg int i = 2;i <= 2 * N;i++)lg[i] = lg[i>>1] + 1;
		for(rg int i = 1;i <= En;i++)m[i][0] = i;
		for(rg int j = 1;j <= 20;j++)
			for(rg int i = 1;i + (1<<(j-1)) <= En;i++){
				int x = m[i][j-1],y = m[i+(1<<(j-1))][j-1];
				m[i][j] = dep[E[x]] < dep[E[y]] ? x : y;
			}	
	}
	In int query(int l,int r){
		int d = lg[r-l+1];
		int x = m[l][d],y = m[r+1-(1<<d)][d];
		return dep[E[x]] < dep[E[y]] ? x : y;
	}
	In int lca(int u,int v){
		if(in[u] > in[v])swap(u,v);
		return E[query(in[u],in[v])];
	}
}S;

int rt[N+5];

struct SegTree{
	ll f[TN+5],g[TN+5];
	int left[TN+5],right[TN+5],lc[TN+5],rc[TN+5];
	int cnt;
	In void pushup(int u){
		int l = lc[u],r = rc[u];
		f[u] = f[l] + f[r];
		if(!f[l]){
			g[u] = g[r];
			left[u] = left[r];
			right[u] = right[r];
			return;
		}
		if(!f[r]){
			g[u] = g[l];
			left[u] = left[l];
			right[u] = right[l];
			return;
		}
		left[u] = left[l],right[u] = right[r];
		g[u] = g[l] + g[r] + dep[S.lca(D[right[l]],D[left[r]])]; 
	}
	In ll query(int u){
		if(!f[u])return 0;
		return f[u] - g[u] - dep[S.lca(D[right[u]],D[left[u]])];
	}
	void ud(int u,int l,int r,int x,ll d){
		if(l == r){
			f[u] += d * dep[D[x]];
			if(!f[u])left[u] = right[u] = f[u] = g[u] = 0;
			else{
				int n = f[u] / dep[D[x]];
				g[u] = (n - 1) * dep[D[x]];
				left[u] = right[u] = x;
			}
			return;
		}
		int m = (l + r) >> 1;
		if(x <= m){
			if(!lc[u])lc[u] = ++cnt;
			ud(lc[u],l,m,x,d);	
		}
		else{
			if(!rc[u])rc[u] = ++cnt;
			ud(rc[u],m + 1,r,x,d);
		} 
		pushup(u);
	}
	int merge(int u,int v,int l,int r){
		if(!u || !v)return u + v;
		if(l == r){
			f[u] += f[v];
			int n = f[u] / dep[D[l]];
			if(!n)g[u] = left[u] = right[u] = 0;
			else g[u] = 1ll * (n - 1) * dep[D[l]],left[u] = right[u] = l;
			return u;
 		}
		int m = (l + r) >> 1;
		lc[u] = merge(lc[u],lc[v],l,m);
		rc[u] = merge(rc[u],rc[v],m + 1,r);
		pushup(u);
		return u;
	}
}T;

vector<pii>v[N+5];
ll ans[N+5];

void dfs2(int u){
	rt[u] = ++T.cnt;
	for(rg int i = head[u];i;i = e[i].next){
		int v = e[i].des;
		if(v == fa[u])continue;
		dfs2(v);
		rt[u] = T.merge(rt[u],rt[v],1,n);
	}
	for(rg int i = 0;i < v[u].size();i++){
		int id = v[u][i].first,dx = v[u][i].second;
		T.ud(rt[u],1,n,dfn[s[id]],dx);
		T.ud(rt[u],1,n,dfn[t[id]],dx);
	}
	ans[u] = T.query(rt[u]);
}

int main(){
//	freopen("L3046.in","r",stdin);
//	freopen("L3046.out","w",stdout);
	n = read(),m = read();
	for(rg int i = 1;i < n;i++){
		int u = read(),v = read();
		addedge(u,v);
		addedge(v,u);
	}
	dfs1(1,0);
	S.prepro();
	for(rg int i = 1;i <= m;i++){
		s[i] = read(),t[i] = read();
		if(dfn[s[i]] > dfn[t[i]])swap(s[i],t[i]);
		v[s[i]].push_back(make_pair(i,1));
		v[t[i]].push_back(make_pair(i,1));
		int Lca = S.lca(s[i],t[i]);
		v[Lca].push_back(make_pair(i,-1));
		v[fa[Lca]].push_back(make_pair(i,-1));
	}
	dfs2(1);
	ll rt = 0;
	for(rg int i = 1;i <= n;i++)rt += ans[i];
	rt >>= 1;
	cout << rt << endl;
	return 0;
}
posted @ 2020-10-06 19:47  coder66  阅读(165)  评论(0编辑  收藏  举报