树形 DP 总结
常见的树形 DP 模型:
- 换根类:对每个节点求"以它为根时、满足一定条件的最优结果"。这类题通常先自底向上预处理子树信息,再自上向下转移,用父节点 u 已经算好的信息更新子节点 v 的信息。
- H 题中给叶子分配 1~k 的问题,可以转化为"根节点的值是第几大"的模型来求解。
- 像 J 题这类涉及两点间路径选择的问题,通常需要分情况讨论,转移较多。
- F 题:树形 DP 与背包结合(对删去某个节点后剩下的各连通块大小做 0/1 背包)。
// Globals for the subtree-centroid solution (N is defined outside this view).
int n, q;
// Edge lists as singly linked lists: head[u] = first edge index of u (-1 = none),
// to[e] = target of edge e, nxt[e] = next edge with the same source; cnt = next free slot.
int head[N << 1], cnt = 0;
int to[N << 1], nxt[N << 1];
// res[u]: answer for u (a centroid of u's subtree, filled by dfs), siz[u]: subtree size,
// son[u]: heaviest child (0 when u has no children), fa[u]: parent as read in main.
int res[N], siz[N], son[N], fa[N];
// Insert the directed edge u -> v at the front of u's edge list.
void add(int u, int v){
    const int e = cnt++;
    to[e] = v;
    nxt[e] = head[u];
    head[u] = e;
}
// Post-order pass from u (parent pre). Fills siz[u], the heaviest child son[u]
// and res[u], a centroid of u's subtree.
// If the heaviest child holds more than half of the subtree, the code starts
// from that child's centroid and climbs parents (fa[]) until the part of u's
// subtree outside the candidate is at most half. Otherwise u itself is used.
void dfs(int u, int pre){
    siz[u] = 1;
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == pre) continue;
        dfs(v, u);
        siz[u] += siz[v];
        if (siz[son[u]] < siz[v]) son[u] = v;
    }
    const int heavy = son[u];
    if (siz[heavy] * 2 <= siz[u]) {
        res[u] = u;
        return;
    }
    int cand = res[heavy];
    while ((siz[u] - siz[cand]) * 2 > siz[u]) cand = fa[cand];
    res[u] = cand;
}
// Reads n and q and the parent of every vertex 2..n, then builds parent -> child
// lists. Computes res[] with one dfs from vertex 1 and answers each query x with res[x].
int main()
{
    scanf("%d%d", &n, &q);
    cnt = 0;
    for (int v = n; v >= 0; --v) head[v] = -1;
    for (int child = 2; child <= n; ++child) {
        int parent;
        scanf("%d", &parent);
        fa[child] = parent;
        add(parent, child);
    }
    dfs(1, 0);
    while (q--) {
        int x;
        scanf("%d", &x);
        printf("%d\n", res[x]);
    }
    return 0;
}
// Globals for the path-gcd solution (N is defined outside this view).
int n;
// Undirected tree stored as two directed edges per input edge; head[] is set to -1 in main.
int head[N], cnt = 0;
int to[N << 1], nxt[N << 1];
// a[u]: value of vertex u; dp[u]: gcd of the values on the path 1..u;
// res[u]: printed answer, the largest candidate gcd computed for u.
int a[N], dp[N], res[N];
// sol[u]: sorted, de-duplicated gcds of the path 1..u with exactly one vertex's
// value taken as 0 (gcd(x, 0) = x, so "zeroing" v contributes dp[parent]).
vector<int> sol[N];
// Store the undirected edge {u, v} as two directed edges: u -> v first, then v -> u.
void add(int u, int v){
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;

    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}
void dfs(int u, int pre){
for(int i = head[u]; i != -1; i = nxt[i]){
int v = to[i];
if(v == pre) continue;
int tt = sol[u].size();
res[v] = dp[u];
for(int i = 0; i < tt; ++ i){
sol[v].push_back(gcd(sol[u][i], a[v]));
res[v] = max(res[v], sol[v][i]);
}
sol[v].push_back(dp[u]);
sort(sol[v].begin(),sol[v].end());
sol[v].erase(unique(sol[v].begin(),sol[v].end()), sol[v].end());
dp[v] = gcd(dp[u], a[v]);
dfs(v, u);
}
}
// Reads n, the values a[1..n] and n - 1 undirected edges, then seeds the root.
// Runs dfs from vertex 1 and prints res[1..n] separated by spaces.
int main()
{
    scanf("%d", &n);
    cnt = 0;
    for (int v = 0; v <= n; ++v) head[v] = -1;
    for (int v = 1; v <= n; ++v) scanf("%d", &a[v]);
    for (int e = 1; e < n; ++e) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
    }
    // Root: the unchanged path gcd is a[1]; taking the root itself as 0 gives gcd 0.
    dp[1] = a[1];
    res[1] = a[1];
    sol[1].push_back(0);
    dfs(1, 0);
    for (int v = 1; v <= n; ++v) printf(v == n ? "%d\n" : "%d ", res[v]);
    return 0;
}
// Globals (N is defined outside this view): n vertices, m marked vertices, distance limit d.
int n, m, d;
int head[N], cnt = 0;
int to[N << 1], nxt[N << 1];
// down[u] / up[u]: largest distance from u to a marked vertex inside / outside u's
// subtree, -1 when there is none (both start at -1 in main; dfs later folds u itself into up[u]).
int down[N], up[N];
// res: number of vertices whose largest distance to a marked vertex is <= d.
int res = 0;
// vis[x] = 1 marks vertex x.
int vis[N];
// Store the undirected edge {u, v} as the directed edges u -> v and v -> u (in that order).
void add(int u, int v){
    for (int side = 0; side < 2; ++side) {
        const int from = side == 0 ? u : v;
        const int dest = side == 0 ? v : u;
        to[cnt] = dest;
        nxt[cnt] = head[from];
        head[from] = cnt++;
    }
}
// Bottom-up pass: down[u] becomes the largest distance from u to a marked vertex
// in u's subtree, staying -1 if the subtree has no marked vertex.
void dfs1(int u, int pre){
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == pre) continue;
        dfs1(v, u);
        if (down[v] == -1) continue;   // nothing marked below v
        if (down[v] + 1 > down[u]) down[u] = down[v] + 1;
    }
    if (vis[u] && down[u] < 0) down[u] = 0;
}
// Sort predicate: vertices with a larger down[] value come first.
bool cmp(int x, int y){
    return down[y] < down[x];
}
// Top-down (rerooting) pass, run after dfs1.
// On entry up[u] holds the largest distance from u to a marked vertex outside u's
// subtree (-1 if none; the root keeps the -1 set in main). It counts u into res when
// both directions are within d, then derives up[v] for every child v and recurses.
void dfs(int u, int pre){
vector<int> sol; sol.clear();
// u qualifies when every marked vertex, above or below, is within distance d.
if(max(up[u], down[u]) <= d) res ++;
// From here on up[u] also covers u itself being marked, so the children can use it.
if(vis[u]) up[u] = max(up[u], 0);
// Collect the children and order them by down[] descending (cmp): sol[0] owns the
// deepest marked vertex below u, sol[1] (if any) the runner-up.
for(int i = head[u]; i != -1; i = nxt[i]){
int v = to[i];
if(v == pre) continue;
sol.push_back(v);
}
sort(sol.begin(), sol.end(), cmp);
for(int i = head[u]; i != -1; i = nxt[i]){
int v = to[i];
if(v == pre) continue;
if(v == sol[0]){
// v owns the best downward arm, so it must use up[u] or the second-best arm.
int maxx = -1;
if(up[u] != -1) maxx = max(maxx, up[u]);
if(sol.size() > 1 && down[sol[1]] != -1) maxx = max(maxx, down[sol[1]] + 1);
// +1 for the edge (v, u); up[v] stays -1 when nothing marked is reachable this way.
if(maxx != -1) up[v] = maxx + 1;
}
else{
// Any other child can use the best arm, through sol[0].
int maxx = -1;
if(up[u] != -1) maxx = max(maxx, up[u]);
if(down[sol[0]] != -1) maxx = max(maxx, down[sol[0]] + 1);
if(maxx != -1) up[v] = maxx + 1;
}
dfs(v, u);
}
}
// Reads n, m, d, the m marked vertices and n - 1 edges. Prints how many vertices
// are within distance d of every marked vertex (0 when there are no marked vertices).
int main()
{
    scanf("%d%d%d", &n, &m, &d);
    cnt = 0;
    for (int v = 0; v <= n; ++v) {
        head[v] = -1;
        up[v] = -1;
        down[v] = -1;
    }
    for (int k = 0; k < m; ++k) {
        int marked;
        scanf("%d", &marked);
        vis[marked] = 1;
    }
    for (int e = 1; e < n; ++e) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
    }
    if (m == 0) {
        printf("0\n");
        return 0;
    }
    dfs1(1, 0);
    dfs(1, 0);
    printf("%d\n", res);
    return 0;
}
// Globals (N is defined outside this view). Vertex 1 is the root; B is the second
// start vertex read in main.
int n, B;
// Edge lists; head[] is initialised to -1 in main.
int head[N], cnt = 0;
int to[N << 1], nxt[N << 1];
// dep[u]: height of u's subtree (edges on its longest downward path),
// len[u]: depth of u below vertex 1, fa[u]: parent of u (0 for the root).
int dep[N], len[N], fa[N];
// Store the undirected edge {u, v} as u -> v followed by v -> u.
void add(int u, int v){
    int e = cnt;
    to[e] = v; nxt[e] = head[u]; head[u] = e;
    ++e;
    to[e] = u; nxt[e] = head[v]; head[v] = e;
    cnt = e + 1;
}
// Records fa[u] and each child's depth len[], then computes the height dep[u] bottom-up.
void dfs(int u, int pre){
    fa[u] = pre;
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == pre) continue;
        len[v] = len[u] + 1;
        dfs(v, u);
        if (dep[v] + 1 > dep[u]) dep[u] = dep[v] + 1;
    }
}
// Reads the tree and B, then walks from B towards vertex 1.
// Every vertex rt on that walk with 2 * len[rt] > len[B] is strictly closer to B than
// to vertex 1 along the path, so it is a candidate.
// The answer is the largest 2 * (len[rt] + dep[rt]): twice the longest path from
// vertex 1 that goes down through rt.
int main()
{
    scanf("%d%d", &n, &B);
    cnt = 0;
    for (int i = 0; i <= n + 10; ++i) head[i] = -1;
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
    }
    dfs(1, 0);
    int rt = B, res = 0;
    // Stop at the root or at the first vertex that is no longer strictly closer to B.
    while (rt != 1 && len[rt] * 2 > len[B]) {
        res = max(res, (len[rt] + dep[rt]) * 2);
        rt = fa[rt];
    }
    printf("%d\n", res);
    return 0;
}
E_CF219D Choosing Capital for Treeland
// Globals for "Choosing Capital for Treeland" (N and INF are defined outside this view).
int n;
int head[N], cnt = 0;
// c[e] = +1 when edge e is walked in its input direction, -1 when walked against it.
int to[N << 1], nxt[N << 1], c[N << 1];
// dp[u]: number of edges directed towards u when the tree is rooted at u (dp[1] seeded in main);
// num[u][1] / num[u][0]: edges in u's subtree (rooted at 1) pointing away from / towards the root.
int dp[N], num[N][2];
// Vertices whose dp equals the minimum, collected in increasing order by main.
vector<int> sol;
// Store u -> v with weight w and the reverse edge v -> u with weight -w, so the sign
// of c[e] tells whether edge e is walked along or against its input direction.
void add(int u, int v, int w){
    to[cnt] = v; c[cnt] = w;  nxt[cnt] = head[u]; head[u] = cnt; cnt += 1;
    to[cnt] = u; c[cnt] = -w; nxt[cnt] = head[v]; head[v] = cnt; cnt += 1;
}
void dfs1(int u,int pre){
for(int i = head[u]; i != -1; i = nxt[i]){
int v = to[i], w = c[i];
if(v == pre) continue;
if(w == 1) num[u][1] ++;
else num[u][0] ++;
dfs1(v, u);
num[u][1] += num[v][1];
num[u][0] += num[v][0];
}
}
void dfs(int u, int pre){
for(int i = head[u]; i != -1; i = nxt[i]){
int v = to[i], w = c[i];
if(v == pre) continue;
dp[v] = dp[u] + w;
dfs(v, u);
}
}
// Reads n and n - 1 directed edges, then computes dp for root 1 and reroots.
// Prints the minimum dp value and every vertex that attains it, in increasing order.
int main()
{
    scanf("%d", &n);
    cnt = 0;
    for (int v = 0; v <= n; ++v) head[v] = -1;
    for (int e = 1; e < n; ++e) {
        int from, dest;
        scanf("%d%d", &from, &dest);
        add(from, dest, 1);
    }
    dfs1(1, 0);
    dp[1] = num[1][0];
    dfs(1, 0);

    int best = INF;
    for (int v = 1; v <= n; ++v) {
        if (dp[v] < best) best = dp[v];
    }
    for (int v = 1; v <= n; ++v) {
        if (dp[v] == best) sol.push_back(v);
    }

    const int total = static_cast<int>(sol.size());
    printf("%d\n", best);
    for (int idx = 0; idx < total; ++idx) {
        printf(idx == total - 1 ? "%d\n" : "%d ", sol[idx]);
    }
    return 0;
}
// Globals for the component-size knapsack solution (N is defined outside this view).
int n;
int head[N], cnt = 0;
int to[N << 1], nxt[N << 1];
// siz[u]: size of u's subtree with the tree rooted at 1.
int siz[N];
// res: recorded values a; main prints each one paired with n - 1 - a.
vector<int> res;
// vis[a]: value a (or its partner n - 1 - a) has already been recorded.
bool vis[N];
// dp[u][s] != 0 iff some subset of the components left after deleting u has total size s.
// NOTE(review): N*N ints, the dominant memory cost of this solution.
int dp[N][N];
// Store the undirected edge {u, v}: first u -> v, then v -> u.
void add(int u, int v){
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt++;
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt++;
}
void dfs(int u, int pre){
vector<int> sol;
siz[u] = 1;
dp[u][0] = 1;
for(int i = head[u]; i != -1; i = nxt[i]){
int v = to[i];
if(v == pre) continue;
dfs(v, u);
siz[u] += siz[v];
sol.push_back(siz[v]);
}
if(pre != 0) sol.push_back(n - siz[u]);
int tt = sol.size();
for(int i = 0; i < tt; ++ i){
for(int j = n - 1; j >= 0; -- j){
if(dp[u][j]) dp[u][j + sol[i]] = 1;
}
}
for(int i = 1; i < n - 1; ++ i){
if(dp[u][i] && !vis[i]){
vis[i] = true;
vis[n - i - 1] = true;
res.push_back(i);
if(n - i - 1 != i) res.push_back(n - i - 1);
}
}
}
// Reads n and n - 1 edges and runs dfs from vertex 1. Prints how many values were
// recorded, then each value a (ascending) with its partner n - 1 - a.
int main()
{
    scanf("%d", &n);
    cnt = 0;
    for (int v = 0; v <= n; ++v) head[v] = -1;
    for (int e = 1; e < n; ++e) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
    }
    dfs(1, 0);
    sort(res.begin(), res.end());
    printf("%d\n", static_cast<int>(res.size()));
    for (const int first : res) printf("%d %d\n", first, n - 1 - first);
    return 0;
}
// Globals for counting vertex pairs at distance exactly k (N and ll are defined outside this view).
int n, k;
int head[N], cnt = 0;
int to[N << 1], nxt[N << 1];
// dp[u][i]: number of vertices in u's subtree exactly i edges below u (i = 0..k).
// NOTE(review): the 520 bound assumes k <= 519 — confirm against the constraints.
ll dp[N][520];
// res: number of unordered pairs at distance exactly k.
ll res = 0;
void add(int u, int v){
to[cnt] = v, nxt[cnt] = head[u], head[u] = cnt ++;
to[cnt] = u, nxt[cnt] = head[v], head[v] = cnt ++;
}
// For each child v, first counts into res the pairs whose path turns at u: a vertex
// du edges below u (already merged) paired with one (k - 1 - du) edges below v.
// Then merges v's depth counts into u, shifted by one edge. Each pair is counted
// once, at its highest vertex.
void dfs(int u, int pre){
    dp[u][0] = 1;
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == pre) continue;
        dfs(v, u);
        for (int du = 0; du < k; ++du) res += dp[u][du] * dp[v][k - 1 - du];
        for (int dv = 0; dv < k; ++dv) dp[u][dv + 1] += dp[v][dv];
    }
}
// Reads n, k and n - 1 edges, runs dfs from vertex 1 and prints the pair count.
int main()
{
    scanf("%d%d", &n, &k);
    cnt = 0;
    for (int v = 0; v <= n; ++v) head[v] = -1;
    for (int e = n - 1; e > 0; --e) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
    }
    dfs(1, 0);
    printf("%lld\n", res);
    return 0;
}
H_CF1153D Serval and Rooted Tree
// Globals for "Serval and Rooted Tree" (N and INF are defined outside this view).
int n;
int head[N], cnt = 0;
int to[N << 1], nxt[N << 1];
// val[u]: operation flag read in main (nonzero -> dp[u] is the min over children,
// zero -> dp[u] is the sum over children); dp[u]: value computed by dfs (1 for leaves);
// res: number of leaves.
int val[N], dp[N], res = 0;
// Store the undirected edge {u, v}: u -> v is inserted before v -> u.
void add(int u, int v){
    int slot = cnt;
    to[slot] = v; nxt[slot] = head[u]; head[u] = slot; ++slot;
    to[slot] = u; nxt[slot] = head[v]; head[v] = slot; ++slot;
    cnt = slot;
}
void dfs(int u, int pre){
int tp = INF, flag = 0;
for(int i = head[u]; i != -1; i = nxt[i]){
int v = to[i];
if(v == pre) continue;
flag = 1;
dfs(v, u);
if(val[u]) tp = min(tp, dp[v]);
else dp[u] += dp[v];
}
if(val[u]) dp[u] = tp;
if(!flag){
dp[u] = 1, res ++;
}
}
// Reads n, the flags val[1..n] and the parent of each vertex 2..n, then runs dfs
// from vertex 1. Prints res - dp[1] + 1 (leaf count minus the root's dp, plus one).
int main()
{
    scanf("%d", &n);
    cnt = 0;
    for (int v = 1; v <= n; ++v) head[v] = -1;
    for (int v = 1; v <= n; ++v) scanf("%d", &val[v]);
    for (int child = 2; child <= n; ++child) {
        int parent;
        scanf("%d", &parent);
        add(parent, child);
    }
    dfs(1, 0);
    printf("%d\n", res - dp[1] + 1);
    return 0;
}
// Globals for the two-disjoint-paths product (N is defined outside this view).
int n;
int head[N], cnt = 0;
int to[N << 1], nxt[N << 1];
// up[u]: meant to be the longest path (in edges) from u that leaves through u's parent (set by dfs2);
// dep[u]: height of u's subtree; val2[u]: longest path inside u's subtree (dfs1);
// val1[u]: scratch written by dfs2 just before visiting each child — best path avoiding that child.
int up[N], dep[N], val1[N], val2[N];
// res: best product of two path lengths found so far.
int res = 0;
// One "arm" at a vertex u in dfs2: val = its length in edges, si = the child it goes
// into, or u itself for the arm that leaves through u's parent.
struct node{
int val, si;
};
// Sort predicate: longer arms first.
bool cmp(node a, node b){
    return b.val < a.val;
}
// Store the undirected edge {u, v} as the directed pair u -> v, v -> u.
void add(int u, int v){
    to[cnt] = v; nxt[cnt] = head[u]; head[u] = cnt; cnt++;
    to[cnt] = u; nxt[cnt] = head[v]; head[v] = cnt; cnt++;
}
// Bottom-up pass: dep[u] is the longest downward arm from u. val2[u] is the longest
// path inside u's subtree: either one already inside a child's subtree, or the two
// longest arms joined at u.
void dfs1(int u, int pre){
    int best = 0, second = 0;
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == pre) continue;
        dfs1(v, u);
        if (val2[v] > val2[u]) val2[u] = val2[v];
        const int arm = dep[v] + 1;
        if (arm > best) {
            second = best;
            best = arm;
        } else if (arm > second) {
            second = arm;
        }
    }
    dep[u] = best;
    val2[u] = max(val2[u], best + second);
}
// Rerooting pass, run after dfs1. Every tree edge (u, v) is tried as the separator
// between the two paths: one lies inside v's subtree (best length val2[v]), the other
// avoids that subtree (best length stored in val1[u] right before recursing into v).
// res keeps the largest product.
//   up[u] on entry     : longest path from u that leaves through its parent.
//   val1[pre] on entry : best path avoiding u's subtree (val1[0] == 0 for the root).
void dfs2(int u, int pre){
    vector<node> sol;                        // every arm at u: one per child plus the upward arm
    int dep1 = 0, dep2 = 0, tdep = 0;        // two longest downward arms (edge u-child included); tdep owns dep1
    int max1 = val1[pre], max2 = 0, tp = 0;  // two best whole paths among val1[pre] and the children's val2; tp owns max1
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        if(dep[v] + 1 > dep1){
            dep2 = dep1;
            dep1 = dep[v] + 1;
            tdep = v;
        }
        else if(dep[v] + 1 > dep2) dep2 = dep[v] + 1;
        sol.push_back((node){dep[v] + 1, v});
        if(val2[v] > max1){
            max2 = max1;
            max1 = val2[v];
            tp = v;
        }
        else if(val2[v] > max2) max2 = val2[v];
    }
    sol.push_back((node){up[u], u});
    sort(sol.begin(),sol.end(),cmp);
    int tt = sol.size();
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        // up[v] = edge (v, u) + longest arm at u that does not enter v. dep1/dep2 already
        // include the edge from u to their child, so only the single edge (v, u) is added.
        if(v == tdep) up[v] = 1 + max(up[u], dep2);
        else up[v] = 1 + max(up[u], dep1);
        // Longest path through u that avoids v: the two longest arms other than v's.
        if(tt > 2){
            int tval = 0;
            if(v == sol[0].si) tval = sol[1].val + sol[2].val;
            else if(v == sol[1].si) tval = sol[0].val + sol[2].val;
            else tval = sol[0].val + sol[1].val;
            val1[u] = tval;
        }
        else val1[u] = up[u];   // u has only v and the upward arm
        // ...or a whole path elsewhere (above u, or inside a sibling subtree).
        if(v == tp) val1[u] = max(val1[u], max2);
        else val1[u] = max(val1[u], max1);
        res = max(res, val2[v] * val1[u]);
        dfs2(v, u);
    }
}
// Reads n and n - 1 edges, runs both passes from vertex 1 and prints the best product.
int main()
{
    scanf("%d", &n);
    cnt = 0;
    for (int v = 0; v <= n; ++v) head[v] = -1;
    for (int e = 1; e < n; ++e) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
    }
    dfs1(1, 0);
    dfs2(1, 0);
    printf("%d\n", res);
    return 0;
}
// Globals for the 0/1-labelled tree pair count (N and ll are defined outside this view).
int n;
int head[N], cnt = 0;
// c[e]: label of edge e as read in main (stored identically in both directions).
int to[N << 1], nxt[N << 1], c[N << 1];
// Two disjoint-set forests: op = 1 joins vertices over edges with nonzero label,
// op = 0 over edges labelled 0. num[op][r] is the size of the set rooted at r (0 once r is merged away).
int root[2][N], num[2][N];
// Input edges: endpoints tx/ty and label tz.
int tx[N], ty[N], tz[N];
// res: number of ordered pairs counted by solve().
ll res = 0;
// Store the undirected edge {u, v} with label w in both directions (u -> v first).
void add(int u, int v, int w){
    int e = cnt;
    to[e] = v; c[e] = w; nxt[e] = head[u]; head[u] = e;
    ++e;
    to[e] = u; c[e] = w; nxt[e] = head[v]; head[v] = e;
    cnt = e + 1;
}
// Root of x in forest op, compressing the path on the way back.
int Find(int op, int x){
    if (root[op][x] != x) root[op][x] = Find(op, root[op][x]);
    return root[op][x];
}
// Merge the sets of x and y in forest op: x's root is attached under y's root,
// which takes over the combined size.
void Union(int op, int x, int y){
    const int rx = Find(op, x);
    const int ry = Find(op, y);
    if (rx == ry) return;
    root[op][rx] = ry;
    num[op][ry] += num[op][rx];
    num[op][rx] = 0;
}
void solve(){
for(int i = 1; i <= n; ++ i){
if(root[1][i] == i){
res += 1ll * num[1][i] * (num[1][i] - 1);
}
if(root[0][i] == i){
res += 1ll * num[0][i] * (num[0][i] - 1);
}
}
// cout<<res<<endl;
for(int u = 1; u <= n; ++ u){
int num0 = 0, num1 = 0;
for(int i = head[u]; i != -1; i = nxt[i]){
int v = to[i], w = c[i];
if(w){
num1 = 1;
}
else{
num0 = 1;
}
}
if(num1 && num0){
num1 = num[1][Find(1, u)] - 1;
num0 = num[0][Find(0, u)] - 1;
res += 1ll * num1 * num0;
}
}
}
// Reads n and n - 1 labelled edges and builds the adjacency lists. Unions each edge
// into forest 1 (nonzero label) or forest 0, then prints the count from solve().
int main()
{
    scanf("%d", &n);
    cnt = 0;
    for (int v = 0; v <= n; ++v) {
        head[v] = -1;
        root[0][v] = v;
        root[1][v] = v;
        num[0][v] = 1;
        num[1][v] = 1;
    }
    for (int e = 1; e < n; ++e) {
        scanf("%d%d%d", &tx[e], &ty[e], &tz[e]);
        add(tx[e], ty[e], tz[e]);
        Union(tz[e] ? 1 : 0, tx[e], ty[e]);
    }
    solve();
    printf("%lld\n", res);
    return 0;
}
K_CF1092F Tree with Maximum Cost
// Globals for "Tree with Maximum Cost" (N and ll are defined outside this view).
int n;
int head[N], cnt = 0;
int to[N << 1], nxt[N << 1];
// a[u]: vertex weight; dep[u]: depth below vertex 1;
// dp[u]: sum over all i of a[i] * dist(u, i); val[u]: total weight of u's subtree (rooted at 1).
ll a[N], dep[N], dp[N], val[N];
// all: total weight; res: largest dp value seen.
ll all, res;
// Store the undirected edge {u, v} as u -> v followed by v -> u.
void add(int u, int v){
    const int ends[2][2] = {{u, v}, {v, u}};
    for (int side = 0; side < 2; ++side) {
        to[cnt] = ends[side][1];
        nxt[cnt] = head[ends[side][0]];
        head[ends[side][0]] = cnt;
        ++cnt;
    }
}
// Bottom-up pass from vertex 1: sets the depth of every child and val[u], the total
// weight of u's subtree.
void dfs(int u, int pre){
    val[u] = a[u];
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == pre) continue;
        dep[v] = dep[u] + 1;
        dfs(v, u);
        val[u] += val[v];
    }
}
// Rerooting pass. Moving the root from u to child v brings the val[v] weight one edge
// closer and pushes the other all - val[v] one edge farther, so
// dp[v] = dp[u] + all - 2 * val[v]. res tracks the maximum over all roots.
void dfs1(int u, int pre){
    if (dp[u] > res) res = dp[u];
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == pre) continue;
        dp[v] = dp[u] + all - 2 * val[v];
        dfs1(v, u);
    }
}
// Reads n, the weights a[1..n] and n - 1 edges. Computes the cost for root 1,
// reroots with dfs1 and prints the maximum cost.
int main()
{
    scanf("%d", &n);
    cnt = 0;
    for (int i = 0; i <= n; ++i) head[i] = -1;
    for (int i = 1; i <= n; ++i) {
        // a[] is long long, so the conversion must be %lld; %d into a long long* is undefined behaviour.
        scanf("%lld", &a[i]);
        all += a[i];
    }
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
    }
    dfs(1, 0);
    // Cost with vertex 1 as the root: sum of a[i] * depth(i).
    for (int i = 1; i <= n; ++i) {
        dp[1] += a[i] * dep[i];
    }
    dfs1(1, 0);
    printf("%lld\n", res);
    return 0;
}