联考20200729 T1 zsy家今天的饭



分析:
如果我们知道了有哪些点需要访问,最短距离是多少呢
建出虚树,所有边权和为\(Sum\),直径为\(L\),那么答案为\(2Sum-L\)
期望=总贡献/方案
方案肯定为\(\binom{m}{k}\),我们开始算总贡献
先求\(Sum\),考虑每条边会在多少种情况下做贡献,显然是其两端都有关键点的情况下
设其在树上所接的儿子为\(u\)
这里的方案为\(\binom{m}{k}-\binom{m-sz_u}{k}-\binom{sz_u}{k}\)
应该不用解释什么

开始算直径
(我记得直径期望不是个很恐怖的东西吗(错乱)
由于这里\(m\)只有500,我们可以暴力处理出每对点的距离
我们暴力枚举,强行让某两点做直径端点,遇到同样大小的取编号最小,看剩下哪些点是可以选择的,假设有\(P\)
那么这条直径做贡献的方案数为\(\binom{P}{k-2}\)

总复杂度\(O(nlogn+m^2logn+m^3)\),可以通过
\(O(nlogn+m^2logn)\)这里看自己的LCA求法吧,我主要为了省事(

#include<cstdio>
#include<cmath>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<vector>
#include<string>

#define maxn 200005
#define maxm 505
#define INF 0x3f3f3f3f
#define MOD 998244353

using namespace std;

inline long long getint()
{
	long long num=0,flag=1;char c;
	while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
	while(c>='0'&&c<='9')num=num*10+c-48,c=getchar();
	return num*flag;
}

int n,m,K;
int fir[maxn],nxt[maxn],to[maxn],len[maxn],cnt;
int f[maxn][18],sz[maxn],dpt[maxn];
long long dis[maxn];
int C[maxm][maxm];
int p[maxm];
long long D[maxm][maxm];
int ans;

inline int upd(int x){return x<MOD?x:x-MOD;}
inline int ksm(int num,int k)
{
	int ret=1;
	for(;k;k>>=1,num=1ll*num*num%MOD)if(k&1)ret=1ll*ret*num%MOD;
	return ret;
}

inline void newnode(int u,int v,int w)
{to[++cnt]=v,nxt[cnt]=fir[u],fir[u]=cnt,len[cnt]=w;}
inline void dfs(int u)
{
	for(int i=fir[u];i;i=nxt[i])if(to[i]!=f[u][0])
	{
		dpt[to[i]]=dpt[u]+1,dis[to[i]]=dis[u]+len[i],f[to[i]][0]=u;
		dfs(to[i]),sz[u]+=sz[to[i]];
		int tmp=upd(C[m][K]-upd(C[m-sz[to[i]]][K]+C[sz[to[i]]][K])+MOD);
		ans=(ans+1ll*len[i]*tmp)%MOD;
	}
}

inline int LCA(int u,int v)
{
	if(dpt[u]<dpt[v])swap(u,v);
	for(int i=17;~i;i--)if((dpt[u]-dpt[v])&(1<<i))u=f[u][i];
	if(u==v)return u;
	for(int i=17;~i;i--)if(f[u][i]!=f[v][i])u=f[u][i],v=f[v][i];
	return f[u][0];
}
inline long long getdis(int u,int v)
{return dis[u]+dis[v]-2*dis[LCA(u,v)];}

int main()
{
	n=getint(),m=getint(),K=getint();
	for(int i=1;i<=m;i++)sz[p[i]=getint()]=1;
	for(int i=1;i<n;i++)
	{
		int u=getint(),v=getint(),w=getint();
		newnode(u,v,w),newnode(v,u,w);
	}
	if(K==1){printf("0\n");return 0;}
	for(int i=0;i<=m;i++)
	{
		C[i][0]=1;
		for(int j=1;j<=i;j++)C[i][j]=upd(C[i-1][j-1]+C[i-1][j]);
	}
	dfs(1);
	ans=upd(2*ans);
	for(int j=1;j<18;j++)for(int i=1;i<=n;i++)f[i][j]=f[f[i][j-1]][j-1];
	for(int i=1;i<=m;i++)for(int j=1;j<=m;j++)D[i][j]=getdis(p[i],p[j]);
	for(int i=1;i<=m;i++)for(int j=i+1;j<=m;j++)
	{
		int P=0;
		long long L=D[i][j];
		for(int k=1;k<=m;k++)
		{
			long long L1=D[i][k],L2=D[j][k];
			if((L>L1||(L==L1&&j<k))&&(L>L2||(L==L2&&i<k)))P++;
		}
		L%=MOD;
		ans=upd(ans-1ll*L*C[P][K-2]%MOD+MOD);
	}
	ans=1ll*ans*ksm(C[m][K],MOD-2)%MOD;
	printf("%d\n",ans);
}

posted @ 2020-07-29 15:15  Izayoi_Doyo  阅读(199)  评论(0编辑  收藏  举报