追击圣诞老人
有一个经典的套路(序列合并):有一个dag,节点数非常多,直接存是存不下的。
但是不用把整个图建出来。
如果把一个点的状态向后继状态连边,边有边权,定义一条路径的权值为它的边权和。求最小的k种路径边权和。
这是个经典的问题。使用堆,每次把最小值取出来,把它的后继状态插入堆,把最小值pop掉k次即可。
如果k比较大可以二分。
这道题也可以这么解决,但是会发现一个点的出边数量非常多。
我们真的需要每次把一个点的所有出边都插入到堆里吗?
注意到一个点的所有出边,边权小的边是会先于边权大的边被pop出来的。所以可以把“插入一个点的所有出边的过程”拆成若干个“插入这个点的出边的k小值”的过程。
把一个点的a,b,c限制拆成若干条直上直下的链。每次在拓展一个新的状态的时候,把这3条直上直下的链都插入堆里。
接着根据我们之前的分析,把“插入一个点的所有出边的过程”拆成若干个“插入这个点的出边的k小值”的过程,就是把一条直上直下的链在最小值位置分裂成两段,都插入堆里。
只要删除k次,我们就可以得到答案。
由于每次最多会插入5条链,所以时间复杂度是\(O(n+k\log_2n)\)的。
我们需要求出一条链的最小值位置。注意到我们在轻重链剖分跳链的时候,我们要查询的是链的一段前缀,只有最后部分我们才需要查询链的一段区间。
所以可以预处理链的前缀最小值,查询一次时间复杂度降低为\(O(\log_2n)\)
使用虚树即可避免对链并的分类讨论。
一份被卡空间的代码,当然正确性是保证的。
#include<bits/stdc++.h>
using namespace std;
#define N 500010
int h[N],v[N],n,k,nxt[N],ec,sz[N],p[N],tp[N],a[N],id[N],ct,f[N],st[N],d[N],ss[N],ee[N],vi[N],sg[N],tt,mp[N*4],po[N],sv[N],tv[N],ts,rt,cv[N],mn[N*4];
int *mv[N],*pp[N];
void add(int x,int y){v[++ec]=y;nxt[ec]=h[x];h[x]=ec;}
struct no{
int x,y;
}c[N][4];
struct nn{
int po,x,y,z;
};
int operator <(nn x,nn y){
return x.z>y.z;
}
void d1(int x){
sz[x]=1;
for(int i=h[x];i;i=nxt[i]){
d[v[i]]=d[x]+1;
f[v[i]]=x;
d1(v[i]);
sz[x]+=sz[v[i]];
if(sz[v[i]]>sz[p[x]])
p[x]=v[i];
}
}
void d2(int x,int t){
id[x]=++ct;
st[ct]=a[x];
sv[ct]=x;
tp[x]=t;
if(p[x])
d2(p[x],t);
for(int i=h[x];i;i=nxt[i])
if(v[i]!=p[x])
d2(v[i],v[i]);
}
void bd(int o,int l,int r){
if(l==r){
mn[o]=st[l];
mp[o]=l;
return;
}
int md=(l+r)/2;
bd(o*2,l,md);
bd(o*2+1,md+1,r);
mn[o]=min(mn[o*2],mn[o*2+1]);
if(mn[o*2]==mn[o])
mp[o]=mp[o*2];
else
mp[o]=mp[o*2+1];
}
int ma,vp;
void qu(int o,int l,int r,int x,int y){
if(r<x||y<l)
return;
if(x<=l&&r<=y){
if(mn[o]<ma){
ma=mn[o];
vp=mp[o];
}
return;
}
int md=(l+r)/2;
qu(o*2,l,md,x,y);
qu(o*2+1,md+1,r,x,y);
}
int lca(int x,int y){
while(tp[x]!=tp[y]){
if(d[tp[x]]>d[tp[y]])
x=f[tp[x]];
else
y=f[tp[y]];
}
if(d[x]>d[y])
return y;
return x;
}
int qv(int x,int y){
int mp=0;
while(tp[x]!=tp[y]){
if(d[tp[x]]>d[tp[y]]){
int vp=mv[po[tp[x]]][id[x]-id[tp[x]]+1];
if(st[vp]<st[mp])
mp=vp;
x=f[tp[x]];
}
else{
int vp=mv[po[tp[y]]][id[y]-id[tp[y]]+1];
if(st[vp]<st[mp])
mp=vp;
y=f[tp[y]];
}
}
vp=0;
ma=1e9;
if(d[x]>d[y]){
qu(1,1,n,id[y],id[x]);
if(st[vp]<st[mp])
mp=vp;
}
else{
qu(1,1,n,id[x],id[y]);
if(st[vp]<st[mp])
mp=vp;
}
return mp;
}
int jp(int x,int y){
if(y<0)
return 0;
if(!y)
return x;
while(1){
if(d[x]-d[tp[x]]+1>y)
break;
else{
y-=d[x]-d[tp[x]]+1;
x=f[tp[x]];
}
}
return pp[po[tp[x]]][d[x]-y+1-d[tp[x]]];
}
int cp(int x,int y){
return id[x]<id[y];
}
vector<int>g[N];
void ins(int x,int *st){
if(!ts){
st[++ts]=x;
return;
}
int lc=lca(st[ts],x);
if(lc==st[ts])st[++ts]=x;
else{
while(ts>1&&id[st[ts-1]]>=id[lc]){
g[st[ts-1]].push_back(st[ts]);
ts--;
}
if(lc!=st[ts]){
g[lc].push_back(st[ts]);
st[ts]=lc;
}
st[++ts]=x;
}
}
void bd(vector<int>va,int id){
ts=0;
sort(va.begin(),va.end(),cp);
auto it=unique(va.begin(),va.end());
va.erase(it,va.end());
if(va.size()==1){
cv[id]=1;
c[id][0]=(no){va[0],va[0]};
return;
}
for(int i=0;i<va.size();i++)
ins(va[i],tv);
for(int i=2;i<=ts;i++)
g[tv[i-1]].push_back(tv[i]);
}
void dfs(int x,int id){
int ok=0;
for(int y:g[x]){
dfs(y,id);
if(!ok&&x==rt)
c[id][cv[id]++]=(no){x,y};
else
c[id][cv[id]++]=(no){jp(y,d[y]-d[x]-1),y};
ok=1;
}
g[x].clear();
}
priority_queue<nn>q;
void gf(int x,int y,int z){
if(z>1e8)
return;
if(!x||!y||d[x]>d[y])
return;
int mp=qv(x,y);
q.push((nn){sv[mp],x,y,z+st[mp]});
}
int main(){
//freopen("r.in","r",stdin);
//freopen("w.out","w",stdout);
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
a[0]=1e9;
for(int i=2;i<=n;i++){
int p;
scanf("%d",&p);
add(p,i);
}
d[0]=st[0]=1e9;
d1(1);
d2(1,1);
for(int i=1;i<=n;i++)
ss[i]=1e9;
for(int i=1;i<=n;i++){
ss[tp[i]]=min(id[i],ss[tp[i]]);
ee[tp[i]]=max(id[i],ee[tp[i]]);
if(!vi[tp[i]]){
vi[tp[i]]=1;
sg[++tt]=tp[i];
}
}
for(int i=1;i<=tt;i++){
mv[i]=new int[ee[sg[i]]-ss[sg[i]]+2];
pp[i]=new int[ee[sg[i]]-ss[sg[i]]+2];
po[sg[i]]=i;
for(int j=ss[sg[i]];j<=ee[sg[i]];j++){
int po=j-ss[sg[i]]+1;
pp[i][po]=sv[j];
if(j==ss[sg[i]])
mv[i][po]=j;
else{
mv[i][po]=mv[i][po-1];
if(st[mv[i][po]]>st[j])
mv[i][po]=j;
}
}
}
bd(1,1,n);
for(int i=1;i<=n;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
vector<int>v;
v.push_back(x);
v.push_back(y);
v.push_back(z);
bd(v,i);
rt=lca(lca(x,y),z);
dfs(rt,i);
}
for(int i=1;i<=n;i++)
q.push((nn){i,i,i,a[i]});
for(int i=1;i<=k;i++){
if(q.empty())
break;
nn x=q.top();
q.pop();
printf("%d\n",x.z);
gf(x.x,f[x.po],x.z-a[x.po]);
gf(jp(x.y,d[x.y]-d[x.po]-1),x.y,x.z-a[x.po]);
for(int j=0;j<cv[x.po];j++)
gf(c[x.po][j].x,c[x.po][j].y,x.z);
}
}