hihoCoder 1238 : Total Highway Distance(dfs + 二分)

题目连接

题意:给出n个城市,n-1条道路,求出两两路径的和。

思路:题意等价于求每天道路的使用次数,如下图所示

 

红色路径的使用度为以节点2为根节点的子树的节点数x * (n-x),此处为2 * 2 = 4。先按u<v的规则保存好道路,然后dfs一遍处理处每天道路的使用度,dfs过程中需要知道当前的边是哪条道路,此过程用二分查找,这中双变量的二分之前也没怎么写过。

code:

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <map> 
#define maxn 100005
typedef long long ll;
using namespace std;
struct node{
    int to,next;
    int length;
}road[maxn*2];
struct node1{
    int u,v;
    int length;
}edge[maxn];
int head[maxn];
int cnt = 0;
int n,m;
ll ans;
bool vis[maxn];
int rootSize[maxn];
ll tt[maxn];
bool cmp(node1 a,node1 b){  
    if(a.u==b.u)  
        return a.v<b.v;  
    else return a.u<b.u;  
}  
void init(){
    for(int i = 0;i < maxn;i ++) head[i] = -1;
}
int binarysearch(int u,int v){///二分查找u,v所对应的边,此处可以通过map<pair<int,int>,int>实现  
    if(u>v) swap(u,v);  
    int l=0,r=n-2;  
    int ll,rr;             //确定edge[i].u = u 的左右边界 
    while(l<=r){  
        int mid=(l+r)>>1;  
        if(edge[mid].u==u){  
            ll=mid;  
            r=mid-1;  
        }  
        else if(edge[mid].u>u)  
            r=mid-1;  
        else  
            l=mid+1;  
    }  
    l=ll;r=n-2;  
    while(l<=r){  
        int mid=(l+r)>>1;  
        if(edge[mid].u==u){  
            rr=mid;  
            l=mid+1;  
        }  
        else if(edge[mid].u>u)  
            r=mid-1;  
        else  
            l=mid+1;  
    }  
    while(ll<=rr){  
        int mid=(ll+rr)>>1;  
        if(edge[mid].v==v) return mid;  
        else if(edge[mid].v>v)  
            rr=mid-1;  
        else  
            ll=mid+1;  
    }  
}
void built(int from,int to,int length){
    road[cnt].to = to;
    road[cnt].length = length;
    road[cnt].next = head[from];
    head[from] = cnt;
    cnt ++;
}

int dfs(int u){
    vis[u] = true;
    rootSize[u] = 1;
    for(int j = head[u];j != -1;j = road[j].next){
        int v = road[j].to;
        int l = road[j].length;
        if(!vis[v]){
            int tn = binarysearch(u,v);
            rootSize[u] += dfs(v);
            tt[tn] = (ll)(n - rootSize[v])*(ll)rootSize[v];
            
        }
    }
    return rootSize[u];
}
void init2(){
    memset(vis,0,sizeof(vis));
    dfs(1);
}

void query(){
    for(int i = 0;i < n-1;i ++)
         ans += (ll)(tt[i]*(ll)edge[i].length);
}
void update(int a,int b,int l){
    if(a > b) swap(a,b);
    int tn = binarysearch(a,b);
    ans -= (ll)(tt[tn]*(ll)edge[tn].length);
    edge[tn].length = l;
    ans += (ll)(tt[tn]*(ll)edge[tn].length);
}
int main()
{
    char op[10];
    init();
    cin >> n >> m;
    for(int i = 0;i <n-1;i ++){
        int a,b,l;
        scanf("%d %d %d",&a,&b,&l);
        if(a > b) swap(a,b);
        edge[i].u = a;
        edge[i].v = b;
        edge[i].length = l;
        built(a,b,l);
        built(b,a,l);
    }
    sort(edge,edge+n-1,cmp);
    cnt = 0;
    init2();
    //for(int i = 0;i < n - 1;i ++) cout << tt[i] << endl;
    ans = 0;
    query();
    for(int i = 0;i < m;i ++){
        scanf("%s",op);
        if(op[0] == 'Q'){
            cout << ans << endl;
        }
        else{
            int a,b,l;
            scanf("%d %d %d",&a,&b,&l);
            update(a,b,l);
        }
    }
}
View Code

 

posted on 2016-08-26 23:01  Tob's_the_top  阅读(143)  评论(0编辑  收藏  举报

导航