HDU5877 Weak Pair dfs + 线段树/树状数组 + 离散化
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5877
题意:
weak pair的要求:
1.u是v的祖先(注意不一定是父亲)
2.val[u]*val[k] <=k;
题解:
1.将val(以及k/val)离散化,可用map 或者 用数组。 只要能将val与树状数组或线段树的下标形成映射就可以了。
2.从根节点开始搜索(题目中的树,不是线段树或树状数组的树),先统计当前节点与祖先能形成多少对weak pair,然后将其插入到树状数组或线段树中。
3.递归其子树。递归完子树后,再把当前节点从树状数组或线段树中删去。因为:根据递归的特性,如果不删除,这个值将会残留在c数组, 那么他的堂兄弟,堂叔伯,堂侄子等(后一步 递归的)会误认为这个值是他的祖先的。所以要及时删除。
类似的题(边查询边更新):http://blog.csdn.net/dolfamingo/article/details/71001021
树状数组(map离散):
#include<cstdio>//hdu5877 树状数组 map离散 dfs #include<cstring> #include<cstdlib> #include<cmath> #include<algorithm> #include<map> #include<vector> #define LL long long #define INF 2e18 using namespace std; LL sval[200020], val[200020]; int n,fa[100010],c[200010]; map<LL,int> m; vector<int> son[100010]; LL k,ans; int lowbit(int x) { return x & (-x); } void add(int x, int d) { for(; x<=2*n; x += lowbit(x)) { c[x] += d; } } int sumc(int x) { int s = 0; for(;x>0; x -= lowbit(x)) { s += c[x]; } return s; } void dfs(int rt)//c数组中的下标与val[i],k/val[i]映射 且k/v[i]的下标-i的下边等于n(自己定) { ans += sumc(m[val[n+rt]]);//统计<=k/val[rt]的个数,为什么不直接 m[k/val[rt]]? 因为val[rt]可能为0 add(m[val[rt]],1); for(int i = 0; i<son[rt].size(); i++) { dfs(son[rt][i]); } add(m[val[rt]],-1); } int main() { int T; scanf("%d",&T); while(T--) { scanf("%d%lld",&n,&k); for(int i = 1; i<=n; i++) { //存k/val[i]是因为 val[i]*(k/val[i])<= k. 到时可以直接从树状数组中统计<=k/val[i]的个数 scanf("%lld",&val[i]); if(val[i]) val[n+i] = k/val[i]; else //0为特殊情况 val[n+i] = INF; } //sval的作用是将v值按从小到大,一一与c数组的下标形成映射 for(int i = 1; i<=2*n; i++) sval[i] = val[i]; sort(sval+1,sval+2*n+1); int cnt = 0; m.clear(); for(int i = 1; i<=2*n; i++) { //map的作用是将v值与c数组的下标形成映射 if(!m[sval[i]]) m[sval[i]] = ++cnt; } for(int i = 1; i<=n; i++) fa[i] = 0, son[i].clear(); for(int i = 1,u,v; i<n; i++) { scanf("%d%d",&u,&v); son[u].push_back(v); fa[v] = u; } ans = 0; memset(c,0,sizeof(c)); for(int i = 1; i<=n; i++) { if(!fa[i]) { dfs(i); break; } } printf("%lld\n",ans); } return 0; }
线段树(map离散):
#include<cstdio>//hdu5877 线段树 map离散 dfs #include<cstring>//注意区分题目的树和线段树的树 #include<cstdlib> #include<cmath> #include<algorithm> #include<vector> #include<map> #define LL long long #define INF 2e18 using namespace std; int n,len;//n是题目给出的树的结点个数,len是线段树的线段长度。 LL k,ans; LL val[200100],sval[200100];//val记录原始值,sval记录经过排序,删除重复操作的值,用于线段树的操作 int fa[100100],sum[800100]; vector<int>son[100100]; map<LL,int>m; int query(int root, int le, int ri, int x, int y) { if(x<=le && y>=ri) return sum[root]; int mid = (le+ri)/2, ret = 0; if(x<=mid) ret += query(root*2,le,mid,x,y); if(y>=mid+1) ret += query(root*2+1,mid+1,ri,x,y); return ret; } void update(int root, int le, int ri, int pos, int d) { if(le==ri) { sum[root] += d; return; } int mid = (le+ri)/2; if(pos<=mid) update(root*2,le,mid,pos,d); else update(root*2+1,mid+1,ri,pos,d); sum[root] = sum[root*2] + sum[root*2+1]; } void dfs(int rt) { int last = m[val[n+rt]]; int pos = m[val[rt]]; ans += query(1,1,len,1,last); update(1,1,len,pos,1); for(int i = 0; i<son[rt].size(); i++) { dfs(son[rt][i]); } update(1,1,len,pos,-1); } int main() { int T; scanf("%d",&T); while(T--) { scanf("%d%lld",&n,&k); for(int i = 1; i<=n; i++) { scanf("%lld",&val[i]); if(val[i]) val[n+i] = k/val[i]; else val[n+i] = INF; } for(int i = 1; i<=2*n; i++) sval[i] = val[i]; sort(sval+1,sval+2*n+1); m.clear(); len = 0; for(int i = 1; i<=2*n; i++) { if(!m[sval[i]]) m[sval[i]] = ++len; } for(int i = 1; i<=n; i++) fa[i] = 0, son[i].clear(); for(int i = 1,u,v; i<n; i++) { scanf("%d%d",&u,&v); son[u].push_back(v); fa[v] = u; } ans = 0; memset(sum,0,sizeof(sum)); for(int i = 1; i<=n; i++) { if(!fa[i]) { dfs(i); break; } } printf("%lld\n",ans); } return 0; }
线段树(数组离散):
#include<cstdio>//hdu5877 线段树 dfs 普通数组进行离散 #include<cstring>//注意区分题目的树和线段树的树 #include<cstdlib> #include<cmath> #include<algorithm> #include<vector> #define LL long long #define INF 2e18 using namespace std; int n,m;//n是题目给出的树的结点个数,m是线段树的线段长度。 LL k,ans; LL val[200100],sval[200100];//val记录原始值,sval记录经过排序,删除重复操作的值,用于线段树的操作 int fa[100100],sum[800100]; vector<int>son[100100]; int query(int root, int le, int ri, int x, int y) { if(x<=le && y>=ri) return sum[root]; int mid = (le+ri)/2, ret = 0; if(x<=mid) ret += query(root*2,le,mid,x,y); if(y>=mid+1) ret += query(root*2+1,mid+1,ri,x,y); return ret; } void update(int root, int le, int ri, int pos, int d) { if(le==ri) { sum[root] += d; return; } int mid = (le+ri)/2; if(pos<=mid) update(root*2,le,mid,pos,d); else update(root*2+1,mid+1,ri,pos,d); sum[root] = sum[root*2] + sum[root*2+1]; } void dfs(int rt) { int last = lower_bound(sval+1, sval+m+1,val[n+rt]) - sval; int pos = lower_bound(sval+1,sval+m+1,val[rt]) - sval; ans += query(1,1,m,1,last); update(1,1,m,pos,1); for(int i = 0; i<son[rt].size(); i++) { dfs(son[rt][i]); } update(1,1,m,pos,-1); } int main() { int T; scanf("%d",&T); while(T--) { scanf("%d%lld",&n,&k); for(int i = 1; i<=n; i++) { scanf("%lld",&val[i]); if(val[i]) val[n+i] = k/val[i]; else val[n+i] = INF; } for(int i = 1; i<=2*n; i++) sval[i] = val[i]; sort(sval+1,sval+2*n+1); m = unique(sval+1,sval+2*n+1) - (sval+1); for(int i = 1; i<=n; i++) fa[i] = 0, son[i].clear(); for(int i = 1,u,v; i<n; i++) { scanf("%d%d",&u,&v); son[u].push_back(v); fa[v] = u; } ans = 0; memset(sum,0,sizeof(sum)); for(int i = 1; i<=n; i++) { if(!fa[i]) { dfs(i); break; } } printf("%lld\n",ans); } return 0; }