bzoj 4182
首先很容易看出这是一个树上多重背包问题
设状态$f[i][j]$表示以$i$为根的子树中利用的体积是$j$
但是题目中有要求:选择的点集必须是一个联通块
这要怎么处理?
点分治!
首先我们利用点分治的思想,每次拎起一个根节点进行处理,要求这个根节点必选,然后在子树内进行dp
为了保证根节点必选(至少选一个),所以我们在初值时按根节点先选一个处理,也就是在最大合法体积上先去掉一个根节点的体积,然后进行dfs更新,对于子树中每个点同理。
多重背包用二进制优化
#include <cstdio> #include <cmath> #include <cstring> #include <cstdlib> #include <iostream> #include <algorithm> #include <queue> #include <stack> using namespace std; const int inf=0x3f3f3f3f; struct Edge { int next; int to; }edge[1005]; int head[505]; int w[505]; int v[505]; int d[505]; int maxp[505]; int siz[505]; int f[505][5005]; bool vis[505]; int s,rt; int cnt=1; int ans=0; int n,m; int T; void init() { memset(head,-1,sizeof(head)); memset(vis,0,sizeof(vis)); memset(f,0,sizeof(f)); ans=0; cnt=1; } void add(int l,int r) { edge[cnt].next=head[l]; edge[cnt].to=r; head[l]=cnt++; } void get_rt(int x,int fx) { siz[x]=1,maxp[x]=0; for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(to==fx||vis[to])continue; get_rt(to,x); siz[x]+=siz[to],maxp[x]=max(maxp[x],siz[to]); } maxp[x]=max(maxp[x],s-siz[x]); if(maxp[x]<maxp[rt])rt=x; } void dfs(int x,int fx,int lim) { if(lim<=0)return; int j=d[x]; for(int i=1;i<j;j-=i,i<<=1) { for(int k=lim;k>=i*v[x];k--)f[x][k]=max(f[x][k],f[x][k-i*v[x]]+i*w[x]); } for(int k=lim;k>=j*v[x];k--)f[x][k]=max(f[x][k],f[x][k-j*v[x]]+j*w[x]); for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to]||to==fx)continue; for(int j=0;j<=lim-v[to];j++)f[to][j]=f[x][j]+w[to]; dfs(to,x,lim-v[to]); for(int j=0;j<=lim-v[to];j++)f[x][j+v[to]]=max(f[x][j+v[to]],f[to][j]); } } void solve(int x) { vis[x]=1; for(int i=0;i<=m-v[x];i++)f[x][i]=w[x]; dfs(x,0,m-v[x]); for(int i=0;i<=m-v[x];i++)ans=max(ans,f[x][i]); for(int i=head[x];i!=-1;i=edge[i].next) { int to=edge[i].to; if(vis[to])continue; rt=0,s=siz[to],maxp[rt]=inf; get_rt(to,0); solve(rt); } } int main() { scanf("%d",&T); while(T--) { scanf("%d%d",&n,&m); init(); for(int i=1;i<=n;i++)scanf("%d",&w[i]); for(int i=1;i<=n;i++)scanf("%d",&v[i]); for(int i=1;i<=n;i++)scanf("%d",&d[i]),d[i]--; for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); add(x,y),add(y,x); } rt=0; maxp[rt]=s=n; get_rt(1,0); solve(rt); printf("%d\n",ans); } return 0; }