P10060 [SNOI2024] 树 V 图

原题链接

首先想到 \(f\) 值相同的点一定构成一个连通块,所以应当有 \(k\) 个连通块并且每个连通块 \(f\) 值互不相同。

判断一下 \([1,k]\) 是否在 \(f\) 中都出现过,并且是否有 \(k-1\) 条边两个端点的 \(f\) 值不同,若有不符合的就是非法输入,直接输出 \(0\)

考虑 \(k=2\) 的部分分,对于那个两端点不同的边 \((x,y)\),一个属于 \(x\) 所在连通块的点 \(i\),一个属于 \(y\) 所在连通块的点 \(j\),点对 \((i,j)\) 是一组合法答案当且仅当

\[([dis(x,i)=dis(y,j)+1]\wedge[a_x<a_y]\vee[dis(x,i)<dis(y,j)+1]) \]

并且

\[([dis(y,j)=dis(x,i)+1]\wedge[a_y<a_x]\vee[dis(y,j)<dis(x,i)+1]) \]

于是可以枚举每个点对,时间复杂度 \(O(n^2)\)

对于所有数据考虑,每个连通块缩起来以后肯定也构成一棵 \(k\) 个节点的树,在这棵树上从下往上处理即可。

具体的,记 \(f_i\) 为只考虑 \(i\) 所在的连通块及其子树中,最后 \(i\) 可以作为关键点的方案数,初始时都是 \(1\)。对于一个子节点,枚举其中所有的节点来暴力和当前连通块中的节点做上述的匹配。

对于一个 \(i\),找到所有合法的 \(j\),将 \(f_j\) 累加起来再乘到 \(f_i\) 上,因为不同的子节点的贡献是相乘的关系。

可以先 \(O(n^2)\) 预处理出每个点对的距离,上述算法每个点对只会考虑一次所以是 \(O(n^2)\) 的。

#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
const int MAXN=3e3+10,MOD=998244353;
int T,n,k,cur,cnt,x[MAXN],y[MAXN],a[MAXN];
long long t[MAXN],f[MAXN],ans;bool vis[MAXN];
vector <int> v[MAXN],p[MAXN][MAXN];
inline void clear()
{
    for(int i=1;i<=n;++i)
    {
        vis[a[i]]=false;t[i]=f[i]=0,v[i].clear();
        for(int j=0;!p[i][j].empty();++j) p[i][j].clear();
    }
    cnt=ans=0;return ;
}
void init(int x,int fa=0,int dep=0)
{
    p[cur][dep].push_back(x);
    for(int y:v[x])
        if(y!=fa&&a[x]==a[y]) init(y,x,dep+1);
    return ;
}
void dfs(int x,int fa=0)
{
    for(int y:v[x])
    {
        if(y==fa) continue;
        if(a[x]==a[y]) dfs(y,x);
        else
        {
            dfs(y,x);
            for(int depy=0;!p[y][depy].empty();++depy)
                for(int j:p[y][depy])
            for(int depx=max(depy-1,0);depx<=depy+1;++depx)//这里有个小优化,但是其实没啥必要也没啥用
                for(int i:p[x][depx])
                if((depx<depy+1||a[x]<a[y])&&(depy<depx+1||a[y]<a[x]))
                    t[i]=(t[i]+f[j])%MOD;
            for(int depx=0;!p[x][depx].empty();++depx)
                for(int i:p[x][depx]) f[i]=f[i]*t[i]%MOD,t[i]=0;
        }
    }
    return ;
}
inline void work()
{
    clear();cin>>n>>k;
    for(int i=1;i<n;++i)
        cin>>x[i]>>y[i],
        v[x[i]].push_back(y[i]),
        v[y[i]].push_back(x[i]);
    for(int i=1;i<=n;++i)
        cin>>a[i],vis[a[i]]=true,f[i]=1;
    for(int i=1;i<=k;++i)
        if(!vis[i]) {cout<<"0\n";return ;}
    for(int i=1;i<n;++i)
        if(a[x[i]]!=a[y[i]]) ++cnt;
    if(cnt!=k-1){cout<<"0\n";return ;}
    for(int i=1;i<=n;++i) cur=i,init(i);dfs(1);
    for(int d=0;!p[1][d].empty();++d)
        for(int i:p[1][d]) ans=(ans+f[i])%MOD;
    cout<<ans<<'\n';return ;
}
int main()
{
    cin.tie(0),cout.tie(0);
    ios::sync_with_stdio(0);
    cin>>T;while(T--) work();
    return 0;
}
posted @ 2024-01-22 18:07  int_R  阅读(53)  评论(0编辑  收藏  举报