【BZOJ3648】寝室管理-环套树+点分治+树状数组

测试地址:寝室管理
题目大意:给定一棵树或环套树,求图中经过至少k个点的路径数。
做法:本题需要用到环套树+点分治+树状数组。
先考虑树上的做法。对于这种树上路径计数的问题,应该能形成一种条件反射了,不能DP马上想到点分治。点分治中,每一次我们考虑过某个点的合法路径数时,先把子树列成一列,对于一棵子树里的所有点,它到根的距离dis和之前子树中的点到根的距离x应该满足x+disk1才是合法的,那么实际上我们就是要求之前子树中满足xkdis1的点的数量,这样一个明显的后缀和形式显然可以用树状数组维护。那么我们就得到了一个O(nlog2n)的树上的算法。
那么再考虑环套树。首先对于所有外向树,我们都可以点分治出该外向树中的所有合法路径,因此我们只需要再考虑过环上的路径即可。为了不算重,我们需要计算从每个环上点的外向树中的点,顺时针(或逆时针,总之就是按同一个方向)走环,最后走到某个其他点的合法路径数。按套路破环为链并倍长,然后顺次编号,那么如果两个点u,v到它们对应的外向树的根的距离分别为dis,x,而它们外向树的编号分别为i,j,不妨设i<ji=j时就是同一棵外向树了,我们已经算过了),那么u,v之间的路径合法当且仅当ji+dis+xk1成立。也就是x+jkdis1+i成立。因此我们可以把x+j看做每个点的权值,这样我们就可以相似地用树状数组求出满足条件的点数了。上述算法的时间复杂度为O(nlogn),加上点分治,总的时间复杂度还是O(nlog2n),可以通过此题。
我傻逼的地方:太久没写大代码了,重心又求错了,TLE了两发……我可能要NOIP退役了……
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,k,first[100010]={0},tot=0,limit=0;
int st[100010],top,siz[100010],mxson[100010];
int inst[100010]={0},loop[100010],looplen;
ll ans=0,sum[400010]={0};
bool vis[100010]={0},inloop[100010]={0};
struct edge
{
    int v,next;
}e[200010];

void insert(int a,int b)
{
    e[++tot].v=b;
    e[tot].next=first[a];
    first[a]=tot;
}

int lowbit(int x)
{
    return x&(-x);
}

void add(int x,ll d)
{
    for(int i=x;i<=(n<<2);i+=lowbit(i))
        sum[i]+=d;
}

ll query(int x)
{
    ll ans=0;
    for(int i=x;i;i-=lowbit(i))
        ans+=sum[i];
    return ans;
}

ll Sum(int l,int r)
{
    if (r<1||l>r) return 0;
    return query(r)-query(l-1);
}

void dp(int v,int fa)
{
    st[++top]=v;
    siz[v]=1,mxson[v]=0;
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v]&&e[i].v!=fa)
        {
            if (inloop[v]&&inloop[e[i].v]) continue;
            dp(e[i].v,v);
            mxson[v]=max(mxson[v],siz[e[i].v]);
            siz[v]+=siz[e[i].v];
        }
}

int find_ctr(int v)
{
    top=0;
    dp(v,0);
    int mn=1000000000,mni;
    for(int i=1;i<=top;i++)
        if (max(mxson[st[i]],siz[v]-siz[st[i]])<mn)
        {
            mn=max(mxson[st[i]],siz[v]-siz[st[i]]);
            mni=st[i];
        }
    return mni;
}

void maintain(int v,int fa,int dis,ll d)
{
    add(dis,d);
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v]&&e[i].v!=fa)
        {
            if (inloop[v]&&inloop[e[i].v]) continue;
            maintain(e[i].v,v,dis+1,d);
        }
}

void calc(int v,int fa,int dis)
{
    ans+=Sum(max(1,k-dis-1+limit),n<<2);
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v]&&e[i].v!=fa)
        {
            if (inloop[v]&&inloop[e[i].v]) continue;
            calc(e[i].v,v,dis+1);
        }
}

void solve(int v)
{
    v=find_ctr(v);
    vis[v]=1;
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v])
        {
            if (inloop[v]&&inloop[e[i].v]) continue;
            solve(e[i].v);
        }

    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v])
        {
            if (inloop[v]&&inloop[e[i].v]) continue;
            calc(e[i].v,0,1);
            maintain(e[i].v,0,1,1);
        }
    ans+=Sum(k-1,n);
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v])
        {
            if (inloop[v]&&inloop[e[i].v]) continue;
            maintain(e[i].v,0,1,-1);
        }

    vis[v]=0;
}

bool find_loop(int v,int fa)
{
    st[++top]=v;
    inst[v]=top;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=fa)
        {
            if (!inst[e[i].v])
            {
                if (find_loop(e[i].v,v))
                    return 1;
            }
            else
            {
                looplen=0;
                for(int j=inst[e[i].v];j<=top;j++)
                {
                    loop[++looplen]=st[j];
                    inloop[st[j]]=1;
                }
                return 1;
            }
        }
    top--;
    inst[v]=0;
    return 0;
}

int main()
{
    scanf("%d%d%d",&n,&m,&k); 
    for(int i=1;i<=m;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        insert(a,b),insert(b,a);
    }

    if (m<n)
    {
        solve(1);
    }
    else
    {
        top=0;
        find_loop(1,0);
        for(int i=1;i<=looplen;i++)
            solve(loop[i]);
        for(int i=1;i<=looplen;i++)
            maintain(loop[i],0,i,1);
        for(limit=1;limit<=looplen;limit++)
        {
            maintain(loop[limit],0,limit,-1);
            calc(loop[limit],0,0);
            maintain(loop[limit],0,looplen+limit,1);
        }
    }
    printf("%lld",ans);

    return 0;
}
posted @ 2018-08-05 11:32  Maxwei_wzj  阅读(98)  评论(0编辑  收藏  举报