bzoj3572[HNOI2014]世界树

只说一下大概做法吧,我这个做法好难写(或许有的地方是可以简便的做到的),首先建虚树,把虚树上的"议事处"叫做黑点,其他点叫做白点。
对于每个白点算出最近的黑点以及到它的距离(这个我是dfs一遍用线段树维护深度做的),那么这个白点可以看做那个黑点了,只不过计算距离时要多加一个值,然后dfs一遍,对于虚树上的每条边都计算一下它对两个端点答案的贡献,具体的贡献是原树在这条链上所有的点及其子树。好像不在虚树边上的点答案不好算,于是对于每个黑点把它的初始答案设为它的所有白点siz的最大值(虚树的根的最近的黑点初始答案设为n),枚举每条边时就只需从答案中减去不是黑点控制的即可。
复杂度\(O(logn*\sum m)\)

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#define P puts("lala")
#define cp cerr<<"lala"<<endl
#define fi first
#define se second
#define ln putchar('\n')
#define pb push_back
#define shmem(x) cerr<<sizeof(x)/(1024*1024.0)<<"MB"<<endl
using namespace std;
inline int read()
{
    char ch=getchar();int g=1,re=0;
    while(ch<'0'||ch>'9'){if(ch=='-')g=-1; ch=getchar();}
    while(ch<='9'&&ch>='0') re=(re<<1)+(re<<3)+(ch^48),ch=getchar();
    return re*g;
}
typedef long long ll;
typedef pair<int,int> pii;

const int N=300050;
const int inf=0x3f3f3f3f;
int head[N],cnt=0;
struct node
{
    int to,next,w;
}e[N<<1];
inline void add(int x,int y,int w)
{
    e[++cnt]=(node){y,head[x],w};head[x]=cnt;
    e[++cnt]=(node){x,head[y],w};head[y]=cnt;
}
int n,f[N][23],dep[N],m,dfn[N],clk=0,efn[N];
void dfs(int u,int fa,int d)
{
    f[u][0]=fa; dep[u]=d; dfn[u]=++clk;
    for(int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa) continue;
        dfs(v,u,d+1);
    }
    efn[u]=clk;
}
inline int lca(int x,int y)
{
    if(dep[x]<dep[y]) swap(x,y);
    int d=dep[x]-dep[y];
    for(int i=19;i>=0;--i) if(d&1<<i) x=f[x][i];
    if(x==y) return x;
    for(int i=19;i>=0;--i) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    return f[x][0];
}
inline bool cmp(int a,int b) {return dfn[a]<dfn[b];}

int dis[N],rnk[N];
bool black[N];
namespace seg
{
    pii minv[N<<2];
    int add[N<<2];
    void build(int o,int l,int r)
    {
        add[o]=0;
        if(l==r) 
        {
            if(black[rnk[l]]) minv[o]=pii(dis[rnk[l]],rnk[l]);
            else minv[o]=pii(inf,0);
            return ;
        }
        int mid=l+r>>1;
        build(o<<1,l,mid); build(o<<1|1,mid+1,r);
        minv[o]=min(minv[o<<1],minv[o<<1|1]);
    }
    inline void pushdown(int o)
    {
        if(add[o])
        {
            add[o<<1]+=add[o]; add[o<<1|1]+=add[o];
            minv[o<<1].fi+=add[o]; minv[o<<1|1].fi+=add[o];
            add[o]=0;
        }
    }
    void update(int o,int l,int r,int x,int y,int k)
    {
        if(x<=l&&r<=y) {add[o]+=k; minv[o].fi+=k; return ;}
        pushdown(o);
        int mid=l+r>>1;
        if(x<=mid) update(o<<1,l,mid,x,y,k);
        if(y>mid) update(o<<1|1,mid+1,r,x,y,k);
        minv[o]=min(minv[o<<1],minv[o<<1|1]);
    }
}

int vt[N],tot=0,stk[N],top=0,last[N],dfn2[N],efn2[N],clk2=0,Ans[N];
pii p[N];

inline int getfa(int x,int k)
{
    k=max(k,0);
    for(int i=19;i>=0;--i) if(k&1<<i) x=f[x][i];
    return x;
}
inline void calc(int x,int y,int l) //dep[x]<dep[y]!!!
{
    int cx=getfa(y,dep[y]-dep[x]-1),z=getfa(y,dep[y]-dep[x]-l-1);
    Ans[p[x].se]-=(efn[z]-dfn[z]+1);
    Ans[p[y].se]+=efn[z]-dfn[z]+1-(efn[y]-dfn[y]+1);
}

void dfs1(int u,int fa,int d)
{
    dis[u]=d; dfn2[u]=++clk2; rnk[clk2]=u;
    for(int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa) continue;
        dfs1(v,u,d+e[i].w);
    }
    efn2[u]=clk2;
}
void dfs2(int u,int fa)
{   
    if(black[u]) p[u]=pii(0,u);
    else p[u]=seg::minv[1];
    Ans[p[u].se]=max(Ans[p[u].se],efn[u]-dfn[u]+1);
    for(int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa) continue;
        seg::update(1,1,tot,1,tot,e[i].w);
        seg::update(1,1,tot,dfn2[v],efn2[v],-2*e[i].w);
        dfs2(v,u);
        seg::update(1,1,tot,1,tot,-e[i].w);
        seg::update(1,1,tot,dfn2[v],efn2[v],2*e[i].w);
    }
}
void dfs3(int u,int fa) //fa'dep must be less than u'dep
{
    for(int i=head[u];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v==fa) continue;
        if(p[u].se==p[v].se) ;
        else
        {
            int len=(e[i].w+p[v].fi-p[u].fi)/2;
            if(!(e[i].w+p[u].fi+p[v].fi&1))
            {
                if(p[v].se<p[u].se) len--;
            }
            calc(u,v,len);
        }
        dfs3(v,u);
    }
}

int ve[N],siz=0;
int main()
{
#ifndef ONLINE_JUDGE
    freopen("1.in","r",stdin);freopen("1.out","w",stdout);
#endif
    n=read();
    for(int i=1;i<n;++i)
    {
        int x=read(),y=read();
        add(x,y,1);
    }
    dfs(1,0,0);
    for(int j=1;j<=19;++j) for(int i=1;i<=n;++i) f[i][j]=f[f[i][j-1]][j-1];

    cnt=0;
    for(int i=1;i<=n;++i) head[i]=0;

    int Q=read();
    for(int cas=1;cas<=Q;++cas)
    {
        m=read();
        siz=0;
        for(int i=1;i<=m;++i) 
            vt[i]=read(),black[vt[i]]=1,last[vt[i]]=cas,ve[++siz]=vt[i];
        sort(vt+1,vt+1+m,cmp);
        tot=m;
        for(int i=1;i<m;++i)
        {
            int o=lca(vt[i],vt[i+1]);
            if(last[o]!=cas) last[o]=cas,vt[++tot]=o;
        }
        sort(vt+1,vt+1+tot,cmp);

        cnt=0;
        for(int i=1;i<=tot;++i) head[vt[i]]=0,Ans[vt[i]]=0;
        stk[top=1]=vt[1];
        for(int i=2;i<=tot;++i)
        {
            while(dfn[vt[i]]>efn[stk[top]]) top--;
            add(stk[top],vt[i],dep[vt[i]]-dep[stk[top]]);
            stk[++top]=vt[i];
        }

        clk2=0;
        dfs1(vt[1],0,0);
        seg::build(1,1,tot);
        dfs2(vt[1],0);

        Ans[p[vt[1]].se]=n;
        dfs3(vt[1],0);
        for(int i=1;i<=siz;++i) printf("%d ",Ans[ve[i]]);
        ln;

        for(int i=1;i<=tot;++i) black[vt[i]]=0;
    }
    return 0;
}
posted @ 2018-03-29 10:27  BLMontgomery  阅读(218)  评论(0编辑  收藏  举报