Codeforces Round #425 (Div. 2) - D
题目链接:http://codeforces.com/contest/832/problem/D
题意:给定一棵n个点的树,然后给你q个询问,每个询问为三元组(a,b,c),问你从这三个点中选取一个作为终点,一个作为Misha的起点,一个作为Grisha的起点。然后每天早上Misha从起点到终点所经过的点都是标记为1, 傍晚Grisha从起点到终点所经过的点中带有标记的点的数目最多是多少?
思路:对于每个询问,我们枚举终点(共3种情况),其余两个点作为一个作为M的起点一个作为G的起点,然后问题就是M的起点到终点这条路径的点赋值1,统计G的起点到终点这条路径的1的个数,然后3种情况取个最大值即可。 然后就是经典的树链剖分题目,树剖之后就是区间覆盖+区间查询问题了。 起初用的是线段树,然后终测TLE掉了(可能我写的线段树不够优美,被卡常了),后来换成树状数组来维护区间覆盖,区间查询就AC掉了。
#define _CRT_SECURE_NO_DEPRECATE #include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<string> #include<queue> #include<vector> #include<time.h> #include<cmath> using namespace std; typedef long long int LL; const int MAXN = 1e5 + 24; const int INF = 0x3f3f3f3f; const int mod = 1e9 + 7; int fa[MAXN],top[MAXN],deep[MAXN],num[MAXN],p[MAXN],fp[MAXN],son[MAXN]; int pos,cp[MAXN],n,q; LL bit0[MAXN],bit1[MAXN]; vector<int>edge[MAXN]; void init(){ pos = 1; memset(bit0,0,sizeof(bit0)); memset(bit1,0,sizeof(bit1)); memset(son, -1, sizeof(son)); } void dfs1(int u, int pre, int d){ deep[u] = d; fa[u] = pre; num[u] = 1; for (int i = 0; i < edge[u].size(); i++){ int v = edge[u][i]; if (v != pre){ dfs1(v, u, d + 1); num[u] += num[v]; if (son[u] == -1 || num[v] > num[son[u]] ){ son[u] = v; } } } } void dfs2(int u, int sp){ top[u] = sp; p[u] = pos++; fp[p[u]] = u; if (son[u] == -1){ return; } dfs2(son[u], sp); for (int i = 0; i < edge[u].size(); i++){ int v = edge[u][i]; if (v != fa[u] && v != son[u]){ dfs2(v, v); } } } //BIT void Add(LL *b,int i,LL val){ while (i<=n){ b[i]+=val; i+=i&-i; } } LL Sum(LL *b,int i){ LL s=0; while (i>0){ s+=b[i]; i-=i&-i; } return s; } void Modify(int l,int r,int val){ //区间[l,r] + val //printf("M:%d %d %d\n",l,r,val); Add(bit0,l,-val*(l-1)); Add(bit1,l,val); Add(bit0,r+1,val*r); Add(bit1,r+1,-val); } int Query(int l,int r){ //区间[l,r] 1 的个数 //printf("Q:%d %d\n",l,r); LL res=0; res+=Sum(bit0,r)+1LL*Sum(bit1,r)*r; res-=Sum(bit0,l-1)+1LL*Sum(bit1,l-1)*(l-1); return res; } void solveC(int u, int v,int val){ //修改链 int f1 = top[u], f2 = top[v]; while (f1!=f2){ if (deep[f1] < deep[f2]){ swap(f1, f2); swap(u, v); } Modify(p[f1], p[u], val); u = fa[f1]; f1 = top[u]; } if (deep[u] > deep[v]){ swap(u, v); } Modify(p[u], p[v], val); } int solveQ(int u, int v){ //查询链 int f1 = top[u], f2 = top[v]; int tmp = 0; while (f1 != f2){ if (deep[f1] < deep[f2]){ swap(f1, f2); swap(u, v); } tmp+=Query(p[f1], p[u]); u = fa[f1]; f1 = top[u]; } if (deep[u] > deep[v]){ swap(u, v); } tmp+=Query(p[u], p[v]); return tmp; } int solve(int s, int t, int f){ solveC(s, f, 1); int tmp = solveQ(t, f); solveC(s, f, -1); return tmp; } int main(){ #ifdef kirito freopen("in.txt", "r", stdin); freopen("out.txt", "w", stdout); #endif while (~scanf("%d%d",&n,&q)){ init(); for (int i = 1; i <= n; i++){ edge[i].clear(); } for (int i = 2; i <= n; i++){ scanf("%d", &cp[i]); edge[cp[i]].push_back(i); edge[i].push_back(cp[i]); } dfs1(1, 0, 0); dfs2(1, 1); for (int i = 1; i <= q; i++){ int a, b, c,res=0; scanf("%d%d%d", &a, &b, &c); res = max(res, solve(a, b, c)); res = max(res, solve(a, c, b)); res = max(res, solve(b, c, a)); printf("%d\n", res); } } return 0; }