题解:

树链剖分

和普通的树链剖分不一样,这里的线段树不只是要记录x-y的和

而是要记录x左到y左,x左到y右,x右到y左,x右到y右

然后就可以了

代码:

#include<bits/stdc++.h>
const int N=30005,M=80005,inf=1e9;
using namespace std;
int n,m,cnt,place,sz,last[N],x,y,deep[N],u,v,fa[N][16],son[N],belong[N],pl[N];
char mp[N][2];
struct data{int l1,l2,r1,r2,d1,d2,d3,d4;};
struct edge{int to,next;}e[2*N];
struct seg{int l,r;data d;}t[4*N];
void insert(int u,int v)
{
    e[++cnt].to=v;
    e[cnt].next=last[u];
    last[u]=cnt;
}
data merge(data a,data b)
{
    data tmp;
    tmp.d1=max(a.d1+b.d1,a.d2+b.d3);
    if (tmp.d1<0)tmp.d1=-inf;
    tmp.d2=max(a.d1+b.d2,a.d2+b.d4);
    if (tmp.d2<0)tmp.d2=-inf;
    tmp.d3=max(a.d3+b.d1,a.d4+b.d3);
    if (tmp.d3<0)tmp.d3=-inf;
    tmp.d4=max(a.d3+b.d2,a.d4+b.d4);
    if (tmp.d4<0)tmp.d4=-inf;
    tmp.l1=max(a.d1+b.l1,a.d2+b.l2);
    tmp.l1=max(tmp.l1,a.l1);
    if (tmp.l1<0)tmp.l1=-inf;
    tmp.l2=max(a.d3+b.l1,a.d4+b.l2);
    tmp.l2=max(tmp.l2,a.l2);
    if (tmp.l2<0)tmp.l2=-inf;
    tmp.r1=max(b.d1+a.r1,b.d3+a.r2);
    tmp.r1=max(tmp.r1,b.r1);
    if (tmp.r1<0)tmp.r1=-inf;
    tmp.r2=max(b.d2+a.r1,b.d4+a.r2);
    tmp.r2=max(tmp.r2,b.r2);
    if (tmp.r2<0)tmp.r2=-inf;
    return tmp;
}
void build(int k,int l,int r)
{
    t[k].l=l;t[k].r=r;
    if (l==r)return;
    int mid=(l+r)>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
}
void update(int k,int x,char mp[2])
{
    int l=t[k].l,r=t[k].r;
    if(l==r)
    {
        t[k].d.d1=t[k].d.d2=t[k].d.d3=t[k].d.d4=-inf;
        t[k].d.l1=t[k].d.l2=t[k].d.r1=t[k].d.r2=-inf;
        if (mp[0]=='.')t[k].d.d1=t[k].d.l1=t[k].d.r1=1;
        if (mp[1]=='.')t[k].d.d4=t[k].d.l2=t[k].d.r2=1;
        if (mp[0]=='.'&&mp[1]=='.')
         t[k].d.d2=t[k].d.d3=t[k].d.l1=t[k].d.l2=t[k].d.r1=t[k].d.r2=2;
        return;
    }
    int mid=(l+r)>>1;
    if (x<=mid)update(k<<1,x,mp);
    else update(k<<1|1,x,mp);
    t[k].d=merge(t[k<<1].d,t[k<<1|1].d);
}
data query(int k,int x,int y)
{
    int l=t[k].l,r=t[k].r;
    if (x==l&&y==r)return t[k].d;
    int mid=(l+r)>>1;
    if (mid>=y)return query(k<<1,x,y);
    else if (mid<x)return query(k<<1|1,x,y);
    else return merge(query(k<<1,x,mid),query(k<<1|1,mid+1,y));
}
void dfs1(int x)
{
    son[x]=1;
    for (int i=1;i<=15;i++)
     if (deep[x]>=(1<<i))fa[x][i]=fa[fa[x][i-1]][i-1];
    for (int i=last[x];i;i=e[i].next)
     {
        if(e[i].to==fa[x][0])continue;
        fa[e[i].to][0]=x;
        deep[e[i].to]=deep[x]+1;
        dfs1(e[i].to);
        son[x]+=son[e[i].to];
     } 
}
void dfs2(int x,int chain)
{
    pl[x]=++place;belong[x]=chain;
    update(1,pl[x],mp[x]);
    int k=0;
    for (int i=last[x];i;i=e[i].next)
     {
        if (e[i].to==fa[x][0])continue;
        if (son[e[i].to]>son[k])k=e[i].to;
     }
    if (k)dfs2(k,chain);
    for (int i=last[x];i;i=e[i].next)
     {
        if (e[i].to==k||e[i].to==fa[x][0])continue;
        dfs2(e[i].to,e[i].to);
     }
}
int lca(int x,int y)
{
    if(deep[x]<deep[y])swap(x,y);
    int t=deep[x]-deep[y];
    for (int i=0;i<=15;i++)
     if ((1<<i)&t)x=fa[x][i];
    for (int i=15;i>=0;i--)
     if (fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
    if (x==y)return x;
    return fa[x][0];
}
data solveque(int x,int f,bool flag)
{
    data ans;
    ans.l1=ans.l2=ans.r1=ans.r2=0;
    ans.d1=ans.d4=ans.d2=ans.d3=0;
    while (belong[x]!=belong[f])
     {
        ans=merge(query(1,pl[belong[x]],pl[x]),ans);
        x=fa[belong[x]][0];
     }
    if (flag==1&&pl[f]+1<=pl[x])ans=merge(query(1,pl[f]+1,pl[x]),ans);
    if (!flag)ans=merge(query(1,pl[f],pl[x]),ans);
    return ans;
}
void que(int x,int y)
{
    if (mp[x][0]=='#'&&mp[x][1]=='#'){puts("0");return;}
    int f=lca(x,y);
    data a=solveque(x,f,1),b=solveque(y,f,0);
    swap(a.d2,a.d3);
    swap(a.l1,a.r1);
    swap(a.l2,a.r2);
    data ans=merge(a,b);
    printf("%d\n",max(ans.l1,ans.l2));
}
int main()
{
       scanf("%d%d",&n,&m);
    for (int i=1;i<n;i++)scanf("%d%d",&u,&v),insert(u,v),insert(v,u);
    for (int i=1;i<=n;i++)scanf("%s",mp[i]);
    build(1,1,n);
    dfs1(1);dfs2(1,1);
    for (int i=1;i<=m;i++)
     { 
        char ch[2];
        scanf("%s",ch);
        if (ch[0]=='Q')scanf("%d%d",&x,&y),que(x,y);
        else 
         {
             scanf("%d",&x);
            scanf("%s",mp[x]);
            update(1,pl[x],mp[x]);
         }
     }
    return 0;
}

 

posted on 2017-12-07 19:10  宣毅鸣  阅读(130)  评论(0编辑  收藏  举报