树链剖分

//计算过程中随时取模,否则一般会wrong 7个点 
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std; 
const int maxn=100010;

struct edge{ int t; edge *nxt; edge(int to, edge *next){ t=to, nxt=next; } };
edge *h[maxn];
void add(int u, int v){ h[u]=new edge(v, h[u]); }
int n, m, r, mod, v[maxn], f[maxn], d[maxn], size[maxn], son[maxn], rk[maxn], top[maxn], id[maxn], cnt;

void dfs1(int x)									//计算f, d, size, son 
{
	size[x]=1, d[x]=d[f[x]]+1;						//计算深度d 
	for(edge *p=h[x]; p; p=p->nxt)
	{
		if(p->t==f[x])	continue;					
		f[p->t]=x, dfs1(p->t), size[x]+=size[p->t];	//计算size,指定父亲f 
		if(size[son[x]]<size[p->t])	son[x]=p->t;	//计算重儿子son 
	}
}

void dfs2(int x, int tp)							//重新排序结点,计算rk, top, id 
{
	top[x]=tp, id[x]=++cnt, rk[cnt]=x;
	if(son[x])	dfs2(son[x], tp);					//先递归到重儿子
	for(edge *p=h[x]; p; p=p->nxt)
		if(p->t!=f[x] && p->t!=son[x])				//再处理轻儿子 
			dfs2(p->t, p->t);
}

inline long long read()                             //快读可以定义为内联函数,效率更高 
{
    long long s=0, w=1;
    char ch=getchar();
    while(ch<'0'  || ch>'9' ){ if(ch=='-') w=-1; ch=getchar(); }
    while(ch>='0' && ch<='9'){ s=s*10+ch-'0';    ch=getchar(); }
    return s*w;
}

struct SegNode
{
    int l, r;                                       //[l, r]为当前结点表示的区间,bj为懒人标记 
    long long val, bj;                              //根据题目要求,一定注意数据范围! 
    SegNode *lc, *rc;                               //指向左、右子树的指针
    SegNode(int left, int right){ l=left, r=right, bj=0, lc=rc=NULL; } 
}*root;												//root为线段树的根结点 

void buildT(SegNode * &x, int L, int R)             //x必须定义为引用,否则递归时无法修改指针的值
{
    x=new SegNode(L, R);                            //新建结点 
    if(L==R)                                        //已到达叶子结点 
    {
        x->val=v[rk[L]]%mod;
        return ;
    }
    int mid=(L+R)>>1;
    buildT(x->lc, L, mid);                          //递归构建左、右子树 
    buildT(x->rc, mid+1, R); 
    x->val=x->lc->val+x->rc->val;                   //回溯时合并左右子树的值 
}

void pushDown(SegNode *x)                           //当前结点的懒人标记下传 
{
    if(x->bj && x->l!=x->r)                         //有标记且x不是叶子结点,可以下传,从叶子结点下传可能会RE 
    {
        long long k=x->bj;
    	(x->lc->bj+=k)%=mod;
        (x->lc->val+=k*(x->lc->r-x->lc->l+1))%=mod;	//这里非常容易错,要乘k 
        (x->rc->bj+=k)%=mod;
        (x->rc->val+=k*(x->rc->r-x->rc->l+1))%=mod; //这里非常容易错,要乘k
        x->bj=0;                                    //清除x结点标记 
    }
}

void radd(int L, int R, int k, SegNode *x)          //区间加,[L, R] + k,x为根 
{
    if(L<=x->l && x->r<=R)                          //当前结点表示的区间被[L, R]覆盖,可直接被修改、设置懒人标记,停止递归 
    {
        (x->val+=k*(x->r-x->l+1)%mod)%=mod;         //修改区间val至正确值 
        (x->bj+=k)%=mod;                            //设懒人标记 
        return;
    }
    int m=(x->l+x->r)>>1;
    pushDown(x);                                    //要递归修改儿子,先把标记下传 
    if(L<=m)    radd(L, R, k, x->lc);               //递归修改左孩子 
    if(R>=m+1)  radd(L, R, k, x->rc);               //递归修改右孩子 
    x->val=(x->lc->val+x->rc->val)%mod;             //回溯时合并,更新父结点的值 
}
 
long long rquery(int L, int R, SegNode *x)          //区间和[L, R]查询,x为根 
{
    if(L<=x->l && x->r<=R)                          //当前结点表示的区间被[L, R]覆盖,可直接返回值 
        return x->val;
    int m=(x->l+x->r)>>1;
    pushDown(x);                                    //要向孩子查询,先更新孩子的值 
    long long ans=0;
    if(L<=m)    ans=(ans+rquery(L, R, x->lc))%mod;  //加左子树返回的值 
    if(R>=m+1)  ans=(ans+rquery(L, R, x->rc))%mod;  //加右子树返回的值 
    return ans;
}

int sum(int x, int y)								//求x--y最短路路径和 
{
	int ret=0;
	while(top[x]!=top[y])							//先统计位置较低树链 
	{
		if(d[top[x]]<d[top[y]])	swap(x, y);
		ret=(ret+rquery(id[top[x]], id[x], root))%mod;
		x=f[top[x]];
	}
	if(id[x]>id[y])	swap(x, y);
	return (ret+rquery(id[x], id[y], root))%mod; 	//加上最后剩余的树链 
}

void update(int x, int y, int c)					//跟新x--y最短路的值 
{
	while(top[x]!=top[y])							//先更新位置较低树链
	{
		if(d[top[x]]<d[top[y]])	swap(x, y);
		radd(id[top[x]], id[x], c, root);
		x=f[top[x]];
	}
	if(id[x]>id[y])	swap(x, y);
	radd(id[x], id[y], c, root);					//最后更新剩余的部分树链 
}

int main()
{
    scanf("%d%d%d%d", &n, &m, &r, &mod);
    for(int i=1; i<=n; i++)			scanf("%d", &v[i]);
    for(int i=1, x, y; i<n; i++)	scanf("%d%d", &x, &y), add(x, y), add(y, x);
	dfs1(r), dfs2(r, r), buildT(root, 1, n);		//第一次dfs计算轻重链,第二次dfs对结点重新排序,然后使用新序列rk构建线段树 
	for(int i=1, op, x, y, k; i<=m; i++)
    {
        scanf("%d", &op);
        if(op==1)	scanf("%d%d%d", &x, &y, &k), update(x, y, k);
        if(op==2)	scanf("%d%d", &x, &y), printf("%d\n", sum(x,y));
        if(op==3)	scanf("%d%d", &x, &y), radd(id[x], id[x]+size[x]-1, y, root);
        if(op==4)	scanf("%d", &x), printf("%lld\n", rquery(id[x], id[x]+size[x]-1, root));
    }
    return 0;
}
posted @ 2019-04-12 17:16  LFYZOI题解  阅读(191)  评论(0编辑  收藏  举报