[洛谷P2486] [SDOI2011]染色

## 题目链接

[传送门](https://www.luogu.com.cn/problem/P2486)

## 题目分析

树链剖分经典题
首先考虑在序列上如何维护染色与色段个数:线段树的每个节点记录区间内的色段个数以及区间左、右端点的颜色,每次 pushup 时把左儿子和右儿子的色段个数相加;如果左儿子右端点的颜色与右儿子左端点的颜色相同,说明中间两段其实是同一段,合并后的个数要减一。
在树上同理:树链剖分后按 dfs 序建一棵线段树,并如上所述维护。注意查询时每跳过一段重链,都要记录这一段链顶端点的颜色,并与同一侧下一段靠近它的端点颜色比较,相同则答案减一;最后两点跳到同一条重链上时,两侧的端点都要各判一次。
也可用LCT做,还没学,学完回来补档。
写树剖细节很多,尽量考虑全面。

## 代码

//记录线段树区间左端点右端点颜色和每段重链的头尾颜色 

#include<bits/stdc++.h>
#define N (100000 + 5) 
using namespace std;
// Reads the next decimal integer from stdin.
//
// Every non-digit character before the number is skipped; each '-' among the
// skipped characters flips the sign. The character that terminates the digit
// run is consumed as well.
//
// Returns the parsed value, or 0 when the input ends before any digit is seen.
inline int read() {
    int value = 0, sign = 1;
    // getchar() returns int: keeping it as int lets EOF be told apart from real
    // characters and keeps the argument of isdigit() inside its valid domain.
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -sign;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        // Subtract '0' first so the intermediate sum cannot overflow for
        // values close to INT_MAX.
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}
// n: number of tree nodes; m: number of operations; a[v]: colour of node v.
// Colours are stored shifted by +1 (see solve), so 0 never appears as a colour
// and can serve as the "no pending tag" value in the segment tree.
int n, m, a[N];
// Adjacency list: first[v] is the index of the newest edge leaving v, nxt[e] is
// the edge stored before e for the same node, to[e] is the endpoint of e and
// tot counts the edges stored so far.
// Decomposition data: num[v] is the DFS position of v, id[pos] the node at that
// position, top[v] the head of v's heavy chain, idx the position counter, and
// father / siz / dep / son the parent, subtree size, depth and heavy child.
int nxt[N<<1],	first[N], to[N<<1], tot, num[N], id[N], top[N], idx = 0, father[N], siz[N], dep[N], son[N];
// ans1 / ans2: colours carried between segments by chain_query.
// Lc / Rc: colours at the two ends of the range, written by query().
int ans1, ans2, Lc, Rc;
// Operands and opcode of the operation currently being processed by solve().
int x, y, z; char opr = '0'; 

// One segment-tree node, responsible for the closed position range [l, r].
//   tag     : colour waiting to be pushed to the children (0 = nothing pending)
//   dat     : number of colour runs inside the range
//   lc / rc : colour of the leftmost / rightmost position of the range
struct node {
    int l;
    int r;
    long long tag;
    long long dat;
    long long lc;
    long long rc;
} tree[N * 4];

// Field accessors taking a node index.
#define l(p) tree[p].l
#define r(p) tree[p].r
#define dat(p) tree[p].dat
#define tag(p) tree[p].tag
#define lc(p) tree[p].lc
#define rc(p) tree[p].rc

// Stores the directed edge x -> y at the front of x's adjacency list.
void add(int x, int y) {
    ++tot;
    to[tot] = y;
    nxt[tot] = first[x];  // the previous head becomes the successor
    first[x] = tot;
}

// Declared early because debug() is defined before them.
void pushdown(int p);
void pushup(int p);

// Debug helper: prints the colour stored in every leaf of the subtree rooted
// at node p, which covers positions [l, r], in increasing position order.
// Pending tags are pushed down on the way so the leaves are up to date.
void debug(int p, int l, int r) {
    if (l == r) {
        // The leaf is node p; l is a position, not a node index.
        cout << lc(p) << " ";
        return;
    }
    pushdown(p);
    int mid = (l + r) >> 1;
    debug(p << 1, l, mid);
    debug(p << 1 | 1, mid + 1, r);
    pushup(p);
}

// First decomposition pass over the subtree of cur, whose parent is fa.
// Fills father[], dep[] and siz[] and picks son[cur], the child with the
// largest subtree (the first such child wins on ties).
void dfs1(int cur, int fa) {
    father[cur] = fa;
    dep[cur] = dep[fa] + 1;
    siz[cur] = 1;
    for (int e = first[cur]; e; e = nxt[e]) {
        const int child = to[e];
        if (child == fa) continue;  // do not walk back up the tree
        dfs1(child, cur);
        siz[cur] += siz[child];
        if (siz[son[cur]] < siz[child]) son[cur] = child;
    }
}

// Second decomposition pass: assigns cur the next DFS position and records tp
// as the head of its chain. The heavy child is visited first so that every
// heavy chain occupies consecutive positions; every other child starts a new
// chain of its own.
void dfs2(int cur, int tp) {
    top[cur] = tp;
    num[cur] = ++idx;
    id[idx] = cur;
    if (son[cur]) dfs2(son[cur], tp);
    for (int e = first[cur]; e; e = nxt[e]) {
        const int child = to[e];
        // num[] doubles as the visited mark: the parent and the heavy child
        // already have a position at this point.
        if (!num[child]) dfs2(child, child);
    }
}

// Recomputes node p from its two children: the run counts add up, minus one
// when the colours on both sides of the midpoint are equal.
void pushup(int p) {
    const int lson = p << 1;
    const int rson = p << 1 | 1;
    const bool joined = (rc(lson) == lc(rson));
    dat(p) = dat(lson) + dat(rson) - (joined ? 1 : 0);
    lc(p) = lc(lson);
    rc(p) = rc(rson);
}

// Hands the pending colour of node p (if any) to both children: each child
// becomes a single run of that colour and inherits the tag.
void pushdown(int p) {
    if (!tag(p)) return;
    const long long colour = tag(p);
    for (int child = p << 1; child <= (p << 1 | 1); ++child) {
        tag(child) = colour;
        lc(child) = colour;
        rc(child) = colour;
        dat(child) = 1;
    }
    tag(p) = 0;
}

// Builds the subtree rooted at node p for positions [l, r]. A leaf takes the
// colour of the tree node placed at that position.
void build_tree(int p, int l, int r) {
    l(p) = l;
    r(p) = r;
    if (l != r) {
        const int mid = (l + r) >> 1;
        build_tree(p << 1, l, mid);
        build_tree(p << 1 | 1, mid + 1, r);
        pushup(p);
        return;
    }
    dat(p) = 1;
    lc(p) = rc(p) = a[id[l]];
}

// Paints every position of [l, r] lying inside node p's range with colour d.
void modify(int p, int l, int r, int d) {
    const bool covered = (l <= l(p)) && (r(p) <= r);
    if (!covered) {
        pushdown(p);
        const int mid = (l(p) + r(p)) >> 1;
        if (l <= mid) modify(p << 1, l, r, d);
        if (mid < r) modify(p << 1 | 1, l, r, d);
        pushup(p);
        return;
    }
    // Whole node covered: it becomes one run of colour d, and d stays behind
    // as the tag for the children.
    tag(p) = d;
    lc(p) = d;
    rc(p) = d;
    dat(p) = 1;
}

// Returns the number of colour runs among the positions of [l, r] that lie in
// node p's range. As a side effect the globals Lc and Rc receive the colour at
// position l and at position r respectively.
long long query(int p, int l, int r) {
    // A visited node starting at l (ending at r) holds that end's colour.
    if (l(p) == l) Lc = lc(p);
    if (r(p) == r) Rc = rc(p);
    if (l <= l(p) && r(p) <= r) return dat(p);

    pushdown(p);
    const int mid = (l(p) + r(p)) >> 1;
    if (l > mid) return query(p << 1 | 1, l, r);
    if (r <= mid) return query(p << 1, l, r);

    // The range straddles the midpoint: a run crossing it is counted once.
    const bool joined = (rc(p << 1) == lc(p << 1 | 1));
    const long long lower = query(p << 1, l, r);
    const long long upper = query(p << 1 | 1, l, r);
    return lower + upper - (joined ? 1 : 0);
}

// Paints every node on the tree path between u and v with colour d.
void chain_modify(int u, int v, int d) {
    // Lift the endpoint whose chain head is deeper, one chain at a time.
    for (; top[u] != top[v]; u = father[top[u]]) {
        if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
        modify(1, num[top[u]], num[u], d);
    }
    // Same chain now: the shallower node has the smaller position.
    if (dep[u] >= dep[v]) {
        modify(1, num[v], num[u], d);
    } else {
        modify(1, num[u], num[v], d);
    }
}

long long chain_query(int u, int v) {
    long long ans = 0;
    ans1 = ans2 = 0;
    while(top[u] != top[v]) {
//		cout<<"chain_query_start:"<<u<<" "<<v<<endl;
        if (dep[top[u]] < dep[top[v]]) swap(u, v), swap(ans1, ans2);
        ans += query(1, num[top[u]], num[u]);
//		cout<<"fuck "<<u<<" "<<top[u]<<endl;
        if (ans1 == Rc) ans--;
        ans1 = Lc;
        u = father[top[u]];
    }
    if (dep[u] < dep[v]) swap(u, v), swap(ans1, ans2);
    ans += query(1, num[v], num[u]);
    if (Rc == ans1) ans--;
    if (Lc == ans2) ans--;
    return ans;
}

// Reads the whole input, builds the decomposition and the segment tree, then
// answers the m operations:
//   C x y z : paint the path x..y with colour z
//   Q x y   : print the number of colour runs on the path x..y
void solve() {
    n = read();
    m = read();
    if (n < 1) return;  // nothing to build; build_tree needs a non-empty range

    // Colours are stored shifted by +1 so that 0 stays free as "no tag".
    for (int i = 1; i <= n; i++) a[i] = read() + 1;
    for (int i = 1; i < n; i++) {
        x = read();
        y = read();
        add(x, y);
        add(y, x);
    }

    dfs1(1, 0);
    dfs2(1, 1);
    build_tree(1, 1, n);

    for (int i = 1; i <= m; i++) {
        // Stop on truncated input instead of reusing a stale opcode.
        if (!(cin >> opr)) break;
        x = read();
        y = read();
        if (opr == 'C') {
            z = read();
            chain_modify(x, y, z + 1);
        } else if (opr == 'Q') {
            long long res = chain_query(x, y);
            printf("%lld\n", res);
        }
    }
}
// Program entry point; all the work happens in solve().
int main() {
    // freopen("input.in", "r", stdin);  // enable to read from a local file
    solve();
}
posted @ 2019-06-19 21:35  kma_093  阅读(224)  评论(0编辑  收藏  举报