【bzoj1036】[ZJOI2008]树的统计Count

*题目描述:
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
*输入:
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
*输出:
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
*样例输入:
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
*样例输出:
4
1
2
2
10
6
5
6
5
16
*题解:
树链剖分套线段树。
树链剖分就是两次dfs,第一次处理出每棵子树的重儿子,大小,每个节点的深度;第二次处理出每个节点所属的链的顶部还有每个节点在线段树中的位置。
然后每次查询两点间的信息就是如果两点的不在同一条重链上就由跳到更深的重链的顶部的父亲上,直到两点在同一条重链上为止。此时再查询这两点间的信息。
时间复杂度分析:轻重链剖分的复杂度是log的,每次在线段树上查询是log的,所以单次查询的复杂度是log方的。
*代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>

#ifdef WIN32
    #define LL "%I64d"
#else
    #define LL "%lld"
#endif

#ifdef CT
    #define debug(...) printf(__VA_ARGS__)
    #define setfile() 
#else
    #define debug(...)
    #define filename ""
    #define setfile() freopen(filename".in", "r", stdin); freopen(filename".out", "w", stdout);
#endif

#define R register
#define getc() (S == T && (T = (S = B) + fread(B, 1, 1 << 15, stdin), S == T) ? EOF : *S++)
#define dmax(_a, _b) ((_a) > (_b) ? (_a) : (_b))
#define dmin(_a, _b) ((_a) < (_b) ? (_a) : (_b))
#define cmax(_a, _b) (_a < (_b) ? _a = (_b) : 0)
#define cmin(_a, _b) (_a > (_b) ? _a = (_b) : 0)
char B[1 << 15], *S = B, *T = B;
inline int FastIn()
{
    R char ch; R int cnt = 0; R bool minus = 0;
    while (ch = getc(), (ch < '0' || ch > '9') && ch != '-') ;
    ch == '-' ? minus = 1 : cnt = ch - '0';
    while (ch = getc(), ch >= '0' && ch <= '9') cnt = cnt * 10 + ch - '0';
    return minus ? -cnt : cnt;
}
int M;
#define maxn 1 << 16
struct Edge
{
    int to;
    Edge *next;
}*last[maxn], e[maxn], *ecnt = e;
inline void add(R int a, R int b)
{
    *++ecnt = (Edge) {b, last[a]};
    last[a] = ecnt;
}
int dep[maxn], fa[maxn], size[maxn], son[maxn], top[maxn], pos[maxn], timer;
int v[maxn];
bool vis[maxn];
void dfs1(R int x)
{
    dep[x] = dep[fa[x]] + 1; size[x] = 1; vis[x] = 1;
    for (R Edge *iter = last[x]; iter; iter = iter -> next)
        if (!vis[iter -> to])
        {
            R int pre = iter -> to;
            fa[pre] = x;
            dfs1(pre);
            size[x] += size[pre];
            size[pre] > size[son[x]] ? son[x] = pre : 0;
        }
}
void dfs2(R int x)
{
    vis[x] = 0;
    pos[x] = timer++;
    top[x] = x == son[fa[x]] ? top[fa[x]] : x;
    for (R Edge *iter = last[x]; iter; iter = iter -> next)
        if (iter -> to == son[x])
            dfs2(iter -> to);
    for (R Edge *iter = last[x]; iter; iter = iter -> next)
        if (vis[iter -> to])
            dfs2(iter -> to);
}
inline int getlca(R int a, R int b)
{
    while (top[a] != top[b])
    {
        dep[top[a]] < dep[top[b]] ? b = fa[top[b]] : a = fa[top[a]];
    }
    return dep[a] < dep[b] ? a : b;
}
struct SegmentTree
{
    int sum, mx;
}tr[maxn];
inline void Build()
{
    for (R int i = M - 1; i; --i)
    {
        tr[i].sum = tr[i << 1].sum + tr[i << 1 | 1].sum;
        tr[i].mx = dmax(tr[i << 1].mx, tr[i << 1 | 1].mx);
    }
}
inline void Change(R int pos, R int val)
{
    R int i = pos + M;
    for (tr[i].sum = tr[i].mx = val, i >>= 1; i; i >>= 1)
    {
        tr[i].sum = tr[i << 1].sum + tr[i << 1 | 1].sum;
        tr[i].mx = dmax(tr[i << 1].mx, tr[i << 1 | 1].mx);
    }
}
#define inf 0x7fffffff
inline int Query_max(R int s, R int t)
{
    R int ans = -inf;
    for (s = s + M - 1, t = t + M + 1; s ^ t ^ 1; s >>= 1, t >>= 1)
    {
        if (~ s & 1) cmax(ans, tr[s ^ 1].mx);
        if (t & 1) cmax(ans, tr[t ^ 1].mx);
    }
    return ans;
}
inline int Query_sum(R int s, R int t)
{
    R int ans = 0;
    for (s = s + M - 1, t = t + M + 1; s ^ t ^ 1; s >>= 1, t >>= 1)
    {
        if (~ s & 1) ans += tr[s ^ 1].sum;
        if (t & 1) ans += tr[t ^ 1].sum;
    }
    return ans;
}
inline int getmax(R int a, R int b)
{
    R int ans = -inf;
    while (top[a] != top[b])
    {
        cmax(ans, Query_max(pos[top[a]], pos[a]));
        a = fa[top[a]];
    }
    cmax(ans, Query_max(pos[b], pos[a]));
    return ans;
}
inline int getsum(R int a, R int b)
{
    R int ans = 0;
    while (top[a] != top[b])
    {
        ans += Query_sum(pos[top[a]], pos[a]);
        a = fa[top[a]];
    }
    ans += Query_sum(pos[b], pos[a]);
    return ans;
}
int main()
{
//  setfile();
    R int n = FastIn();
    for (M = 1; M < n; M <<= 1);
    for (R int i = 1; i < n; ++i)
    {
        R int a = FastIn(), b = FastIn();
        add(a, b);
        add(b, a);
    }
    dfs1(1);
    dfs2(1);
    for (R int i = 1; i <= n; ++i)
    {
        v[i] = FastIn();
        tr[pos[i] + M] = (SegmentTree) {v[i], v[i]};
    }
    for (R int i = n; i < M; ++i) tr[i + M] = (SegmentTree) {0, -inf};
    Build();
    for (R int q = FastIn(); q; --q)
    {
        char opt = getc();
        while (opt < 'A' || opt > 'Z') opt = getc();
        if (opt == 'C')
        {
            R int x = FastIn(), y = FastIn();
            Change(pos[x], y);
            v[x] = y;
        }
        else
        {
            opt = getc();
            R int x = FastIn(), y = FastIn();
            R int lca = getlca(x, y);
            if (opt == 'M')
                printf("%d\n", dmax(getmax(x, lca), getmax(y, lca)) );
            else
                printf("%d\n", getsum(x, lca) + getsum(y, lca) - v[lca]);
        }
    }
    return 0;
}
/*
input:
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
output:
4
1
2
2
10
6
5
6
5
16
*/
posted @ 2016-05-07 13:15  cot  阅读(140)  评论(0编辑  收藏  举报