poj2763 树链剖分(线段树)
注意这里都是把边放到线段树中,所以lca的时候,要注意如果top[x]==top[y] && x==y 的时候已经完成了。
仔细想想边和点的不同之处!!!
#include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define INF 1000000007 #define mod 100000000 #define ll long long #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 using namespace std; const int MAXN = 100010; struct node { int to; int val; int next; }edge[MAXN*2]; int a[MAXN][3]; int fa[MAXN],pre[MAXN],ind,w[MAXN],top[MAXN],siz[MAXN],son[MAXN],deq[MAXN],cnt;//w[]表示它和他父亲的连线在线段树种的位置 int vis[MAXN],n,m,tree[MAXN<<2]; void add(int x,int y,int z) { edge[ind].to = y; edge[ind].val = z; edge[ind].next = pre[x]; pre[x] = ind++; } void dfs1(int rt,int pa,int d) { vis[rt] = 1; siz[rt] = 1; son[rt] = -1; deq[rt] = d; fa[rt] = pa; for(int i = pre[rt]; i != -1; i = edge[i].next){ int t = edge[i].to; if(!vis[t]){ dfs1(t,rt,d+1); siz[rt] += siz[t]; if(siz[t] > siz[son[rt]]){ son[rt] = t; } } } } void dfs2(int rt,int tp) { vis[rt] = 1; w[rt] = ++ cnt; top[rt] = tp; if(son[rt] != -1){ dfs2(son[rt],tp); } for(int i = pre[rt]; i != -1; i = edge[i].next){ int t = edge[i].to; if(!vis[t] && t != son[rt]){ dfs2(t,t); } } } void pushup(int rt) { tree[rt] = tree[rt<<1] + tree[rt<<1|1]; } void updata(int p,int val,int l,int r,int rt) { if(l == r){ tree[rt] = val; return ; } int m = (l + r) >> 1; if(m >= p){ updata(p,val,lson); } else { updata(p,val,rson); } pushup(rt); } int query(int L,int R,int l,int r,int rt) { if(L <= l && r <= R){ return tree[rt]; } int m = (l + r) >> 1; int ans = 0; if(m >= L){ ans += query(L,R,lson); } if(m < R){ ans += query(L,R,rson); } return ans; } int lca(int x,int y) { int ans = 0; while(top[x] != top[y]){ if(deq[top[x]] < deq[top[y]]){ swap(x,y); } ans += query(w[top[x]],w[x],1,cnt,1); x = fa[top[x]]; } if(x == y){//同一点没有边,仔细想想这就是为什么边和点的不同地方! return ans; } if(deq[x] < deq[y]){ swap(x,y); } ans += query(w[son[y]],w[x],1,cnt,1);//既然top[x] == top[y]说明x 和 y在同一条重边上 所以son[y]也在该路径上 //并且是树根的孩子,又w[]表示该点和其父亲相连的点的边在线段树中的编号, //所以需要这步,这样能得出答案。 return ans; } int main() { int s,p; while(~scanf("%d%d%d",&n,&p,&s)){ int x,y,z; ind = 0; memset(pre,-1,sizeof(pre)); for(int i = 1; i < n; i++){ scanf("%d%d%d",&a[i][0],&a[i][1],&a[i][2]); add(a[i][0],a[i][1],a[i][2]); add(a[i][1],a[i][0],a[i][2]); } memset(vis,0,sizeof(vis)); dfs1(1,1,1);//求deq siz son fa cnt = 0; memset(vis,0,sizeof(vis)); dfs2(1,1);//求top w memset(tree,0,sizeof(tree)); for(int i = 1; i < n; i++){ if(deq[a[i][0]] < deq[a[i][1]]){//让a[i][0]在下面 swap(a[i][0],a[i][1]); } updata(w[a[i][0]],a[i][2],1,cnt,1);//将a[i][0]和他父亲相连的边插入到线段树中 } int q; while(p--){ scanf("%d",&q); if(q == 0){ scanf("%d",&x); printf("%d\n",lca(s,x)); s = x; } else { scanf("%d%d",&x,&y); updata(w[a[x][0]],y,1,cnt,1); } } } return 0; }