BZOJ2152 聪聪可可 (点分治)

2152: 聪聪可可

题意:

  在一棵有 n 个点、边带权的树中,独立地任取两个点(有序,允许两次取到同一个点),求这两个点间路径的权值和是3的倍数的概率,以最简分数输出。

思路:

  经典的点分治题目。

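  利用点分治计算所有路径长度,把路径长度对3取模,用$t[0],t[1],t[2]$分别记录模为0、1、2的情况,那么显然经过当前重心的(有序)点对数就是$t[1]\cdot t[2]\cdot 2+t[0]\cdot t[0]$;其中两个端点落在同一棵子树内的点对并不真正经过重心,需要用同样的方法对每棵子树计算后减去。把所有重心的结果累加得到满足条件的点对数$ans$,答案就是$ans/n^2$约分后的结果。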

#include <algorithm>
#include  <iterator>
#include  <iostream>
#include   <cstring>
#include   <cstdlib>
#include   <iomanip>
#include    <bitset>
#include    <cctype>
#include    <cstdio>
#include    <string>
#include    <vector>
#include     <cmath>
#include     <queue>
#include      <list>
#include       <map>
#include       <set>
using namespace std;
//#pragma GCC optimize(3)
//#pragma comment(linker, "/STACK:102400000,102400000")  //c++
// Segment-tree argument packs: `f lson` expands to `f (l , mid , rt << 1)`.
#define lson (l , mid , rt << 1)
#define rson (mid + 1 , r , rt << 1 | 1)
#define debug(x) cerr << #x << " = " << x << "\n";
#define pb push_back
#define pq priority_queue



typedef long long ll;
typedef unsigned long long ull;

typedef pair<ll ,ll > pll;
typedef pair<int ,int > pii;
typedef pair<int,pii> p3;

//priority_queue<int> q; // max-heap q
//priority_queue<int,vector<int>,greater<int> >q; // min-heap q
#define fi first
#define se second
//#define endl '\n'

#define OKC ios::sync_with_stdio(false);cin.tie(0)
#define FT(A,B,C) for(int A=B;A <= C;++A)  // inclusive for-loop shorthand
#define REP(i , j , k)  for(int i = j ; i <  k ; ++i)
//priority_queue<int ,vector<int>, greater<int> >que;

const ll mos = 0x7FFFFFFFLL;  //2147483647
const ll nmos = 0x80000000LL;  // 2147483648 (magnitude of INT_MIN; the value itself is positive)
const int inf = 0x3f3f3f3f;
const ll inff = 0x3f3f3f3f3f3f3f3fLL; //18
const int mod = 998244353;

const double PI=acos(-1.0);

// #define _DEBUG;         //*//
#ifdef _DEBUG
// A bare `freopen(...);` is a statement and is ill-formed at namespace scope, so the
// previous version failed to compile whenever _DEBUG was defined (MSVC defines it
// automatically for Debug builds). Binding the call to a static object runs it during
// static initialization, i.e. before main() reads stdin.
static FILE *const debug_input_redirect = freopen("input", "r", stdin);
// freopen("output.txt", "w", stdout);
#endif
/*-----------------------showtime----------------------*/
            const int maxn = 1e5+9;
            // root/mx: best centroid candidate and its largest-piece size (set by getRoot);
            // S: size used by getRoot for the piece "above" a node.
            int root = 0,S,mx;
            int n,k;
            // sz: DFS subtree sizes, f: largest piece if the node were removed,
            // dis/cnt: distance list filled by getDis/getAns.
            int sz[maxn],f[maxn],dis[maxn],cnt;
            int t[4];           // t[r]: how many collected distances are == r (mod 3)
            bool used[maxn];    // nodes already chosen as centroids
            // Forward-star adjacency list. Every undirected tree edge is stored twice
            // (u->v and v->u), so a tree on up to maxn-1 nodes needs about 2*maxn slots;
            // the previous e[maxn] overflowed once n exceeded about maxn/2.
            struct node
            {
                int to,w,nx;    // target node, edge weight, next edge index (-1 ends list)
            }e[2 * maxn];
            int h[maxn],tot = 0; // h[u]: first edge index of u (-1 = none); tot: edges used
            // Append the directed edge u -> v with weight w at the head of u's list.
            void add(int u,int v,int w){
                e[tot].to = v;
                e[tot].w = w;
                e[tot].nx = h[u];
                h[u] = tot++;
            }
            // Centroid search over the part of the tree reachable from u without entering
            // a used[] node. Fills sz[u] with the size of u's DFS subtree and f[u] with the
            // largest piece that would remain if u were removed (a child's subtree or the
            // S - sz[u] part on the parent's side). The node with the smallest f[] seen so
            // far is kept in root / mx. Callers set S and reset mx = inf before calling.
            void getRoot(int u, int fa){
                sz[u] = 1;
                int largest = 1;  // starting value 1, as in the original
                for (int edge = h[u]; edge != -1; edge = e[edge].nx) {
                    const int child = e[edge].to;
                    if (child == fa || used[child]) continue;
                    getRoot(child, u);
                    sz[u] += sz[child];
                    if (sz[child] > largest) largest = sz[child];
                }
                const int above = S - sz[u];
                if (above > largest) largest = above;
                f[u] = largest;
                if (largest < mx) {
                    mx = largest;
                    root = u;
                }
            }

            // Walk the subtree hanging below u (never entering fa or a used[] node) and
            // append, for every node visited after u, its distance value to dis[++cnt].
            // D is u's distance value: the caller's starting offset plus path weights.
            // Stored values are reduced mod 3 because getAns only uses the residue; raw
            // path sums can reach (n-1) * max_weight and overflow int on long heavy chains.
            void getDis(int u,int fa,int D){
                for(int i=h[u] ; ~i; i=e[i].nx){
                    int v = e[i].to;
                    if(used[v]||v == fa)continue;
                    dis[++cnt] = (D + e[i].w) % 3;
                    getDis(v,u,dis[cnt]);
                }
            }

            int getAns(int x,int D){
                dis[cnt = 1] = D;
                getDis(x,0,D);
                // sort(dis+1,dis+1+cnt);
                int ans = 0;
                t[0] = t[1] = t[2] = 0;
                for(int i=1; i<=cnt; i++){
                    t[dis[i]%3]++;
                }
                ans += t[1]*t[2]*2 + t[0]*t[0];
                return ans;
            }

            int Divide(int x){
                used[x] = true;
                ll ans = getAns(x,0);
                for(int i=h[x]; ~i; i= e[i].nx){
                    int v = e[i].to;
                    if(used[v])continue;
                    ans -= getAns(v,e[i].w);
                    mx = inf,S = sz[v];
                    getRoot(v,x);ans += Divide(root);
                }
                return ans;
            }
            // Greatest common divisor by Euclid's algorithm, written iteratively:
            // replace (a, b) with (b, a % b) until b is zero; gcd(a, 0) == a.
            ll gcd(ll a,ll b){
                while (b != 0) {
                    const ll rem = a % b;
                    a = b;
                    b = rem;
                }
                return a;
            }
int main(){
            
            while(~scanf("%d", &n) && n)
            {
                memset(h,-1,sizeof(h));
                memset(used,false,sizeof(used));
                tot = 0;
                for(int i=1; i<n; i++){
                    int u,v,c;
                    scanf("%d%d%d", &u, &v,&c);
                    add(u,v,c);
                    add(v,u,c);
                }
                S = n;mx = inf;
                getRoot(1,-1);
                int r = n*n;
                int ans = Divide(root);
                int tmp = gcd(ans,r);
                printf("%d/%d\n",ans/tmp,r/tmp);
            }
            return 0;
}
BZOJ2152

 

posted @ 2018-08-20 22:28  ckxkexing  阅读(148)  评论(0)  编辑  收藏  举报