P2486 [SDOI2011]染色 - 树剖

先考虑在链上的做法,线段树子节点合并的时候减去重复算的颜色,然后树剖,跨越轻链的时候可以单点查询找颜色(我想的是维护颜色。。。但是明显这个点也是在线段树上的,直接单点查询就好了,想想问题的本质是什么,有没有不那么麻烦的做法)
所以要考虑好一些问题的区间可维护性,然后选用适当的数据结构
比如区间最大公约数,显然不同的区间可以合并维护,可以用线段树或者st表在logn时间内维护好,但是st表常数更小
容易写出错的地方:
应该写成左右子节点:query(w << 1 | 1, mid+1, r, x, y);,然而w<<1|1写成了w。。。
树剖建树应该在两次dfs之后,并且用rnk数组映射dfn序为编号
路径查询/修改的时候比较的应该是两个点的dep[top[x]],比较的是深度!

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
using namespace std;
#define debug(x) cerr << #x << "=" << x << endl;
const int MAXN = 200000 + 10;
int n,m,last[MAXN],cnt_div,edge_tot,col[MAXN],dep[MAXN],siz[MAXN],dfn[MAXN];
int top[MAXN],rnk[MAXN],fa[MAXN],son[MAXN];
struct Edge{
    int u, v, to;
    Edge(){}
    Edge(int u, int v, int to) : u(u), v(v), to(to) {}
}e[MAXN * 2];
inline void add(int u, int v) {
    e[++edge_tot] = Edge(u, v, last[u]);
    last[u] = edge_tot;
}
struct segment{
    int sum, siz, coll, colr, add;
}tr[MAXN*4];

void update(int w) {
    int le = w << 1, ri = w << 1 | 1;
    tr[w].sum = tr[le].sum + tr[ri].sum;
    tr[w].coll = tr[le].coll, tr[w].colr = tr[ri].colr;
    if(tr[le].colr == tr[ri].coll) tr[w].sum--;
}
void down(int w) {
    int le = w << 1, ri = w << 1 | 1, add = tr[w].add;
    if(add == -1) return;
    tr[w].add = -1;
    tr[le].add = tr[le].coll = tr[le].colr = add;
    tr[ri].add = tr[ri].coll = tr[ri].colr = add;
    tr[le].sum = tr[ri].sum = 1;
    tr[w].add = -1;
}

void build(int w, int l, int r) {
    tr[w].add = -1;
    if(l == r) {
        tr[w].sum = 1;
        tr[w].coll = col[rnk[l]], tr[w].colr = col[rnk[l]];
        return;
    }
    int mid = (l + r) >> 1;
    build(w<<1, l, mid);
    build(w<<1|1, mid+1, r);
    update(w);
}

int query(int w, int l, int r, int x, int y) {
    if(x <= l && r <= y) {
        return tr[w].sum;
    }
    down(w);
    int mid = (l + r) >> 1;
    int sum = 0;
    int le = w << 1, ri = w << 1 | 1;
    if(x <= mid) sum += query(le, l, mid, x, y);
    if(y > mid) sum += query(ri, mid+1, r, x, y);
    if(x <= mid && y > mid)
        if(tr[le].colr == tr[ri].coll) sum--;
    return sum;
}

int sin_que(int w, int l, int r, int pos) {
    if(l == r) {
        return tr[w].coll;
    }
    down(w);
    int mid = (l + r) >> 1;
    int le = w << 1, ri = w << 1 | 1;
    if(pos <= mid) return sin_que(le, l, mid, pos);
    return sin_que(ri, mid+1, r, pos);
}

void change(int w, int l, int r, int x, int y, int k) {
    if(x <= l && r <= y) {
        tr[w].add = tr[w].coll = tr[w].colr = k;
        tr[w].sum = 1;
        return;
    }
    down(w);
    int mid = (l + r) >> 1;
    int le = w << 1, ri = w << 1 | 1;
    if(x <= mid) change(le, l, mid, x, y, k);
    if(y > mid) change(ri, mid+1, r, x, y, k);
    update(w); 
}

void dfs_init(int x, int depth, int f) {
    dep[x] = depth;
    siz[x] = 1;
    fa[x] = f;
    for(int i=last[x]; i; i=e[i].to) {
        int v = e[i].v;
        if(v == f) continue;
        dfs_init(v, depth+1, x);
        if(siz[v] > siz[son[x]]) son[x] = v;
        siz[x] += siz[v];
    }
}

void dfs_top(int x, int t) {
    top[x] = t;
    dfn[x] = ++cnt_div;
    rnk[cnt_div] = x;
    if(son[x]) dfs_top(son[x], t);
    for(int i=last[x]; i; i=e[i].to) {
        int v = e[i].v;
        if(v == fa[x] || v == son[x]) continue;
        dfs_top(v, v);
    }
}


int query_path(int s, int t) {
    int sum = 0;
    while(top[s] != top[t]) {
        if(dep[top[s]] < dep[top[t]]) swap(s, t); //这里十分容易忘了写dep[]
        sum += query(1, 1, n, dfn[top[s]], dfn[s]);
        if(sin_que(1, 1, n, dfn[top[s]]) == sin_que(1, 1, n, dfn[fa[top[s]]])) sum--;
        s = fa[top[s]];
    }
    if(dfn[s] > dfn[t]) swap(s, t);
    sum += query(1, 1, n, dfn[s], dfn[t]);
    return sum;
}

void change_path(int s, int t, int k) {
    while(top[s] != top[t]) {
        if(dep[top[s]] < dep[top[t]]) swap(s, t);
        change(1, 1, n, dfn[top[s]], dfn[s], k);
        s = fa[top[s]];
    }
    if(dfn[s] > dfn[t]) swap(s, t); 
    change(1, 1, n, dfn[s], dfn[t], k); 
}

int main() {
    scanf("%d%d", &n, &m);
    for(int i=1; i<=n; i++) {
        scanf("%d", &col[i]);
    }
    for(int i=1; i<n; i++) {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    } 
    dfs_init(1, 1, 0);
    dfs_top(1, 1);
    build(1, 1, n); 
    char temp_que[5];
    for(int i=1; i<=m; i++) {
        scanf("%s", temp_que);
        int s, t, c;
        if(temp_que[0] == 'C') {
            scanf("%d%d%d", &s, &t, &c);
            change_path(s, t, c);
        } else {
            scanf("%d%d", &s, &t);
            printf("%d\n", query_path(s, t));
        }
    }
    return 0;
}
posted @ 2018-10-17 18:59  Zolrk  阅读(149)  评论(0编辑  收藏  举报