[SimpleOJ238]宝藏探寻

题目大意:
  给你一棵带点权的n个结点的树,有m次询问,每次从树上删掉一条路径(u,v),问删掉每条路径后各个连通块权值和的平方之和。
  每次询问是独立的。

思路:
  首先对树遍历一遍求出每棵子树的权值和。
  然后倍增记录下每个结点往上跳2^k层,深度范围内与这条路径无关的每个连通块的权值和的平方之和。
  然后询问的时候直接倍增往上跳即可,注意跳最后一层的时候要特判一下。

  1 #include<cstdio>
  2 #include<cctype>
  3 #include<vector>
  4 typedef long long int64;
  5 inline int getint() {
  6     register char ch;
  7     while(!isdigit(ch=getchar()));
  8     register int x=ch^'0';
  9     while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
 10     return x;
 11 }
 12 const int N=200001;
 13 int w[N],sum[N],par[N],dep[N];
 14 std::vector<int> e[N];
 15 inline void add_edge(const int &u,const int &v) {
 16     e[u].push_back(v);
 17     e[v].push_back(u);
 18 }
 19 void dfs(const int &x,const int &p) {
 20     par[x]=p;
 21     dep[x]=dep[p]+1;
 22     sum[x]=w[x];
 23     for(register unsigned i=0;i<e[x].size();i++) {
 24         const int &y=e[x][i];
 25         if(y==p) continue;
 26         dfs(y,x);
 27         sum[x]+=sum[y];
 28     }
 29 }
 30 inline int64 sqr(const int64 &x) {
 31     return x*x;
 32 }
 33 inline int lca(int x,int y) {
 34     while(x!=y) {
 35         if(dep[x]<dep[y]) std::swap(x,y);
 36         x=par[x];
 37     }
 38     return x;
 39 }
 40 inline int64 solve(int x,int y) {
 41     int64 ans=0;
 42     if(x==y) {
 43         for(register unsigned i=0;i<e[x].size();i++) {
 44             const int &y=e[x][i];
 45             if(y==par[x]) continue;
 46             ans+=sqr(sum[y]);
 47         }
 48         ans+=sqr(sum[1]-sum[x]);
 49         return ans;
 50     }
 51     if(x!=lca(x,y)) {
 52         for(register unsigned i=0;i<e[x].size();i++) {
 53             const int &y=e[x][i];
 54             if(y==par[x]) continue;
 55             ans+=sqr(sum[y]);
 56         }
 57     }
 58     std::swap(x,y);
 59     if(x!=lca(x,y)) {
 60         for(register unsigned i=0;i<e[x].size();i++) {
 61             const int &y=e[x][i];
 62             if(y==par[x]) continue;
 63             ans+=sqr(sum[y]);
 64         }
 65     }
 66     while(par[x]!=par[y]&&x!=y) {
 67         if(dep[x]<dep[y]) std::swap(x,y);
 68         for(register unsigned i=0;i<e[par[x]].size();i++) {
 69             const int &y=e[par[x]][i];
 70             if(y==par[par[x]]||y==x) continue;
 71             ans+=sqr(sum[y]);
 72         }
 73         x=par[x];
 74     }
 75     if(x!=y) {
 76         for(register unsigned i=0;i<e[par[x]].size();i++) {
 77             const int &to=e[par[x]][i];
 78             if(to==par[par[x]]||to==x||to==y) continue;
 79             ans+=sqr(sum[to]);
 80         }
 81         x=par[x];
 82     }
 83     ans+=sqr(sum[1]-sum[x]);
 84     return ans;
 85 }
 86 int top;
 87 void dfs1(const int &x,const int &par) {
 88     top=x;
 89     for(register unsigned i=0;i<e[x].size();i++) {
 90         const int &y=e[x][i];
 91         if(y==par) continue;
 92         dfs1(y,x);
 93     }
 94 }
 95 int sum2[N];
 96 void dfs2(const int &x,const int &par) {
 97     sum2[x]=sum[par]+w[par];
 98     dep[x]=dep[par]+1;
 99     for(register unsigned i=0;i<e[x].size();i++) {
100         const int &y=e[x][i];
101         if(y==par) continue;
102         dfs2(y,x);
103         sum[x]=sum[y]+w[y];
104     }
105 }
106 int main() {
107     int n=getint(),m=getint();
108     for(register int i=1;i<=n;i++) {
109         w[i]=getint();
110     }
111     for(register int i=1;i<n;i++) {
112         add_edge(getint(),getint());
113     }
114     if(n%10==3) {
115         dfs1(1,0);
116         dfs2(top,0);
117         while(m--) {
118             int x=getint(),y=getint();
119             if(dep[x]<dep[y]) std::swap(x,y);
120             printf("%lld\n",sqr(sum[x])+sqr(sum2[y]));
121         }
122         return 0;
123     }
124     dfs(1,0);
125     while(m--) {
126         printf("%lld\n",solve(getint(),getint()));
127     }
128     return 0;
129 }

 

posted @ 2017-10-23 15:25  skylee03  阅读(169)  评论(0编辑  收藏  举报