[luogu8294]最大权独立集问题
记\(ls\)和\(rs\)分别为\(k\)的左右儿子\(,sub_{k}\)表示以\(k\)为根的子树中节点集合
定义\(f_{k,i,j}\)表示以\(k\)为根的子树中,子树内\(d_{i}\)与子树外\(d_{j}\)发生交换的最小代价,则
\[f_{k,i,j}=d_{i}+d_{j}+\begin{cases}
0&(ls=rs=\empty)\\
\begin{cases}
\min_{x\in sub_{ls}}f_{ls,x,j}&i=k\\f_{ls,i,k}&i\in sub_{ls}
\end{cases}&(ls\ne \empty,rs=\empty)\\
\begin{cases}
\min_{x\in sub_{ls},y\in sub_{rs}}\min(f_{ls,x,j}+f_{rs,y,x},f_{rs,y,j}+f_{ls,x,y})&i=k\\
\min_{y\in sub_{rs}}\min(f_{ls,i,k}+f_{rs,y,j},f_{rs,y,k}+f_{ls,i,y})&i\in sub_{ls}
\end{cases}&(ls,rs\ne \empty)
\end{cases}
\]
(根据\(ls\)和\(rs\)的对称性,这里省略了部分情况)
另外,最终答案需要对根节点做一个类似的分类讨论,具体略
暴力转移,时空复杂度均为\(o(n^{3})\),无法通过
观察转移式子,构造\(\begin{cases}f0(k,j)=f_{k,k,j}\\f1(k,i)=f_{k,i,fa}\\f2(k,j)=\min_{i\in sub_{k}}f_{k,i,j}\\f3(k,i)=\min_{j\in sub_{rs}}f_{rs,j,k}+f_{ls,i,j}\end{cases}\),则
\[f_{k,i,j}=\begin{cases}f0(k,j)&i=k\\f1(ls,i)+d_{i}+d_{j}&i\in sub_{ls},ls\ne \empty,rs=\empty\\\min(f1(ls,i)+f2(rs,j),f3(k,i))+d_{i}+d_{j}&i\in sub_{ls},ls\ne \empty,rs\ne\empty\end{cases}
\]
利用\(f[0-3]\)在线\(o(1)\)算出\(f_{k,i,j}\),空间复杂度降为\(o(n^{2})\)
对式子简单优化(求\(f0\)时需要将\(f\)展开),时间复杂度也降为\(o(n^{2})\)
#include<bits/stdc++.h>
using namespace std;
#define N 5005
#define ll long long
int n,x,d[N],ls[N],rs[N],dfn[N],vis[N];
ll ans,f0[N][N],f1[N][N],f2[N][N],f3[N][N];
vector<int>v0,v[N];
ll get(int k,int i,int j){
if (i==k)return f0[k][j];
if (!rs[k])return f1[ls[k]][i]+d[i]+d[j];
if (dfn[i]<dfn[rs[k]])return min(f1[ls[k]][i]+f2[rs[k]][j],f3[k][i])+d[i]+d[j];
return min(f1[rs[k]][i]+f2[ls[k]][j],f3[k][i])+d[i]+d[j];
}
void dfs(int k,int fa){
dfn[k]=++dfn[0],v[k].push_back(k);
if (ls[k]){
dfs(ls[k],k);
for(int i:v[ls[k]])v[k].push_back(i);
}
if (rs[k]){
dfs(rs[k],k);
for(int i:v[rs[k]])v[k].push_back(i);
}
v0.clear();
memset(vis,0,sizeof(vis));
for(int i:v[k])vis[i]=1;
for(int i=1;i<=n;i++)
if (!vis[i])v0.push_back(i);
if (!ls[k]){
for(int j:v0)f0[k][j]=0;
if (fa){
for(int i:v[k])f1[k][i]=0;
}
int s=1e9;
for(int i:v[k])s=min(s,d[i]);
for(int j:v0)f2[k][j]=s;
}
else{
if (!rs[k]){
for(int j:v0)f0[k][j]=f2[ls[k]][j];
if (fa){
for(int x:v[ls[k]])f1[k][k]=min(f1[k][k],get(ls[k],x,fa));
for(int i:v[ls[k]])f1[k][i]=get(ls[k],i,k);
}
ll s=1e18;
for(int i:v[ls[k]])s=min(s,get(ls[k],i,k)+d[i]);
for(int j:v0)f2[k][j]=min(f2[ls[k]][j]+d[k],s);
}
else{
for(int j:v0)f0[k][j]=f0[ls[k]][j]+f2[rs[k]][ls[k]];
if (ls[ls[k]]){
ll s1=1e18,s2=1e18;
for(int x:v[ls[ls[k]]]){
s1=min(s1,f1[ls[ls[k]]][x]+d[x]+f2[rs[k]][x]);
s2=min(s2,f3[ls[k]][x]+d[x]+f2[rs[k]][x]);
}
if (!rs[ls[k]]){
for(int j:v0)f0[k][j]=min(f0[k][j],s1+d[j]);
}
else{
for(int j:v0)f0[k][j]=min(f0[k][j],min(s1+f2[rs[ls[k]]][j],s2)+d[j]);
s1=s2=1e18;
for(int x:v[rs[ls[k]]]){
s1=min(s1,f1[rs[ls[k]]][x]+d[x]+f2[rs[k]][x]);
s2=min(s2,f3[ls[k]][x]+d[x]+f2[rs[k]][x]);
}
for(int j:v0)f0[k][j]=min(f0[k][j],min(s1+f2[ls[ls[k]]][j],s2)+d[j]);
}
}
for(int j:v0)f0[k][j]=min(f0[k][j],f0[rs[k]][j]+f2[ls[k]][rs[k]]);
if (ls[rs[k]]){
ll s1=1e18,s2=1e18;
for(int y:v[ls[rs[k]]]){
s1=min(s1,f1[ls[rs[k]]][y]+d[y]+f2[ls[k]][y]);
s2=min(s2,f3[rs[k]][y]+d[y]+f2[ls[k]][y]);
}
if (!rs[rs[k]]){
for(int j:v0)f0[k][j]=min(f0[k][j],s1+d[j]);
}
else{
for(int j:v0)f0[k][j]=min(f0[k][j],min(s1+f2[rs[rs[k]]][j],s2)+d[j]);
s1=s2=1e18;
for(int y:v[rs[rs[k]]]){
s1=min(s1,f1[rs[rs[k]]][y]+d[y]+f2[ls[k]][y]);
s2=min(s2,f3[rs[k]][y]+d[y]+f2[ls[k]][y]);
}
for(int j:v0)f0[k][j]=min(f0[k][j],min(s1+f2[ls[rs[k]]][j],s2)+d[j]);
}
}
for(int i:v[ls[k]])
for(int j:v[rs[k]])f3[k][i]=min(f3[k][i],get(rs[k],j,k)+get(ls[k],i,j));
for(int i:v[rs[k]])
for(int j:v[ls[k]])f3[k][i]=min(f3[k][i],get(ls[k],j,k)+get(rs[k],i,j));
if (fa){
for(int x:v[ls[k]])f1[k][k]=min(f1[k][k],get(ls[k],x,fa)+f2[rs[k]][x]);
for(int y:v[rs[k]])f1[k][k]=min(f1[k][k],get(rs[k],y,fa)+f2[ls[k]][y]);
for(int i:v[ls[k]])f1[k][i]=min(get(ls[k],i,k)+f2[rs[k]][fa],f3[k][i]);
for(int i:v[rs[k]])f1[k][i]=min(get(rs[k],i,k)+f2[ls[k]][fa],f3[k][i]);
}
ll s=1e18,s1=1e18,s2=1e18;
for(int i:v[ls[k]])s=min(s,f3[k][i]+d[i]),s1=min(s1,get(ls[k],i,k)+d[i]);
for(int i:v[rs[k]])s=min(s,f3[k][i]+d[i]),s2=min(s2,get(rs[k],i,k)+d[i]);
for(int j:v0)f2[k][j]=min(f0[k][j]+d[k],min(s,min(s1+f2[rs[k]][j],s2+f2[ls[k]][j])));
}
}
for(int j:v0)f0[k][j]+=d[k]+d[j];
if (fa){
for(int i:v[k])f1[k][i]+=d[i]+d[fa];
}
for(int j:v0)f2[k][j]+=d[j];
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++)scanf("%d",&d[i]);
for(int i=2;i<=n;i++){
scanf("%d",&x);
if (!ls[x])ls[x]=i;
else rs[x]=i;
}
memset(f0,0x3f,sizeof(f0));
memset(f1,0x3f,sizeof(f1));
memset(f2,0x3f,sizeof(f2));
memset(f3,0x3f,sizeof(f3));
ans=1e18,dfs(1,0);
if (!ls[1])ans=0;
else{
if (!rs[1]){
for(int x:v[ls[1]])ans=min(ans,get(ls[1],x,1));
}
else{
for(int x:v[ls[1]])
for(int y:v[rs[1]])ans=min(ans,min(get(ls[1],x,1)+get(rs[1],y,x),get(rs[1],y,1)+get(ls[1],x,y)));
}
}
printf("%lld\n",ans);
return 0;
}