FZOJ 4267 树上统计

对每一条边单独考虑其贡献。

我们现在想算有多少连续区间跨过了这条边。正难则反,我们考虑有多少连续区间不跨过这条边,最后用总的减去这些。所以我们只需要计算当前子树内和子树外的连续区间的贡献就行。

考虑 dsu on tree。子树内用并查集维护并计算极长子段的贡献,子树外用set维护其补集并计算贡献即可。

代码:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<set>
#define int long long
#define S_IT set <int> ::iterator

using namespace std;

const int N=100009;
set <int> S;
int head[N],cnt,bin[N],son[N],siz[N],L[N],R[N],Index,rev[N],n,ans,ans_In,ans_Out;
struct Edge
{
	int nxt,to;
}g[N*2];
struct Union
{
	int fa[N];
	
	int find(int x)
	{
		if(fa[x]==x)
			return x;
		return fa[x]=find(fa[x]);
	}
	
	void merge(int x,int y)
	{
		int X=find(x),Y=find(y);
		if(X!=Y)
		{
			fa[Y]=X;
			bin[X]+=bin[Y];
		}
	}
}B;

void add(int from,int to)
{
	g[++cnt].nxt=head[from];
	g[cnt].to=to;
	head[from]=cnt;
}

void init()
{
	scanf("%lld",&n);
	for (int x,y,i=1;i<n;i++)
		scanf("%lld %lld",&x,&y),
		add(x,y),add(y,x);
}

void dfs(int x,int fa)
{
	siz[x]=1,L[x]=++Index,rev[Index]=x;
//	printf("%lld %lld\n",x,fa);
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(v==fa)
			continue;
		dfs(v,x);
		siz[x]+=siz[v];
		if(siz[v]>siz[son[x]])
			son[x]=v;
	}
	R[x]=Index;
}

int calc(int x) { return x*(x+1)/2; }

void Insert(int x)
{
	bin[x]=1;
	if(x>1&&bin[x-1])
		ans_In-=calc(bin[B.find(x-1)]),
		B.merge(x,x-1);
	if(x<n&&bin[x+1])
		ans_In-=calc(bin[B.find(x+1)]),
		B.merge(x,x+1);
	ans_In+=calc(bin[B.find(x)]);
	S_IT it=S.insert(x).first;
	S_IT it1=it,it2=it;
	it1--,it2++;
	ans_Out=ans_Out-calc(*it2-*it1-1)+calc(*it-*it1-1)+calc(*it2-*it-1);
}

void DFS(int x,int fa)
{
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(v==fa||v==son[x])
			continue;
		DFS(v,x);
		ans_In=0,ans_Out=calc(n);
		for (int j=L[v];j<=R[v];j++)
		{
			int J=rev[j];
			bin[J]=0,S.erase(J),B.fa[J]=J;
		}
	}
	if(son[x])
		DFS(son[x],x);
	Insert(x);
	for (int i=head[x];i;i=g[i].nxt)
	{
		int v=g[i].to;
		if(v==fa||v==son[x])
			continue;
		for (int j=L[v];j<=R[v];j++)
			Insert(rev[j]);
	}
	ans+=calc(n)-ans_In-ans_Out;
}

void work()
{
	dfs(1,-1);
	S.insert(0),S.insert(n+1);
	for (int i=1;i<=n;i++)
		B.fa[i]=i;
	ans_Out=calc(n);
	DFS(1,-1);
	printf("%lld\n",ans);
}

signed main()
{
	init();
	work();
	return 0;
}
posted @ 2020-07-12 14:56  With_penguin  阅读(208)  评论(0编辑  收藏  举报