2017多校第8场 HDU 6133 Army Formations 线段树合并

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6133

题意:给你一棵n个节点的二叉树,每个节点有一个提交任务的时间,每个节点总的提交任务的罚时为:提交这个节点和其子树所有的任务,每个任务提交时间的总和为该点的罚时。求每个节点提交完所有任务的最小罚时。

解法:根据题意,我们可以知道每个节点的提交的最小罚时为,按照任务的提交时间从小到大的来提交任务,可以得到最小的罚时。所以我们可以用线段树合并,先建立权值线段树,记录权值区间L到R的所有权值sum与size,线段树上的每一个点的ans为ans[lchild]+ans[rchild]+size[rchild]*sum[lchild]。

现场队友写了Splay合并过了,我下来写线段树合并一直MLE,然后交g++卡过了。。。

#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
const int maxn = 100002;
typedef long long LL;
int n,m,cnt,rt[maxn],a[maxn],b[maxn];
LL ans[maxn];
struct edge{
    int v,next;
}E[maxn*2];
int head[maxn], edgecnt;
void init(){
    memset(head,-1,sizeof(head));
    edgecnt=0;
}
void add(int u, int v){
    E[edgecnt].v=v,E[edgecnt].next=head[u],head[u]=edgecnt++;
}
struct node{
    int ls,rs,sz;
    LL ans,sum;
}T[maxn*18];
int Merge(int u, int v){
    if(!u||!v) return u+v;
    if(T[u].ls||T[u].rs){
        T[u].ls=Merge(T[u].ls,T[v].ls);
        T[u].rs=Merge(T[u].rs,T[v].rs);
        T[u].sum=T[T[u].ls].sum+T[T[u].rs].sum;
        T[u].sz=T[T[u].ls].sz+T[T[u].rs].sz;
        T[u].ans=T[T[u].ls].ans+T[T[u].rs].ans+T[T[u].ls].sum*T[T[u].rs].sz;
    }
    else{
        T[u].ans=T[u].ans+T[v].ans+T[u].sum*T[v].sz;
        T[u].sum=T[u].sum+T[v].sum;
        T[u].sz=T[u].sz+T[v].sz;
    }
    return u;
}
void dfs(int u, int fa){
    for(int i = head[u]; ~i; i=E[i].next){
        int v = E[i].v;
        if(v!=fa){
            dfs(v,u);
            Merge(rt[u],rt[v]);
        }
    }
    ans[u]=T[rt[u]].ans;
}
void build(int &node, int l, int r, int pos)
{
    node = ++cnt;
    T[node].sum=T[node].ans=b[pos];
    T[node].sz=1;
    T[node].ls=T[node].rs=0;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(pos<=mid) build(T[node].ls, l, mid, pos);
    else build(T[node].rs, mid+1, r, pos);
}
int main()
{
    int _;
    scanf("%d", &_);
    while(_--){
        cnt=0;
        scanf("%d", &n);
        init();
        for(int i=0; i<maxn; i++) T[i].sz=T[i].ans=T[i].sum=0;
        for(int i=1; i<=n; i++) scanf("%d", &a[i]), b[i]=a[i];
        sort(b+1,b+n+1);
        m = unique(b+1,b+n+1)-b-1;
        for(int i=1; i<=n; i++) a[i] = lower_bound(b+1,b+m+1,a[i])-b;
        for(int i=1; i<=n; i++) build(rt[i],1,m,a[i]);
        for(int i=1; i<n; i++){
            int u, v;
            scanf("%d%d",&u,&v);
            add(u,v);
            add(v,u);
        }
        dfs(1,-1);
        for(int i=1; i<=n; i++){
            printf("%lld", ans[i]);
            if(i!=n) printf(" ");
            else printf(" \n");
        }
    }
    return 0;
}

 

posted @ 2017-08-20 16:01  zxycoder  阅读(316)  评论(0编辑  收藏  举报