【模板】树链剖分
题目传送门：洛谷 P3384（原文超链接未保留）
代码如下
#include <iostream>
#include <cstdio>
#include <vector>
// Capacity of every per-vertex array; vertices are numbered 1..n, so n < maxn.
const int maxn = 100005;
// `using namespace std;` is deliberately NOT used here. Since C++17, <vector>
// makes std::size and std::data visible; together with a using-directive, every
// unqualified use of the global arrays `size` and `data` declared below becomes
// ambiguous, which is a hard compile error. The declarations only need
// std::vector, and the rest of the file only calls scanf/printf unqualified.
typedef long long ll;

// Edge record of the linked adjacency lists: `data` is the target vertex and
// `next` is the index of the next edge leaving the same source (0 ends a list).
// Each undirected edge is stored in both directions, hence 2 * maxn slots.
struct T{
    int data, next;
}e[maxn << 1];

// Per-vertex arrays for the heavy-light decomposition:
//   top   - head vertex of the heavy chain containing the vertex
//   son   - heavy child (0 if the vertex has no children)
//   size  - number of vertices in the vertex's subtree
//   depth - distance from the root
//   data  - initial value of the vertex
//   fa    - parent (0 for the root)
int top[maxn], son[maxn], size[maxn], depth[maxn], data[maxn], fa[maxn];
// head[x] = index of the first edge leaving x (0 if none); cnt = edges stored so far.
int head[maxn], cnt;
// Vertices in the order visited by the second DFS; vec[0] is a placeholder.
std::vector<int> vec;
// Modulus applied to all reported sums.
int p;
// Segment-tree node covering DFS positions [l, r]: `sum` of the values in that
// range, and `lazy`, an increment already applied to `sum` but not yet pushed
// to the children.
struct node{
    int l, r;
    ll sum;
    ll lazy;
}tree[maxn << 2];
// Insert the directed edge x -> y at the front of x's adjacency list.
// Edge slots are numbered from 1 so that index 0 can serve as the list terminator.
void add(int x, int y)
{
    int id = ++cnt;
    e[id].data = y;
    e[id].next = head[x];
    head[x] = id;
}
// First DFS of the heavy-light decomposition, starting at x.
// For every vertex below x it records fa (parent), depth (one more than the
// parent's), and size (subtree vertex count, including x itself). It also
// records son, the child with the largest subtree. On ties the child seen
// first in adjacency order is kept.
// Recursive: the recursion depth equals the height of x's subtree.
void dfs1(int x)
{
    size[x] = 1;
    for(int edge = head[x]; edge; edge = e[edge].next){
        int child = e[edge].data;
        if(child == fa[x])
            continue;   // skip the edge leading back to the parent
        fa[child] = x;
        depth[child] = depth[x] + 1;
        dfs1(child);
        size[x] += size[child];
        bool heavier = son[x] == 0 || size[child] > size[son[x]];
        if(heavier)
            son[x] = child;
    }
}
// mp[x] = index of vertex x inside vec, i.e. its position in the DFS order.
int mp[maxn];
// Second DFS of the heavy-light decomposition, assigning chain head k to x.
// Each vertex reached gets top (its chain head) and mp (its index in vec).
// The heavy child is visited first, so every heavy chain occupies consecutive
// positions. Every subtree occupies [mp[x], mp[x] + size[x] - 1].
// Expects fa and son to have been filled by dfs1; son == 0 marks a leaf.
void dfs2(int x, int k)
{
    if(!x)
        return;   // reached through son[leaf] == 0
    top[x] = k;
    mp[x] = vec.size();   // the index x is about to occupy
    vec.push_back(x);
    dfs2(son[x], k);      // the heavy child continues the current chain
    for(int edge = head[x]; edge; edge = e[edge].next){
        int child = e[edge].data;
        bool light = child != fa[x] && child != son[x];
        if(light)
            dfs2(child, child);   // every light child starts a new chain
    }
}
void build(int l, int r, int k)
{
tree[k].l = l;
tree[k].r = r;
if(l == r){
tree[k].sum = data[vec[l]];
return;
}
int mid = (l + r) / 2;
build(l, mid, 2*k);
build(mid + 1, r, 2*k+1);
tree[k].sum = tree[2*k].sum + tree[2*k+1].sum;
tree[k].sum %= p;
}
void down(int k)
{
if(tree[k].lazy == 0)
return;
tree[2*k].sum += (tree[2*k].r - tree[2*k].l + 1) * tree[k].lazy;
tree[2*k+1].sum += (tree[2*k+1].r - tree[2*k+1].l + 1) * tree[k].lazy;
tree[2*k].lazy += tree[k].lazy;
tree[2*k + 1].lazy += tree[k].lazy;
tree[k].lazy = 0;
}
void add(int l, int r, int z, int k)
{
if(tree[k].l >= l && tree[k].r <= r){
tree[k].sum += ((tree[k].r - tree[k].l + 1) * z) % p;
tree[k].sum %= p;
tree[k].lazy += z;
return;
}
down(k);
int mid = (tree[k].l + tree[k].r) / 2;
if(l <= mid)
add(l, r, z, 2*k);
if(r > mid)
add(l, r, z, 2*k+1);
tree[k].sum = tree[2*k].sum + tree[2*k + 1].sum;
tree[k].sum %= p;
}
// Add z to the value of every vertex on the tree path between x and y
// (both endpoints included), one heavy-chain segment at a time.
void add1(int x, int y, int z)
{
    while(top[x] != top[y]){
        // Choose the endpoint whose chain head is strictly deeper (y on a tie).
        // Update the segment from that endpoint up to its chain head, then
        // move the endpoint to the parent of the chain head.
        int &u = depth[top[x]] > depth[top[y]] ? x : y;
        add(mp[top[u]], mp[u], z, 1);
        u = fa[top[u]];
    }
    // Both endpoints are now on one chain; the shallower has the smaller position.
    bool xDeeper = depth[x] > depth[y];
    int lo = xDeeper ? y : x;
    int hi = xDeeper ? x : y;
    add(mp[lo], mp[hi], z, 1);
}
// Sum of the DFS positions in [l, r] that lie under segment-tree node k.
// A fully covered node returns its stored sum as-is. Otherwise the partial
// result is reduced modulo p after each child's contribution is added.
ll query(int l, int r, int k)
{
    if(l <= tree[k].l && tree[k].r <= r)
        return tree[k].sum;
    down(k);
    int mid = (tree[k].l + tree[k].r) / 2;
    ll total = 0;
    if(l <= mid){
        total += query(l, r, 2*k);
        total %= p;
    }
    if(mid < r){
        total += query(l, r, 2*k+1);
        total %= p;
    }
    return total;
}
// Sum, modulo p, of the values of every vertex on the tree path between x and y
// (both endpoints included), gathered one heavy-chain segment at a time.
ll get1(int x, int y)
{
    ll ans = 0;
    while(top[x] != top[y]){
        // Choose the endpoint whose chain head is strictly deeper (y on a tie).
        // Add the segment from that endpoint up to its chain head, then move
        // the endpoint to the parent of the chain head.
        int &u = depth[top[x]] > depth[top[y]] ? x : y;
        ans = (ans + query(mp[top[u]], mp[u], 1)) % p;
        u = fa[top[u]];
    }
    // Both endpoints are now on one chain; the shallower has the smaller position.
    bool xDeeper = depth[x] > depth[y];
    int lo = xDeeper ? y : x;
    int hi = xDeeper ? x : y;
    ans = (ans + query(mp[lo], mp[hi], 1)) % p;
    return ans;
}
// Add y to the value of every vertex in the subtree of x. Because dfs2 numbers
// vertices in preorder, that subtree is the position range [mp[x], mp[x] + size[x] - 1].
inline void add2(int x, int y)
{
    int first = mp[x];
    int last = first + size[x] - 1;
    add(first, last, y, 1);
}
// Sum, modulo p, of the values in the subtree of x. Because dfs2 numbers
// vertices in preorder, that subtree is the position range [mp[x], mp[x] + size[x] - 1].
inline ll get2(int x)
{
    int first = mp[x];
    int last = first + size[x] - 1;
    ll subtotal = query(first, last, 1);
    return subtotal % p;
}
// Reads n (vertex count), m (operation count), r (root), p (modulus), then the
// n vertex values and n-1 undirected edges. It then processes m operations:
//   1 x y z : add z to every vertex on the path between x and y
//   2 x y   : print the path sum between x and y, modulo p
//   3 x y   : add y to every vertex in the subtree of x
//   any other opcode, followed by x : print the subtree sum of x, modulo p
int main()
{
    int n, m, r;
    scanf("%d%d%d%d", &n, &m, &r, &p);
    for(int v = 1; v <= n; v ++)
        scanf("%d", &data[v]);
    for(int rest = n - 1; rest > 0; rest --){
        int a, b;
        scanf("%d%d", &a, &b);
        // store the undirected edge in both directions
        add(a, b);
        add(b, a);
    }
    vec.push_back(0);   // placeholder so that DFS positions start at 1
    dfs1(r);
    dfs2(r, r);
    build(1, n, 1);
    for(int done = 0; done < m; done ++){
        int opt;
        scanf("%d", &opt);
        switch(opt){
        case 1: {
            int x, y, z;
            scanf("%d%d%d", &x, &y, &z);
            add1(x, y, z);
            break;
        }
        case 2: {
            int x, y;
            scanf("%d%d", &x, &y);
            printf("%lld\n", get1(x, y) % p);
            break;
        }
        case 3: {
            int x, y;
            scanf("%d%d", &x, &y);
            add2(x, y);
            break;
        }
        default: {
            int x;
            scanf("%d", &x);
            printf("%lld\n", get2(x) % p);
            break;
        }
        }
    }
    return 0;
}