[SNOI2024]公交线路 题解

为啥洛谷现有的题解全是 \(O(n^2\log n)\) 的做法?给个好写的 \(O(n^2)\) 做法。

感觉这题是这套题中除了 D1T1 以外最简单的题(


显然最远的距离一定由两个叶子贡献,我们拎出一个非叶节点为根,分析一些性质。

考虑两个叶子 \(u,v\) 何时距离 \(\le 2\),这要求它们所一步能到达的最浅点 \(f(u),f(v)\) 为祖先后代关系。不妨设 \(f(u)\)\(f(v)\) 的祖先,还要求 \(u\) 存在一个走一步的方式走入 \(f(v)\) 的子树中。容易发现这便是充要条件。

考虑枚举 \(x\) 表示最深的 \(f(u)\),这个限制肯定最严。现在要求所有叶子都可以一步走到 \(x\),并且存在一个在其子树内的叶子满足所能达到的最浅点为 \(x\)

注意到第二个限制处理起来并不是很方便,可以先去掉这个条件,再扣掉所有在 \(x\) 子树内的叶子都能达到 \(x\) 的父亲的方案数。本质就是点数减边数等于一这个经典容斥。

只有第一个限制后直接做还是不好做,考虑容斥,这样就变成有一些边不能选,这是方便 dp 的。如果对所有叶子容斥跑背包的话是 \(O(n^3)\) 的,考虑优化。

注意到只对子树内的叶子跑这个算法就是树形背包,复杂度为 \(O(n^2)\),这是可以接受的。

于是我们只容斥子树内的叶子,由于只剩下最后一类反子树的点没有考虑,在确定了前面的信息之后我们是可以直接计算符合题意的方案数的,这样就把复杂度优化到了 \(O(n^2)\)

最后剩下的一点问题就是扣掉都能达到 \(x\) 的父亲的方案数,依旧考虑类似的做法,\(O(n^2)\) 计算是容易的。

#include<bits/stdc++.h>
#define IL inline
#define reg register
#define mod 998244353
#define N 3030
IL int read()
{
    reg int x=0; reg char ch=getchar();
    while(ch<'0'||ch>'9')ch=getchar();
    while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
    return x;
}

IL int Add(reg int x,reg int y){return x+y<mod?x+y:x+y-mod;}
IL int Sub(reg int x,reg int y){return x<y?x-y+mod:x-y;}
IL void Pls(reg int &x,reg int y){x=Add(x,y);}
IL void Dec(reg int &x,reg int y){x=Sub(x,y);}
IL int Mul(reg int x,reg int y){reg long long r=1ll*x*y; return r<mod?r:r%mod;}

int pw[N*N],c[N][N],con[N][N];

IL void init(reg int n)
{
    pw[0]=1;
    for(reg int i=1;i<=n*n;++i)pw[i]=Mul(pw[i-1],2);
    for(reg int i=0,j;i<=n;++i)
    {
        con[i][0]=1;
        for(j=1;j<=n;++j)con[i][j]=Mul(con[i][j-1],Sub(pw[i],1));
    }
    for(reg int i=0;i<=n;++i)c[i][0]=1;
    for(reg int i=1,j;i<=n;++i)for(j=1;j<=i;++j)c[i][j]=Add(c[i-1][j],c[i-1][j-1]);
}

int n,rt,ans;
std::vector<int>G[N];

IL void add(reg int u,reg int v){G[u].push_back(v),G[v].push_back(u);}

int fa[N],sz[N],cnt[N];

void dfs(reg int u)
{
    sz[u]=1,cnt[u]=G[u].size()==1;
    for(reg auto v:G[u])if(!sz[v])
        fa[v]=u,dfs(v),sz[u]+=sz[v],cnt[u]+=cnt[v];
}

IL int A(reg int n){return n*(n-1)>>1;}

main()
{
    n=read();
    if(n<=2)puts("1"),exit(0);
    init(n);
    for(reg int i=n;--i;)add(read(),read());
    for(rt=1;G[rt].size()==1;++rt);
    dfs(rt);
    for(reg int u=1,i,j,w,up;u<=n;++u)
    {
        static int f[N],g[N];
        for(i=1;i<=n;++i)f[i]=0;
        f[1]=up=1;
        for(reg auto v:G[u])if(v!=fa[u])
        {
            for(i=1;i<=up;++i)for(j=0;j<=cnt[v];++j)
            {
                w=Mul(Mul(f[i],c[cnt[v]][j]),pw[i*(sz[v]-j)]);
                if(j&1)Dec(g[i+sz[v]-j],w); else Pls(g[i+sz[v]-j],w);
            }
            up+=sz[v];
            for(i=1;i<=up;++i)f[i]=g[i],g[i]=0;
        }
        reg int a0=cnt[u],b0=sz[u]-a0,a1=cnt[rt]-cnt[u],b1=n-sz[u]-a1;
        w=0;
        for(i=1;i<=up;++i)Pls(w,Mul(Mul(f[i],pw[i*b1]),con[i][a1]));
        for(reg auto v:G[u])if(v!=fa[u])w=Mul(w,pw[A(sz[v])]);
        w=Mul(w,pw[A(n-sz[u])]),Pls(ans,w),w=0;
        for(i=0;i<=a0;++i)
        {
            reg int k=Mul(Mul(c[a0][i],pw[(a0-i)*b1]),con[b0+a0-i][a1]);
            if(i&1)Dec(w,k); else Pls(w,k);
        }
        w=Mul(w,pw[A(a0+b0)+A(a1+b1)+b0*b1]),Dec(ans,w);
    }
    printf("%d\n",ans);
}
posted @ 2024-01-24 13:44  Nesraychan  阅读(172)  评论(0)    收藏  举报