【模板】树链剖分

题目描述

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

输入输出格式

输入格式:

 

第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。

接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。

接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)

接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:

操作1: 1 x y z

操作2: 2 x y

操作3: 3 x z

操作4: 4 x

 

输出格式:

 

输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模)

 

输入输出样例

输入样例#1:
5 5 2 24
7 3 7 8 0 
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3
输出样例#1:
2
21

说明

时空限制:1s,128M

数据规模:

对于30%的数据:N<=10,M<=10

对于70%的数据:N<=1000,M<=1000

对于100%的数据:N<=100000,M<=100000

(其实,纯随机生成的树LCA+暴力是能过的,可是,你觉得可能是纯随机的么233)

样例说明:

树的结构如下:

各个操作如下:

故输出应依次为2、21(重要的事情说三遍:记得取模)

一些

有关树链剖分

 首先,要会建树,建议用链表(我一开始各种T,而后抄的shenben学长的),指针,stl等。(不爆空间就好)

一些概念:

树链剖分实际上就是把树按一种特定规则拆解成一条又一条的链。

这些链经过所有的点,称为重链,链上的边,为重边,其余(一般很少)的边为轻边。

size:以该节点为根节点构建的树上的节点数;

重儿子:某一节点所有儿子中size值最大的一个,由重边与该节点相连,不可能位于链首;

轻儿子:不是宗子(重儿子)的儿子,成为新的诸侯王|族长(新链的链首节点);

如下图:

同一枝上一个色的属于同一条重链。

步骤:

读入数据:

读入点(p)权(k);

读入边并存放在链表中;

恩,我不会。

Dfs1:

求出点(p)的父节点(f),深度(p),size(sz),重儿子(ws);

Dfs2:

求出点(p)的线段坐标(p),链首节点(t),链尾节点(w);

build:

建造线段树;

处理处置:

1&2:lca;

3&4:从 根节点的线段坐标(p[].p) 到 根节点的线段坐标+根节点的size-1(p[].p+p[].sz-1)就是了。(可以想象)

return 0;

有关变量的一些说明:
n节点,m操作,root根节点,mod模数;

a,b,c,d,ans辅助变量;

node{//点对应
k权值,f父节点,d深度;
sz:size ws重儿子; p点对于线段上的位置坐标; t点所在链链首节点; }p[
1*] s[]线段暂存重链节点; l线段长; nl线段读入线段树时使用; tree{l,r,s,flag;}t[4*]//基本线段树(然而我一直开的2*,CE Qrz) h[]记录点对应链表上的启示位置 hs表示链表存到第几个了; nate{//应该是链表,抄的shenben学长的东西(原为node) s连接到的节点; n下一个(原为next) }p[2*]

代码实现:

  1 #include<cstdio>
  2 #include<iostream>
  3 using namespace std;
  4 int n,m,l,hs,nl,lfs,root,mod;
  5 int a,b,c,d,ans;
  6 int s[100010],h[100010];
  7 struct node
  8 {
  9     int k,f,d,sz;
 10     int ws;
 11     int p,t,w;
 12 }p[100010];
 13 struct edge{int s,n;}e[200010];
 14 struct tree{int l,r,s,flag;}t[400010];
 15 void heritage(int k)
 16 {
 17     int ls=k*2,rs=k*2+1;
 18     t[ls].flag=(t[ls].flag+t[k].flag)%mod;
 19     t[ls].s+=(t[ls].r-t[ls].l+1)*t[k].flag%mod;
 20     t[rs].flag=(t[rs].flag+t[k].flag)%mod;
 21     t[rs].s+=(t[rs].r-t[rs].l+1)*t[k].flag%mod;
 22     t[k].flag=0;
 23 }
 24 void build(int k,int l,int r)
 25 {
 26     int ls=k*2,rs=k*2+1;
 27     t[k].l=l;t[k].r=r;
 28     if(l==r)
 29     {
 30         t[k].s=s[++nl];
 31         return;
 32     }
 33     int mid=(l+r)/2;
 34     build(ls,l,mid);
 35     build(rs,mid+1,r);
 36     t[k].s=t[ls].s+t[rs].s;
 37 }
 38 void change(int k,int l,int r,int v)
 39 {
 40     int ls=k*2,rs=k*2+1;
 41     if(t[k].l==l&&t[k].r==r)
 42     {
 43         t[k].flag=(t[k].flag+v)%mod;
 44         t[k].s+=(t[k].r-t[k].l+1)*v%mod;
 45         return;
 46     }
 47     if(t[k].flag) heritage(k);
 48     int mid=(t[k].l+t[k].r)/2;
 49     if( l <= mid )
 50     change(ls,l,min(r,mid),v);
 51     if( r > mid )
 52     change(rs,max(l,mid+1),r,v);
 53     t[k].s=(t[ls].s+t[rs].s)%mod;
 54 }
 55 int query(int k,int l,int r)
 56 {
 57     int ls=k*2,rs=k*2+1;
 58     if(t[k].l==l&&t[k].r==r) return t[k].s;
 59     if(t[k].flag) heritage(k);
 60     int mid=(t[k].l+t[k].r)/2,ans=0;
 61     if( l <= mid )
 62     ans+=query(ls,l,min(r,mid))%mod;
 63     if( r > mid )
 64     ans+=query(rs,max(l,mid+1),r)%mod;
 65     return ans%mod;
 66 }
 67 bool traceability(int x,int y)
 68 {
 69     if(x==p[y].w) return 0;
 70     while(p[x].d>=p[y].d)
 71     {
 72         if(p[x].t==p[y].t) return 1;
 73         x=p[p[x].t].f;
 74     }
 75 }
 76 inline void add(int x,int y){e[++hs]=(edge){y,h[x]};h[x]=hs;}
 77 void dfs1(int k,int f,int d)
 78 {
 79     p[k].f=f;
 80     p[k].d=d;
 81     p[k].sz=1;
 82     for(int i=h[k];i;i=e[i].n)
 83     {
 84         if(e[i].s!=f)
 85         {
 86             dfs1(e[i].s,k,d+1);
 87             p[k].sz+=p[e[i].s].sz;
 88             if(p[e[i].s].sz>p[p[k].ws].sz)
 89             p[k].ws=e[i].s;
 90         }
 91     }
 92 }
 93 void dfs2(int k)
 94 {
 95     s[++l]=p[k].k;
 96     p[k].p=l;
 97     if(p[k].ws)
 98     {
 99         p[p[k].ws].t=p[k].t;
100         dfs2(p[k].ws);
101         p[k].w=p[k].ws;
102     }
103     for(int i=h[k];i;i=e[i].n)
104     if(e[i].s!=p[k].ws&&e[i].s!=p[k].f)
105     {
106         p[e[i].s].w=p[e[i].s].t=e[i].s;
107         dfs2(e[i].s);
108     }
109 }
110 int main()
111 {
112     scanf("%d%d%d%d",&n,&m,&root,&mod);
113     for(int i=1;i<=n;i++) scanf("%d",&p[i].k);
114     for(int i=1;i<n;i++)
115     {
116         scanf("%d%d",&a,&b);
117         add(a,b);add(b,a);
118     }
119     dfs1(root,root,1);
120     p[root].w=p[root].t=root;
121     dfs2(root);
122     build(1,1,l);
123     while(m--)
124     {
125         scanf("%d",&a);
126         if(a==1)
127         {
128             scanf("%d%d%d",&b,&c,&d);
129             d%=mod;
130             for(;p[b].t!=p[c].t;b=p[p[b].t].f)
131             {
132                 if(p[p[b].t].d<p[p[c].t].d) swap(b,c);
133                 change(1,p[p[b].t].p,p[b].p,d);
134             }
135             if(p[b].d>p[c].d) swap(b,c);
136             change(1,p[b].p,p[c].p,d);
137         }
138         if(a==2)
139         {
140             scanf("%d%d",&b,&c);
141             ans=0;
142             for(;p[b].t!=p[c].t;b=p[p[b].t].f)
143             {
144                 if(p[p[b].t].d<p[p[c].t].d) swap(b,c);
145                 ans+=query(1,p[p[b].t].p,p[b].p);
146                 ans%=mod;
147             }
148             if(p[b].d>p[c].d) swap(b,c);
149             ans+=query(1,p[b].p,p[c].p);
150             ans%=mod;
151             printf("%d\n",ans);
152         }
153         if(a==3)
154         {
155             scanf("%d%d",&b,&c);
156             c%=mod;
157             change(1,p[b].p,p[b].p+p[b].sz-1,c);
158         }
159         if(a==4)
160         {
161             scanf("%d",&b);
162             ans=0;
163             ans=query(1,p[b].p,p[b].p+p[b].sz-1)%mod;
164             printf("%d\n",ans);
165         }
166     }
167     return 0;
168 }

我记得我上次代码过百行是在201...

  1 #include<cstdio>
  2 #include<iostream>
  3 #define ls k*2
  4 #define rs k*2+1
  5 using namespace std;
  6 int n,m,l,hs,nl,lfs,root,mod;
  7 int a,b,c,d,ans;
  8 int s[100010],h[100010];
  9 struct node{int k,f,d,sz,ws,p,t;}p[100010];
 10 struct nate{int s,n;}e[200010];
 11 struct tree{int l,r,s,f;}t[400010];
 12 void heritage(int k){
 13     t[ls].f=(t[ls].f+t[k].f)%mod;
 14     t[ls].s+=(t[ls].r-t[ls].l+1)*t[k].f%mod;
 15     t[rs].f=(t[rs].f+t[k].f)%mod;
 16     t[rs].s+=(t[rs].r-t[rs].l+1)*t[k].f%mod;
 17     t[k].f=0;
 18 }
 19 void build(int k,int l,int r){
 20     t[k].l=l;t[k].r=r;
 21     if(l==r){t[k].s=s[++nl];return;}
 22     int mid=(l+r)/2;
 23     build(ls,l,mid);
 24     build(rs,mid+1,r);
 25     t[k].s=t[ls].s+t[rs].s;
 26 }
 27 void change(int k,int l,int r,int v){
 28     if(t[k].l==l&&t[k].r==r){
 29         t[k].f=(t[k].f+v)%mod;
 30         t[k].s+=(t[k].r-t[k].l+1)*v%mod;
 31         return;
 32     }
 33     if(t[k].f) heritage(k);
 34     int mid=(t[k].l+t[k].r)/2;
 35     if(l<=mid) change(ls,l,min(r,mid),v);
 36     if(r>mid) change(rs,max(l,mid+1),r,v);
 37     t[k].s=(t[ls].s+t[rs].s)%mod;
 38 }
 39 int query(int k,int l,int r){
 40     if(t[k].l==l&&t[k].r==r) return t[k].s;
 41     if(t[k].f) heritage(k);
 42     int mid=(t[k].l+t[k].r)/2,ans=0;
 43     if(l<=mid) ans+=query(ls,l,min(r,mid))%mod;
 44     if(r>mid) ans+=query(rs,max(l,mid+1),r)%mod;
 45     return ans%mod;
 46 }
 47 inline void add(int x,int y){e[++hs]=(nate){y,h[x]};h[x]=hs;}
 48 void dfs1(int k,int f,int d){
 49     p[k].f=f;p[k].d=d;p[k].sz=1;
 50     for(int i=h[k];i;i=e[i].n)
 51     if(e[i].s!=f){
 52         dfs1(e[i].s,k,d+1);
 53         p[k].sz+=p[e[i].s].sz;
 54         if(p[e[i].s].sz>p[p[k].ws].sz) p[k].ws=e[i].s;
 55     }
 56 }
 57 void dfs2(int k){
 58     s[++l]=p[k].k;p[k].p=l;
 59     if(p[k].ws){
 60         p[p[k].ws].t=p[k].t;
 61         dfs2(p[k].ws);
 62     }
 63     for(int i=h[k];i;i=e[i].n)
 64     if(e[i].s!=p[k].ws&&e[i].s!=p[k].f){
 65         p[e[i].s].t=e[i].s;
 66         dfs2(e[i].s);
 67     }
 68 }
 69 int main(){
 70     scanf("%d%d%d%d",&n,&m,&root,&mod);
 71     for(int i=1;i<=n;i++) scanf("%d",&p[i].k);
 72     for(int i=1;i<n;i++) scanf("%d%d",&a,&b),add(a,b),add(b,a);
 73     dfs1(root,root,1);
 74     dfs2(root);
 75     build(1,1,l);
 76     while(m--){
 77         scanf("%d",&a);
 78         if(a==1){
 79             scanf("%d%d%d",&b,&c,&d);d%=mod;
 80             for(;p[b].t!=p[c].t;b=p[p[b].t].f){
 81                 if(p[p[b].t].d<p[p[c].t].d) swap(b,c);
 82                 change(1,p[p[b].t].p,p[b].p,d);
 83             }
 84             if(p[b].d>p[c].d) swap(b,c);
 85             change(1,p[b].p,p[c].p,d);
 86         }
 87         if(a==2){
 88             scanf("%d%d",&b,&c);ans=0;
 89             for(;p[b].t!=p[c].t;b=p[p[b].t].f){
 90                 if(p[p[b].t].d<p[p[c].t].d) swap(b,c);
 91                 ans+=query(1,p[p[b].t].p,p[b].p),ans%=mod;
 92             }
 93             if(p[b].d>p[c].d) swap(b,c);
 94             ans+=query(1,p[b].p,p[c].p),ans%=mod;
 95             printf("%d\n",ans);
 96         }
 97         if(a==3){
 98             scanf("%d%d",&b,&c),c%=mod;
 99             change(1,p[b].p,p[b].p+p[b].sz-1,c);
100         }
101         if(a==4){
102             scanf("%d",&b),ans=0;
103             ans=query(1,p[b].p,p[b].p+p[b].sz-1)%mod;
104             printf("%d\n",ans);
105         }
106     }
107     return 0;
108 }
无注释版(我不是邪教徒)(反复更新)

评测结果:

搞了一天半,身心俱疲啊~

题目来源:洛谷

posted @ 2017-02-08 21:07  J_william  阅读(287)  评论(0编辑  收藏  举报