Luogu_4242 树上的毒瘤

题意
给你一棵n个点的树,每个点都有一个颜色。定义一条路径<u,v>的权值为u到v路径上依次经过的点(包括u,v)的颜色序列的颜色段数。有两种操作:
1.将<u,v>路径上的点的颜色全改为y。
2.给定一个点的集合,对于集合内每个点求其到集合内所有点的路径权值的和。
\(1≤n,q≤100000,c_i,y\leq 10^9 ,\sum m\leq 200000,m\leq n。\)

题解
好一个小水题(苦笑)。
还是一样看到总询问点数与n同阶,直接上虚树。
我们先不考虑修改,那么每次询问我们只要给虚树上每条边赋一个权值,为这条边两端点的路径权值-1(左闭右开,减掉fa的那个颜色段,这样计算时直接两边权相加就行了)这样虚树上任意两点之间的路径权值即为经过的边权和+1,然后二次换根就可以算出每个点的答案。
但是怎么样算在原树上两点的路径权值呢?
我们可以用树剖+线段树。线段树每个点记录它所对应的区间的颜色段数,左端的颜色和右端的颜色就行了。
然后带修改的话,用树剖+线段树也很好直接维护。

#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=1e5;
int n,Q,tot,Time,m;
int c[maxn+8],pre[maxn*2+8],now[maxn+8],son[maxn*2+8];
int dep[maxn+8],siz[maxn+8],heavy[maxn+8];
int id[maxn+8],dfn[maxn+8],top[maxn+8],fa[maxn+8];
int a[maxn+8],st[maxn+8];

int read()
{
    int x=0,f=1;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
    return x*f;
}

void add(int u,int v)
{
    pre[++tot]=now[u];
    now[u]=tot;
    son[tot]=v;
}

struct node
{
    int l,r,sum,lazy;
};

node operator +(node x,node y)
{
    if (!x.sum) return y;
    if (!y.sum) return x;
    return (node){x.l,y.r,x.sum+y.sum-(x.r==y.l),0};
}

struct Segment_Tree
{
    node tree[maxn*4+8];
    void build(int p,int l,int r)
    {
		if (l==r)
			{
				tree[p]=(node){c[id[l]],c[id[l]],1,0};
				return;
			}
		int mid=(l+r)>>1;
		build(p<<1,l,mid),build(p<<1|1,mid+1,r);
		tree[p]=tree[p<<1]+tree[p<<1|1];
    }
    void mark(int p,int v){tree[p]=(node){v,v,1,v};}
    void pushdown(int p)
    {
		if (tree[p].lazy)
			{
				mark(p<<1,tree[p].lazy);
				mark(p<<1|1,tree[p].lazy);
				tree[p].lazy=0;
			}
    }
    void change(int p,int l,int r,int L,int R,int v)
    {
		if (L<=l&&r<=R)
			{
				mark(p,v);
				return;
			}
		int mid=(l+r)>>1;
		pushdown(p);
		if (L<=mid) change(p<<1,l,mid,L,R,v);
		if (R>mid) change(p<<1|1,mid+1,r,L,R,v);
		tree[p]=tree[p<<1]+tree[p<<1|1];
    }
    node query(int p,int l,int r,int L,int R)
    {
		if (L<=l&&r<=R) return tree[p];
		int mid=(l+r)>>1;
		pushdown(p);
		node tmp=(node){0,0,0,0};
		if (L<=mid) tmp=tmp+query(p<<1,l,mid,L,R);
		if (R>mid) tmp=tmp+query(p<<1|1,mid+1,r,L,R);
		return tmp;
    }
}Sg;

void prepare(int x)
{
    dep[x]=dep[fa[x]]+1;
    siz[x]=1;
    for (int p=now[x];p;p=pre[p])
		{
			int child=son[p];
			if (child==fa[x]) continue;
			fa[child]=x;
			prepare(child);
			siz[x]+=siz[child];
			if (siz[heavy[x]]<siz[child]) heavy[x]=child;
		}
}
					      
void build(int x)
{
    id[dfn[x]=++Time]=x;
    top[x]=heavy[fa[x]]==x?top[fa[x]]:x;
    if (heavy[x]) build(heavy[x]);
    for (int p=now[x];p;p=pre[p])
		{
			int child=son[p];
			if (dfn[child]) continue;
			build(child);
		}
}

int Get_Lca(int u,int v,int t)
{
    //Big_flag=0;
    //if (u+v==12&&abs(u-v)==10) Big_flag=1,puts("Lemon Party");
    node tmp=(node){0,0,0,0};
    while(top[u]!=top[v])
		{
			if (dep[top[u]]<dep[top[v]]) swap(u,v);
			if (t>=0) Sg.change(1,1,n,dfn[top[u]],dfn[u],t);
			if (t==-2) tmp=Sg.query(1,1,n,dfn[top[u]],dfn[u])+tmp;
			u=fa[top[u]];
		}
    if (dep[u]<dep[v]) swap(u,v);
    if (t==-1) return v;
    if (t>=0)Sg.change(1,1,n,dfn[v],dfn[u],t);
    if (t==-2) tmp=Sg.query(1,1,n,dfn[v],dfn[u])+tmp;
    return tmp.sum;
}

struct Virtual_Tree
{
    int tot,tail;
    int now[maxn+8],pre[maxn*2+8],son[maxn*2+8],val[maxn*2+8];
    int color[maxn+8],st[maxn+8],siz[maxn+8];
    ll f[maxn+8];
    void clear()
    {
		for (int i=1;i<=m;i++) color[a[i]]=0;
		while(tail) now[st[tail--]]=0;
		tot=0;
    }
    void add(int u,int v,int w)
    {
		if (!now[u]) st[++tail]=u;
		pre[++tot]=now[u];
		now[u]=tot;
		son[tot]=v;
		val[tot]=w;
    }  
    void insert(int u,int v)
    {
		int w=Get_Lca(u,v,-2)-1;
		add(u,v,w),add(v,u,w);
    }
    void dfs1(int x,int fa)
    {
		siz[x]=color[x];
		f[x]=0;
		for (int p=now[x];p;p=pre[p])
			{
				int child=son[p];
				if (child==fa) continue;
				//printf("Dfs:%d %d %d\n",child,x,val[p]);
				dfs1(child,x);
				siz[x]+=siz[child];
				f[x]+=f[child]+1ll*val[p]*siz[child];
			}
    }
    void dfs2(int x,int fa)
    {
		for (int p=now[x];p;p=pre[p])
			{
				int child=son[p];
				if (child==fa) continue;
				f[child]=f[x]+1ll*val[p]*(m-2*siz[child]);
				dfs2(child,x);
			}
    }
}VT;

bool cmp(int x,int y){return dfn[x]<dfn[y];}
void solve()
{
    //puts("Enter");
    m=read();
    for (int i=1;i<=m;i++) VT.color[a[i]=read()]=1,id[i]=a[i];
    sort(a+1,a+m+1,cmp);
    int tail=1;
    st[tail]=n;
    for (int i=1;i<=m;i++)
		{
			int Lca=Get_Lca(st[tail],a[i],-1),pre=0;
			while(dep[Lca]<dep[st[tail]])
				{
					if (pre) VT.insert(pre,st[tail]);
					pre=st[tail--];
				}
			if (pre) VT.insert(pre,Lca);
			if (Lca!=st[tail]) st[++tail]=Lca;
			st[++tail]=a[i];
		}
    tail--;
    while(tail) VT.insert(st[tail],st[tail+1]),tail--;
    VT.dfs1(n,0);
    VT.dfs2(n,0);
    for (int i=1;i<=m;i++) printf("%lld ",VT.f[id[i]]+m);puts("");
    VT.clear();
}

int main()
{
    n=read(),Q=read();
    for (int i=1;i<=n;i++) c[i]=read();
    for (int i=1;i<n;i++)
		{
			int u=read(),v=read();
			add(u,v),add(v,u);
		}
    n++;add(n,1),add(1,n);c[n]=-2;
    prepare(n);
    build(n);
    Sg.build(1,1,n);
    while(Q--)
		{
			int mode=read();
			if (mode==1)
				{
					int u=read(),v=read(),y=read();
					Get_Lca(u,v,y);
				}
			else
				solve();
		}
    return 0;
}

	    
posted @ 2019-01-03 11:34  Alseo_Roplyer  阅读(164)  评论(0编辑  收藏  举报