bzoj千题计划243:bzoj2325: [ZJOI2011]道馆之战

http://www.lydsy.com/JudgeOnline/problem.php?id=2325

 

设线段树节点区间为[l,r]

每个节点维护sum[0/1][0/1]  从l的A/B区域到r的A/B区域 经过冰块的最大数量

mx[0][0] 从l的A区域出发向r经过冰块的最大数量

mx[0][1] 从l的B区域出发向r经过冰块的最大数量

mx[1][0] 从r的A区域出发向l经过冰块的最大数量

mx[1][1] 从r的B区域出发向l经过冰块的最大数量

 

#include<cstdio>
#include<iostream>
#include<algorithm>

#define N 30001

using namespace std;

int n;

int front[N],nxt[N<<1],to[N<<1],tot;

int fa[N],dep[N],siz[N];
int bl[N];

int id[N],dy[N],cnt;

bool a[N][2];

#define max(x,y) ((x)>(y) ? (x) : (y))

struct node
{
    int sum[2][2];
    int mx[2][2];
    
    node()
    {
        for(int i=0;i<2;++i)
            for(int j=0;j<2;++j)
                sum[i][j]=mx[i][j]=0;
    } 

    node operator + (node p) const
    {
        node k;
        for(int i=0;i<=1;++i)
            for(int j=0;j<=1;++j)
            {
                if(sum[i][0] && p.sum[0][j]) 
                k.sum[i][j]=max(k.sum[i][j],sum[i][0]+p.sum[0][j]);
                if(sum[i][1] && p.sum[1][j]) 
                k.sum[i][j]=max(k.sum[i][j],sum[i][1]+p.sum[1][j]); 
            }
        k.mx[0][0]=max(k.sum[0][0],k.sum[0][1]);
        k.mx[0][1]=max(k.sum[1][0],k.sum[1][1]);
        k.mx[1][0]=max(k.sum[0][0],k.sum[1][0]);
        k.mx[1][1]=max(k.sum[0][1],k.sum[1][1]);
        if(sum[0][0]) k.mx[0][0]=max(k.mx[0][0],sum[0][0]+p.mx[0][0]);
        if(sum[0][1]) k.mx[0][0]=max(k.mx[0][0],sum[0][1]+p.mx[0][1]);
        if(sum[1][0]) k.mx[0][1]=max(k.mx[0][1],sum[1][0]+p.mx[0][0]);
        if(sum[1][1]) k.mx[0][1]=max(k.mx[0][1],sum[1][1]+p.mx[0][1]);
        if(p.sum[0][0]) k.mx[1][0]=max(k.mx[1][0],mx[1][0]+p.sum[0][0]);
        if(p.sum[1][0]) k.mx[1][0]=max(k.mx[1][0],mx[1][1]+p.sum[1][0]);
        if(p.sum[0][1]) k.mx[1][1]=max(k.mx[1][1],mx[1][0]+p.sum[0][1]);
        if(p.sum[1][1]) k.mx[1][1]=max(k.mx[1][1],mx[1][1]+p.sum[1][1]);
        k.mx[0][0]=max(k.mx[0][0],mx[0][0]);
        k.mx[0][1]=max(k.mx[0][1],mx[0][1]);
        k.mx[1][0]=max(k.mx[1][0],p.mx[1][0]);
        k.mx[1][1]=max(k.mx[1][1],p.mx[1][1]); 
        return k;
    }
    
    void turn()
    {
        swap(sum[0][1],sum[1][0]);
        swap(mx[0][0],mx[1][0]);
        swap(mx[0][1],mx[1][1]);
    }

}tr[N<<2];

void read(int &x)
{
    x=0; char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); }
}

void add(int u,int v)
{
    to[++tot]=v; nxt[tot]=front[u]; front[u]=tot;
    to[++tot]=u; nxt[tot]=front[v]; front[v]=tot;
}

void dfs1(int x)
{
    siz[x]=1;
    for(int i=front[x];i;i=nxt[i])
        if(to[i]!=fa[x])
        {
            fa[to[i]]=x;
            dep[to[i]]=dep[x]+1;
            dfs1(to[i]);
            siz[x]+=siz[to[i]];
        }
}

void dfs2(int x,int top)
{
    bl[x]=top;
    id[x]=++cnt;
    dy[cnt]=x;
    int y=0;
    for(int i=front[x];i;i=nxt[i])
        if(to[i]!=fa[x] && siz[to[i]]>siz[y]) y=to[i];
    if(y) dfs2(y,top);
    else return;
    for(int i=front[x];i;i=nxt[i])
        if(to[i]!=y && to[i]!=fa[x]) dfs2(to[i],to[i]);
}

void build(int k,int l,int r)
{
    if(l==r)
    {
        if(a[dy[l]][0]) tr[k].sum[0][0]=1;
        if(a[dy[l]][1]) tr[k].sum[1][1]=1;
        if(a[dy[l]][0] && a[dy[l]][1]) tr[k].sum[0][1]=tr[k].sum[1][0]=2;
        tr[k].mx[0][0]=max(tr[k].sum[0][0],tr[k].sum[0][1]);
        tr[k].mx[0][1]=max(tr[k].sum[1][0],tr[k].sum[1][1]);
        tr[k].mx[1][0]=max(tr[k].sum[0][0],tr[k].sum[1][0]);
        tr[k].mx[1][1]=max(tr[k].sum[0][1],tr[k].sum[1][1]);
        return;
    }
    int mid=l+r>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
    tr[k]=tr[k<<1]+tr[k<<1|1];
}

node query(int k,int l,int r,int opl,int opr)
{
    if(l>=opl && r<=opr) return tr[k];
    int mid=l+r>>1;
    if(opr<=mid) return query(k<<1,l,mid,opl,opr);
    else if(opl>mid) return query(k<<1|1,mid+1,r,opl,opr);
    else return query(k<<1,l,mid,opl,opr)+query(k<<1|1,mid+1,r,opl,opr);
}

void Query(int u,int v)
{
    node ansu,ansv; 
    bool firstu=false,firstv=false;
    while(bl[u]!=bl[v])
    {
        if(dep[bl[u]]>dep[bl[v]])
        {
            if(!firstu) firstu=true,ansu=query(1,1,n,id[bl[u]],id[u]);
            else ansu=query(1,1,n,id[bl[u]],id[u])+ansu;
            u=fa[bl[u]];
        }
        else
        {
            if(!firstv) firstv=true,ansv=query(1,1,n,id[bl[v]],id[v]);
            else ansv=query(1,1,n,id[bl[v]],id[v])+ansv;
            v=fa[bl[v]];
        }
    }
    if(dep[u]>dep[v])
    {
        if(!firstu) firstu=true,ansu=query(1,1,n,id[v],id[u]);
        else ansu=query(1,1,n,id[v],id[u])+ansu;
    }
    else
    {
        if(!firstv) firstv=true,ansv=query(1,1,n,id[u],id[v]);
        else ansv=query(1,1,n,id[u],id[v])+ansv;
    }
    if(!firstu) ansu=ansv;
    else
    {
        ansu.turn();
        if(firstv) ansu=ansu+ansv;
    }
    cout<<max(ansu.mx[0][0],ansu.mx[0][1])<<'\n';
}

void change(int k,int l,int r,int x,bool u,bool v)
{
    if(l==r)
    {
        a[l][0]=u;
        a[l][1]=v;
        if(a[l][0]) tr[k].sum[0][0]=1;
        else tr[k].sum[0][0]=0;
        if(a[l][1]) tr[k].sum[1][1]=1;
        else tr[k].sum[1][1]=0;
        if(a[l][0] && a[l][1]) tr[k].sum[0][1]=tr[k].sum[1][0]=2;
        else tr[k].sum[0][1]=tr[k].sum[1][0]=0;
        tr[k].mx[0][0]=max(tr[k].sum[0][0],tr[k].sum[0][1]);
        tr[k].mx[0][1]=max(tr[k].sum[1][0],tr[k].sum[1][1]);
        tr[k].mx[1][0]=max(tr[k].sum[0][0],tr[k].sum[1][0]);
        tr[k].mx[1][1]=max(tr[k].sum[0][1],tr[k].sum[1][1]);
        return;
    }
    int mid=l+r>>1;
    if(x<=mid) change(k<<1,l,mid,x,u,v);
    else change(k<<1|1,mid+1,r,x,u,v);
    tr[k]=tr[k<<1]+tr[k<<1|1];
}

int main()
{
    freopen("fight.in","r",stdin);
    freopen("fight.out","w",stdout); 
    int m;
    read(n); read(m);
    int u,v;
    for(int i=1;i<n;++i)
    {
        read(u); read(v);
        add(u,v);
    }
    dfs1(1);
    dfs2(1,1);
    char s[3];
    for(int i=1;i<=n;++i)
    {
        scanf("%s",s);
        if(s[0]=='.') a[i][0]=true;
        if(s[1]=='.') a[i][1]=true;
    }
    build(1,1,n);
    char c[3];
    while(m--)
    {
        scanf("%s",c);
        if(c[0]=='Q') 
        {
            read(u); read(v);
            Query(u,v);
        }
        else
        {
            read(u);
            scanf("%s",c);
            change(1,1,n,id[u],c[0]=='.',c[1]=='.');
        }
    }
}

 

posted @ 2018-02-20 22:44  TRTTG  阅读(217)  评论(0编辑  收藏  举报