bzoj 2286(虚树+树形dp) 虚树模板

树链求并又不会写,学了一发虚树,再也不虚啦~

2286: [Sdoi2011]消耗战

Time Limit: 20 Sec  Memory Limit: 512 MB
Submit: 5002  Solved: 1869
[Submit][Status][Discuss]

Description

在一场战争中,战场由n个岛屿和n-1个桥梁组成,保证每两个岛屿间有且仅有一条路径可达。现在,我军已经侦查到敌军的总部在编号为1的岛屿,而且他们已经没有足够多的能源维系战斗,我军胜利在望。已知在其他k个岛屿上有丰富能源,为了防止敌军获取能源,我军的任务是炸毁一些桥梁,使得敌军不能到达任何能源丰富的岛屿。由于不同桥梁的材质和结构不同,所以炸毁不同的桥梁有不同的代价,我军希望在满足目标的同时使得总代价最小。
侦查部门还发现,敌军有一台神秘机器。即使我军切断所有能源之后,他们也可以用那台机器。机器产生的效果不仅仅会修复所有我军炸毁的桥梁,而且会重新随机资源分布(但可以保证的是,资源不会分布到1号岛屿上)。不过侦查部门还发现了这台机器只能够使用m次,所以我们只需要把每次任务完成即可。

Input

第一行一个整数n,代表岛屿数量。

接下来n-1行,每行三个整数u,v,w,代表u号岛屿和v号岛屿由一条代价为c的桥梁直接相连,保证1<=u,v<=n且1<=c<=100000。

第n+1行,一个整数m,代表敌方机器能使用的次数。

接下来m行,每行一个整数ki,代表第i次后,有ki个岛屿资源丰富,接下来k个整数h1,h2,…hk,表示资源丰富岛屿的编号。

 

Output

输出有m行,分别代表每次任务的最小代价。

 

 

Sample Input

10
1 5 13
1 9 6
2 1 19
2 4 8
2 3 91
5 6 8
7 5 4
7 8 31
10 7 9
3
2 10 6
4 5 7 8 3
3 9 4 6

Sample Output

12
32
22
 

  如题,就是做虚树+树形dp。
  虚树是什么呢?他是针对原来的树,剔除对结果无影响的点,剩下的点连接起来的一颗新的树。这样的树可以尽可能的避免对无影响点的计算,从而降低求解的时间复杂度。
  这样的一颗虚树能降低时间复杂度的前提要求是,建立这颗树时间为O(k),空间为O(k),k为影响点的数目。
  首先我们想到针对每个任务的朴素做法。先dfs向下求每个点和根割离的最小代价val[u]。然后针对目标点,dfs向上的时候求每个点u的最小代价dp[u]=min(Σdp[son[u]],val[u]),其中son[u]为u的孩子节点,而任务目标点的dp[u]初值为val[u],其余为0。这样的确能做出来,但是时间复杂度爆炸。岛屿数为n,任务数为m。那么时间复杂度O(mn)肯定过不了的,巨爆炸。
  但我们观察到ΣKi是≤50w的。首先我们算出val[u]。我们对于每个任务,分别建立一个空间为O(Ki)的虚树。我们在这颗书上做树形dp。这样我们时间复杂度也就降到了O(ΣKi)了。
  那每个任务的影响点有哪些呢?很明显的任务给出的丰富节点和他们的lca、根节点1为影响点。首先k个点的lca一定小于k个,这个我就不在这里证明了。那么接下来就是求lca点了。我们把所有丰富节点按照dfs序排序,然后每相邻两个求lca。这些点就是所有的lca点了。注意lca点和丰富节点可能是同一个点,所以记得标记一下。
  然后我们求每个点他们的父亲。求父亲的话你要多想想dfs时间戳的性质-每个节点的子树是在一个连续区间上的。因此我们再把这所有影响点按dfs序排序,然后维护一个栈。从头到尾处理,每次不断弹栈直到dfs时间戳上栈顶的点的子树区间包含该处理点,那么该栈顶点就是该点的父亲了。这样求完一颗虚树就建好了。
  然后在上面跑个树形dp就AC了~。
  1 #include<bits/stdc++.h>
  2 #define clr(x) memset(x,0,sizeof(x))
  3 #define clr_1(x) memset(x,-1,sizeof(x))
  4 #define mod 1000000007
  5 #define LL long long
  6 #define INF 0x3f3f3f3f
  7 using namespace std;
  8 const int N=5e5+10;
  9 int n,m,t,u,v;
 10 int bit[20];
 11 struct edg
 12 {
 13     int next,to,val;
 14 }edge[N];
 15 int head[N],tot;
 16 void addedge(int u,int v,int val)
 17 {
 18     edge[++tot].val=val;
 19     edge[tot].next=head[u];
 20     edge[tot].to=v;
 21     head[u]=tot;
 22     return ;
 23 }
 24 int fa[N][20],fro[N],bac[N],clk,dep[N],pre,vfa[N],vis[N],val[N];
 25 LL dp[N];
 26 stack<int> sta;
 27 int cntp,p[N],cntall;
 28 void init()
 29 {
 30     clr_1(head);
 31     tot=0;
 32     clk=0;
 33     bit[0]=1;
 34     val[1]=INF;
 35     for(int i=1;i<20;i++)
 36         bit[i]=bit[i-1]<<1;
 37     return ;
 38 }
 39 void dfs(int u,int father,int deep)
 40 {
 41     int p;
 42     fro[u]=++clk;
 43     dep[u]=deep;
 44     fa[u][0]=father;
 45     for(int i=1;bit[i]<=deep;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
 46 //    cout<<"dep["<<u<<"]:"<<deep<<" fa["<<u<<"]: ";
 47 //    for(int i=0;bit[i]<=deep;i++) cout<<fa[u][i]<<" ";
 48 //    cout<<endl;
 49     for(int i=head[u];i!=-1;i=edge[i].next)
 50     if(edge[i].to!=father)
 51     {
 52         val[edge[i].to]=min(val[u],edge[i].val);
 53         dfs(edge[i].to,u,deep+1);
 54     }
 55     bac[u]=clk;
 56     return ;
 57 }
 58 int lca(int u,int v)
 59 {
 60     if(dep[u]<dep[v]) swap(u,v);
 61     int tmp=dep[u]-dep[v];
 62     for(int i=0;bit[i]<=tmp;i++)
 63         if(tmp&bit[i]) u=fa[u][i];
 64     int i=19;
 65     while(bit[i]>dep[u]) i--;
 66     for(i;i>=0;i--)
 67         if(fa[u][i]!=fa[v][i]) {u=fa[u][i]; v=fa[v][i];}
 68     return u==v?u:fa[u][0];
 69 }
 70 bool cmp(int a,int b)
 71 {
 72     return fro[a]<fro[b];
 73 }
 74 int main()
 75 {
 76     init();
 77     scanf("%d",&n);
 78     for(int i=2;i<=n;i++)
 79     {
 80         scanf("%d%d%d",&u,&v,&t);
 81         addedge(u,v,t);
 82         addedge(v,u,t);
 83     }
 84     dfs(1,1,0);
 85     scanf("%d",&m);
 86     for(int i=1;i<=m;i++)
 87     {
 88         scanf("%d",&cntp);
 89         for(int i=1;i<=cntp;i++)
 90         {
 91             scanf("%d",p+i);
 92             vis[p[i]]=1;
 93         }
 94         sort(p+1,p+cntp+1,cmp);
 95         cntall=cntp;
 96         for(int i=2;i<=cntp;i++)
 97         {
 98             pre=lca(p[i],p[i-1]);
 99 //            cout<<pre<<"~"<<endl;
100             if(!vis[pre])
101             {
102                 vis[pre]=2;
103                 p[++cntall]=pre;
104             }
105         }
106         if(!vis[1])
107         {
108             vis[1]=2;
109             p[++cntall]=1;
110         }
111         sort(p+1,p+cntall+1,cmp);
112         while(!sta.empty())
113             sta.pop();
114         sta.push(p[1]);
115         dp[p[1]]=0;
116         for(int i=2;i<=cntall;i++)
117         {
118 //            cout<<p[i]<<endl;
119             dp[p[i]]=0;
120             while(!sta.empty() && fro[p[i]]>bac[sta.top()]) sta.pop();
121             vfa[p[i]]=sta.top();
122             sta.push(p[i]);
123         }
124         for(int i=cntall;i>1;i--)
125         {
126             if(vis[p[i]]==2)
127                 dp[p[i]]=min(dp[p[i]],(LL)val[p[i]]);
128             else
129                 dp[p[i]]=val[p[i]];
130             dp[vfa[p[i]]]+=dp[p[i]];
131 //            cout<<p[i]<<" "<<val[p[i]]<<" "<<dp[p[i]]<<" "<<vfa[p[i]]<<endl;
132         }
133         printf("%lld\n",dp[p[1]]);
134         for(int i=1;i<=cntall;i++)
135             vis[p[i]]=0;
136     }
137     return 0;
138 }
View Code

 

这个板子太假2333。

然后我们做虚树的板子需要啥呢?

首先数据结构

1 struct edg
2 {
3     int next,to;
4 }edge[N],vedge[N];//实树和虚树的边
5 int head[N],vhead[N],etot,vtot;//实树和虚树的头,边总数。
6 int bit[20];//倍增法lca需要计算2^i的值
7 int timed;//dfs序计数器
8 int dep[N],fa[N][20],dfn[N];//实树点的深度,倍增lca里跃迁2^k的祖先,每个实树点对应的dfs序
9 int pt[N],cnt;//虚树的点和点总数

 

然后是需要做的函数

 1 void init()//初始化所有条件
 2 {
 3     clr_1(head);
 4     clr_1(vhead);
 5     etot=vtot=0;
 6     bit[0]=1;
 7     for(int i=1;i<20;i++)
 8         bit[i]=bit[i-1]<<1;
 9     timed=0;
10     return;
11 }
12 void addedge(int u,int v)//加入实边
13 {
14     edge[++etot]=(edg){head[u],v};
15     head[u]=etot;
16     return;
17 }
18 void vaddedge(int u,int v)//加入虚边
19 {
20     vedge[++vtot]=(edg){vhead[u],v};
21     vhead[u]=vtot;
22     return;
23 }
24 void dfs(int u,int fat,int d)//一遍dfs把dfn,dep,fa[n][k]求出
25 {
26     dfn[u]=++timed;
27     dep[u]=d;
28     fa[u][0]=fat;
29     for(int i=1;bit[i]<=d;i++)
30         fa[u][i]=fa[fa[u][i-1]][i-1];
31     int p;
32     for(int i=head[u];i!=-1;i=edge[i].next)
33     {
34         p=edge[i].to;
35         if(p==fat) continue;
36         dfs(p,u,d+1);
37     }
38     return ;
39 }
40 int lca(int u,int v)//倍增求lca
41 {
42     if(dep[u]<dep[v]) swap(u,v);
43     int tmp=dep[u]-dep[v];
44     for(int i=0;bit[i]<=tmp;i++)
45         if(tmp&bit[i]) u=fa[u][i];
46     for(int i=19;i>=0;i--)
47     if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
48     return u==v?u:fa[u][0];
49 }
50 bool cmp(int a,int b)//dfn排序比较函数
51 {
52     return dfn[a]<dfn[b];
53 }
54 int sta[N],top;//暂存处理dfs链的栈和栈顶
55 void getvt(int *pt,int &cnt)
56 {
57     sort(pt+1,pt+cnt+1,cmp);
58     top=0;
59     int f;
60     for(int i=1,cntp=cnt;i<=cntp;i++)
61     {
62         if(top==0) {sta[++top]=pt[i]; continue;}
63         f=lca(pt[i],sta[top]);
64         while(top>1 && dep[f]<dep[sta[top-1]])
65             vaddedge(sta[top-1],sta[top]),top--;
66         if(dep[f]<dep[sta[top]])
67             vaddedge(f,sta[top]),top--;
68         if(top>0 && sta[top]!=f) sta[++top]=f,pt[++cnt]=f;
69         sta[++top]=pt[i];
70     }
71     while(top>1)
72         vaddedge(sta[top-1],sta[top]),top--;
73     sort(pt+1,pt+cnt+1,cmp);
74     return ;
75 }

 

posted @ 2018-04-11 17:53  hk_lin  阅读(766)  评论(0编辑  收藏  举报