P10060 [SNOI2024] 树 V 图
首先想到 \(f\) 值相同的点一定构成一个连通块,所以应当有 \(k\) 个连通块并且每个连通块 \(f\) 值互不相同。
判断一下 \([1,k]\) 是否在 \(f\) 中都出现过,并且是否有 \(k-1\) 条边两个端点的 \(f\) 值不同,若有不符合的就是非法输入,直接输出 \(0\)。
考虑 \(k=2\) 的部分分,对于那个两端点不同的边 \((x,y)\),一个属于 \(x\) 所在连通块的点 \(i\),一个属于 \(y\) 所在连通块的点 \(j\),点对 \((i,j)\) 是一组合法答案当且仅当
\[([dis(x,i)=dis(y,j)+1]\wedge[a_x<a_y]\vee[dis(x,i)<dis(y,j)+1])
\]
并且
\[([dis(y,j)=dis(x,i)+1]\wedge[a_y<a_x]\vee[dis(y,j)<dis(x,i)+1])
\]
于是可以枚举每个点对,时间复杂度 \(O(n^2)\)。
对于所有数据考虑,每个连通块缩起来以后肯定也构成一棵 \(k\) 个节点的树,在这棵树上从下往上处理即可。
具体的,记 \(f_i\) 为只考虑 \(i\) 所在的连通块及其子树中,最后 \(i\) 可以作为关键点的方案数,初始时都是 \(1\)。对于一个子节点,枚举其中所有的节点来暴力和当前连通块中的节点做上述的匹配。
对于一个 \(i\),找到所有合法的 \(j\),将 \(f_j\) 累加起来再乘到 \(f_i\) 上,因为不同的子节点的贡献是相乘的关系。
可以先 \(O(n^2)\) 预处理出每个点对的距离,上述算法每个点对只会考虑一次所以是 \(O(n^2)\) 的。
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
const int MAXN=3e3+10,MOD=998244353;
int T,n,k,cur,cnt,x[MAXN],y[MAXN],a[MAXN];
long long t[MAXN],f[MAXN],ans;bool vis[MAXN];
vector <int> v[MAXN],p[MAXN][MAXN];
inline void clear()
{
for(int i=1;i<=n;++i)
{
vis[a[i]]=false;t[i]=f[i]=0,v[i].clear();
for(int j=0;!p[i][j].empty();++j) p[i][j].clear();
}
cnt=ans=0;return ;
}
void init(int x,int fa=0,int dep=0)
{
p[cur][dep].push_back(x);
for(int y:v[x])
if(y!=fa&&a[x]==a[y]) init(y,x,dep+1);
return ;
}
void dfs(int x,int fa=0)
{
for(int y:v[x])
{
if(y==fa) continue;
if(a[x]==a[y]) dfs(y,x);
else
{
dfs(y,x);
for(int depy=0;!p[y][depy].empty();++depy)
for(int j:p[y][depy])
for(int depx=max(depy-1,0);depx<=depy+1;++depx)//这里有个小优化,但是其实没啥必要也没啥用
for(int i:p[x][depx])
if((depx<depy+1||a[x]<a[y])&&(depy<depx+1||a[y]<a[x]))
t[i]=(t[i]+f[j])%MOD;
for(int depx=0;!p[x][depx].empty();++depx)
for(int i:p[x][depx]) f[i]=f[i]*t[i]%MOD,t[i]=0;
}
}
return ;
}
inline void work()
{
clear();cin>>n>>k;
for(int i=1;i<n;++i)
cin>>x[i]>>y[i],
v[x[i]].push_back(y[i]),
v[y[i]].push_back(x[i]);
for(int i=1;i<=n;++i)
cin>>a[i],vis[a[i]]=true,f[i]=1;
for(int i=1;i<=k;++i)
if(!vis[i]) {cout<<"0\n";return ;}
for(int i=1;i<n;++i)
if(a[x[i]]!=a[y[i]]) ++cnt;
if(cnt!=k-1){cout<<"0\n";return ;}
for(int i=1;i<=n;++i) cur=i,init(i);dfs(1);
for(int d=0;!p[1][d].empty();++d)
for(int i:p[1][d]) ans=(ans+f[i])%MOD;
cout<<ans<<'\n';return ;
}
int main()
{
cin.tie(0),cout.tie(0);
ios::sync_with_stdio(0);
cin>>T;while(T--) work();
return 0;
}