HDU 4871 Shortest-path tree(树分治+spfa)(待续)

题意:给定一幅无向带权图的节点数、边数、每条边的两个端点及边权,以及整数K;

        以1号节点为根构造最短路径树(若某点有多条最短路,取经过的节点编号序列字典序最小的那条),在这棵树上求恰好包含K个节点的最长路径的长度,以及这样的路径的条数;

        参考:http://www.cnblogs.com/chanme/p/3863793.html

思路:用spfa求最短路,再按节点编号从小到大DFS构建最短路径树,最后用树分治(点分治)求恰好含K个节点的最长路径长度及其条数;

        适合用树分治解决的典型问题有:统计乘积等于k的路径条数、长度和大于k的路径条数、乘积为完全立方数的路径条数等,做法都是典型的树分治。

        树分治的实现:

        具体做法是:找出重心,对去掉重心后的各部分递归求解;合并时枚举经过重心的所有路径。枚举时用一个全局的map ds记录已处理的子树中到达重心的所有情况(按路径上的节点数记录最长距离及其条数),再用tds枚举新子树中的路径,通过ds和tds更新答案,更新完后再把tds的内容并入ds。

#pragma warning(disable:4996)
#include <climits>
#include <cmath>
#include <cstdio>
#include <cstring>

#include <algorithm>
#include <iostream>
#include <map>
#include <queue>
#include <string>
#include <vector>
using namespace std;
 
// Textual alias for long long, used for the answer accumulators.
#define ll long long
// Static array bounds for vertices (maxn) and edges (maxm). Presumably sized
// with some slack above the problem limits — TODO confirm against the statement.
// maxm is not referenced in the code shown here.
#define maxn 31000
#define maxm 61000
// Shorthand for make_pair.
#define MP make_pair
 
// One adjacency-list entry: the neighbour vertex `v` and the edge weight `w`.
// Entries compare by neighbour id only, so sorting an adjacency list puts
// neighbours in increasing vertex order (the weight is ignored).
struct Edge{
    int v, w;
    Edge(){}
    Edge(int to, int weight) : v(to), w(weight) {}
    bool operator < (const Edge &other) const {
        return other.v > v;
    }
};
 
// Adjacency lists indexed by vertex id:
//   G  - the input graph; main() stores each undirected edge in both directions
//   E  - per vertex v, incoming edges (u, w) that spfa() found tight
//        (d[u] + w == d[v]) when relaxing
//   EE - the E edges mirrored to both endpoints, each list sorted by neighbour id
//   T  - the extracted shortest-path tree, stored undirected by dfs()
vector<Edge> G[maxn];
vector<Edge> E[maxn];
vector<Edge> EE[maxn];
vector<Edge> T[maxn];
 
// Vertex count, edge count, and the number of vertices a counted path must
// contain (see solve()).
int n, m, k;
 
// Shortest distance from vertex 1, filled by spfa().
int d[maxn];
// Distance at which dfs() reached each vertex.
int dx[maxn];
// spfa(): "vertex is currently queued"; dfs(): "vertex already visited".
bool in[maxn];
 
void dfs(int u,int dis)
{
    in[u] = true; dx[u] = dis;
    if (dx[u] != d[u]) puts("fuck");
    for (int i = 0; i < EE[u].size(); i++){
        int v = EE[u][i].v, w = EE[u][i].w;
        if (!in[v]&&w+dis==d[v]) {
            T[u].push_back(Edge(v, w));
            T[v].push_back(Edge(u, w));
            dfs(v, w + dis);
        }
    }
}
 
void spfa()
{
    queue<int> que;
    memset(in, 0, sizeof(in));
    memset(d, 0x3f, sizeof(d));
    d[1] = 0; in[1] = true; que.push(1);
    while (!que.empty()){
        int u = que.front(); que.pop(); in[u] = false;
        for (int i = 0; i < G[u].size(); i++){
            int v = G[u][i].v, w = G[u][i].w;
            if (d[u] + w < d[v]){
                d[v] = d[u] + w;
                if (!in[v]) {
                    in[v] = true; que.push(v);
                }
                E[v].clear(); E[v].push_back(Edge(u, w));
            }
            else if (d[u] + w == d[v]){
                E[v].push_back(Edge(u, w));
            }
        }
    }
    for (int i = 1; i <= n; i++){
        for (int j = 0; j < E[i].size(); j++){
            EE[E[i][j].v].push_back(Edge(i, E[i][j].w));
            EE[i].push_back(E[i][j]);
        }
    }
    for (int i = 1; i <= n; i++) sort(EE[i].begin(), EE[i].end());
    memset(in, 0, sizeof(in));
    memset(dx, 0x3f, sizeof(dx));
    dfs(1,0);
}
 
// centroid[v] is true while v is the centroid of a component being handled by
// solve(). Such vertices are skipped by the tree walks, so they bound the
// current component.
bool centroid[maxn];
// ssize[v]: size of v's subtree as last computed by compute_subtree_size().
int ssize[maxn];
 
int compute_subtree_size(int v, int p){
    int c = 1;
    for (int i = 0; i < T[v].size(); i++){
        int  w = T[v][i].v;
        if (w == p || centroid[w]) continue;
        c += compute_subtree_size(w, v);
    }
    ssize[v] = c;
    return c;
}
 
// Finds the centroid of the component containing `v` (entered from `p`);
// `t` is the component's total size. For each vertex it computes the largest
// piece left if that vertex were removed. It returns the lexicographically
// smallest (largest piece, vertex) pair. ssize[] must already hold subtree
// sizes from compute_subtree_size() on this component with the same root.
pair<int, int> search_centroid(int v, int p, int t){
    pair<int, int> best(INT_MAX, -1);
    int below = 1;      // vertices in v's subtree, v included
    int heaviest = 0;   // largest child subtree (later also the part above v)
    for (size_t idx = 0; idx < T[v].size(); ++idx){
        const int child = T[v][idx].v;
        if (child == p || centroid[child]) continue;

        const pair<int, int> cand = search_centroid(child, v, t);
        if (cand < best) best = cand;

        if (ssize[child] > heaviest) heaviest = ssize[child];
        below += ssize[child];
    }
    if (t - below > heaviest) heaviest = t - below;
    const pair<int, int> here(heaviest, v);
    if (here < best) best = here;
    return best;
}
 
// Per-centroid half-path tables used by solve()/enumerate().
// Key: number of vertices on the half-path. Value: (maximum distance, number
// of half-paths achieving it).
// ds covers the pieces already processed, and its keys count the centroid
// itself. tds covers the piece currently being enumerated, and its keys count
// only vertices inside that piece.
map<int, pair<int, int> > ds;
map<int, pair<int, int> > tds;
// Global scratch iterators over the tables above.
map<int, pair<int, int> >::iterator it;
map<int, pair<int, int> >::iterator itt;
// pass kk points, distant is dis
void enumerate(int v, int p, int kk, int dis, map<int, pair<int, int> > &tds)
{
    if (kk > k) return;
    it = tds.find(kk);
    if (it!=tds.end()){
        if (it->second.first == dis) {
            it->second.second += 1;
        }
        else if(it->second.first<dis){
            tds.erase(it);
            tds.insert(MP(kk, MP(dis, 1)));
        }
    }
    else{
        tds.insert(MP(kk, MP(dis, 1)));
    }
    for (int i = 0; i < T[v].size(); i++){
        int w = T[v][i].v;
        if (w == p || centroid[w]) continue;
        enumerate(w, v, kk + 1, dis + T[v][i].w, tds);
    }
}
 
// Best path length found so far, and how many k-vertex paths achieve it.
// Reset for every test case in main() and updated by solve().
ll ans, num;
 
void solve(int v)
{
    compute_subtree_size(v, -1);
    int s = search_centroid(v, -1, ssize[v]).second;
    centroid[s] = true;
    for (int i = 0; i < T[s].size(); i++){
        if (centroid[T[s][i].v]) continue;
        solve(T[s][i].v);
    }
    ds.clear();
    ds.insert(MP(1, MP(0, 1)));
    for (int i = 0; i < T[s].size(); i++){
        if (centroid[T[s][i].v]) continue;
        tds.clear();
        enumerate(T[s][i].v, s, 1, T[s][i].w, tds);
        it = tds.begin();
        while (it != tds.end()){
            int kk = it->first;
            if (ds.count(k - kk)){
                itt = ds.find(k - kk);
                int ldis = it->second.first + itt->second.first;
                if (ldis>ans) {
                    ans = ldis; num = it->second.second*itt->second.second;
                }
                else if (ldis == ans){
                    num += it->second.second*itt->second.second;
                }
            }
            ++it;
        }
        it = tds.begin();
        while (it != tds.end()){
            int kk = it->first + 1;
            if (ds.count(kk)){
                itt = ds.find(kk);
                if (it->second.first > itt->second.first){
                    ds.erase(itt);
                    ds.insert(MP(kk, it->second));
                }
                else if (it->second.first == itt->second.first) itt->second.second += it->second.second;
            }
            else{
                ds.insert(MP(kk, it->second));
            }
            ++it;
        }
    }
    centroid[s] = false;
}
 
// Reads the number of test cases. For each case it reads n, m, k and the m
// weighted undirected edges, builds the shortest-path tree (spfa), runs the
// centroid decomposition (solve) and prints "<best length> <path count>".
int main()
{
    int TE;
    cin >> TE;
    for (int tc = TE; tc != 0; --tc){
        scanf("%d%d%d", &n, &m, &k);
        for (int x = 0; x <= n; ++x) {
            G[x].clear();
            E[x].clear();
            EE[x].clear();
            T[x].clear();
        }
        int a, b, c;
        for (int e = 0; e < m; ++e){
            scanf("%d%d%d", &a, &b, &c);
            // Undirected edge: store it in both adjacency lists.
            G[a].push_back(Edge(b, c));
            G[b].push_back(Edge(a, c));
        }
        spfa();
        ans = 0;
        num = 0;
        solve(1);
        cout << ans << " " << num << endl;
    }
    return 0;
}

 

posted on 2015-07-17 21:17  大树置林  阅读(270)  评论(0)  编辑  收藏  举报

导航