2017多校第8场 HDU 6133 Army Formations 线段树合并
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6133
题意:给你一棵n个节点的二叉树,每个节点有一个提交任务的时间,每个节点总的提交任务的罚时为:提交这个节点和其子树所有的任务,每个任务提交时间的总和为该点的罚时。求每个节点提交完所有任务的最小罚时。
解法:根据题意,我们可以知道每个节点的提交的最小罚时为,按照任务的提交时间从小到大的来提交任务,可以得到最小的罚时。所以我们可以用线段树合并,先建立权值线段树,记录权值区间L到R的所有权值sum与size,线段树上的每一个点的ans为ans[lchild]+ans[rchild]+size[rchild]*sum[lchild]。
现场队友写了Splay合并过了,我下来写线段树合并一直MLE,然后交g++卡过了。。。
#include <stdio.h> #include <string.h> #include <algorithm> using namespace std; const int maxn = 100002; typedef long long LL; int n,m,cnt,rt[maxn],a[maxn],b[maxn]; LL ans[maxn]; struct edge{ int v,next; }E[maxn*2]; int head[maxn], edgecnt; void init(){ memset(head,-1,sizeof(head)); edgecnt=0; } void add(int u, int v){ E[edgecnt].v=v,E[edgecnt].next=head[u],head[u]=edgecnt++; } struct node{ int ls,rs,sz; LL ans,sum; }T[maxn*18]; int Merge(int u, int v){ if(!u||!v) return u+v; if(T[u].ls||T[u].rs){ T[u].ls=Merge(T[u].ls,T[v].ls); T[u].rs=Merge(T[u].rs,T[v].rs); T[u].sum=T[T[u].ls].sum+T[T[u].rs].sum; T[u].sz=T[T[u].ls].sz+T[T[u].rs].sz; T[u].ans=T[T[u].ls].ans+T[T[u].rs].ans+T[T[u].ls].sum*T[T[u].rs].sz; } else{ T[u].ans=T[u].ans+T[v].ans+T[u].sum*T[v].sz; T[u].sum=T[u].sum+T[v].sum; T[u].sz=T[u].sz+T[v].sz; } return u; } void dfs(int u, int fa){ for(int i = head[u]; ~i; i=E[i].next){ int v = E[i].v; if(v!=fa){ dfs(v,u); Merge(rt[u],rt[v]); } } ans[u]=T[rt[u]].ans; } void build(int &node, int l, int r, int pos) { node = ++cnt; T[node].sum=T[node].ans=b[pos]; T[node].sz=1; T[node].ls=T[node].rs=0; if(l==r) return; int mid=(l+r)>>1; if(pos<=mid) build(T[node].ls, l, mid, pos); else build(T[node].rs, mid+1, r, pos); } int main() { int _; scanf("%d", &_); while(_--){ cnt=0; scanf("%d", &n); init(); for(int i=0; i<maxn; i++) T[i].sz=T[i].ans=T[i].sum=0; for(int i=1; i<=n; i++) scanf("%d", &a[i]), b[i]=a[i]; sort(b+1,b+n+1); m = unique(b+1,b+n+1)-b-1; for(int i=1; i<=n; i++) a[i] = lower_bound(b+1,b+m+1,a[i])-b; for(int i=1; i<=n; i++) build(rt[i],1,m,a[i]); for(int i=1; i<n; i++){ int u, v; scanf("%d%d",&u,&v); add(u,v); add(v,u); } dfs(1,-1); for(int i=1; i<=n; i++){ printf("%lld", ans[i]); if(i!=n) printf(" "); else printf(" \n"); } } return 0; }