BZOJ 2314 士兵的放置(支配集)
显然是\(DP\)。
设\(dp[i][0/1/2]\)代表以i为根且\(i上有士兵放置/i被控制但i上没有士兵/i没有被控制\)的最小代价。
\(g[i][0/1/2]\)代表对应的方案数。
然后运用乘法原理和加法原理转移即可。
转移是我写过的树形\(DP\)里比较\(X\)(不可描述)的。
所以还是看代码吧。。(虽然可能也看不懂)
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
#define int long long
const int mod=1032992941;
const int INF=1e9;
const int N=501000;
int cnt,head[N];
struct edge{
int to,nxt;
}e[N*2];
void add_edge(int u,int v){
cnt++;
e[cnt].nxt=head[u];
e[cnt].to=v;
head[u]=cnt;
}
int read(){
int sum=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){sum=sum*10+ch-'0';ch=getchar();}
return sum*f;
}
int dp[N][3],g[N][3],book[N];
int ksm(int x,int b){
int tmp=1;
while(b){
if(b&1)tmp=tmp*x%mod;
b>>=1;
x=x*x%mod;
}
return tmp;
}
void dfs(int u,int f){
dp[u][0]=1;
g[u][0]=g[u][1]=g[u][2]=1;
bool flag=false;
bool mmp=false;
int hhh=0;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f)continue;
flag=true;
dfs(v,u);
int tmp=min(dp[v][0],min(dp[v][1],dp[v][2]));
dp[u][0]+=tmp;
int w=0;
if(dp[v][0]==tmp)w=(w+g[v][0])%mod;
if(dp[v][1]==tmp)w=(w+g[v][1])%mod;
if(dp[v][2]==tmp)w=(w+g[v][2])%mod;
g[u][0]=g[u][0]*w%mod;
dp[u][2]=min(INF,dp[u][2]+dp[v][1]),g[u][2]=g[u][2]*g[v][1]%mod;
if(dp[v][0]<dp[v][1])mmp=true;
if(dp[v][1]==dp[v][0])hhh++;
}
if(flag==false){
dp[u][1]=INF;g[u][1]=0;
return;
}
if(mmp){
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f)continue;
int tmp=min(dp[v][0],dp[v][1]);
dp[u][1]+=tmp;
int w=0;
if(dp[v][0]==tmp)w=(w+g[v][0])%mod;
if(dp[v][1]==tmp)w=(w+g[v][1])%mod;
g[u][1]=g[u][1]*w%mod;
}
}
else{
if(hhh==0){
g[u][1]=0;
int mn=INF,awsl=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f)continue;
mn=min(mn,dp[v][0]-dp[v][1]);
awsl=awsl*g[v][1]%mod;
dp[u][1]+=dp[v][1];
}
dp[u][1]+=mn;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f)continue;
if(dp[v][0]-dp[v][1]==mn)
g[u][1]=(g[u][1]+awsl*ksm(g[v][1],mod-2)%mod*g[v][0]%mod)%mod;
}
}
else{
int awsl=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f)continue;
int tmp=min(dp[v][0],dp[v][1]);
dp[u][1]+=tmp;
int w=0;
if(dp[v][0]==tmp)w=(w+g[v][0])%mod;
if(dp[v][1]==tmp)w=(w+g[v][1])%mod;
g[u][1]=g[u][1]*w%mod;
awsl=awsl*g[v][1]%mod;
}
g[u][1]=(g[u][1]-awsl+mod)%mod;
}
}
}
int n;
signed main(){
n=read();
for(int i=1;i<n;i++){
int u=read(),v=read();
add_edge(u,v);add_edge(v,u);
}
dfs(1,0);
printf("%lld\n",min(dp[1][0],dp[1][1]));
if(dp[1][0]<dp[1][1])printf("%lld",g[1][0]);
else if(dp[1][0]>dp[1][1])printf("%lld",g[1][1]);
else printf("%lld",(g[1][0]+g[1][1])%mod);
}