【bzoj3282】Tree

*题目描述:
给定N个点以及每个点的权值,要你处理接下来的M个操作。操作有4种。操作从0到3编号。点从1到N编号。
0:后接两个整数(x,y),代表询问从x到y的路径上的点的权值的xor和。保证x到y是联通的。
1:后接两个整数(x,y),代表连接x到y,若x到Y已经联通则无需连接。
2:后接两个整数(x,y),代表删除边(x,y),不保证边(x,y)存在。
3:后接两个整数(x,y),代表将点X上的权值变成Y。
*输入:
第1行两个整数,分别为N和M,代表点数和操作数。
第2行到第N+1行,每行一个整数,整数在[1,10^9]内,代表每个点的权值。
第N+2行到第N+M+1行,每行三个整数,分别代表操作类型和操作所需的量。
*输出:
对于每一个0号操作,你须输出X到Y的路径上点权的Xor和。
*样例输入:
3 3
1
2
3
1 1 2
0 1 2
0 1 1
*样例输出:
3
1
*题解:
lct模板题,每个节点维护一下preferred path的xor和。然后每次把路径搞出来,直接查询即可。
*代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>

#ifdef WIN32
    #define LL "%I64d"
#else
    #define LL "%lld"
#endif

#ifdef CT
    #define debug(...) printf(__VA_ARGS__)
    #define setfile() 
#else
    #define debug(...)
    #define filename ""
    #define setfile() freopen(filename".in", "r", stdin); freopen(filename".out", "w", stdout);
#endif

#define R register
#define getc() (S == T && (T = (S = B) + fread(B, 1, 1 << 15, stdin), S == T) ? EOF : *S++)
#define dmax(_a, _b) ((_a) > (_b) ? (_a) : (_b))
#define dmin(_a, _b) ((_a) < (_b) ? (_a) : (_b))
#define cmax(_a, _b) (_a < (_b) ? _a = (_b) : 0)
#define cmin(_a, _b) (_a > (_b) ? _a = (_b) : 0)
char B[1 << 15], *S = B, *T = B;
inline int FastIn()
{
    R char ch; R int cnt = 0; R bool minus = 0;
    while (ch = getc(), (ch < '0' || ch > '9') && ch != '-') ;
    ch == '-' ? minus = 1 : cnt = ch - '0';
    while (ch = getc(), ch >= '0' && ch <= '9') cnt = cnt * 10 + ch - '0';
    return minus ? -cnt : cnt;
}
#define maxn 300010
int val[maxn];
struct Node *null;
struct Node
{
    bool rev;
    Node *ch[2], *fa;
    int xorsum, id;
    inline void update()
    {
        xorsum = val[id] ^ ch[0] -> xorsum ^ ch[1] -> xorsum;
    }
    inline bool type()
    {
        return fa -> ch[1] == this;
    }
    inline bool check()
    {
        return fa -> ch[type()] == this;
    }
    inline void set_rev()
    {
        std::swap(ch[0], ch[1]);
        rev ^= 1;
    }
    inline void pushdown()
    {
        if (rev)
        {
            ch[0] -> set_rev();
            ch[1] -> set_rev();
            rev = 0;
        }
    }
    void pushdownall()
    {
        if (check())
            fa -> pushdownall();
        pushdown();
    }
    inline void rotate()
    {
        R Node *f = fa;
        R bool d = type();
        (fa = f -> fa), f -> check() ? fa -> ch[f -> type()] = this : 0;
        (f -> ch[d] = ch[!d]) != null ? ch[!d] -> fa = f : 0;
        (ch[!d] = f) -> fa = this;
        f -> update();
    }
    inline void splay(R bool need = 1)
    {
        if (need) pushdownall();
        for (; check(); rotate())
            if (fa -> check())
                (type() != fa -> type() ? this : fa) -> rotate();
        update();
    }
    inline Node *access()
    {
        R Node *i = this, *j = null;
        for (; i != null; i = (j = i) -> fa)
        {
            i -> splay();
            i -> ch[1] = j;
            i -> update();
        }
        return j;
    }
    inline void make_root()
    {
        access();
        splay(0);
        set_rev();
    }
    inline void link(R Node *that)
    {
        make_root();
        fa = that;
    }
    inline void cut(R Node *that)
    {
        make_root();
        that -> access();
        splay(0);
        if (ch[1] == that)
            that -> fa = ch[1] = null;
    }
    inline bool find(R Node *that)
    {
        access();
        splay();
        while (that -> fa != null)
            that = that -> fa;
        return that == this;
    }
}mem[maxn];
int main()
{
//  setfile();
    R int n = FastIn(), m = FastIn();
    null = mem;
    null -> fa = null -> ch[0] = null -> ch[1] = null;
    null -> xorsum = null -> id = 0;
    for (R int i = 1; i <= n; ++i)
    {
        val[i] = FastIn();
        (mem + i) -> fa = (mem + i) -> ch[0] = (mem + i) -> ch[1] = null;
        (mem + i) -> id = i;
    }
    for (R int i = 1; i <= m; ++i)
    {
        R int opt = FastIn(), x = FastIn(), y = FastIn();
        if (opt == 0)
        {
            (mem + x) -> make_root();
            (mem + y) -> access();
            (mem + y) -> splay(0);
            printf("%d\n", (mem + y) -> xorsum );
        }
        else if (opt == 1)
        {
            if (!(mem + x) -> find(mem + y))
                (mem + x) -> link(mem + y);
        }
        else if (opt == 2)
        {
            (mem + x) -> cut(mem + y);
        }
        else
        {
            (mem + x) -> access();
            (mem + x) -> splay(0);
            val[x] = y;
            (mem + x) -> update();
        }
    }
    return 0;
}
/*
3 3 
1
2
3
1 1 2
0 1 2 
0 1 1
*/
posted @ 2016-05-12 21:22  cot  阅读(157)  评论(0编辑  收藏  举报