POJ - 1741 - Tree - 点分治 模板
题意:
给定一棵带边权的树,求树上距离不超过 k 的点对的数量。
思路:
点分治的裸题。每次选取当前子树的重心,统计经过该重心且距离不超过 k 的点对(要减去两个端点落在同一棵子树内的重复计数),然后删去重心,对剩下的各棵子树递归求解。
#include <algorithm> #include <iterator> #include <iostream> #include <cstring> #include <cstdlib> #include <iomanip> #include <bitset> #include <cctype> #include <cstdio> #include <string> #include <vector> #include <cmath> #include <queue> #include <list> #include <map> #include <set> using namespace std; //#pragma GCC optimize(3) //#pragma comment(linker, "/STACK:102400000,102400000") //c++ #define lson (l , mid , rt << 1) #define rson (mid + 1 , r , rt << 1 | 1) #define debug(x) cerr << #x << " = " << x << "\n"; #define pb push_back #define pq priority_queue typedef long long ll; typedef unsigned long long ull; typedef pair<ll ,ll > pll; typedef pair<int ,int > pii; typedef pair<int,pii> p3; //priority_queue<int> q;//这是一个大根堆q //priority_queue<int,vector<int>,greater<int> >q;//这是一个小根堆q #define fi first #define se second //#define endl '\n' #define OKC ios::sync_with_stdio(false);cin.tie(0) #define FT(A,B,C) for(int A=B;A <= C;++A) //用来压行 #define REP(i , j , k) for(int i = j ; i < k ; ++i) //priority_queue<int ,vector<int>, greater<int> >que; const ll mos = 0x7FFFFFFFLL; //2147483647 const ll nmos = 0x80000000LL; //-2147483648 const int inf = 0x3f3f3f3f; const ll inff = 0x3f3f3f3f3f3f3f3fLL; //18 const int mod = 998244353; const double PI=acos(-1.0); // #define _DEBUG; //*// #ifdef _DEBUG freopen("input", "r", stdin); // freopen("output.txt", "w", stdout); #endif /*-----------------------showtime----------------------*/ const int maxn = 1e5+9; int root = 0,S,mx; int n,k; int sz[maxn],f[maxn],dis[maxn],cnt; bool used[maxn]; struct node { int to,w,nx; }e[maxn]; int h[maxn],tot = 0; void add(int u,int v,int w){ e[tot].to = v; e[tot].w = w; e[tot].nx = h[u]; h[u] = tot++; } void getRoot(int u, int fa){ sz[u] = 1,f[u] = 1; for(int i = h[u] ; ~i; i= e[i].nx){ int v = e[i].to; if(used[v] || fa == v)continue; getRoot(v,u); sz[u] += sz[v]; f[u] = max(f[u] , sz[v]); } f[u] = max(f[u],S - sz[u]); if(f[u] < mx){root = u;mx = f[u];} } void getDis(int u,int fa,int D){ for(int 
i=h[u] ; ~i; i=e[i].nx){ int v = e[i].to; if(used[v]||v == fa)continue; dis[++cnt] = D + e[i].w; getDis(v,u,dis[cnt]); } } int getAns(int x,int D){ dis[cnt = 1] = D; getDis(x,0,D); sort(dis+1,dis+1+cnt); int le = 1,ri =cnt,ans = 0; while(le <= ri){ if(dis[le] + dis[ri] <= k)ans += ri - le,le++; else ri--; } return ans; } int Divide(int x){ used[x] = true; ll ans = getAns(x,0); for(int i=h[x]; ~i; i= e[i].nx){ int v = e[i].to; if(used[v])continue; ans -= getAns(v,e[i].w); mx = inf,S = sz[v]; getRoot(v,x);ans += Divide(root); } return ans; } int main(){ while(~scanf("%d%d", &n, &k) && n+k) { memset(h,-1,sizeof(h)); memset(used,false,sizeof(used)); tot = 0; for(int i=1; i<n; i++){ int u,v,c; scanf("%d%d%d", &u, &v,&c); add(u,v,c); add(v,u,c); } S = n;mx = inf; getRoot(1,-1); printf("%d\n",Divide(root)); } return 0; }
自己今天又写了一遍。
//点分治 //#pragma GCC optimize(3) //#pragma comment(linker, "/STACK:102400000,102400000") //c++ // #pragma GCC diagnostic error "-std=c++11" // #pragma comment(linker, "/stack:200000000") // #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native") // #pragma GCC optimize("-fdelete-null-pointer-checks,inline-functions-called-once,-funsafe-loop-optimizations,-fexpensive-optimizations,-foptimize-sibling-calls,-ftree-switch-conversion,-finline-small-functions,inline-small-functions,-frerun-cse-after-loop,-fhoist-adjacent-loads,-findirect-inlining,-freorder-functions,no-stack-protector,-fpartial-inlining,-fsched-interblock,-fcse-follow-jumps,-fcse-skip-blocks,-falign-functions,-fstrict-overflow,-fstrict-aliasing,-fschedule-insns2,-ftree-tail-merge,inline-functions,-fschedule-insns,-freorder-blocks,-fwhole-program,-funroll-loops,-fthread-jumps,-fcrossjumping,-fcaller-saves,-fdevirtualize,-falign-labels,-falign-loops,-falign-jumps,unroll-loops,-fsched-spec,-ffast-math,Ofast,inline,-fgcse,-fgcse-lm,-fipa-sra,-ftree-pre,-ftree-vrp,-fpeephole2",3) #include <algorithm> #include <iterator> #include <iostream> #include <cstring> #include <cstdlib> #include <iomanip> #include <bitset> #include <cctype> #include <cstdio> #include <string> #include <vector> #include <stack> #include <cmath> #include <queue> #include <list> #include <map> #include <set> #include <cassert> using namespace std; #define lson (l , mid , rt << 1) #define rson (mid + 1 , r , rt << 1 | 1) #define debug(x) cerr << #x << " = " << x << "\n"; #define pb push_back #define pq priority_queue typedef long long ll; typedef unsigned long long ull; //typedef __int128 bll; typedef pair<ll ,ll > pll; typedef pair<int ,int > pii; typedef pair<int,pii> p3; //priority_queue<int> q;//这是一个大根堆q //priority_queue<int,vector<int>,greater<int> >q;//这是一个小根堆q #define fi first #define se second //#define endl '\n' #define OKC ios::sync_with_stdio(false);cin.tie(0) #define FT(A,B,C) for(int A=B;A <= C;++A) //用来压行 
#define REP(i , j , k) for(int i = j ; i < k ; ++i) #define max3(a,b,c) max(max(a,b), c); #define min3(a,b,c) min(min(a,b), c); //priority_queue<int ,vector<int>, greater<int> >que; const ll mos = 0x7FFFFFFF; //2147483647 const ll nmos = 0x80000000; //-2147483648 const int inf = 0x3f3f3f3f; const ll inff = 0x3f3f3f3f3f3f3f3f; //18 const int mod = 1e9+7; const double esp = 1e-8; const double PI=acos(-1.0); const double PHI=0.61803399; //黄金分割点 const double tPHI=0.38196601; template<typename T> inline T read(T&x){ x=0;int f=0;char ch=getchar(); while (ch<'0'||ch>'9') f|=(ch=='-'),ch=getchar(); while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar(); return x=f?-x:x; } /*-----------------------showtime----------------------*/ /* using namespace std; #define pb push_back #define debug(x) cerr<<#x<<" = " << x<<endl; #define fi first #define se second typedef pair<int,int> pii; const int inf = 0x3f3f3f3f; */ int n,k,ans; const int maxn = 10009; vector<pii>mp[maxn]; int dp[maxn],cen[maxn]; void getsize(int u,int fa){ dp[u] = 1; for(int i=0; i<mp[u].size(); i++){ int v = mp[u][i].fi; if(fa == v || cen[v]) continue; getsize(v, u); dp[u] += dp[v]; } } pii getbig(int u,int fa,int t){ pii res = pii(inf, u); int mx = 0; for(int i=0; i<mp[u].size(); i++){ int v = mp[u][i].fi; if(v == fa || cen[v])continue; res = min(res, getbig(v, u, t)); mx = max(mx, dp[v]); } res = min(res, pii(max(mx, t - dp[u]), u)); return res; } void dfs(int u,int fa,int c, vector<int> & b){ b.pb(c); for(int i=0; i<mp[u].size(); i++){ int v = mp[u][i].fi,d = c + mp[u][i].se; if(v == fa || cen[v])continue; dfs(v,u,d,b); } } int cal(vector<int>&b){ sort(b.begin(), b.end()); int res = 0,r = b.size(); for(int i=0; i<b.size(); i++){ while(r && b[i] + b[r-1] > k) r--; if(r > i) res += r - 1; else res += r; } return res/2; } void solve(int u){ getsize(u, -1); int s = getbig(u, -1,dp[u]).se; cen[s] = 1; for(int i=0; i<mp[s].size(); i++){ int v = mp[s][i].fi; if(cen[v])continue; solve(v); } vector<int>a; a.pb(0); 
for(int i=0; i<mp[s].size(); i++) { int v = mp[s][i].fi; if(cen[v])continue; vector<int>b; dfs(v, s, mp[s][i].se, b); ans -= cal(b); a.insert(a.end(),b.begin(),b.end()); } ans += cal(a); cen[s] = 0; } int main(){ while(~scanf("%d%d", &n, &k) && n + k){ for(int i=1; i<=n; i++) mp[i].clear(); for(int i=1; i<n; i++){ int u,v,w; scanf("%d%d%d", &u, &v, &w); mp[u].pb(pii(v,w)); mp[v].pb(pii(u,w)); } ans = 0; solve(1); printf("%d\n", ans); } return 0; }
下面是可读性更好的版本(算法相同,只是重新命名了函数并整理了格式):
//点分治 #include <algorithm> #include <iterator> #include <iostream> #include <cstring> #include <cstdlib> #include <iomanip> #include <bitset> #include <cctype> #include <cstdio> #include <string> #include <vector> #include <stack> #include <cmath> #include <queue> #include <list> #include <map> #include <set> #include <cassert> using namespace std; #define lson (l, mid, rt << 1) #define rson (mid + 1, r, rt << 1 | 1) #define debug(x) cerr << #x << " = " << x << "\n"; #define pb push_back #define pq priority_queue typedef long long ll; typedef unsigned long long ull; //typedef __int128 bll; typedef pair<ll, ll> pll; typedef pair<int, int> pii; typedef pair<int, pii> p3; //priority_queue<int> q;//这是一个大根堆q //priority_queue<int,vector<int>,greater<int> >q;//这是一个小根堆q #define fi first #define se second //#define endl '\n' #define OKC \ ios::sync_with_stdio(false); \ cin.tie(0) #define FT(A, B, C) for (int A = B; A <= C; ++A) //用来压行 #define REP(i, j, k) for (int i = j; i < k; ++i) #define max3(a, b, c) max(max(a, b), c); #define min3(a, b, c) min(min(a, b), c); //priority_queue<int ,vector<int>, greater<int> >que; const ll mos = 0x7FFFFFFF; //2147483647 const ll nmos = 0x80000000; //-2147483648 const int inf = 0x3f3f3f3f; const ll inff = 0x3f3f3f3f3f3f3f3f; //18 const int mod = 1e9 + 7; const double esp = 1e-8; const double PI = acos(-1.0); const double PHI = 0.61803399; //黄金分割点 const double tPHI = 0.38196601; template <typename T> inline T read(T &x) { x = 0; int f = 0; char ch = getchar(); while (ch < '0' || ch > '9') f |= (ch == '-'), ch = getchar(); while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar(); return x = f ? 
-x : x; } /*-----------------------showtime----------------------*/ int n, k, ans; const int maxn = 10009; vector<pii> mp[maxn]; int dp[maxn], cen[maxn]; void get_size(int u, int fa) { dp[u] = 1; for (int i = 0; i < mp[u].size(); i++) { int v = mp[u][i].fi; if (fa == v || cen[v]) continue; get_size(v, u); dp[u] += dp[v]; } } pii get_center(int u, int fa, int t) { pii res = pii(inf, u); int mx = 0; for (int i = 0; i < mp[u].size(); i++) { int v = mp[u][i].fi; if (v == fa || cen[v]) continue; res = min(res, get_center(v, u, t)); mx = max(mx, dp[v]); } res = min(res, pii(max(mx, t - dp[u]), u)); return res; } void get_dists_to_son(int u, int fa, int c, vector<int> &b) { b.pb(c); for (int i = 0; i < mp[u].size(); i++) { int v = mp[u][i].fi, d = c + mp[u][i].se; if (v == fa || cen[v]) continue; get_dists_to_son(v, u, d, b); } } int cal_pair_less_k(vector<int> &b) { sort(b.begin(), b.end()); int res = 0, r = b.size(); for (int i = 0; i < b.size(); i++) { while (r && b[i] + b[r - 1] > k) r--; if (r > i) res += r - 1; else res += r; } return res / 2; } void solve(int u) { get_size(u, -1); int s = get_center(u, -1, dp[u]).se; cen[s] = 1; for (int i = 0; i < mp[s].size(); i++) { int v = mp[s][i].fi; if (cen[v]) continue; solve(v); } vector<int> a; a.pb(0); for (int i = 0; i < mp[s].size(); i++) { int v = mp[s][i].fi; if (cen[v]) continue; vector<int> b; get_dists_to_son(v, s, mp[s][i].se, b); ans -= cal_pair_less_k(b); a.insert(a.end(), b.begin(), b.end()); } ans += cal_pair_less_k(a); cen[s] = 0; } int main() { while (~scanf("%d%d", &n, &k) && n + k) { for (int i = 1; i <= n; i++) mp[i].clear(); for (int i = 1; i < n; i++) { int u, v, w; scanf("%d%d%d", &u, &v, &w); mp[u].pb(pii(v, w)); mp[v].pb(pii(u, w)); } ans = 0; solve(1); printf("%d\n", ans); } return 0; }
skr