P5405 [CTS2019]氪金手游
链接:https://www.luogu.com.cn/problem/P5405
题解:由于这是一个内外向树,不好做,先将它转化为内向树。
转化后其实可算出第\(i\)在\(i\)子树前被抽中的概率:
令子树和为\(SZ\),则\(p=\frac{p_{i}}{n}+\frac{p_{i}}{n}\times (\frac{n-SZ}{n})+\frac{p_{i}}{n}\times (\frac{n-SZ}{n})^2+...\)
那么\(p=\frac{p_{i}}{n}\times\frac{1}{1-\frac{n-SZ}{n}}=\frac{p_{i}}{n}\times\frac{n}{SZ}=\frac{p_{i}}{SZ}\)
所以我们可以令\(dp_{i,j}\)表示第\(i\)个节点,子树和为\(j\)的概率,转移即可。
但又有外向边,那么对于外向边的贡献,可拆为总贡献\(-\)内向边的贡献,那么转移时将总共献算进去,并且将内向边本身贡献乘\(-1\),这样就可以处理外向边了。
#include<iostream>
#include<cstdio>
#define int long long
#define mod 998244353
using namespace std;
struct node
{
int v,nxt,data;
};
node edge[2001];
long long len,ans,x,y,z,S[1001],inv[3001],sz[1001],T[1001][4],f[1001][3001],dp[1001][3001],head[1001],n;
bool used[1001];
int read()
{
char c=0;
int sum=0;
while (c<'0'||c>'9')
c=getchar();
while ('0'<=c&&c<='9')
{
sum=sum*10+c-'0';
c=getchar();
}
return sum;
}
void add(int x,int y,int z)
{
edge[++len].v=y;
edge[len].data=z;
edge[len].nxt=head[x];
head[x]=len;
return;
}
long long fast_pow(long long a,int b)
{
if (b==0)
return 1;
if (b&1)
return fast_pow(a*a%mod,b/2)*a%mod;
else
return fast_pow(a*a%mod,b/2);
}
void dfs(int x)
{
used[x]=1;
sz[x]=1;
dp[x][0]=1;
for (int i=head[x];i>0;i=edge[i].nxt)
if (!used[edge[i].v])
{
dfs(edge[i].v);
sz[x]+=sz[edge[i].v];
if (edge[i].data==1)
{
for (int j=3*sz[x];j>=1;--j)
{
dp[x][j]=0;
for (int k=max(j+3*(sz[edge[i].v]-sz[x]),1ll);k<=min(3*sz[edge[i].v],j*1ll);++k)
dp[x][j]=(dp[x][j]+dp[edge[i].v][k]*dp[x][j-k]%mod)%mod;
}
dp[x][0]=0;
}
else
{
for (int j=3*sz[x];j>=1;--j)
{
dp[x][j]=dp[x][j]*S[edge[i].v]%mod;
for (int k=max(j+3*(sz[edge[i].v]-sz[x]),1ll);k<=min(3*sz[edge[i].v],j*1ll);++k)
dp[x][j]=(dp[x][j]-dp[edge[i].v][k]*dp[x][j-k]%mod)%mod;
}
dp[x][0]=dp[x][0]*S[edge[i].v]%mod;
}
}
for (int j=1;j<=3*sz[x];++j)
for (int t=1;t<=min(3ll,j);++t)
f[x][j]=(f[x][j]+dp[x][j-t]*t%mod*inv[j]%mod*T[x][t]%mod)%mod;
for (int j=1;j<=3*sz[x];++j)
dp[x][j]=f[x][j];
for (int j=1;j<=3*sz[x];++j)
S[x]=(S[x]+dp[x][j])%mod;
return;
}
signed main()
{
n=read();
for (int i=1;i<=n;++i)
{
x=read(),y=read(),z=read();
T[i][1]=x*fast_pow(x+y+z,mod-2)%mod;
T[i][2]=y*fast_pow(x+y+z,mod-2)%mod;
T[i][3]=z*fast_pow(x+y+z,mod-2)%mod;
}
for (int i=1;i<=n-1;++i)
{
x=read(),y=read();
add(x,y,1);
add(y,x,0);
}
inv[1]=1;
for (int i=2;i<=3*n;++i)
inv[i]=(-inv[mod%i]*(mod/i)%mod+mod)%mod;
dfs(1);
printf("%lld\n",(S[1]+mod)%mod);
return 0;
}
本文来自博客园,作者:zhouhuanyi,转载请注明原文链接:https://www.cnblogs.com/zhouhuanyi/p/16983723.html