2016 大连网赛---Weak Pair(dfs+树状数组)
题目链接
http://acm.split.hdu.edu.cn/showproblem.php?pid=5877
Problem Description
You are given a rooted tree of N nodes, labeled from 1 to N. To the ith node a non-negative value ai is assigned.An ordered pair of nodes (u,v) is said to be weak if
(1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
(2) au×av≤k.
Can you find the number of weak pairs in the tree?
(1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
(2) au×av≤k.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, N and k, respectively.
The second line contains N space-separated integers, denoting a1 to aN.
Each of the subsequent lines contains two space-separated integers defining an edge connecting nodes u and v , where node u is the parent of node v.
Constrains:
1≤N≤105
0≤ai≤109
0≤k≤1018
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, N and k, respectively.
The second line contains N space-separated integers, denoting a1 to aN.
Each of the subsequent lines contains two space-separated integers defining an edge connecting nodes u and v , where node u is the parent of node v.
Constrains:
1≤N≤105
0≤ai≤109
0≤k≤1018
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
题意:输入n,k 表示有n个节点的一棵树,然后输入n个节点的权值和n-1条边,求点对(u,v)的对数,满足:
1、u是v的祖先节点。
2、a[u]*a[v]<=k,a[]是存储权值的数组。
思路:从根节点开始向下深搜,每到一个点时计算sum+=Sum(a[i]),Sum(x)表示大于等于x的个数,然后向树状数组中加入k/a[i],继续递归深搜,退栈时从树状数组中减去k/a[i] ,这样可以保证树状数组中存的一直是一条到根节点的路径值。大题思路如上,这里要做一个离散化的处理,输入的权值<=1e9 k<=1e18 而只有1e5个点,所以可以离散到2*1e5 后处理;
题解中提示用treap计算大于等于x的个数,这样可以不需要进行离散化;
第一次自己做出深搜的题,挺高兴的^_^ 看样子我对深搜有了一点认识了
代码如下:
#include <iostream> #include <algorithm> #include <cstring> #include <cstdio> #include <vector> #include <map> using namespace std; const long long maxn=200003; long long root; long long sum,k; long long in[100005]; vector<long long>g[100005]; long long a[100005]; long long b[200005]; long long c[200005]; map<long long,long long>q; long long Lowbit(long long t) { return t&(t^(t-1)); } void add(long long x,long long t) { while(x > 0) { c[x]+=t; x -= Lowbit(x); } } long long Sum(long long li) { long long s=0; while(li<200005) { s+=c[li]; li=li+Lowbit(li); } return s; } void dfs(long long t) { long long n=g[t].size(); for(long long i=0;i<n;i++) { long long v=g[t][i]; sum+=(long long)Sum(q[a[v]]); if(a[v]==0) add(maxn,1); else add(q[k/a[v]],1); dfs(v); if(a[v]==0) add(maxn,-1); else add(q[k/a[v]],-1); } } int main() { long long T,N; scanf("%lld",&T); while(T--) { q.clear(); memset(in,0,sizeof(in)); memset(c,0,sizeof(c)); memset(b,0,sizeof(b)); scanf("%lld%lld",&N,&k); for(long long i=1;i<=N;i++) { scanf("%lld",&a[i]); b[2*i-2]=a[i]; if(a[i]!=0) b[2*i-1]=k/a[i]; g[i].clear(); } sort(b,b+2*N); long long tot=0,pre=-1; for(long long i=0;i<2*N;i++) { if(b[i]!=pre) { pre=b[i]; q[pre]=++tot; } } for(long long i=0;i<N-1;i++) { long long aa,bb; scanf("%lld%lld",&aa,&bb); g[aa].push_back(bb); in[bb]++; } for(long long i=1;i<=N;i++) if(in[i]==0) { root=i; break; } sum=0; if(a[root]==0) add(maxn,1); else add(q[k/a[root]],1); dfs(root); printf("%lld\n",sum); } return 0; }