P2486 [SDOI2011]染色

P2486染色

细节较多：线段树 \(pushup\) 时不仅要更新 \(kinds\)，还要更新 \(lco\) 和 \(rco\)（区间最左、最右端点的颜色），否则无法正确合并相邻区间。

#include <set>
#include <cmath>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <assert.h>
#include <algorithm>

using namespace std;

// Competitive-programming shorthands used throughout the solution.
#define fir first
#define sec second
#define pb push_back
#define mp make_pair
#define LL long long
#define INF (0x3f3f3f3f)
#define mem(a, b) memset(a, b, sizeof (a))
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define Debug(x) cout << #x << " = " << x << endl
// Loop helpers. The former `register` specifier is omitted: it was deprecated in C++11
// and removed in C++17, where it makes every expansion ill-formed.
// travle(i, x): visit every edge index i in the forward-star adjacency list of vertex x.
#define travle(i, x) for (int i = head[x]; i; i = nxt[i])
// For / Forr: inclusive ascending / descending integer loops.
#define For(i, a, b) for (int (i) = (a); (i) <= (b); ++ (i))
#define Forr(i, a, b) for (int (i) = (a); (i) >= (b); -- (i))
// file(s): redirect stdin/stdout to s.in / s.out.
#define file(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout)

namespace io {
    // Buffered stdin reader: [pos, end) is the unread part of buf, refilled 2 MiB at a time.
    static char buf[1<<21], *pos = buf, *end = buf;

    // Returns the next input byte as an unsigned char value, or EOF once stdin is exhausted.
    // Returning int keeps EOF distinct from byte 0xFF and is always a valid <cctype> argument
    // (a negative plain char passed to isdigit/isspace is undefined behaviour).
    inline int getc() {
        if (pos == end) {
            pos = buf;
            end = buf + fread(buf, 1, sizeof buf, stdin);
            if (pos == end) return EOF;
        }
        return static_cast<unsigned char>(*pos++);
    }

    // Reads a decimal int. Any '-' seen before the first digit makes the result negative, and
    // the character right after the last digit is consumed. Stops at EOF instead of spinning
    // forever; returns 0 if EOF is reached before any digit.
    inline int rint() {
        int x = 0, f = 1, c = getc();
        while (c != EOF && !isdigit(c)) {
            if (c == '-') f = -1;
            c = getc();
        }
        while (isdigit(c)) {
            x = x * 10 + (c - '0');
            c = getc();
        }
        return x * f;
    }

    // long long counterpart of rint(); same parsing rules.
    inline long long rLL() {
        long long x = 0, f = 1;
        int c = getc();
        while (c != EOF && !isdigit(c)) {
            if (c == '-') f = -1;
            c = getc();
        }
        while (isdigit(c)) {
            x = x * 10 + (c - '0');
            c = getc();
        }
        return x * f;
    }

    // Reads one whitespace-delimited token into str and NUL-terminates it; the delimiter after
    // the token is consumed. The caller must supply room for the token plus terminator.
    // Yields an empty string at EOF.
    inline void rstr(char *str) {
        int c = getc();
        while (c != EOF && isspace(c)) c = getc();
        while (c != EOF && !isspace(c)) {
            *str++ = static_cast<char>(c);
            c = getc();
        }
        *str = '\0';
    }

    // Assign y to x if it is smaller (chkmin) / larger (chkmax); return whether x changed.
    template<typename T>
        inline bool chkmin(T &x, T y) { return x > y ? (x = y, 1) : 0; }
    template<typename T>
        inline bool chkmax(T &x, T y) { return x < y ? (x = y, 1) : 0; }
}
using namespace io;

// Capacity for 1-based vertex arrays; assumes n <= 1e5 (vertices are indexed 1..n).
const int N = 1e5 + 1;

// n vertices, m operations, a[v] = initial colour of vertex v.
int n, m, a[N];

// Forward-star adjacency list: head[u] is u's newest edge, nxt chains edges, ver is the target.
// Every undirected edge is inserted twice, hence N<<1 edge slots.
int tot, head[N], ver[N<<1], nxt[N<<1];
inline void add(int u, int v)
{ ver[++tot] = v, nxt[tot] = head[u], head[u] = tot; }

// Heavy-light decomposition: seg = position in chain order, rev = inverse of seg, top = chain
// head, dep = depth, fa = parent, son = heavy child, size = subtree size, cnt = positions used.
int cnt, seg[N], top[N], dep[N], fa[N], son[N], size[N], rev[N];

#define ls(x) ((x) << 1)
#define rs(x) ((x) << 1 | 1)
// Segment tree over chain positions. Node indices reach about 4n, so every per-node array,
// including the lazy colour tag (previously only N long, which overflowed), has N<<2 slots.
// lco/rco: colour at the node's left/right end; kinds: number of colour segments in the node.
int val[N<<2], lco[N<<2], rco[N<<2], kinds[N<<2], tag[N<<2];

void pushup(int x) {
    kinds[x] = kinds[ls(x)] + kinds[rs(x)] - (rco[ls(x)] == lco[rs(x)]);
    lco[x] = lco[ls(x)]; rco[x] = rco[rs(x)];
}

// Build node x covering chain positions [l, r]; leaf l holds the colour of vertex rev[l].
void build(int x, int l, int r) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(ls(x), l, mid);
        build(rs(x), mid + 1, r);
        pushup(x);
        return;
    }
    const int colour = a[rev[l]];
    val[x] = colour;
    lco[x] = colour;
    rco[x] = colour;
    kinds[x] = 1;  // a single position is exactly one segment
}

void pushdown(int x) {
    if (~tag[x]) {
        tag[ls(x)] = lco[ls(x)] = rco[ls(x)] = tag[x]; kinds[ls(x)] = 1;
        tag[rs(x)] = lco[rs(x)] = rco[rs(x)] = tag[x]; kinds[rs(x)] = 1;
        tag[x] = -1;
    }
}

// Paint chain positions [L, R] (within node x's range [l, r]) with `color`.
// A fully covered node is painted in O(1) and keeps the colour as a lazy tag.
void change(int x, int l, int r, int L, int R, int color) {
    const bool covered = L <= l && r <= R;
    if (!covered) {
        pushdown(x);
        const int mid = (l + r) >> 1;
        if (L <= mid) change(ls(x), l, mid, L, R, color);
        if (mid < R) change(rs(x), mid + 1, r, L, R, color);
        pushup(x);
        return;
    }
    tag[x] = color;
    lco[x] = color;
    rco[x] = color;
    kinds[x] = 1;
}

int query(int x, int l, int r, int L, int R, int &a, int &b) {
    if (L <= l && r <= R) {
        if (l == L) a = lco[x];
        if (r == R) b = rco[x];
        return kinds[x];
    }
    pushdown(x);
    int Lres = 0, Rres = 0, mid = l + r >> 1;
    if (L <= mid) Lres = query(ls(x), l, mid, L, R, a, b);
    if (mid < R) Rres = query(rs(x), mid + 1, r, L, R, a, b);
    if (!Lres or !Rres) return Lres + Rres;
    return Lres + Rres - (rco[ls(x)] == lco[rs(x)]);
}


// Paint every vertex on the tree path x -> y with `color`.
// Repeatedly paint the piece from the endpoint with the deeper chain top up to that top,
// then jump to the top's parent.
void Modify(int x, int y, int color) {
    for (; top[x] != top[y]; x = fa[top[x]]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        change(1, 1, n, seg[top[x]], seg[x], color);
    }
    // Both endpoints are on one chain, whose positions are contiguous.
    const int lo = std::min(seg[x], seg[y]);
    const int hi = std::max(seg[x], seg[y]);
    change(1, 1, n, lo, hi, color);
}

// Count the colour segments on the tree path between x and y (heavy-light decomposition).
// Both endpoints climb chain by chain: side 0 follows x, side 1 follows y. Each chain piece is
// answered by query(), which also reports the colour at the piece's upper end (coll[k], the
// smaller chain position) and at its lower end (colr[k]). On one side, the lower end of a new
// piece is the parent of the previous piece's top. If their colours match, the two pieces form
// one segment, so the count is reduced by one.
int Query(int x, int y) {
    // last0 / last1: colour of the top vertex of the previously climbed piece on side 0 / 1.
    // -1 stands for "no piece yet" (assumes -1 is never a real colour; pushdown also uses -1).
    int coll[2], colr[2], last0, last1, ans = 0;
    coll[0] = coll[1] = colr[0] = colr[1] = last0 = last1 = -1;
    while (top[x] != top[y]) {
        // Climb the side whose chain top is deeper (ties climb the y side).
        if (dep[top[x]] > dep[top[y]]) {
            ans += query(1, 1, n, seg[top[x]], seg[x], coll[0], colr[0]);
            ans -= colr[0] == last0;  // x is the parent of the previous x-side piece's top
            last0 = coll[0];
            x = fa[top[x]];
        } else {
            ans += query(1, 1, n, seg[top[y]], seg[y], coll[1], colr[1]);
            ans -= colr[1] == last1;  // y is the parent of the previous y-side piece's top
            last1 = coll[1];
            y = fa[top[y]];
        }
    }
    // Same chain now. Make x the shallower endpoint, keeping each last* paired with its endpoint.
    if (dep[x] > dep[y]) swap(x, y), swap(last1, last0);
    // Final piece runs from x (upper end, colour coll[1]) down to y (lower end, colour colr[1]).
    // Each end may merge with the last piece climbed on its own side.
    ans += query(1, 1, n, seg[x], seg[y], coll[1], colr[1]);
    ans -= (coll[1] == last0) + (colr[1] == last1);
    return ans ;
}
void DFS1(int u, int f) {
    fa[u] = f;
    dep[u] = dep[f] + 1;
    size[u] = 1;
    travle(i, u) {
        if (ver[i] != f) {
            DFS1(ver[i], u);
            size[u] += size[ver[i]];
            if (size[son[u]] < size[ver[i]]) son[u] = ver[i];
        }
    }
}
void DFS2(int u, int topf) {
    top[u] = topf;
    seg[u] = ++cnt;
    rev[cnt] = u;
    if (son[u]) DFS2(son[u], topf);
    travle(i, u) if (ver[i] != fa[u] && ver[i] != son[u]) 
        DFS2(ver[i], ver[i]);
}

int main() {
#ifndef ONLINE_JUDGE
    file("P2468");
#endif
    memset(tag, -1, sizeof tag);
    n = rint(), m = rint();
    For(i, 1, n) a[i] = rint();
    For(i, 1, n -1 ) {
        int u = rint(), v = rint();
        add(u, v); add(v, u);
    }   
    DFS1(1, 0);
    DFS2(1, 1);
    build(1, 1, n);
    char op[30];
    while (m --) {
        rstr(op);
        if (op[0] == 'C') {
            int u = rint(), v = rint(), val = rint();
            Modify(u, v, val);
        } else {
            int u = rint(), v = rint();
            printf("%d\n", Query(u, v));
        }
//    cout << "ok " << query(1, 1, n, seg[5], seg[5], cnt, cnt) << endl;
    }
}

posted @ 2019-02-28 07:34  茶Tea  阅读(107)  评论(0编辑  收藏  举报