BZOJ 2152 聪聪可可(树形DP)
给出一颗n个点带边权的树(n<=20000),求随机选择两个点,使得它们之间的路径边权是3的倍数的概率是多少。
首先总的对数是n*n,那么只需要统计路径边权是3的倍数的点对数量就行了。
考虑将无根树化为有根树,令dp[x][i]表示以x点为路径起点,x的某个子孙为路径终点的边权值模3为i的点对数量。
那么显然有dp[x][i]+=dp[son[x]][(i-w)%3].
考虑点对之间的路径,要么是它们的LCA是点对中的一个点,要么不在点对中,因此统计一下以每个点x为LCA时的路径边权值%3为i的点对数量。
而这两个统计都可以在一次树形DP中完成。因此总复杂度为O(n).
# include <cstdio> # include <cstring> # include <cstdlib> # include <iostream> # include <vector> # include <queue> # include <stack> # include <map> # include <bitset> # include <set> # include <cmath> # include <algorithm> using namespace std; # define lowbit(x) ((x)&(-x)) # define pi acos(-1.0) # define eps 1e-8 # define MOD 30031 # define INF 1000000000 # define mem(a,b) memset(a,b,sizeof(a)) # define FOR(i,a,n) for(int i=a; i<=n; ++i) # define FO(i,a,n) for(int i=a; i<n; ++i) # define bug puts("H"); # define lch p<<1,l,mid # define rch p<<1|1,mid+1,r # define mp make_pair # define pb push_back typedef pair<int,int> PII; typedef vector<int> VI; # pragma comment(linker, "/STACK:1024000000,1024000000") typedef long long LL; int Scan() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } const int N=20005; //Code begin... struct Edge{int p, next, w;}edge[N<<1]; int head[N], cnt=1, dp[N][3], son[N]; void add_edge(int u, int v, int w){edge[cnt].p=v; edge[cnt].w=w; edge[cnt].next=head[u]; head[u]=cnt++;} void dfs(int x, int fa){ for (int i=head[x]; i; i=edge[i].next) { int v=edge[i].p; if (v==fa) continue; dfs(v,x); FO(j,0,3) dp[x][j]+=dp[v][((j-edge[i].w)%3+3)%3]; } for (int i=head[x]; i; i=edge[i].next) { int v=edge[i].p; if (v==fa) continue; int y0=((-edge[i].w)%3+3)%3, y1=((1-edge[i].w)%3+3)%3, y2=((2-edge[i].w)%3+3)%3; son[x]+=dp[v][y0]*(dp[x][0]-dp[v][y0])+dp[v][y1]*(dp[x][2]-dp[v][y2])+dp[v][y2]*(dp[x][1]-dp[v][y1]); } dp[x][0]+=1; } int main () { int n, ans=0, sum=0, u, v, w; scanf("%d",&n); FO(i,1,n) scanf("%d%d%d",&u,&v,&w), add_edge(u,v,w%3), add_edge(v,u,w%3); dfs(1,0); FOR(i,1,n) ans+=dp[i][0]; ans=ans*2-n; sum=n*n; FOR(i,1,n) ans+=son[i]; int gcd=__gcd(ans,sum); ans/=gcd; sum/=gcd; printf("%d/%d\n",ans,sum); return 0; }