BZOJ3052 [wc2013] 糖果公园 【树上莫队】

树上莫队和普通的序列莫队很像,我们把树进行dfs,然后存一个长度为2n的括号序列,就是一个点进去当作左括号,出来当作右括号,然后如果访问从u到v路径,我们可以转化成括号序列的区间,记录x进去的时候编号为f[x],出来时为g[x],然后分类讨论一下(f[u]<f[v]),如果u和v的lca不是u,那么就是从g[u]到f[v],否则就是lca的f到另一个点的f,(可以自己试一下,中间过程没有用的点正好就抵消掉了)这里要注意一下,从g[u]到f[v]的时候我们会少掉lca这个点,特殊处理一下即可,然后按照普通莫队排一下序,暴力就行了。 —— by VANE

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100005;
int n,m,cnt1,cnt2,tot,clk,f[N],g[N];
vector<int> M[N];
int id[N<<1],blg[N<<1];
int bin[25],pos[N],fa[N][17],c[N],d[N];
int v[N],w[N],last[N],u[N];
bool vis[N];
struct node
{
    int l,r,t,id;
}a[N],b[N];
ll ans[N],sum;
void dfs(int x)
{
    f[x]=++clk;id[clk]=x;
    for(int i=1;bin[i]<=d[x];++i)
    fa[x][i]=fa[fa[x][i-1]][i-1];
    for(int i=0;i<M[x].size();++i)
    {
        int y=M[x][i];
        if(y!=fa[x][0])
        {
            fa[y][0]=x;
            d[y]=d[x]+1;
            dfs(y);
        }
    }
    g[x]=++clk;
    id[clk]=x;
}
int lca(int x,int y)
{
    if(d[x]<d[y]) swap(x,y);
    int tmp=d[x]-d[y];
    for(int i=0;bin[i]<=tmp;++i)
    if(tmp&bin[i]) x=fa[x][i];
    if(x==y) return x;
    for(int i=16;i>=0;--i)
    if(fa[x][i]!=fa[y][i])
    x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
bool cmp(node x,node y)
{
    if(blg[x.l]<blg[y.l]) return 1;
    if(blg[x.l]==blg[y.l]&&blg[x.r]<blg[y.r]) return 1;
    if(blg[x.l]==blg[y.l]&&blg[x.r]==blg[y.r]) return x.t<y.t;
    return 0;
}
void modify(int x)
{
    if(vis[x]) sum-=1ll*v[c[x]]*w[u[c[x]]--];
    else sum+=1ll*v[c[x]]*w[++u[c[x]]];
    vis[x]^=1;
}
void change(int x,int y)
{
    if(vis[x]) {modify(x);c[x]=y;modify(x);}
    else c[x]=y;
}
int main()
{
    int cas;
    scanf("%d%d%d",&n,&m,&cas);
    bin[0]=1;for(int i=1;i<=17;++i) bin[i]=bin[i-1]<<1;
    for(int i=1;i<=m;++i) scanf("%d",v+i);
    for(int i=1;i<=n;++i) scanf("%d",w+i);
    for(int i=1;i<n;++i)
    {
        int l,r;scanf("%d%d",&l,&r);
        M[l].push_back(r);
        M[r].push_back(l);
    }
    for(int i=1;i<=n;++i)
    scanf("%d",c+i),last[i]=c[i];
    int sz=pow(n,2.0/3);
    dfs(1);
    for(int i=1;i<=clk;++i) blg[i]=(i-1)/sz;
    while(cas--)
    {
        int l,r,t;
        scanf("%d%d%d",&t,&l,&r);
        if(t)
        {
            if(f[l]>f[r]) swap(l,r);
            a[++cnt1].r=f[r];a[cnt1].t=cnt2;
            a[cnt1].id=cnt1;
            a[cnt1].l=(lca(l,r)==l)?f[l]:g[l];
        }
        else
        {
            b[++cnt2].l=l;b[cnt2].t=last[l];
            last[l]=b[cnt2].r=r;
        }
    }
    sort(a+1,a+1+cnt1,cmp);
    int l=1,r=0,t=1;
    for(int i=1;i<=cnt1;++i)
    {
        for(;t<=a[i].t;++t) change(b[t].l,b[t].r);
        for(;t>a[i].t;--t) change(b[t].l,b[t].t);
        while(l>a[i].l) modify(id[--l]);
        while(l<a[i].l) modify(id[l++]);
        while(r>a[i].r) modify(id[r--]);
        while(r<a[i].r) modify(id[++r]);
        int x=id[l],y=id[r],tmp=lca(x,y);
        if(x!=tmp&&y!=tmp) {modify(tmp);ans[a[i].id]=sum;modify(tmp);}
        else ans[a[i].id]=sum;
    }
    for(int i=1;i<=cnt1;++i)
    printf("%lld\n",ans[i]);
}

 

posted @ 2018-01-24 21:01  大奕哥&VANE  阅读(280)  评论(0编辑  收藏  举报