[bzoj4765]普通计算姬
Description
给定一棵\(n\)个节点的带权树,节点编号为\(1\)~\(n\),以\(root\)为根,设\(sum_p\)表示以点\(p\)为根的这棵子树中所有节点的权值和。支持下列两种操作:
1.给定两个整数\(u,v\),修改点\(u\)的权值为\(v\);
2.给定两个整数\(l,r\),计算\(\sum_{i=l}^{r}sum_{i}\)。
Input
第一行两个整数\(n,m\),表示树的节点数与操作次数。
接下来一行\(n\)个整数,第\(i\)个整数\(d_i\)表示点\(i\)的初始权值。
接下来\(n\)行每行两个整数\(a_i,b_i\),表示一条树上的边,若\(a_i=0\)则说明\(b_i\)是根\(root\)。
接下来\(m\)行每行三个整数,第一个整数\(op\)表示操作类型。
若\(op=1\)则接下来两个整数\(u,v\)表示将点\(u\)的权值修改为\(v\)。
若\(op=2\)则接下来两个整数\(l,r\)表示询问\(\sum_{i=l}^{r}sum_{i}\)。
Output
对每个操作类型2输出一行一个整数表示答案。
Sample Input
6 4
0 0 3 4 0 1
0 1
1 2
2 3
2 4
3 5
5 6
2 1 2
1 1 1
2 3 6
2 3 5
Sample Output
16
10
9
HINT
\(0\;\leq\;d_i,v<2^{31},1\;\leq\;l\;\leq\;r\;\leq\;n,1\;\leq\;u\;\leq\;n\).
Solution
将\(a_i\)分成\(k(\sqrt{n}\;\leq\;k\;\leq\;\sqrt{n}+1)\)个区间,每个区间的大小为\(\sqrt{n}\)。预处理出每个节点\(i\)到\(root\)路线上的所有节点所属分块\(j\)的总数\(t[i][j]\)。每次修改时只需修改每个分块的总值。每次询问时,只需算至多\(k\)个区间的和,以及至多\(2\times\sqrt{n}\)个\(a_i\)。
再记录每个节点的\(dfs\)序:
- \(fro[i]\)表示开始访问\(i\)节点的时间
- \(beh[i]\)表示结束访问\(i\)节点的时间
- \(dfs\)序对应的值记为\(key[\;]\)
每次改变一个节点i的值时,在\(fro[i]\)之前(包括其自身)的所有\(key[i]\)值都加上\(v-a_i\)。
再对\(dfs\)序用类似的方法进行分块,记录每个分块里的节点统一被改变的\(a_i\)值,记为\(s[\;]\)。
记\(fro[i]\)所属分块为\(fr\),\(beh[i]\)所属分块为\(be\),则节点\(i\)的值为\((s[fr]+key[fro[i]])-(s[be]+key[beh[i]])\)。
时间复杂度:\(O(n+\sqrt{n}\;\times\;n)\)。
#include<cmath>
#include<ctime>
#include<queue>
#include<stack>
#include<cstdio>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define K 317
#define N 100005
#define M 200005
using namespace std;
typedef unsigned long long ll;
struct graph{
int nxt,to;
}e[M];
int g[N],f[N],n,m,cnt;
ll a[N];bool v[N];
/*==========================_d:dfs序 _n:编号==========================*/
ll tot[N][K]/*每个点对应编号分块个数*/,s_d[K<<1],s_n[K]/*每个分块里的数字和*/,key[N<<1]/*dfs序单个值*/;
int fro[N],beh[N],r_d,t_d=1,r_n,t_n;//s:分块总数,r:分块大小
int n_n[N]/*序号所属分块*/,n_f[N]/*fro所属编号*/,n_b[N]/*beh所属编号*/;
/*===============================read&write===============================*/
inline int read(){
int ret=0;char c=getchar();
while(!isdigit(c))
c=getchar();
while(isdigit(c)){
ret=ret*10+c-'0';
c=getchar();
}
return ret;
}
inline ll read_ll(){
ll ret=0;char c=getchar();
while(!isdigit(c))
c=getchar();
while(isdigit(c)){
ret=ret*10+c-'0';
c=getchar();
}
return ret;
}
inline void write(ll k){
if(!k) return;
write(k/10);
putchar(k%10+'0');
}
/*===============================ini_tree=================================*/
inline void addedge(int x,int y){
e[++cnt].nxt=g[x];g[x]=cnt;e[cnt].to=y;
}
inline void dfs(int u){
int sta[N],top=0;
cnt=0;v[u]=true;
for(int i=g[u];i;i=e[i].nxt){
v[e[i].to]=true;
sta[++top]=e[i].to;
++tot[e[i].to][n_n[e[i].to]];
}
while(top){
u=sta[top];fro[u]=++cnt;
if(v[e[g[u]].to]){
beh[u]=++cnt;--top;
while(f[u]){
u=sta[top--];beh[u]=++cnt;
}
}
else f[e[g[u]].to]=u;
for(int i=g[u];i;i=e[i].nxt)
if(!v[e[i].to]){
v[e[i].to]=true;
sta[++top]=e[i].to;
for(int j=1;j<=t_n;++j)
tot[e[i].to][j]=tot[u][j];
++tot[e[i].to][n_n[e[i].to]];
}
}
}
/*==================================do===================================*/
inline void change(int u,ll k){
// printf("k=%lld\n",k);
a[u]+=k;
for(int i=1;i<=t_n;++i)
s_n[i]+=tot[u][i]*k;
for(int i=1;i<n_f[u];++i)
s_d[i]+=k;
if(n_f[u]*r_d==fro[u])
s_d[n_f[u]]+=k;
else for(int i=(n_f[u]-1)*r_d+1;i<=fro[u];++i)
key[i]+=k;
}
/*=======================================================================*/
inline void init(){
/*==================read===================*/
n=read();m=read();
for(int i=1;i<=n;++i)
a[i]=read_ll();
for(int i=1,j,k;i<=n;++i){
j=read();k=read();
addedge(j,k);addedge(k,j);
}
/*================ini_tree=================*/
r_n=sqrt(n);
for(int i=1;i<=n;i+=r_n){
++t_n;
for(int j=0;j<r_n&&i+j<=n;++j)
n_n[i+j]=t_n;
}
dfs(0);
r_d=sqrt(n<<1);t_d=((n<<1)+r_d-1)/r_d;
for(int i=1;i<=n;++i){
n_f[i]=(fro[i]+r_d-1)/r_d;
n_b[i]=(beh[i]+r_d-1)/r_d;
}
for(int i=1;i<=n;++i){
a[0]=a[i];a[i]=0;change(i,a[0]);
}
/*===================do====================*/
int op,l,r,u;ll v,ans;
while(m--){
op=read();
if(op==1){
u=read();v=read_ll();
change(u,v-a[u]);
}
else{
l=read();r=read();ans=0;
if(n_n[l]!=n_n[r]){
for(int i=n_n[l]+1;i<n_n[r];++i)
ans+=s_n[i];
if(l==(n_n[l]-1)*r_n+1) ans+=s_n[n_n[l]];
else for(int i=n_n[l]*r_n;i>=l;--i)
ans+=(s_d[n_f[i]]+key[fro[i]])-(s_d[n_b[i]]+key[beh[i]]);
if(r==n_n[r]*r_n) ans+=s_n[n_n[r]];
else for(int i=(n_n[r]-1)*r_n+1;i<=r;++i)
ans+=(s_d[n_f[i]]+key[fro[i]])-(s_d[n_b[i]]+key[beh[i]]);
}
else{
if(l==(n_n[l]-1)*r_n+1&&r==n_n[r]*r_n){
ans=s_n[n_n[l]];
}
else for(int i=l;i<=r;++i)
ans+=(s_d[n_f[i]]+key[fro[i]])-(s_d[n_b[i]]+key[beh[i]]);
}
if(!ans) putchar('0');
else write(ans);
putchar('\n');
}
}
}
int main(){
freopen("common.in","r",stdin);
freopen("common.out","w",stdout);
init();
fclose(stdin);
fclose(stdout);
return 0;
}