bzoj3991 [SDOI2015]寻宝游戏

3991: [SDOI2015]寻宝游戏

Time Limit: 40 Sec  Memory Limit: 128 MB
Submit: 1600  Solved: 779
[Submit][Status][Discuss]

Description

 小B最近正在玩一个寻宝游戏,这个游戏的地图中有N个村庄和N-1条道路,并且任何两个村庄之间有且仅有一条路径可达。游戏开始时,玩家可以任意选择一个村庄,瞬间转移到这个村庄,然后可以任意在地图的道路上行走,若走到某个村庄中有宝物,则视为找到该村庄内的宝物,直到找到所有宝物并返回到最初转移到的村庄为止。小B希望评测一下这个游戏的难度,因此他需要知道玩家找到所有宝物需要行走的最短路程。但是这个游戏中宝物经常变化,有时某个村庄中会突然出现宝物,有时某个村庄内的宝物会突然消失,因此小B需要不断地更新数据,但是小B太懒了,不愿意自己计算,因此他向你求助。为了简化问题,我们认为最开始时所有村庄内均没有宝物

Input

 第一行,两个整数N、M,其中M为宝物的变动次数。

接下来的N-1行,每行三个整数x、y、z,表示村庄x、y之间有一条长度为z的道路。
接下来的M行,每行一个整数t,表示一个宝物变动的操作。若该操作前村庄t内没有宝物,则操作后村庄内有宝物;若该操作前村庄t内有宝物,则操作后村庄内没有宝物。

Output

 M行,每行一个整数,其中第i行的整数表示第i次操作之后玩家找到所有宝物需要行走的最短路程。若只有一个村庄内有宝物,或者所有村庄内都没有宝物,则输出0。

Sample Input

4 5
1 2 30
2 3 50
2 4 60
2
3
4
2
1

Sample Output

0
100
220
220
280

HINT

 1<=N<=100000

1<=M<=100000
对于全部的数据,1<=z<=10^9

Source

Round 1 感谢yts1999上传

分析:算是比较简单的一题吧.

   有一个比较常见的结论:从一个点出发,经过指定的k个点再回到原点走的路径的长度是这k个点构成的生成树的边权和的两倍.这道题如果把有宝藏的点当作特殊点,就相当于维护m棵虚树.

   每次重新建虚树是肯定不行的,每次只会修改一个点,插入和删除这个点一定会伴随着若干条路径的改变,找到这些路径就好了.画图分析可以知道:这些路径一定是当前操作的点x和dfs序在它前面的最后一个点l组成的或者是和dfs序在它后面的第一个点r组成的.如果l存在,那么这条路径就是lca(l,x)到x,如果r存在路径lca(r,x)到x也会受到影响. 如果l和r同时存在,还要排除(l,r)的影响.

   不是特别好描述......只需要记住每次加点还是删点都会改变若干条路径对答案的贡献,找到这些路径并修改贡献就好了.

   因为要每次要插入一个dfs序,找前驱后继,所以用set来维护.

#include <cstdio>
#include <set>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long ll;
const ll maxn = 200010;
ll n,m,head[maxn],to[maxn],nextt[maxn],w[maxn],tot = 1,deep[maxn],fa[maxn][20];
ll flag[maxn],ans,dep[maxn],pos[maxn],id[maxn],dfs_clock;
set <ll> S;

void add(ll x,ll y,ll z)
{
    w[tot] = z;
    to[tot] = y;
    nextt[tot] = head[x];
    head[x] = tot++;
}

void dfs(ll u,ll faa)
{
    pos[u] = ++dfs_clock;
    id[dfs_clock] = u;
    fa[u][0] = faa;
    dep[u] = dep[faa] + 1;
    for (ll i = head[u]; i; i = nextt[i])
    {
        ll v = to[i];
        if (v == faa)
            continue;
        deep[v] = deep[u] + w[i];
        dfs(v,u);
    }
}

ll lca(ll x,ll y)
{
    if (dep[x] < dep[y])
        swap(x,y);
    for (ll i = 19; i >= 0; i--)
        if (dep[fa[x][i]] >= dep[y])
            x = fa[x][i];
    if (x == y)
        return x;
    for (ll i = 19; i >= 0; i--)
        if (fa[x][i] != fa[y][i])
        {
            x = fa[x][i];
            y = fa[y][i];
        }
    return fa[x][0];
}

ll cal(ll x,ll y)
{
    return deep[x] + deep[y] - 2 * deep[lca(x,y)];
}

void solve()
{
    S.insert(0);
    S.insert(n + 1);
    for (ll i = 1; i <= m; i++)
    {
        ll x;
        scanf("%lld",&x);
        flag[x] ^= 1;
        if (flag[x])
        {
            S.insert(pos[x]);
            ll l = *(--S.find(pos[x]));
            ll r = *(++S.find(pos[x]));
            if (l >= 1 && r <= n)
            {
                ll p1 = lca(x,id[l]);
                ll p2 = lca(x,id[r]);
                if (deep[p1] >= deep[p2])
                    ans += cal(p1,x) * 2;
                else
                    ans += cal(p2,x) * 2;
            }
            else if (l >= 1)
                ans += cal(id[l],x);
            else if(r <= n)
                ans += cal(id[r],x);
        }
        else
        {
            ll l = *(--S.find(pos[x]));
            ll r = *(++S.find(pos[x]));
            if (l >= 1 && r <= n)
            {
                ll p1 = lca(x,id[l]);
                ll p2 = lca(x,id[r]);
                if (deep[p1] >= deep[p2])
                    ans -= cal(p1,x) * 2;
                else
                    ans -= cal(p2,x) * 2;
            }
            else if (l >= 1)
                ans -= cal(id[l],x);
            else if(r <= n)
                ans -= cal(id[r],x);
            S.erase(pos[x]);
        }
        ll l = *(++S.find(0)),r = *(--S.find(n + 1));
        ll temp = 0;
        if (l >= 1 && r <= n)
            temp = cal(id[l],id[r]);
        printf("%lld\n",ans + temp);
    }
}

int main()
{
    scanf("%lld%lld",&n,&m);
    for (ll i = 1; i < n; i++)
    {
        ll x,y,z;
        scanf("%lld%lld%lld",&x,&y,&z);
        add(x,y,z);
        add(y,x,z);
    }
    dfs(1,0);
    for (ll j = 1; j <= 19; j++)
        for (ll i = 1; i <= n; i++)
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
    solve();

    return 0;
}

 

posted @ 2018-03-03 21:28  zbtrs  阅读(215)  评论(0编辑  收藏  举报