2015多校第9场 HDU 5405 Sometimes Naive 树链剖分
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5405
题意: 给你一棵n个节点的树,有点权。
要求支持两种操作:
操作1:更改某个节点的权值。
操作2:给定u,v, 求 Σw[i][j] i , j 为任意两点且i到j的路径与u到v的路径相交。
解法:
这是一个大树剖题。
容易发现对于一个询问,答案为总点权和的平方 减去 去掉u--v这条链后各个子树的点权和的平方的和。
开两棵线段树,tag1记录点权和,tag2记录某点的所有轻链子树的点权和的平方的和。
每次沿着重链往上走时,直接加上这条重链的所有点的tag2和,若有重儿子则直接用tag1计算。由于该条重链必定为其父亲的轻链,故为防止计算重复,还需减去该重链所有点的tag1平方和。
最后爬到同一颗重链后,还需计算重链上方所有点的贡献。
//HDU 5405 //答案为总点权和的平方 减去 去掉u--v这条链后各个子树的点权和的平方的和。 //T1记录点权和,T2记录某点的所有轻链子树的点权和的平方的和 //每次沿着重链往上走时,直接加上这条重链的所有点的tag2和,若有重儿子则直接用tag1计算。 //由于该条重链必定为其父亲的轻链,故为防止计算重复,还需减去该重链所有点的tag1平方和。 //最后爬到同一颗重链后,还需计算重链上方所有点的贡献。 #include <bits/stdc++.h> using namespace std; typedef long long LL; const int maxn = 1e5+5; const int mod = 1e9+7; struct Tree{ LL sum[maxn<<2]; void build(){ memset(sum,0,sizeof(sum)); } void pushup(int rt){ sum[rt] = (sum[rt<<1]+sum[rt<<1|1])%mod; } void update(int pos, LL v, int l, int r, int rt){ if(l == r){ sum[rt] += v; sum[rt] %= mod; return; } int mid = (l+r)>>1; if(pos <= mid) update(pos, v, l, mid, rt<<1); else update(pos, v, mid+1, r, rt<<1|1); pushup(rt); } LL query(int L, int R, int l, int r, int rt){ if(L<=l&&r<=R){ return sum[rt]; } int mid = (l+r)/2; if(R<=mid) return query(L,R,l,mid,rt<<1)%mod; else if(L>mid) return query(L,R,mid+1,r,rt<<1|1)%mod; else return (query(L,mid,l,mid,rt<<1)+query(mid+1,R,mid+1,r,rt<<1|1))%mod; } }T1, T2; int head[maxn],n, m, edgecnt, dfsclk; struct edge{ int to,next; }E[maxn*2]; int sz[maxn], top[maxn], son[maxn], dep[maxn]; int fa[maxn], tid[maxn], val[maxn]; void init(){ edgecnt = 0; dfsclk = 0; memset(head, -1, sizeof(head)); memset(son, -1, sizeof(son)); } void addedge(int u, int v){ E[edgecnt].to = v, E[edgecnt].next = head[u], head[u] = edgecnt++; } void dfs1(int u, int father, int d){ dep[u] = d; fa[u] = father; sz[u] = 1; for(int i = head[u]; ~i; i=E[i].next){ int v = E[i].to; if(v == father) continue; dfs1(v, u, d+1); sz[u] += sz[v]; if(son[u] == -1 || sz[v]>sz[son[u]]) son[u] = v; } } void dfs2(int u, int tp) { top[u] = tp; tid[u] = ++dfsclk; if(son[u] == -1) return; dfs2(son[u], tp); for(int i = head[u]; ~i; i=E[i].next){ int v = E[i].to; if(v!=son[u]&&v!=fa[u]) dfs2(v,v); } } inline LL sqr(int x){ return (LL)x*x; } void update(int x, int v){ int u = top[x]; while(fa[u]){ LL sum = T1.query(tid[u], tid[u]+sz[u]-1, 1, n, 1); T2.update(tid[fa[u]], ((sqr(val[x]-v)%mod)%mod-(LL)sum*2*(val[x]-v)%mod)%mod, 1, n, 1); u = top[fa[u]]; } T1.update(tid[x], v-val[x], 1, n, 1); val[x] = v; } LL query(int x, int y){ LL ret = 0; while(top[x] != top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); ret += T2.query(tid[top[x]], tid[x], 1, n, 1); ret %= mod; if(son[x]!=-1){ LL sum = T1.query(tid[son[x]], tid[son[x]]+sz[son[x]]-1, 1, n, 1); ret = ret + sum*sum%mod; ret %= mod; } LL sum = T1.query(tid[top[x]], tid[top[x]]+sz[top[x]]-1, 1, n, 1); ret = (ret - sum*sum%mod + mod)%mod; x = fa[top[x]]; } if(dep[x] > dep[y]) swap(x, y); ret += T2.query(tid[x], tid[y], 1, n, 1); ret %= mod; if(son[y]!=-1){ LL sum = T1.query(tid[son[y]], tid[son[y]]+sz[son[y]]-1, 1, n, 1); ret = (ret + sum*sum%mod)%mod; } if(fa[x]){ LL sum = T1.query(1, n, 1, n, 1) - T1.query(tid[x], tid[x]+sz[x]-1, 1, n, 1); ret = (ret+sum*sum%mod)%mod; } return ret; } int main() { while(~scanf("%d %d", &n,&m)) { init(); T1.build(); T2.build(); for(int i=1; i<=n; i++) scanf("%d", &val[i]); for(int i=1; i<n; i++){ int u, v; scanf("%d %d", &u,&v); addedge(u, v); addedge(v, u); } dfs1(1, 0, 0); dfs2(1, 1); for(int i=1; i<=n; i++){ int x = val[i]; val[i] = 0; update(i, x); } while(m--) { int op, x, y; scanf("%d %d %d", &op,&x,&y); if(op == 1){ update(x, y); } else{ LL sum = T1.query(tid[1], tid[1]+sz[1]-1, 1, n, 1); sum = sum*sum; sum = sum-query(x, y); sum = sum%mod; if(sum<0) sum+=mod; printf("%lld\n", sum); } } } return 0; }