洛谷 P2634 [国家集训队]聪聪可可
聪聪和可可是兄弟俩,他们俩经常为了一些琐事打起来,例如家中只剩下最后一根冰棍而两人都想吃、两个人都想玩儿电脑(可是他们家只有一台电脑)……遇到这种问题,一般情况下石头剪刀布就好了,可是他们已经玩儿腻了这种低智商的游戏。
他们的爸爸快被他们的争吵烦死了,所以他发明了一个新游戏:由爸爸在纸上画n个“点”,并用n-1条“边”把这n个“点”恰好连通(其实这就是一棵树)。并且每条“边”上都有一个数。接下来由聪聪和可可分别随即选一个点(当然他们选点时是看不到这棵树的),如果两个点之间所有边上数的和加起来恰好是3的倍数,则判聪聪赢,否则可可赢。
聪聪非常爱思考问题,在每次游戏后都会仔细研究这棵树,希望知道对于这张图自己的获胜概率是多少。现请你帮忙求出这个值以验证聪聪的答案是否正确。
一看就是点分治,注意下处理答案的时候就好了
每次处理出来这个子树的路径长度模\(3\)后的长度的个数\(s_0,s_1,s_2\),再和之前累加的值\(mp_0,mp_1,mp_2\)更新答案
根自己到自己也就是\(mp_0\)刚开始赋为\(1\),这样便于统计这个点到根的答案
最后加上每个点自己到自己的方案数也就是加上\(n\)
Code
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
const int N = 20000;
using namespace std;
int n,vis[N + 5],dis[N + 5],num,su,cnt,rt,size[N + 5],maxp[N + 5],s[5],ans,mp[5];
struct node
{
int to,cost;
};
vector <node> d[N + 5];
void get_rt(int u,int fa)
{
size[u] = 1;
maxp[u] = 0;
vector <node>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it).to;
if (v == fa || vis[v])
continue;
get_rt(v,u);
size[u] += size[v];
maxp[u] = max(maxp[u],size[v]);
}
maxp[u] = max(maxp[u],su - size[u]);
if (maxp[u] < maxp[rt])
rt = u;
}
void get_dis(int u,int fa)
{
s[dis[u]]++;
vector <node>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it).to,w = (*it).cost;
if (v == fa || vis[v])
continue;
dis[v] = (dis[u] + w) % 3;
get_dis(v,u);
}
}
void calc(int u)
{
vector <node>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it).to,w = (*it).cost;
if (vis[v])
continue;
dis[v] = w % 3;
s[0] = s[1] = s[2] = 0;
get_dis(v,u);
ans += (mp[0] * s[0] + mp[1] * s[2] + mp[2] * s[1]) * 2;
for (int i = 0;i <= 2;i++)
mp[i] += s[i];
}
for (int i = 0;i <= 2;i++)
mp[i] = 0;
}
void solve(int u)
{
vis[u] = 1;
mp[0] = 1;
calc(u);
vector <node>::iterator it;
for (it = d[u].begin();it != d[u].end();it++)
{
int v = (*it).to;
if (vis[v])
continue;
su = size[v];
rt = 0;
maxp[0] = size[v];
get_rt(v,0);
solve(rt);
}
}
int gcd(int a,int b)
{
if (!b)
return a;
return gcd(b,a % b);
}
int main()
{
scanf("%d",&n);
int u,v,w;
for (int i = 1;i < n;i++)
{
scanf("%d%d%d",&u,&v,&w);
d[u].push_back((node){v,w});
d[v].push_back((node){u,w});
}
su = n;
maxp[rt] = n;
get_rt(1,0);
solve(rt);
ans += n;
printf("%d/%d\n",ans / gcd(ans,n * n),n * n / gcd(ans,n * n));
return 0;
}