Solution -「HDU #6566」The Hanged Man
\(\mathcal{Description}\)
Link.
给定一棵含 \(n\) 个点的树,每个结点有两个权值 \(a\) 和 \(b\)。对于 \(k\in[1,m]\),分别求
\[\left|\arg\max_{\sum_{u\in S} a_u=k}\sum_{u\in S}b_u\right|
\]
其中 \(S\) 是树上的一个独立点集。
测试数据组数 \(\le20\),\(n\le50\),\(m\le5\times10^3\)。
\(\mathcal{Solution}\)
一个 naive 的 DP 方法是,令 \(f(u,i,0/1)\) 表示 \(u\) 子树内,入 \(S\) 的点 \(a\) 值之和为 \(i\),且 \(u\) 入/未入集时最大的 \(b\) 值之和及其方案数。这种做法慢在需要 \(\mathcal O(m^2)\) 合并两棵子树的 DP 信息。我们尝试转化成一种仅需要单点更新的 DP。
考虑点分树,因为结点 \(u\) 在原树上的邻接点必然是点分树上 \(u\) 的祖先或子树内的孩子,所以可以在点分树上以 DFN 的顺序 DP,令状态 \(f(i,s,j)\) 表示考虑了 DFN 在 \([1,i]\) 之间的结点,设 \(u\) 是 DFN 为 \(i\) 的结点,那么 \(s\) 为 \(u\) 的祖先中被选结点的深度集合,\(j\) 为 \(a\) 值之和,状态表示最大值及方案数。
以上,每次暴力插入一个结点更新 DP 数组,就能做到 \(\mathcal O(Tnm2^{\log n})=\mathcal O(Tn^2m)\) 了。
(我写 T 了,也许是代码问题,抱歉 qwq。
\(\mathcal{Code}\)
会 T 掉的可怜代码。
/* Clearink */
#include <cstdio>
#define rep( i, l, r ) for ( int i = l, repEnd##i = r; i <= repEnd##i; ++i )
#define per( i, r, l ) for ( int i = r, repEnd##i = l; i >= repEnd##i; --i )
typedef long long LL;
inline void chkmax( int& a, const int b ) { a < b && ( a = b, 0 ); }
const int MAXN = 100, MAXM = 5e3;
int n, m, a[MAXN + 5], b[MAXN + 5];
int siz[MAXN + 5], wgt[MAXN + 5], ban[MAXN + 5], dep[MAXN + 5];
bool vis[MAXN + 5];
struct Graph {
int ecnt, head[MAXN + 5], to[MAXN * 2 + 5], nxt[MAXN * 2 + 5];
inline void operator () ( const int u, const int v ) {
to[++ecnt] = v, nxt[ecnt] = head[u], head[u] = ecnt;
}
inline void clear( const int n ) {
ecnt = 0;
rep ( i, 1, n ) head[i] = 0;
}
} srcT, divT;
#define adj( T, u, v ) \
for ( int i = T.head[u], v; v = T.to[i], i; i = T.nxt[i] )
struct Atom {
int val; LL cnt;
Atom(): val( 0 ), cnt( 0 ) {}
Atom( const int v, const LL c ): val( v ), cnt( c ) {}
inline Atom operator + ( const int x ) {
return cnt ? Atom( val + x, cnt ) : Atom();
}
inline Atom& operator += ( const Atom& u ) {
if ( val == u.val ) cnt += u.cnt;
if ( val < u.val ) *this = u;
return *this;
}
} f[MAXN * 4 + 5][MAXM + 5];
inline void findG( const int u, const int fa, const int all, int& rt ) {
siz[u] = 1, wgt[u] = 0;
adj ( srcT, u, v ) if ( !vis[v] && v != fa ) {
findG( v, u, all, rt ), siz[u] += siz[v];
chkmax( wgt[u], siz[v] );
}
chkmax( wgt[u], all - siz[u] );
if ( !rt || wgt[u] < wgt[rt] ) rt = u;
}
inline void build( const int u ) {
vis[u] = true;
adj ( srcT, u, v ) if ( !vis[v] ) {
int rt = 0; findG( v, 0, siz[v], rt );
divT( u, rt ), dep[rt] = dep[u] + 1;
build( rt );
}
}
inline void solve( const int u, int& lasd ) {
int d = dep[u];
rep ( i, 1 << d, ( 1 << lasd << 1 ) - 1 ) {
Atom *cf = f[i & ( ( 1 << d ) - 1 )], *lf = f[i];
rep ( j, 0, m ) cf[j] += lf[j], lf[j] = Atom();
}
rep ( i, 0, ( 1 << d ) - 1 ) {
if ( i & ban[u] ) continue;
Atom *cf = f[i | 1 << d], *lf = f[i];
rep ( j, 0, m - a[u] ) cf[j + a[u]] += lf[j] + b[u];
}
lasd = d;
adj ( divT, u, v ) solve( v, lasd );
}
int main() {
int T; scanf( "%d", &T );
rep ( cas, 1, T ) {
scanf( "%d %d", &n, &m );
srcT.clear( n ), divT.clear( n );
rep ( i, 1, n ) vis[i] = false;
// f[][] is cleared when outputing;
// ban[] is cleared when calculating.
rep ( i, 1, n ) scanf( "%d %d", &a[i], &b[i] );
for ( int i = 2, u, v; i <= n; ++i ) {
scanf( "%d %d", &u, &v );
srcT( u, v ), srcT( v, u );
}
int rt = 0; findG( 1, 0, n, rt );
build( rt );
rep ( u, 1, n ) {
ban[u] = 0;
adj ( srcT, u, v ) ban[u] |= ( dep[v] < dep[u] ) << dep[v];
}
f[0][0].cnt = 1;
int lasd = 0; solve( rt, lasd );
printf( "Case %d:\n", cas );
rep ( i, 1, m ) {
Atom ans;
rep ( j, 0, ( 1 << lasd << 1 ) - 1 ) {
ans += f[j][i], f[j][i] = Atom();
}
printf( "%lld%c", ans.cnt, i ^ m ? ' ' : '\n' );
}
}
return 0;
}