CF1481F AB Tree
说实话,这道题真的是一道十分神仙的题。
看见这道题最直观的感受就是,我们试图让同一深度的节点都能够填上相同的字符。
最优情况是 \(x\) 恰好可以被某些深度的点的个数瓜分完,也就是说答案的下界就是节点的最大深度。
既然已经知道下界,我们不妨来分析一下上界。
我们考虑一种构造方案,对于每一层,我们用 \(a\) 或 \(b\) 先填非叶子节点,我们不妨设该层非叶子有 \(l\) 个,总共还没有被填入字符的点有 \(m\) 个,由于非叶子节点至少有一个儿子,所以则有 \(l \le \frac{m}{2}\)。我们设还有 \(x\) 个 \(a\) 可以填,则必然有 \(l \le \max(x , m - x)\) ,所以该层的非叶子节点一定能被填上 \(a\) 或 \(b\) 中的一个的相同字符。接下来我们考虑填该层的叶子节点,如果用填该层非叶子节点的字符把他们填完了,那也不错。如果没有填完,就意味着有一种字符被填完了,我们用另一种字符填剩下的叶子节点,同时之后的每一层我们都可以用另一种字符填完。由于我们该层与非叶子节点字符不同的都是叶子节点,因此我们的答案恰好为最大深度加上 \(1\) 。
我个人认为这道题最大的难点可能就是运用贪心构造的方式计算出一个优秀的答案上界。
现在我们的问题就只有如何判断答案是否是最大深度。
由之前的分析可以得到一个 \(n^2\) 的背包 \(dp\) 。
接下来就本题第二大难点(也可能只是一个我第一次见的常见技巧)。我们可以按照每层的节点个数进行合并,那么这就被转化成为一个多重背包问题。
很显然,在最劣情况下,每层的节点个数为: \(1,2,3 \dots x\) 。由于我们一共只有 \(n\) 个节点,因此不同的节点个数最多只有 \(\sqrt n\) 个。
代码实现比较简单,没有什么好说的。
#include <queue>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
template <typename T>
void read(T &x) {
T f=1;x=0;char s=getchar();
while(s<'0'||s>'9') {if(s=='-') f=-1;s=getchar();}
while(s>='0'&&s<='9') {x=(x<<3)+(x<<1)+(s^'0');s=getchar();}
x *= f;
}
template <typename T>
void write(T x , char s='\n') {
if(x<0) {putchar('-');x=-x;}
if(!x) {putchar('0');putchar(s);return;}
int tmp[22]={},t=0;
while(x) tmp[t++]=x%10,x/=10;
while(t-->0) putchar(tmp[t]+'0');
putchar(s);
}
const int MAXN = 1e5 + 5;
int head[MAXN] , to[MAXN << 1] , nxt[MAXN << 1] , ecnt;
void add(int u , int v) {nxt[++ecnt] = head[u];head[u] = ecnt;to[ecnt] = v;}
int n , x , dep[MAXN] , cnt , deg[MAXN] , c[500] , val[500] , us[500] , bl[MAXN];
pair <int , int> siz[MAXN];
int md;
char ans[MAXN];
vector <int> nod[MAXN];
bool cmp(int x , int y) {return deg[x] > deg[y];}
void pre(int x , int fa) {
dep[x] = dep[fa] + 1;
nod[dep[x]].push_back(x);
md = max(md , dep[x]);
siz[dep[x]].first ++;
siz[dep[x]].second = dep[x];
for (int i = head[x]; i; i = nxt[i]) {
int v = to[i];
if(v == fa) continue;
pre(v , x);
}
}
int f[480][MAXN];
int main() {
read(n),read(x);
for (int i = 2; i <= n; ++i) {
int p;read(p);
add(p , i);
deg[p] ++ , deg[i] ++;
}
pre(1 , 0);
sort(siz + 1 , siz + 1 + md);
for (int i = 1; i <= md; ++i) {
if(siz[i].first != siz[i - 1].first) c[++cnt] = 1 , val[cnt] = siz[i].first;
else c[cnt] ++;
bl[siz[i].second] = cnt;
}
for (int i = 1; i <= n; ++i) f[1][i] = -1;
for (int j = 1; j <= c[1]; ++j) f[1][j * val[1]] = 0;
for (int i = 2; i <= cnt; ++i) {
f[i][0] = 0;
for (int j = 1; j <= n; ++j) {
if(f[(i - 1)][j] != -1) f[i][j] = j;
else if(j >= val[i] && f[i][j - val[i]] != -1 && (j - f[i][j - val[i]]) / val[i] <= c[i]) f[i][j] = f[i][j - val[i]];
else f[i][j] = -1;
}
}
if(f[cnt][x] != -1) {
printf("%d\n" , md);
int now = x;
for (int i = cnt; i >= 1; --i) {
us[i] = (now - f[i][now]) / val[i];
now = f[i][now];
}
for (int i = 1; i <= md; ++i) {
char cur;
if(us[bl[i]]) cur = 'a' , us[bl[i]] --;
else cur = 'b';
for (int j = 0; j < (int)nod[i].size(); ++j) {
ans[nod[i][j]] = cur;
}
}
for (int i = 1; i <= n; ++i) putchar(ans[i]);
}
else {
printf("%d\n" , md + 1);
int y = n - x;
char cx = 'a' , cy = 'b';
for (int i = 1; i <= n; ++i) {
sort(nod[i].begin() , nod[i].end() , cmp);
if(x < y) swap(x , y) , swap(cx , cy);
for (int j = 0; j < (int)nod[i].size(); ++j) {
if(nod[i][j]) ans[nod[i][j]] = cx , x --;
if(!x) swap(x , y) , swap(cx , cy);
}
}
for (int i = 1; i <= n; ++i) putchar(ans[i]);
}
return 0;
}