Luogu_4242 树上的毒瘤
题意
给你一棵n个点的树,每个点都有一个颜色。定义一条路径<u,v>的权值为u到v路径上依次经过的点(包括u,v)的颜色序列的颜色段数。有两种操作:
1.将<u,v>路径上的点的颜色全改为y。
2.给定一个点的集合,对于集合内每个点求其到集合内所有点的路径权值的和。
\(1≤n,q≤100000,c_i,y\leq 10^9 ,\sum m\leq 200000,m\leq n。\)
题解
好一个小水题(苦笑)。
还是一样看到总询问点数与n同阶,直接上虚树。
我们先不考虑修改,那么每次询问我们只要给虚树上每条边赋一个权值,为这条边两端点的路径权值-1(左闭右开,减掉fa的那个颜色段,这样计算时直接两边权相加就行了)这样虚树上任意两点之间的路径权值即为经过的边权和+1,然后二次换根就可以算出每个点的答案。
但是怎么样算在原树上两点的路径权值呢?
我们可以用树剖+线段树。线段树每个点记录它所对应的区间的颜色段数,左端的颜色和右端的颜色就行了。
然后带修改的话,用树剖+线段树也很好直接维护。
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=1e5;
int n,Q,tot,Time,m;
int c[maxn+8],pre[maxn*2+8],now[maxn+8],son[maxn*2+8];
int dep[maxn+8],siz[maxn+8],heavy[maxn+8];
int id[maxn+8],dfn[maxn+8],top[maxn+8],fa[maxn+8];
int a[maxn+8],st[maxn+8];
int read()
{
int x=0,f=1;char ch=getchar();
for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x*f;
}
void add(int u,int v)
{
pre[++tot]=now[u];
now[u]=tot;
son[tot]=v;
}
struct node
{
int l,r,sum,lazy;
};
node operator +(node x,node y)
{
if (!x.sum) return y;
if (!y.sum) return x;
return (node){x.l,y.r,x.sum+y.sum-(x.r==y.l),0};
}
struct Segment_Tree
{
node tree[maxn*4+8];
void build(int p,int l,int r)
{
if (l==r)
{
tree[p]=(node){c[id[l]],c[id[l]],1,0};
return;
}
int mid=(l+r)>>1;
build(p<<1,l,mid),build(p<<1|1,mid+1,r);
tree[p]=tree[p<<1]+tree[p<<1|1];
}
void mark(int p,int v){tree[p]=(node){v,v,1,v};}
void pushdown(int p)
{
if (tree[p].lazy)
{
mark(p<<1,tree[p].lazy);
mark(p<<1|1,tree[p].lazy);
tree[p].lazy=0;
}
}
void change(int p,int l,int r,int L,int R,int v)
{
if (L<=l&&r<=R)
{
mark(p,v);
return;
}
int mid=(l+r)>>1;
pushdown(p);
if (L<=mid) change(p<<1,l,mid,L,R,v);
if (R>mid) change(p<<1|1,mid+1,r,L,R,v);
tree[p]=tree[p<<1]+tree[p<<1|1];
}
node query(int p,int l,int r,int L,int R)
{
if (L<=l&&r<=R) return tree[p];
int mid=(l+r)>>1;
pushdown(p);
node tmp=(node){0,0,0,0};
if (L<=mid) tmp=tmp+query(p<<1,l,mid,L,R);
if (R>mid) tmp=tmp+query(p<<1|1,mid+1,r,L,R);
return tmp;
}
}Sg;
void prepare(int x)
{
dep[x]=dep[fa[x]]+1;
siz[x]=1;
for (int p=now[x];p;p=pre[p])
{
int child=son[p];
if (child==fa[x]) continue;
fa[child]=x;
prepare(child);
siz[x]+=siz[child];
if (siz[heavy[x]]<siz[child]) heavy[x]=child;
}
}
void build(int x)
{
id[dfn[x]=++Time]=x;
top[x]=heavy[fa[x]]==x?top[fa[x]]:x;
if (heavy[x]) build(heavy[x]);
for (int p=now[x];p;p=pre[p])
{
int child=son[p];
if (dfn[child]) continue;
build(child);
}
}
int Get_Lca(int u,int v,int t)
{
//Big_flag=0;
//if (u+v==12&&abs(u-v)==10) Big_flag=1,puts("Lemon Party");
node tmp=(node){0,0,0,0};
while(top[u]!=top[v])
{
if (dep[top[u]]<dep[top[v]]) swap(u,v);
if (t>=0) Sg.change(1,1,n,dfn[top[u]],dfn[u],t);
if (t==-2) tmp=Sg.query(1,1,n,dfn[top[u]],dfn[u])+tmp;
u=fa[top[u]];
}
if (dep[u]<dep[v]) swap(u,v);
if (t==-1) return v;
if (t>=0)Sg.change(1,1,n,dfn[v],dfn[u],t);
if (t==-2) tmp=Sg.query(1,1,n,dfn[v],dfn[u])+tmp;
return tmp.sum;
}
struct Virtual_Tree
{
int tot,tail;
int now[maxn+8],pre[maxn*2+8],son[maxn*2+8],val[maxn*2+8];
int color[maxn+8],st[maxn+8],siz[maxn+8];
ll f[maxn+8];
void clear()
{
for (int i=1;i<=m;i++) color[a[i]]=0;
while(tail) now[st[tail--]]=0;
tot=0;
}
void add(int u,int v,int w)
{
if (!now[u]) st[++tail]=u;
pre[++tot]=now[u];
now[u]=tot;
son[tot]=v;
val[tot]=w;
}
void insert(int u,int v)
{
int w=Get_Lca(u,v,-2)-1;
add(u,v,w),add(v,u,w);
}
void dfs1(int x,int fa)
{
siz[x]=color[x];
f[x]=0;
for (int p=now[x];p;p=pre[p])
{
int child=son[p];
if (child==fa) continue;
//printf("Dfs:%d %d %d\n",child,x,val[p]);
dfs1(child,x);
siz[x]+=siz[child];
f[x]+=f[child]+1ll*val[p]*siz[child];
}
}
void dfs2(int x,int fa)
{
for (int p=now[x];p;p=pre[p])
{
int child=son[p];
if (child==fa) continue;
f[child]=f[x]+1ll*val[p]*(m-2*siz[child]);
dfs2(child,x);
}
}
}VT;
bool cmp(int x,int y){return dfn[x]<dfn[y];}
void solve()
{
//puts("Enter");
m=read();
for (int i=1;i<=m;i++) VT.color[a[i]=read()]=1,id[i]=a[i];
sort(a+1,a+m+1,cmp);
int tail=1;
st[tail]=n;
for (int i=1;i<=m;i++)
{
int Lca=Get_Lca(st[tail],a[i],-1),pre=0;
while(dep[Lca]<dep[st[tail]])
{
if (pre) VT.insert(pre,st[tail]);
pre=st[tail--];
}
if (pre) VT.insert(pre,Lca);
if (Lca!=st[tail]) st[++tail]=Lca;
st[++tail]=a[i];
}
tail--;
while(tail) VT.insert(st[tail],st[tail+1]),tail--;
VT.dfs1(n,0);
VT.dfs2(n,0);
for (int i=1;i<=m;i++) printf("%lld ",VT.f[id[i]]+m);puts("");
VT.clear();
}
int main()
{
n=read(),Q=read();
for (int i=1;i<=n;i++) c[i]=read();
for (int i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
n++;add(n,1),add(1,n);c[n]=-2;
prepare(n);
build(n);
Sg.build(1,1,n);
while(Q--)
{
int mode=read();
if (mode==1)
{
int u=read(),v=read(),y=read();
Get_Lca(u,v,y);
}
else
solve();
}
return 0;
}