BZOJ_3653_谈笑风生_树状数组
BZOJ_3653_谈笑风生_树状数组
Description
设T 为一棵有根树,我们做如下的定义:
? 设a和b为T 中的两个不同节点。如果a是b的祖先,那么称“a比b不知道
高明到哪里去了”。
? 设a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定
常数x,那么称“a 与b 谈笑风生”。
给定一棵n个节点的有根树T,节点的编号为1 到 n,根节点为1号节点。你需
要回答q 个询问,询问给定两个整数p和k,问有多少个有序三元组(a;b;c)满足:
1. a、b和 c为 T 中三个不同的点,且 a为p 号节点;
2. a和b 都比 c不知道高明到哪里去了;
3. a和b 谈笑风生。这里谈笑风生中的常数为给定的 k。
Input
第一行含有两个正整数n和q,分别代表有根树的点数与询问的个数。
接下来n - 1行,每行描述一条树上的边。每行含有两个整数u和v,代表在节点u和v之间有一条边。
接下来q行,每行描述一个操作。第i行含有两个整数,分别表示第i个询问的p和k。
1<=P<=N
1<=K<=N
N<=300000
Q<=300000
Output
输出 q 行,每行对应一个询问,代表询问的答案。
Sample Input
5 3
1 2
1 3
2 4
4 5
2 2
4 1
2 3
1 2
1 3
2 4
4 5
2 2
4 1
2 3
Sample Output
3
1
3
1
3
b有位置两种可能,a的祖先和a的子树。
如果b是a的祖先,则c只能选择a子树中除了a的一个。
如果b是a的子树中的一个点,c只能选择b子树除了b的一个。
于是问题转化为子树里距离小于等于k的子树大小和。
两个限制条件:子树内和深度小于等于一个值。
把每个点的dfs序位置和深度看成两个坐标,转化为二维数点,同树状数组解决。
代码:
#include <stdio.h> #include <string.h> #include <algorithm> using namespace std; #define N 300050 typedef long long ll; int head[N],to[N<<1],nxt[N<<1],cnt,n,m; int dep[N],siz[N],dfn[N],S[N],son[N]; ll c[N],ans[N]; inline void add(int u,int v) { to[++cnt]=v; nxt[cnt]=head[u]; head[u]=cnt; } void fix(int x,int v) { for(;x<=n;x+=x&(-x)) c[x]+=v; } ll inq(int x) { ll re=0; for(;x;x-=x&(-x)) re+=c[x]; return re; } struct QAQ { int p,d,id,opt; }a[N<<1]; bool cmp(const QAQ &x,const QAQ &y) { if(x.p==y.p) return x.opt<y.opt; return x.p<y.p; } void dfs(int x,int y) { int i; S[++S[0]]=x; dfn[x]=S[0]; dep[x]=dep[y]+1; siz[x]=1; for(i=head[x];i;i=nxt[i]) if(to[i]!=y) { dfs(to[i],x); siz[x]+=siz[to[i]]; } son[x]=S[0]; } int main() { scanf("%d%d",&n,&m); int i,x,y; for(i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1,0); int tot=0; for(i=1;i<=m;i++) { scanf("%d%d",&x,&y); ans[i]=1ll*min(y,dep[x]-1)*(siz[x]-1); int depp=min(dep[x]+y,n); a[++tot].p=dfn[x]; a[tot].opt=-1; a[tot].d=depp; a[tot].id=i; a[++tot].p=son[x]; a[tot].opt=1; a[tot].d=depp; a[tot].id=i; } sort(a+1,a+tot+1,cmp); int now=0; for(i=1;i<=tot;i++) { while(now<=n&&now<a[i].p) now++,fix(dep[S[now]],siz[S[now]]-1); ans[a[i].id]+=a[i].opt*inq(a[i].d); } for(i=1;i<=m;i++) printf("%lld\n",ans[i]); }