Codeforces Round #425 (Div. 2) - D

 

题目链接:http://codeforces.com/contest/832/problem/D

题意:给定一棵n个点的树,然后给你q个询问,每个询问为三元组(a,b,c),问你从这三个点中选取一个作为终点,一个作为Misha的起点,一个作为Grisha的起点。然后每天早上Misha从起点到终点所经过的点都是标记为1, 傍晚Grisha从起点到终点所经过的点中带有标记的点的数目最多是多少?

思路:对于每个询问,我们枚举终点(共3种情况),其余两个点作为一个作为M的起点一个作为G的起点,然后问题就是M的起点到终点这条路径的点赋值1,统计G的起点到终点这条路径的1的个数,然后3种情况取个最大值即可。 然后就是经典的树链剖分题目,树剖之后就是区间覆盖+区间查询问题了。 起初用的是线段树,然后终测TLE掉了(可能我写的线段树不够优美,被卡常了),后来换成树状数组来维护区间覆盖,区间查询就AC掉了。

#define _CRT_SECURE_NO_DEPRECATE
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
#include<queue>
#include<vector>
#include<time.h>
#include<cmath>
using namespace std;
typedef long long int LL;
const int MAXN = 1e5 + 24;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + 7;
int fa[MAXN],top[MAXN],deep[MAXN],num[MAXN],p[MAXN],fp[MAXN],son[MAXN];
int pos,cp[MAXN],n,q;
LL bit0[MAXN],bit1[MAXN];
vector<int>edge[MAXN];
void init(){
    pos = 1; 
    memset(bit0,0,sizeof(bit0));
    memset(bit1,0,sizeof(bit1));
    memset(son, -1, sizeof(son));
}
void dfs1(int u, int pre, int d){
    deep[u] = d; fa[u] = pre; num[u] = 1;
    for (int i = 0; i < edge[u].size(); i++){
        int v = edge[u][i];
        if (v != pre){
            dfs1(v, u, d + 1);
            num[u] += num[v];
            if (son[u] == -1 || num[v] > num[son[u]]
                ){
                son[u] = v;
            }
        }
    }
}
void dfs2(int u, int sp){
    top[u] = sp;  p[u] = pos++; fp[p[u]] = u;
    if (son[u] == -1){
        return;
    }
    dfs2(son[u], sp);
    for (int i = 0; i < edge[u].size(); i++){
        int v = edge[u][i];
        if (v != fa[u] && v != son[u]){
            dfs2(v, v);
        }
    }
}

//BIT
void Add(LL *b,int i,LL val){
    while (i<=n){
        b[i]+=val; i+=i&-i;
    }
}
LL Sum(LL *b,int i){
    LL s=0;
    while (i>0){
        s+=b[i]; i-=i&-i;
    }
    return s;
}
void Modify(int l,int r,int val){ //区间[l,r] + val
    //printf("M:%d %d %d\n",l,r,val);
    Add(bit0,l,-val*(l-1));
    Add(bit1,l,val);
    Add(bit0,r+1,val*r);
    Add(bit1,r+1,-val);
}
int Query(int l,int r){ //区间[l,r] 1 的个数
    //printf("Q:%d %d\n",l,r);
    LL res=0;
    res+=Sum(bit0,r)+1LL*Sum(bit1,r)*r;
    res-=Sum(bit0,l-1)+1LL*Sum(bit1,l-1)*(l-1);
    return res;
}
void solveC(int u, int v,int val){ //修改链
    int f1 = top[u], f2 = top[v];
    while (f1!=f2){
        if (deep[f1] < deep[f2]){
            swap(f1, f2); swap(u, v);
        }
        Modify(p[f1], p[u], val);
        u = fa[f1]; 
        f1 = top[u];
    }
    if (deep[u] > deep[v]){
        swap(u, v);
    }
    Modify(p[u], p[v], val);
}
int solveQ(int u, int v){ //查询链
    int f1 = top[u], f2 = top[v];
    int tmp = 0;
    while (f1 != f2){
        if (deep[f1] < deep[f2]){
            swap(f1, f2); swap(u, v);
        }
        tmp+=Query(p[f1], p[u]);
        u = fa[f1]; f1 = top[u];
    }
    if (deep[u] > deep[v]){
        swap(u, v);
    }
    tmp+=Query(p[u], p[v]);
    return tmp;
}
int solve(int s, int t, int f){
    solveC(s, f, 1);
    int tmp = solveQ(t, f);
    solveC(s, f, -1);
    return tmp;
}
int main(){
#ifdef kirito
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    while (~scanf("%d%d",&n,&q)){
        init();
        for (int i = 1; i <= n; i++){
            edge[i].clear();
        }
        for (int i = 2; i <= n; i++){
            scanf("%d", &cp[i]); 
            edge[cp[i]].push_back(i);
            edge[i].push_back(cp[i]);
        }
        dfs1(1, 0, 0); 
        dfs2(1, 1);
        for (int i = 1; i <= q; i++){
            int a, b, c,res=0;
            scanf("%d%d%d", &a, &b, &c);
            res = max(res, solve(a, b, c));
            res = max(res, solve(a, c, b));
            res = max(res, solve(b, c, a));
            printf("%d\n", res);
        }
    }
    return 0;
}

 

posted @ 2017-07-27 20:06  キリト  阅读(151)  评论(0编辑  收藏  举报