[HG]子树问题 题解
前言
模拟赛赛时SubtaskR3没开long long丢了20分。
题意简述
题目描述
对于一棵有根树(设其节点数为 \(n\) ,则节点编号从 \(1\) 至 \(n\) ),如果它满足所有非根节点的编号均比起父亲更大,我们就说它是Y树。
此外,出题人给出了 \(k\) 个整数 \(a_1, \dots, a_k\),并规定,只要一棵有根树存在一个子树包含的节点数恰好为 \(a_1, \dots, a_k\) 中的某一个值,那么它不符合条件;
现给定 \(n,k,a1,\dots,ak\),并额外给定整数 \(L,R\) ,请你对于 \(d = L,L+1,\dots,R\) ,分别求出 \(n\) 个节点的深度为 \(d\) 的符合条件的Y树的数量。
数据范围及提示
对于所有测试点,保证 \(0 \leq k < n \leq 500\)
题解
部分分
首先讲暴力,DFS即可。
首先写一下 \(R \leq 3\) 的部分分。
很多同学赛时试图找规律,但是实际上并没有什么规律,
很显然,当树的高度为2的时候,只有一种情况,就是我们常说的"菊花图"。
那么当树的高度增加到3的时候,显然可以想到一种可行的变换,
我们保留一部分的节点,。
对于 \(L \geq n - 2\) 的部分分。
首先打表找规律,当树高 \(n - 1\) 时,方案数为 \(\frac{n \times (n - 1)}{2} - 2\)。
可以想象为一条链上拆下来一个节点,连接到另一个节点上。
减去的两种方案分别为节点 \(1\) 多计算了一次,节点 \(n\) 不能重新连接到节点 \(n-1\) 上
当树高 \(n - 2\) 时,显然从一条链上拆下来两个点,
我们可以分成两种情况来看
- 两个点连在一起
- 两个点分开
最后在减去一些零零散散的重复计算部分,就完成了。
代码
警告:以下代码为暴力代码,非常的长,可以跳过
#pragma GCC optimize(3)
#include <cstdio>
#include <cstring>
#include <vector>
#define MOD 998244353
using namespace std;
int n, k, L, R;
int a[26];
namespace SubtaskBrute{
vector<int> son[26];
int ban[26], sum[26];
int d[26];
long long ans = 0;
int qry;
bool solve(int u){
int u_s = son[u].size(); sum[u] = 1;
for (int i = 0; i < u_s; ++i){
if (!solve(son[u][i])) return 0;
sum[u] += sum[son[u][i]];
}
return (!ban[sum[u]]);
}
void DFS(int u){
if (u == n + 1){ ans += solve(1); return ; }
for (int i = 1; i < u; ++i){
if (d[i] + 1 > qry) continue;
d[u] = d[i] + 1;
son[i].push_back(u);
DFS(u + 1);
son[i].pop_back();
}
}
int res[26];
void index(){
for (int i = 1; i <= k; ++i) ban[a[i]] = 1;
d[1] = 1;
for (qry = L - 1; qry <= R; ++qry){
if (!qry) res[qry] = 0;
else{ ans = 0; DFS(2); res[qry] = ans; }
}
for (int i = L; i <= R; ++i) printf("%lld ", (res[i] - res[i - 1]) % 998244353);
}
}
namespace Subtask1{
int index(int num){
if (num <= 0) return 0;
puts("HJC AK IOI!");
}
}
namespace SubtaskR3{
long long f[505][505];
void index(){
f[1][1] = 1;
for (int i = 2; i <= n; ++i)
for (int j = 1; j < i; ++j)
f[i][j] = (f[i - 1][j - 1] + f[i - 1][j] * j) % MOD;
long long ans = -1;
for (int j = 1; j < n; ++j)
ans = (ans + f[n][j]) % MOD;
if (L <= 1) printf("0 ");
if (L <= 2) printf("1 ");
printf("%lld", ans);
}
}
namespace SubtaskL2{
inline long long getL1(long long n){
return (((n * (n - 1) >> 1) - 2) % 998244353);
}
void index(){
long long ans = 1 - (n << 1);
for (int i = 2; i < n; ++i)
for (int j = i + 1; j <= n; ++j){
ans += i - 1;
if (i == n - 1) --ans;
ans %= MOD;
}
for (int i = 2; i < n; ++i)
for (int j = i + 1; j <= n; ++j){
if (i == n - 1 && j == n)
ans = (ans + (n - 3) * (n - 3)) % MOD;
else{
ans += (i - 1) * (j - 2) % MOD;
ans %= MOD;
if (j == n - 1) ans -= (i - 1);
else if (j == n) ans -= (i - 1);
}
}
printf("%lld", (ans + MOD) % MOD);
if (R >= n - 1) printf(" %lld", getL1(n));
if (R >= n) printf(" 1");
}
}
int main(){
freopen("subtree.in", "r", stdin);
freopen("subtree.out", "w", stdout);
scanf("%d %d", &n, &k);
for (int i = 1; i <= k; ++i) scanf("%d", &a[i]);
scanf("%d %d", &L, &R);
if (n <= 10){ SubtaskBrute::index(); return 0; }
else if (R == 3 && k == 0){ SubtaskR3::index(); return 0; }
else if (L == n - 2 && k == 0){ SubtaskL2::index(); return 0; }
return 0;
}
正解
非常简单的动态规划。
我们定义状态 \(f[i][d]\) 表示大小为 \(i\) ,高度为 \(d\) 的Y树种类数。
我们转移状态的时候考虑以"合并"的方式转移。
为了避免重复计算,我们定义次小节点一定合并到最小节点上(最小节点为树根即1,显然次小节点无论如何都得连接在上面)
避免了重复计算以后,我们枚举合并进来的子树的大小,可以列出如下的式子:
那么您可能会觉得奇怪,这还没有考虑限制条件呢?
其实我们只需要在DP时,把限制条件所限制的子树大小,DP值清零即可。
代码
#include <cstdio>
#define MOD 998244353
int C[505][505], f[505][505];
bool ban[505];
int main(){
int n, k; scanf("%d %d", &n, &k);
for (int i = 0; i <= n; ++i){
C[i][0] = 1;
for (int j = 1; j <= i; ++j)
C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % MOD;
}
for (int i = 1; i <= k; ++i){
int x; scanf("%d", &x);
ban[x] = 1;
}
if (ban[1]){
int L, R; scanf("%d %d", &L, &R);
for (int i = L; i <= R; ++i)
printf("0%c", (i != R ? ' ' : '\n'));
return 0;
}
f[1][1] = 1;
for (int d = 2; d <= n; ++d){
f[1][d] = 1;
for (int i = 2; i <= n; ++i)
for (int j = 1; j < i; ++j)
f[i][d] = (f[i][d] + 1ll * f[i - j][d] * f[j][d - 1] % MOD * C[i - 2][j - 1]) % MOD;
for (int i = 1; i <= n; ++i)
if (ban[i]) f[i][d] = 0;
}
int L, R; scanf("%d %d", &L, &R);
for (int i = L; i <= R; ++i)
printf("%d%c", (f[n][i] - f[n][i - 1] + MOD) % MOD, (i != R ? ' ' : '\n'));
return 0;
}