[SDOI2014]旅行 【树链剖分+主席树】

题目描述

最喜欢这种细节少少的组合型问题了,感觉被树套树虐过之后其他数据结构都简单了
这题就是简单的树链剖分+主席树,好像只要会就不是很容易写错
对信仰每一种宗教的城市单独建一棵（动态开点的）线段树就好了：树的下标是树链剖分的 DFS 序，查询时只在起点城市所属宗教的那棵树上做路径区间查询。

代码

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#include<cmath>
using namespace std;
const int N=100010;
// Per-city data: w[i] is the weight that QS sums / QM maximises, c[i] is the
// religion of city i, and rt[r] is the root of the segment tree holding every
// city whose religion is r (0 means the tree is still empty).
int n, m, T_cnt, w[N], c[N], rt[N];
// Heavy-light decomposition: id[u] is u's position in the DFS order and real[]
// is its inverse; top[u] is the head of u's heavy chain; siz/fa/hson/dep are
// subtree size, parent, heavy child and depth.
int idx, id[N], top[N], real[N], siz[N], fa[N], hson[N], dep[N];
// Segment-tree node pool; index 0 is the shared all-zero "empty" node.
// ch = children, s = number of cities present, mw = max weight, sw = weight sum.
struct node{int ch[2], s, mw, sw;} T[N<<6];
// Adjacency lists stored as singly linked edge lists (head -> next -> ... -> 0).
int head[N], to[N<<1], next[N<<1], now;
// Append the directed edge u -> v.
// `::next` is qualified on purpose: with `using namespace std;` in effect, a bare
// `next` is ambiguous with std::next (C++11 and later) and fails to compile.
void add(int u, int v){to[++now]=v, ::next[now]=head[u], head[u]=now;}

// Left/right child of segment-tree node o. Both expand to lvalues, so ins() can
// take them by reference and attach a newly created child in place.
#define ls(o) T[o].ch[0]
#define rs(o) T[o].ch[1]
void maintain(int o){
    T[o].s=T[ls(o)].s+T[rs(o)].s; T[o].mw=max(T[ls(o)].mw, T[rs(o)].mw);T[o].sw=T[ls(o)].sw+T[rs(o)].sw;
}
// Add v (+1 = insert, -1 = remove) at DFS position x in the tree rooted at o,
// which covers [l, r]. The leaf's weight contribution is w[real[x]] * v, so a city
// must be removed before its w[] entry changes.
// Only the current root of each religion's tree is ever queried, so nodes are
// updated in place rather than path-copied. A node is allocated only when the
// slot is still empty (o == 0), which keeps pool usage to the nodes actually needed.
void ins(int &o, int l, int r, int x, int v){
    if(!o) o=++T_cnt;   // pool entries past T_cnt have never been used, so they are all-zero
    if(l == r) {T[o].s+=v, T[o].mw+=w[real[x]]*v, T[o].sw+=w[real[x]]*v; return ; }
    int mid=(l+r)>>1;
    if(x <= mid) ins(ls(o), l, mid, x, v); else ins(rs(o), mid+1, r, x, v);
    maintain(o);
}
// Sum of weights of the cities present in [L, R], searched within node o covering [l, r].
// Ranges disjoint from [L, R] and empty subtrees (s == 0) contribute 0.
int queryS(int o, int l, int r, int L, int R){
    if(R < l || L > r) return 0;
    if(T[o].s == 0) return 0;
    if(L <= l && r <= R) return T[o].sw;
    int mid=(l+r)>>1;
    int leftPart=queryS(ls(o), l, mid, L, R);
    int rightPart=queryS(rs(o), mid+1, r, L, R);
    return leftPart+rightPart;
}
// Maximum weight among the cities present in [L, R], searched within node o covering [l, r].
// Ranges disjoint from [L, R] and empty subtrees (s == 0) contribute 0.
int queryM(int o, int l, int r, int L, int R){
    if(R < l || L > r) return 0;
    if(T[o].s == 0) return 0;
    if(L <= l && r <= R) return T[o].mw;
    int mid=(l+r)>>1;
    int leftBest=queryM(ls(o), l, mid, L, R);
    int rightBest=queryM(rs(o), mid+1, r, L, R);
    return leftBest < rightBest ? rightBest : leftBest;
}

void dfs1(int u, int f){
    fa[u]=f, siz[u]=1; int v;
    for(int i=head[u]; i; i=next[i]) if((v=to[i]) != f)
        {dep[v]=dep[u]+1; dfs1(v, u); siz[u]+=siz[v]; if(!hson[u] || siz[v] > siz[hson[u]]) hson[u]=v;}
}
void dfs2(int u, int anc){
    top[u]=anc; real[++idx]=u, id[u]=idx; int v; if(!hson[u]) return ; dfs2(hson[u], anc);
    for(int i=head[u]; i; i=next[i]) if((v=to[i]) != fa[u] && v != hson[u]) dfs2(v, v);
}
void Query(int u, int v, int k)
{
    int tu=top[u], tv=top[v], cc=c[u], ans=0;
    while(tu != tv){
        if(dep[tu] < dep[tv]) swap(u, v), swap(tu, tv);
        if(k) ans+=queryS(rt[cc], 1, n, id[tu], id[u]); else ans=max(ans, queryM(rt[cc], 1, n, id[tu], id[u]));
        u=fa[tu], tu=top[u];
    }
    if(dep[u] > dep[v]) swap(u, v); 
    if(k) ans+=queryS(rt[cc], 1, n, id[u], id[v]); else ans=max(ans, queryM(rt[cc], 1, n, id[u], id[v]));
    printf("%d\n", ans);
}

// Read the next integer from stdin, skipping any non-digit characters before it.
// A '-' anywhere among the skipped characters makes the result negative.
// getchar()'s result is kept in an int so EOF can be detected: at end of input
// this returns 0 instead of looping forever.
int read(){
    int out=0, sign=1, ch=getchar();
    while(ch != EOF && (ch < '0' || ch > '9')) {if(ch == '-') sign=-1; ch=getchar();}
    while(ch >= '0' && ch <= '9') {out=out*10+(ch-'0'); ch=getchar();}
    return out*sign;
}

void solve()
{
    n=read(), m=read(); int u, v;
    for(int i=1; i <= n; i++) w[i]=read(), c[i]=read();
    for(int i=1; i < n; i++) u=read(), v=read(), add(u, v), add(v, u);
    dep[1]=1; dfs1(1, 0); dfs2(1, 1);
    for(int i=1; i <= n; i++) ins(rt[c[real[i]]], 1, n, i, 1);
    for(int i=1; i <= m; i++){
        char c1=getchar(), c2; while(c1 != 'C' && c1 != 'Q') c1=getchar(); c2=getchar();
        int x=read(), y=read();
        if(c1 == 'C'){
            if(c2 == 'C') ins(rt[c[x]], 1, n, id[x], -1), c[x]=y, ins(rt[c[x]], 1, n, id[x], 1);
            else ins(rt[c[x]], 1, n, id[x], -1), w[x]=y, ins(rt[c[x]], 1, n, id[x], 1);
        }
        else {if(c2 == 'S') Query(x, y, 1); else Query(x, y, 0);}
    }
}

// Entry point: all input, preprocessing and query answering happen in solve().
// Falling off the end of main returns 0.
int main()
{
    solve();
}
posted @ 2018-03-09 20:51  zerolt  阅读(136)  评论(0)  编辑  收藏  举报