[树链剖分]BZOJ3589动态树
题目描述
别忘了这是一棵动态树, 每时每刻都是动态的. 小明要求你在这棵树上维护两种事件
事件0:
这棵树长出了一些果子, 即某个子树中的每个节点都会长出K个果子.
事件1:
小明希望你求出几条树枝上的果子数. 一条树枝其实就是一个从某个节点到根的路径的一段. 每次小明会选定一些树枝, 让你求出在这些树枝上的节点的果子数的和. 注意, 树枝之间可能会重合, 这时重合的部分的节点的果子只要算一次.
输入
第一行一个整数n(1<=n<=200,000), 即节点数.
接下来n-1行, 每行两个数字u, v. 表示果子u和果子v之间有一条直接的边. 节点从1开始编号.
在接下来一个整数nQ(1<=nQ<=200,000), 表示事件.
最后nQ行, 每行开头要么是0, 要么是1.
如果是0, 表示这个事件是事件0. 这行接下来的2个整数u, delta表示以u为根的子树中的每个节点长出了delta个果子.
如果是1, 表示这个事件是事件1. 这行接下来一个整数K(1<=K<=5), 表示这次询问涉及K个树枝. 接下来K对整数u_k, v_k, 每个树枝从节点u_k到节点v_k. 由于果子数可能非常多, 请输出这个数模2^31的结果. 输出
对于每个事件1, 输出询问的果子数.
样例输入
5
1 2
2 3
2 4
1 5
3
0 1 1
0 2 3
1 2 3 1 1 4
样例输出
13
解析
因为是动态的,所以放弃倍增LCA,考虑树链剖分(为什么做树的问题时我脑子里只有这两个东西?)。
对于事件0,直接在dfs序上区间修改x及其子树,即从x到x+siz[x]-1全部加k。
对于事件1,先对于每条链,求出它在dfs序上对应的区间,然后将这些区间合并起来,最后查询即可
合并区间的时候,先按照左端点由小到大给区间排序,然后从小到大遍历,若当前区间的右端点在下一个区间左端点的右边就可以把两个区间合并起来。
还有就是模2^31的问题,找Master Yi问了一下,好像是可以不管炸int的,只需要在最后&一下2^31-1即可。
代码如下
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=200005;
const int inf=2147483647;
struct node{int l,r;}q[maxn];
int en,mk[maxn<<2],sum[maxn<<2];
int n,Q,id,info[maxn],nx[maxn<<1],v[maxn<<1];
int fa[maxn],dep[maxn],dfn[maxn],siz[maxn],son[maxn],top[maxn];
bool cmp(node a,node b){return a.l==b.l?a.r<b.r:a.l<b.l;}
void add(int u1,int v1){nx[++id]=info[u1];info[u1]=id;v[id]=v1;}
void dfs1(int x,int f)
{
dep[x]=dep[fa[x]=f]+(siz[x]=1);
for(int i=info[x];i;i=nx[i])if(v[i]!=f)
{
dfs1(v[i],x);siz[x]+=siz[v[i]];
if(siz[v[i]]>siz[son[x]])son[x]=v[i];
}
}
void dfs2(int x,int f)
{
dfn[x]=++id;top[x]=f;if(son[x])dfs2(son[x],f);
for(int i=info[x];i;i=nx[i])if(v[i]!=fa[x]&&v[i]!=son[x])dfs2(v[i],v[i]);
}
void pushdown(int id,int l,int r)
{
int mid=(l+r)/2;
mk[id*2]+=mk[id];mk[id*2+1]+=mk[id];
sum[id*2]+=(mid-l+1)*mk[id];sum[id*2+1]+=(r-mid)*mk[id];mk[id]=0;
}
void fix(int id,int l,int r,int l1,int r1,int k)
{
if(r1<l||r<l1)return;
if(l1<=l&&r<=r1){sum[id]+=k*(r-l+1);mk[id]+=k;return;}
pushdown(id,l,r);int mid=(l+r)/2;
fix(id*2,l,mid,l1,r1,k);fix(id*2+1,mid+1,r,l1,r1,k);sum[id]=sum[id*2]+sum[id*2+1];
}
int que(int id,int l,int r,int l1,int r1)
{
if(r1<l||r<l1)return 0;
if(l1<=l&&r<=r1)return sum[id];
pushdown(id,l,r);int mid=(l+r)/2;
return que(id*2,l,mid,l1,r1)+que(id*2+1,mid+1,r,l1,r1);
}
int main()
{
scanf("%d",&n);
for(int i=1,u1,v1;i<n;i++)scanf("%d%d",&u1,&v1),add(u1,v1),add(v1,u1);
id=0;dfs1(1,0);dfs2(1,1);id=0;
scanf("%d",&Q);
for(int i=1,ord,x,k;i<=Q;i++)
{
scanf("%d",&ord);
if(ord==0)scanf("%d%d",&x,&k),fix(1,1,n,dfn[x],dfn[x]+siz[x]-1,k);
if(ord==1)
{
scanf("%d",&k);en=0;
for(int i=1,a,b;i<=k;i++)
{
scanf("%d%d",&a,&b);
if(dep[b]>dep[a])swap(a,b);
while(top[a]!=top[b])q[++en]=(node){dfn[top[a]],dfn[a]},a=fa[top[a]];
q[++en]=(node){dfn[b],dfn[a]};
}
sort(q+1,q+1+en,cmp);
int ans=0;
for(int i=1,l,r;i<=en;i++)
{
l=q[i].l,r=q[i].r;
while(q[i+1].l<=r&&i+1<=en)r=max(q[i+1].r,r),i++;
ans+=que(1,1,n,l,r);
}
printf("%d\n",ans&inf);
}
}
}