题解 最短路径
考场上想了它们会形成类似一个环去掉一段的形状但并没有什么进展
- 树上经过多个点(每个点有一定概率被指定)的最短路径长度(点与边都可以重复经过):
路径长度就是这 \(k\) 个点的虚树的边长度之和乘2 ,再减去虚树的最长链(就是直径)就是最短路径长度了
于是求虚树的边长度之和的期望
分散到每条边上考虑,则一条边产生贡献当且仅当两边都有点被指定
可以用1减去其中一边没有指定点的概率
再求虚树的直径期望
一种 \(O(n^3)\) 的做法是枚举一条路径 \((u, v)\),考虑若这条路径是直径则还有哪些点可以加到虚树中
条件是这个点到两个端点的距离小于等于直径(等于的时候字典序必须比直径大以避免算重)
还有一种 \(O(n^2)\) 的做法(我没有写只是在口胡)
任选一点,枚举另一个点 \(u\),钦定这个点即为虚树上离选定点最远的点
可以在原树上任选点是考虑这个点到找到的那个点路径上第一个在虚树上的点:就相当于从这个点出发找了离它最远的点
将其它点按离这个选定点的距离排序形成一个序列,则距离大于钦定点的都不能选
剩下的是可能的在虚树上的点
于是复杂度 \(O(n+m^2)\)
Code:
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define N 2010
#define ll long long
//#define int long long
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
inline int read() {
int ans=0, f=1; char c=getchar();
while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();}
while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}
return ans*f;
}
int n, m, k;
int head[N], size, key[N];
bool iskey[N];
const ll mod=998244353;
struct edge{int from, to, next;}e[N<<1];
inline void add(int s, int t) {e[++size].to=t; e[size].from=s; e[size].next=head[s]; head[s]=size;}
inline ll qpow(ll a, ll b) {ll ans=1; for (; b; a=a*a%mod,b>>=1) if (b&1) ans=ans*a%mod; return ans;}
inline ll inv(ll a) {return qpow(a, mod-2);}
namespace force{
int buc[N], tag[N];
ll f[N], g[N], ans, cnt;
bool vis[N];
void dfs1(int u, int fa) {
ll dlt=0;
if (vis[u]) tag[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dfs1(v, u);
if (!tag[v]) continue;
else tag[u]=1;
f[u]+=f[v]+2;
dlt=max(dlt, f[v]-g[v]+1);
}
g[u]=f[u]-dlt;
}
ll dij() {
ll ans=INF;
for (int i=1; i<=k; ++i) {
for (int j=1; j<=n; ++j) f[j]=g[j]=tag[j]=0;
dfs1(buc[i], 0);
ans=min(ans, g[buc[i]]);
}
return ans;
}
void solve() {
int lim=1<<m;
for (int s=1; s<lim; ++s) {
if (__builtin_popcount(s)!=k) continue;
int now=0; ++cnt;
for (int i=1; i<=n; ++i) vis[i]=0;
for (int i=0; i<m; ++i) if (s&(1<<i)) {
buc[++now]=key[i+1];
vis[key[i+1]]=1;
}
ans=(ans+dij())%mod;
// cout<<"ans: "<<ans<<endl;
}
printf("%lld\n", ans*inv(cnt)%mod);
exit(0);
}
}
namespace task1{
int fa[22][N], lg[N], dep[N];
ll ans, cnt;
void dfs(int u, int pa) {
for (int i=1; i<=20; ++i)
if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
else break;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==pa) continue;
dep[v]=dep[u]+1; fa[0][v]=u; dfs(v, u);
}
}
int lca(int a, int b) {
if (dep[a]<dep[b]) swap(a, b);
while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
if (a==b) return a;
for (int i=lg[dep[a]]-1; ~i; --i)
if (fa[i][a]!=fa[i][b])
a=fa[i][a], b=fa[i][b];
return fa[0][a];
}
inline int dis(int a, int b) {return dep[a]+dep[b]-2*dep[lca(a, b)];}
void solve() {
for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
dep[1]=1; dfs(1, 0);
for (int i=1; i<=m; ++i) {
for (int j=i+1; j<=m; ++j) {
ans=(ans+dis(key[i], key[j]))%mod;
++cnt;
}
}
printf("%lld\n", ans*inv(cnt)%mod);
exit(0);
}
}
namespace task2{
int buc[N], tag[N];
ll f[N], g[N];
bool vis[N];
void dfs1(int u, int fa) {
ll dlt=0;
if (vis[u]) tag[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dfs1(v, u);
if (!tag[v]) continue;
else tag[u]=1;
f[u]+=f[v]+2;
dlt=max(dlt, f[v]-g[v]+1);
}
g[u]=f[u]-dlt;
}
ll dij() {
ll ans=INF;
for (int i=1; i<=m; ++i) {
for (int j=1; j<=n; ++j) f[j]=g[j]=tag[j]=0;
dfs1(buc[i], 0);
ans=min(ans, g[buc[i]]);
}
return ans;
}
void solve() {
for (int i=1; i<=m; ++i) vis[buc[i]=key[i]]=1;
printf("%lld\n", dij());
exit(0);
}
}
namespace task{
ll fac[N], inv2[N], ans;
int siz[N], dep[N], fa[22][N], lg[N], dis2[310][310];
inline ll C(int n, int k) {return n<k?0:fac[n]*inv2[k]%mod*inv2[n-k]%mod;}
void dfs1(int u, int pa) {
siz[u]=iskey[u];
for (int i=1; i<=20; ++i)
if (dep[u]>=1<<i) fa[i][u]=fa[i-1][fa[i-1][u]];
else break;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==pa) continue;
dep[v]=dep[u]+1; fa[0][v]=u; dfs1(v, u);
siz[u]+=siz[v];
}
}
int lca(int a, int b) {
if (dep[a]<dep[b]) swap(a, b);
while (dep[a]>dep[b]) a=fa[lg[dep[a]-dep[b]]-1][a];
if (a==b) return a;
for (int i=lg[dep[a]]-1; ~i; --i)
if (fa[i][a]!=fa[i][b])
a=fa[i][a], b=fa[i][b];
return fa[0][a];
}
inline int dis(int a, int b) {return dep[a]+dep[b]-2*dep[lca(a, b)];}
void solve() {
fac[0]=fac[1]=1; inv2[0]=inv2[1]=1;
for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod;
for (int i=2; i<=n; ++i) inv2[i]=(mod-mod/i)*inv2[mod%i]%mod;
for (int i=2; i<=n; ++i) inv2[i]=inv2[i-1]*inv2[i]%mod;
for (int i=1; i<=n; ++i) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
dep[1]=1; dfs1(1, 0);
ll p=inv(C(m, k)), tem;
for (int i=1,u,v; i<=size; i+=2) {
u=e[i].from; v=e[i].to; tem=0;
if (dep[u]>dep[v]) swap(u, v);
if (siz[v]>=k) tem=(tem+C(siz[v], k)*p)%mod;
if (m-siz[v]>=k) tem=(tem+C(m-siz[v], k)*p)%mod;
ans=(ans+(1-tem))%mod;
}
ans=ans*2%mod;
// cout<<"ans: "<<ans<<endl;
for (int i=1; i<=m; ++i) for (int j=1; j<=m; ++j) dis2[i][j]=dis(key[i], key[j]);
for (int i=1; i<=m; ++i) {
for (int j=i+1; j<=m; ++j) {
int u=key[i], v=key[j], tem=dis2[i][j], cnt=0;
for (int k=1,t; k<=m; ++k) if (k!=i && k!=j) {
t=dis2[i][k];
if (t>tem || (t==tem&&key[k]<v)) continue;
t=dis2[j][k];
if (t>tem || (t==tem&&key[k]<u)) continue;
++cnt;
}
ans=(ans-tem*C(cnt, k-2)%mod*p)%mod;
}
}
printf("%lld\n", (ans%mod+mod)%mod);
exit(0);
}
}
signed main()
{
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
n=read(); m=read(); k=read();
memset(head, -1, sizeof(head));
for (int i=1; i<=m; ++i) iskey[key[i]=read()]=1;
sort(key+1, key+m+1);
for (int i=1,u,v; i<n; ++i) {
u=read(); v=read();
add(u, v); add(v, u);
}
// if (k==2) task1::solve();
// else if (k==m) task2::solve();
// else force::solve();
task::solve();
return 0;
}