sjtu1591 Count On Tree
Description
Crystal家有一棵树。树上有\(n\)个节点,编号由\(1\)到\(n\)(\(1\)号点是这棵树的根),两点之间距离为1当且仅当它们直接相连。每个点都有各自的权值,第\(i\)号节点的权值为\(value_i\)。Crystal现在指着编号为\(x\)的点问,在以点\(x\)为根的子树中,与点\(x\)距离大于等于\(k\)的所有点的点权和是多少。
Input Format
第\(1\)行两个整数\(n,Q\),分别表示树上点的个数和Crystal有\(Q\)个问题。
第\(2\)行,\(n\)个整数,分别表示\(1\)至\(n\)号点的点权。
接下来的\(n - 1\)行,每行两个整数\(u,v\),表示编号为\(u\)的点与编号为\(v\)的点直接相连。
接下来\(Q\)行,每行两个整数\(x,k\),表示询问在以点\(x\)为根的子树中,与点\(x\)距离大于等于为\(k\)的所有点的点权和是多少。
Output Format
\(Q\)行,每行一个整数,表示对第\(i\)个询问的回答。
Sample Input
5 3
1 1 1 1 1
1 2
1 3
3 4
4 5
1 3
1 2
1 1
Sample Output
1
2
4
Hints
对于\(30\%\)的数据,保证\(n \le 1000, k < 1, Q \le 1000\)。
对于\(60\%\)的数据,保证\(n \le 1000, k < 1000, Q \le 1000\)
对于\(80\%\)的数据,保证\(n \le 1000, k < 1000, Q \le 1000000\);
对于最后\(20\%\)的数据,保证\(n \le 50000, k < 100, Q \le 1000000\);
对于\(100\%\)的数据,保证所有输入数据均为非负整数,且在\(int\)范围内。
这题\(O(NK)\)的做法不难想(用总的减去小于\(K\)的),现在假设\(N,K\)同级怎么做。
首先考虑离线做法,我们可以考虑按照询问最深的深度从小到大一层层加点,答案还是用总的减去小于\(K\)的。
再考虑在线所做法,我们可以先处理出dfs序和子树和,然后对于树的每层开一个vector,vector中记录该层点的编号,按dfs序排序。对于每个询问\(x,k\),我们只需要跳到\(dep[x]+k\)层的vector中,找到在\(x\)子树中的点,且一定是段连续区间,二分即可。现在只需要对该区间求子树和的和即可。
代码是\(O(NK)\)的
#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
#define maxn (50010)
int cnt = 1,side[maxn],toit[maxn*2],next[maxn*2],val[maxn],N,Q,mxk;
int tx[maxn*20],tk[maxn*20],num[20],len; ll sum[maxn]; vector <ll> res[maxn];
inline int read()
{
char ch; int f = 1,ret = 0;
do ch = getchar(); while (!(ch >= '0'&&ch <= '9')&&ch != '-');
if (ch == '-') f = -1,ch = getchar();
do ret = ret*10+ch-'0',ch = getchar(); while (ch >= '0'&&ch <= '9');
return ret*f;
}
inline void add(int a,int b) { next[++cnt] = side[a]; side[a] = cnt; toit[cnt] = b; }
inline void ins(int a,int b) { add(a,b); add(b,a); }
inline void dfs(int now,int fa)
{
for (int i = 0;i <= mxk;++i) res[now].push_back(val[now]);
sum[now] = val[now];
for (int i = side[now];i;i = next[i])
{
if (toit[i] == fa) continue;
dfs(toit[i],now);
sum[now] += sum[toit[i]];
for (int j = 0;j < mxk;++j)
res[now][j+1] += res[toit[i]][j];
}
}
inline void print(ll a)
{
do num[++len] = a%10,a /= 10; while (a);
while (len) putchar('0'+num[len--]);
puts("");
}
int main()
{
//freopen("a.in","r",stdin);
//freopen("a.out","w",stdout);
N = read(); Q = read();
for (int i = 1;i <= N;++i) val[i] = read();
for (int i = 1;i < N;++i) ins(read(),read());
for (int i = 1;i <= Q;++i) tx[i] = read(),tk[i] = read(),mxk = max(mxk,tk[i]);
dfs(1,0);
// print(123456LL);
// print(0LL);
// print(12LL);
for (int i = 1;i <= Q;++i)
{
if (!tk[i]) //cout << sum[tx[i]] << endl;
print(sum[tx[i]]);
else //cout << sum[tx[i]]-res[tx[i]][tk[i]-1] << endl;
print(sum[tx[i]]-res[tx[i]][tk[i]-1]);
}
//fclose(stdin); fclose(stdout);
return 0;
}