题目:这里
题意:
给出一个n个结点的树和一个数k,每个结点都有一个权值,问有多少对点(u,v)满足u是v的祖先结点且二者的权值之积小于等于k、
从根结点开始dfs,假设搜的的点的权值是v,我们需要的是在此之前搜的点中小于等于k/v的数的个数,于是用树状数组查询,查完后将v加入树状数组以供下个
点查询,回溯的时候再一个个的删除将其树状数组。由于这个权值和k的范围比较大,所以得先离散化后在加入树状数组。
1 #include<cstdio> 2 #include<cstring> 3 #include<iostream> 4 #include<algorithm> 5 #include<vector> 6 using namespace std; 7 8 typedef long long ll; 9 const int M = 1e5 + 10; 10 int head[M],cas,du[M],has[M],len; 11 ll a[M],b[M],k,sum; 12 bool vis[M]; 13 14 int max(int x,int y){return x>y?x:y;} 15 struct Edge{ 16 int to,next; 17 }edge[M*2]; 18 19 void add(int u,int v) 20 { 21 edge[++cas].next=head[u]; 22 edge[cas].to=v; 23 head[u]=cas; 24 } 25 26 int lowbit(int x) {return x&(-x);} 27 28 void shuadd(int x,int y) 29 { 30 while (x<=len){ 31 has[x]+=y; 32 x+=lowbit(x); 33 } 34 } 35 36 int getsum(int x){ 37 int ans=0; 38 while (x>0){ 39 ans+=has[x]; 40 x-=lowbit(x); 41 } 42 return ans; 43 } 44 45 int sreach(ll x) 46 { 47 int l=1,r=len,ans=1; 48 while (l<=r){ 49 int mid=(l+r)/2; 50 if (b[mid]<=x) ans=max(ans,mid),l=mid+1; 51 else r=mid-1; 52 } 53 return ans; 54 } 55 56 void dfs(int u) 57 { 58 int pos; 59 if (a[u]!=0) pos=sreach(k/a[u]); 60 else pos=len; 61 sum+=getsum(pos); 62 //cout<<sum<<endl; 63 shuadd(sreach(a[u]),1); 64 for (int i=head[u] ; i ; i=edge[i].next) 65 { 66 int v=edge[i].to; 67 if (vis[v]) continue; 68 vis[v]=true; 69 dfs(v);vis[v]=false; 70 shuadd(sreach(a[v]),-1); 71 } 72 } 73 74 int main() 75 { 76 int t; 77 scanf("%d",&t); 78 while (t--){ 79 int n,j=0;cas=0; 80 scanf("%d%I64d",&n,&k); 81 for (int i=1 ; i<=n ; i++) { 82 scanf("%I64d",&a[i]); 83 if (a[i]==0) continue; 84 b[i]=a[i]; 85 } 86 sort(b+1,b+n+1); 87 len=unique(b+1,b+n+1)-b; 88 sort(b+1,b+len);len--; 89 // cout<<len<<endl; 90 memset(du,0,sizeof(du)); 91 memset(vis,false,sizeof(vis)); 92 memset(head,0,sizeof(head)); 93 memset(has,0,sizeof(has)); 94 for (int i=1 ; i<n ; i++){ 95 int u,v; 96 scanf("%d%d",&u,&v); 97 add(u,v); 98 du[v]++; 99 } 100 //cout<<len<<endl; 101 sum=0; 102 for (int i=1 ; i<=n ; i++) 103 { 104 if (du[i]) continue; 105 dfs(i); 106 } 107 printf("%I64d\n",sum); 108 } 109 return 0; 110 }