poj 3237 Tree

    这题就是简单的树链剖分加线段树，但标记下传的时候一定要用 ^= 1 而不能直接 = 1：同一段区间被取反两次等于没有取反，所以懒标记必须翻转，而不是直接置为 1。我竟然 WA 在这么逗比的错误上，不如一头撞死……

    上代码:

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#define N 1100000
#define inf 0x7f7f7f7f
using namespace std;

// One segment-tree node. Positions of the tree are the heavy-light indices
// assigned in dfs_2; main() stores each edge weight at the position of the
// edge's child-side vertex.
struct sss
{
    int minnum;   // smallest weight in the node's range (inf while empty)
    int maxnum;   // largest weight in the node's range (-inf while empty)
    int push;     // 1 while a negation still has to be handed to the children
};

// Node storage: the root is node 1 and node i has children 2*i and 2*i+1.
sss t[N*4];
// Number of vertices of the tree currently being processed.
int n;
// Last segment-tree position handed out by dfs_2 (reset to 0 per test case).
int nowplace;
// bianp[i]: the child-side endpoint of input edge i, filled in by dfs_1.
int bianp[N];

// Adjacency lists: p[x] is the first entry of vertex x, next[e] is the entry
// following e (0 ends the list) and v[e] is the vertex entry e leads to.
int p[N];
int next[N*2];
int v[N*2];
// c[i]: weight read for input edge i.
int c[N*2];
// Number of adjacency entries created so far (two per edge).
int bnum;

// Per-vertex data of the decomposition: parent, heavy child (0 if none),
// subtree size, depth, top vertex of the chain, and segment-tree position.
int fa[N];
int son[N];
int siz[N];
int deep[N];
int top[N];
int w[N];

void build_tree(int now, int l, int r)
{
    t[now].minnum = inf; t[now].maxnum = -inf; t[now].push = 0;
    if (l == r) return;
    int mid = (l+r)/2;
    build_tree(now*2, l, mid); build_tree(now*2+1, mid+1, r);
}

void downdate(int now)
{
    if (!t[now].push) return; t[now].push = 0;
    t[now*2].push ^= 1; t[now*2+1].push ^= 1;
    swap(t[now*2].maxnum, t[now*2].minnum);
    swap(t[now*2+1].maxnum, t[now*2+1].minnum);
    t[now*2].maxnum *= -1; t[now*2].minnum *= -1;
    t[now*2+1].maxnum *= -1; t[now*2+1].minnum *= -1;
}

void update(int now)
{
    t[now].maxnum = max(t[now*2].maxnum, t[now*2+1].maxnum);
    t[now].minnum = min(t[now*2].minnum, t[now*2+1].minnum);
}

void addbian(int x, int y)
{
    bnum++; next[bnum] = p[x]; p[x] = bnum; v[bnum] = y;
    bnum++; next[bnum] = p[y]; p[y] = bnum; v[bnum] = x;
}

void dfs_1(int now, int nowfa, int nowdeep)
{
    int k = p[now]; fa[now] = nowfa; deep[now] = nowdeep;
    int maxson = 0; son[now] = 0; siz[now] = 1;
    while (k)
    {
        if (v[k] != nowfa)
        {
            bianp[(k+1)/2] = v[k];
            dfs_1(v[k], now, nowdeep+1);
            siz[now] += siz[v[k]];
            if (siz[v[k]] > maxson)
            {
                maxson = siz[v[k]];
                son[now] = v[k];
            }
        }
        k = next[k];
    }
}

void dfs_2(int now, int nowfa, int nowtop)
{
    int k = p[now]; top[now] = nowtop; w[now] = ++nowplace;
    if (son[now]) dfs_2(son[now], now, nowtop);
    while (k)
    {
        if (v[k] != nowfa && v[k] != son[now])
            dfs_2(v[k], now, v[k]);
        k = next[k];
    }
}

// Returns the maximum over positions [al, ar], restricted to node `now`,
// which covers [l, r]. Gives -inf when no visited part contributes a value.
int task(int now, int l, int r, int al, int ar)
{
    // Node lies completely inside the query: its stored maximum is the answer.
    if (l >= al && ar >= r) return t[now].maxnum;

    downdate(now);   // children must be up to date before they are inspected

    const int mid = (l+r)/2;
    int ans = -inf;
    if (mid >= al)
        ans = task(now*2, l, mid, al, ar);
    if (mid < ar)
    {
        const int rightBest = task(now*2+1, mid+1, r, al, ar);
        if (rightBest > ans) ans = rightBest;
    }

    update(now);
    return ans;
}

// Negates every value stored at positions [tl, tr], restricted to node `now`,
// which covers [l, r].
void tneg(int now, int l, int r, int tl, int tr)
{
    if (tl <= l && r <= tr)
    {
        // Fully covered: negate this node's summary and toggle its pending
        // flag. Nothing is pushed down here. Toggling already cancels a
        // negation that was pending, and a push-down at a leaf would write to
        // t[2*now] and t[2*now+1], which are not nodes of the tree and can lie
        // beyond the end of t[] for large n.
        const int oldMax = t[now].maxnum;
        t[now].maxnum = -t[now].minnum;
        t[now].minnum = -oldMax;
        t[now].push ^= 1;
        return;
    }

    // Partially covered: bring the children up to date, recurse into the
    // halves that intersect the range, then rebuild this node's summary.
    int mid = (l+r)/2;
    downdate(now);
    if (tl <= mid) tneg(now*2, l, mid, tl, tr);
    if (tr > mid) tneg(now*2+1, mid+1, r, tl, tr);
    update(now);
}

// Sets the value at position `cplace` to `cnum` inside node `now`, which
// covers [l, r], and refreshes the summaries on the way back up.
void chan(int now, int l, int r, int cplace, int cnum)
{
    if (l != r)
    {
        downdate(now);   // pending negations must reach the children first
        const int mid = (l+r)/2;
        if (cplace > mid)
            chan(now*2+1, mid+1, r, cplace, cnum);
        else
            chan(now*2, l, mid, cplace, cnum);
        update(now);
        return;
    }

    // Leaf: the single value is both the minimum and the maximum.
    t[now].minnum = cnum;
    t[now].maxnum = cnum;
}

void neg(int u, int v)
{
    int f1 = top[u], f2 = top[v];
    if (deep[f1] < deep[f2]) { swap(f1, f2); swap(u, v); }
    if (f1 == f2)
    {
        if (u == v) return;
        if (w[u] > w[v]) swap(u, v);
        tneg(1, 1, n, w[son[u]], w[v]);
        return;
    }
    tneg(1, 1, n, w[f1], w[u]); neg(fa[f1], v);
}

int find(int u, int v)
{
    int f1 = top[u],f2 = top[v];
    if (deep[f1] < deep[f2]) { swap(f1, f2); swap(u, v); }
    if (f1 == f2)
    {
        if (u == v) return -inf;
        if (w[u] > w[v]) swap(u, v);
        return task(1, 1, n, w[son[u]], w[v]);
    }
    int ans = task(1, 1, n, w[f1], w[u]);
    return max(ans, find(fa[f1], v));
}

int main()
{
    int T; scanf("%d", &T);
    while (T--)
    {
        scanf("%d", &n); memset(p, 0, sizeof(p));
        build_tree(1, 1, n); nowplace = 0; bnum = 0;
        for (int i = 1; i < n; ++i)
        {
            int x, y, z; scanf("%d%d%d", &x, &y, &z);
            addbian(x, y); c[i] = z;
        }
        dfs_1(1, 0, 1);
        dfs_2(1, 0, 1);
        for (int i = 1; i < n; ++i)
            chan(1, 1, n, w[bianp[i]], c[i]);
        char s[8];
        while (scanf("%s", s) != EOF)
        {
            if (s[0] == 'D') break;
            int x, y; scanf("%d%d", &x, &y);
            if (s[0] == 'Q') printf("%d\n", find(x, y));
            else if (s[0] == 'C') chan(1, 1, n, w[bianp[x]], y);
            else if (s[0]=='N') neg(x, y);
        }
    }
    return 0;
}

 

posted @ 2014-08-31 15:49  handsomeJian  阅读(122)  评论(0编辑  收藏  举报