[bzoj3307]雨天的尾巴【线段树】

【题目链接】
  http://www.lydsy.com/JudgeOnline/problem.php?id=3307
【题解】
  首先这道题所有的修改可以一次处理完。
  若有一条边(u,v)u,v上各打一个+1的标记,在lca(u,v),dad(lca(u,v))上各打一个1的标记。
  现在考虑怎么统计答案,dfs遍历整棵树,每个节点用权值线段树记录标记,每次回溯时,把标记上传,每个节点在它的所有儿子统计完以后再统计。
  统计的复杂度显然为O(nlogz)现在问题在于如何合并两棵权值线段树,考虑左偏树的合并方式,我们不妨照着模仿一下:
  基本与左偏树的写法相同,只是在叶子节点略有修改。

int merge(int x, int y, int l, int r){
    if (x==0) return y;
    if (y==0) return x;
    if (l==r){
        T[x].num+=T[y].num;
        return x;
    }
    int mid=(l+r)/2;
    T[x].pl=merge(T[x].pl,T[y].pl,l,mid); T[x].pr=merge(T[x].pr,T[y].pr,mid+1,r);
    return x;
}

  若一共有n个叶子节点有值,值域为m,那么时间复杂度为O(nlogm),接下来给出一段口胡:
  若一个节点所包含的区间有k个值,那么每次访问到这个区间,都会使至少2个值合并在一起,所以这个区间最多被访问O(k)次,一个值会被包含在logm个区间,所以总复杂度为O(nlogm)

/* --------------
    user Vanisher
    problem bzoj-3307 
----------------*/
# include <bits/stdc++.h>
# define    ll      long long
# define    inf     0x3f3f3f3f
# define    N       100010
# define    L       17
# define    LL      1
# define    RR      1e9
using namespace std;
int read(){
    int tmp=0, fh=1; char ch=getchar();
    while (ch<'0'||ch>'9'){if (ch=='-') fh=-1; ch=getchar();}
    while (ch>='0'&&ch<='9'){tmp=tmp*10+ch-'0'; ch=getchar();}
    return tmp*fh;
}
struct Tree{
    int num,id,pl,pr;
}T[4000100];
struct node{
    int data,next;
}e[N*2];
vector <int> add[N],del[N];
int place,head[N],dad[N][L+1],dep[N],n,rt[N],ans[N],m;
void build(int u, int v){
    e[++place].data=u; e[place].next=head[v]; head[v]=place;
    e[++place].data=v; e[place].next=head[u]; head[u]=place;
}
void dfs(int x, int fa){
    dad[x][0]=fa; dep[x]=dep[fa]+1;
    for (int ed=head[x]; ed!=0; ed=e[ed].next)
        if (e[ed].data!=fa)
            dfs(e[ed].data,x);
}
void pre(){
    for (int i=1,j=1; j*2<=n; i++, j*=2)
        for (int k=1; k<=n; k++)
            dad[k][i]=dad[dad[k][i-1]][i-1];
}
int lca(int x, int y){
    if (dep[x]>dep[y]) swap(x,y);
    for (int i=L; i>=0; i--)
        if (dep[dad[y][i]]>=dep[x])
            y=dad[y][i];
    if (x==y) return x;
    for (int i=L; i>=0; i--)
        if (dad[x][i]!=dad[y][i])
            x=dad[x][i], y=dad[y][i];
    return dad[x][0];
}
void change(int p){
    if (T[T[p].pl].num>=T[T[p].pr].num)
        T[p].num=T[T[p].pl].num, T[p].id=T[T[p].pl].id;
        else T[p].num=T[T[p].pr].num, T[p].id=T[T[p].pr].id;
}
int extend(int x, int y, int l, int r){
    if (x==0) return y;
    if (y==0) return x;
    if (l==r){
        T[x].num+=T[y].num;
        T[x].id=l; 
        return x;
    }
    int mid=(l+r)/2;
    T[x].pl=extend(T[x].pl,T[y].pl,l,mid); T[x].pr=extend(T[x].pr,T[y].pr,mid+1,r);
    change(x);
    return x;
}
int inc(int p, int x, int l, int r){
    if (p==0) p=++place;
    if (l==r) {
        T[p].num++;
        T[p].id=x;
        return p;
    }
    int mid=(l+r)/2;
    if (x<=mid) T[p].pl=inc(T[p].pl,x,l,mid);
        else T[p].pr=inc(T[p].pr,x,mid+1,r);
    change(p);
    return p;
}
int dec(int p, int x, int l, int r){
    if (p==0) p=++place;
    if (l==r) {
        T[p].num--;
        if (T[p].num==0) T[p].id=0;
        return p;
    }
    int mid=(l+r)/2;
    if (x<=mid) T[p].pl=dec(T[p].pl,x,l,mid);
        else T[p].pr=dec(T[p].pr,x,mid+1,r);
    change(p);
    return p;
}
int getans(int x){
    return T[x].id;
}
void solve(int x, int fa){
    int now=0;
    for (int ed=head[x]; ed!=0; ed=e[ed].next)
        if (e[ed].data!=fa){
            solve(e[ed].data,x);
            now=e[ed].data;
        }
    rt[x]=rt[now];
    for (int ed=head[x]; ed!=0; ed=e[ed].next)
        if (e[ed].data!=fa&&e[ed].data!=now)
            rt[x]=extend(rt[x],rt[e[ed].data],LL,RR);
    for (int i=0; i<add[x].size(); i++)
        rt[x]=inc(rt[x],add[x][i],LL,RR);
    for (int i=0; i<del[x].size(); i++)
        rt[x]=dec(rt[x],del[x][i],LL,RR);
    ans[x]=getans(rt[x]);
}
int main(){
    n=read(), m=read();
    for (int i=1; i<n; i++)
        build(read(),read());
    place=0;
    dfs(1,0); pre();
    for (int i=1; i<=m; i++){
        int u=read(), v=read(), k=read();
        int l=lca(u,v);
        del[l].push_back(k);
        add[u].push_back(k);
        add[v].push_back(k);
        del[dad[l][0]].push_back(k);
    }
    solve(1,0);
    for (int i=1; i<=n; i++)
        printf("%d\n",ans[i]);
    return 0;
}
posted @ 2018-03-13 21:59  Vanisher  阅读(91)  评论(0编辑  收藏  举报