永琳的竹林迷径(path)
永琳的竹林迷径(path)
题目描述
竹林可以看作是一个n 个点的树,每个边有一个边长wi,其中有k 个关键点,永琳需要破坏这些关键点才能走出竹林迷径。
然而永琳打算将这k 个点编号记录下来,然后随机排列,按这个随机的顺序走过k 个点,但是两点之间她只走最短路线。初始时永琳会施展一次魔法,将自己传送到选定的k 个点中随机后的第一个点。
现在永琳想知道,她走过路程的期望是多少,答案对998244353 取模。
注意,如果对期望不理解,题目最后有详细解释,请自行阅读。
输入
第一行一个数Case,表示测试点编号。(样例的编号表示其满足第Case 个测试点的性质)
下一行一个n,表示树的点数。
下面 n-1 行,每行三个数ui,vi,wi,表示一条边连接ui和vi,长度为wi。
下面一行一个数k,表示关键点数。
下面一行k 个数,表示k 个关键点的编号。
输出
一行一个数,表示答案(对998244353 取模)。
数据范围
对于 100%的数据,保证1≤wi≤1041≤wi≤104。
测试点编号 |
n |
k |
特殊性质 |
1 |
≤10≤10 |
=1=1 |
无 |
2 |
|||
3 |
≤5≤5 |
||
4 |
|||
5 |
≤1000≤1000 |
≤7≤7 |
|
6 |
|||
7 |
≤105≤105 |
≤8≤8 |
|
8 |
|||
9 |
|||
10 |
≤16≤16 |
||
11 |
|||
12 |
|||
13 |
≤105≤105 |
||
14 |
|||
15 |
|||
16 |
|||
17 |
|||
18 |
≤106≤106 |
≤106≤106 |
是一条链 |
19 |
|||
20 |
|||
21 |
无 |
||
22 |
|||
23 |
|||
24 |
|||
25 |
【可能会用到的知识】
关于期望:
期望的定义:离散随机变量的一切可能值与其对应的概率P 的乘积之和称为数学期望。
即: E(x)=∑P(x=k)×val(k)E(x)=∑P(x=k)×val(k)
其中E(x)是期望,P(x=k)是 x=k 发生的概率。
提示:答案必定可以表示成pqpq的形式,在模意义下,pq=p×q−1pq=p×q−1,其中q−1q−1是qq的逆元。
【提示】
读入数据较大,请使用快速的读入方式。
solution
本题求有序走完一个排列的期望长度。
考虑一个点i之前是j的贡献:dist[i]+dist[j]-2*dist[lca]
算出所有点的dist的贡献,lca的贡献则做一次类似树形dp的东西
一个点不同子树互相走的数量就是它作lca的贡献
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 1000006
#define mod 998244353
#define ll long long
using namespace std;
int n,num,head[maxn],s[maxn],flag[maxn],t1,t2,t3,tot;
int Te,size[maxn];
ll ans,Sum,ny2,ny,d[maxn];
struct node{
int v,nex,w;
}e[maxn*2];
int read(){
int v=0,ch;
while(!isdigit(ch=getchar()));v=ch-48;
while(isdigit(ch=getchar()))v=(v<<1)+(v<<3)+ch-48;
return v;
}
void lj(int t1,int t2,int t3){
e[++tot].v=t2;e[tot].w=t3;e[tot].nex=head[t1];head[t1]=tot;
}
void dfs(int k,int fa,ll dist){
d[k]=dist;
for(int i=head[k];i;i=e[i].nex){
if(e[i].v==fa)continue;
dfs(e[i].v,k,dist+e[i].w);
size[k]+=size[e[i].v];
}
size[k]+=flag[k];
}
void dp(int k,int fa){
ll sum=0;
for(int i=head[k];i;i=e[i].nex){
if(e[i].v==fa)continue;
dp(e[i].v,k);
sum+=size[e[i].v];
}
ll cnt=0;
for(int i=head[k];i;i=e[i].nex){
if(e[i].v==fa)continue;
cnt=cnt+size[e[i].v]*(sum-size[e[i].v])%mod;
cnt=cnt%mod;
}
ans=(ans-2*d[k]%mod*cnt%mod)%mod;
if(flag[k]){
ans=(ans-4*d[k]%mod*sum%mod)%mod;
}
}
ll work(ll a,int Num){
ll Ans=1,p=a;
while(Num){
if(Num&1)Ans=Ans*p;
p=p*p;p%=mod;Ans%=mod;Num>>=1;
}
return Ans;
}
int main()
{
Te=read();n=read();
for(int i=1;i<n;i++){
t1=read();t2=read();t3=read();
lj(t1,t2,t3);lj(t2,t1,t3);
}
num=read();
for(int i=1;i<=num;i++){
s[i]=read();flag[s[i]]++;
}
dfs(1,0,(ll)0);
ny2=work(2,mod-2);
for(int i=1;i<=num;i++)Sum+=d[s[i]];
for(int i=1;i<=num;i++){
ans=(ans+d[s[i]]*(num-1))%mod+(Sum-d[s[i]])%mod;
ans=ans%mod;
}
dp(1,0);ans%=mod;
ans=ans*work(num,mod-2);
ans=(ans%mod+mod)%mod;
printf("%lld\n",ans);
return 0;
}