Solution -「SV 2020 Round I」「SRM 551 DIV1」「TC 12141」SweetFruits
\(\mathcal{Description}\)
link.
给定 \(n\) 个水果,每个结点可能有甜度 \(v_i\),或不甜(\(v_i=-1\))。现在把这些水果串成一棵无根树。称一个水果“真甜”,当且仅当其本身和至少一个邻接水果是甜的。每个“真甜”水果对树的甜度产生 \(v_i\) 的贡献。求所有甜度不超过 \(maxv\) 的树。
\(n\le40\)。
\(\mathcal{Solution}\)
令无序地取恰好 \(i\) 个水果使其甜度和不超过 \(maxv\) 的方案数为 \(f_i\),树上恰有 \(i\) 个“真甜”果的方案数为 \(g_i\)。显然答案为 \(\sum f_ig_i\)。
\(\mathcal{Part~1}\)
求 \(f_i\)。
这是一个经典(我已经忘记)的 \(\text{meet in the middle}\) 问题。先取出甜水果全集,任意分成等大(或大小差 \(1\))的两部分。分别枚举两部分的子集,统计其甜度和及子集大小,并分别按甜度和为关键字升序排列,构成两个序列,令为 \(A,B\)。接下来用一个 \(\text{two-pointers}\) 的技巧。从小到大枚举 \(B\) 中的元素,注意到其甜度不减,所以可与其配对的 \(A\) 的元素范围逐渐向左减小。每次维护 \(A\) 中新的右端点,在暴力枚举在 \(A\) 中取的集合大小统计与 \(B\) 当前元素构成的方案数。(不像人话 qwq,看代码吧。)
这部分复杂度 \(\mathcal O(2^{\frac{n}2}n)\)。
\(\mathcal{Part~2}\)
求 \(g_i\)。
直接算出每个 \(g_i\) 貌似有些困难。我们先算出树上有不超过 \(i\) 个“真甜”果的方案数。设共有甜果 \(m\) 个,“真甜”果 \(k\) 个。不妨令“真甜”果为 \(1,2,\dots,k\),甜而非“真甜”(就叫它们清甜 w)果为 \(k+1,k+2,\dots,m\),不甜果为 \(m+1,m+2,\dots,n\)。想象一个完全图,删除所有“甜-清甜”与“清甜-清甜”的连边,用 \(\text{Matrix-Tree}\) 求出生成树个数。可以发现,这样的一棵生成树不可能让“清甜”变成“真甜”,所以这就是不超过 \(i\) 个“真甜”果的方案数。此时,再利用此前计算出的 \(g\) 减去多出来的一些方案即可。
这部分复杂度 \(\mathcal O(n^4)\),故总复杂度 \(\mathcal O(2^{\frac{n}2}n+n^4)\)。
\(\mathcal{Code}\)
#include <cstdio>
#include <vector>
#include <algorithm>
const int MAXN = 40, MOD = 1e9 + 7;
int n, val[MAXN + 5], maxv, inv[MAXN + 5], fac[MAXN + 5], ifac[MAXN + 5];
int chose[MAXN + 5], mayswt[MAXN + 5];
std::vector<int> swt;
std::vector<std::pair<int, int> > swtsum[2];
inline int qkpow ( int a, int b, const int p = MOD ) {
int ret = 1;
for ( ; b; a = 1ll * a * a % p, b >>= 1 ) ret = 1ll * ret * ( b & 1 ? a : 1 ) % p;
return ret;
}
struct MatrixTree {
int K[MAXN + 5][MAXN + 5];
inline void clear () {
for ( int i = 1; i <= n; ++ i ) {
for ( int j = 1; j <= n; ++ j ) {
K[i][j] = 0;
}
}
}
inline void add ( const int u, const int v ) {
++ K[u][u], ++ K[v][v], -- K[u][v], -- K[v][u];
if ( K[u][v] < 0 ) K[u][v] += MOD;
if ( K[v][u] < 0 ) K[v][u] += MOD;
}
inline int det () {
int ret = 1, swp = 1;
for ( int i = 1; i < n; ++ i ) {
for ( int j = i; j < n; ++ j ) {
if ( K[j][i] ) {
if ( i ^ j ) std::swap ( K[i], K[j] ), swp *= -1;
break;
}
}
if ( ! ( ret = 1ll * ret * K[i][i] % MOD ) ) return 0;
int inv = qkpow ( K[i][i], MOD - 2 );
for ( int j = i + 1; j < n; ++ j ) {
int d = 1ll * K[j][i] * inv % MOD;
for ( int k = i; k < n; ++ k ) {
K[j][k] = ( K[j][k] - 1ll * d * K[i][k] % MOD + MOD ) % MOD;
}
}
}
return ( ret * swp + MOD ) % MOD;
}
} mt;
inline void init () {
inv[1] = fac[0] = fac[1] = ifac[0] = ifac[1] = 1;
for ( int i = 2; i <= n; ++ i ) {
inv[i] = 1ll * ( MOD - MOD / i ) * inv[MOD % i] % MOD;
fac[i] = 1ll * i * fac[i - 1] % MOD;
ifac[i] = 1ll * inv[i] * ifac[i - 1] % MOD;
}
}
inline int comb ( const int n, const int m ) {
return n < m ? 0 : 1ll * fac[n] * ifac[m] % MOD * ifac[n - m] % MOD;
}
inline void calcSwt () {
int lef = swt.size () >> 1, rig = swt.size () - lef;
for ( int s = 0; s < 1 << lef; ++ s ) {
int bit = 0, sval = 0;
for ( int i = 0; i < lef; ++ i ) {
if ( ( s >> i ) & 1 ) {
++ bit, sval += swt[i];
}
}
if ( sval <= maxv ) swtsum[0].push_back ( std::make_pair ( sval, bit ) );
}
for ( int s = 0; s < 1 << rig; ++ s ) {
int bit = 0, sval = 0;
for ( int i = 0; i < rig; ++ i ) {
if ( ( s >> i ) & 1 ) {
++ bit, sval += swt[lef + i];
}
}
if ( sval <= maxv ) swtsum[1].push_back ( std::make_pair ( sval, bit ) );
}
std::sort ( swtsum[0].begin (), swtsum[0].end () );
std::sort ( swtsum[1].begin (), swtsum[1].end () );
int cnt[45] {};
for ( int i = 0; i ^ swtsum[0].size (); ++ i ) ++ cnt[swtsum[0][i].second];
for ( int i = 0, j = int ( swtsum[0].size () ) - 1; i ^ swtsum[1].size (); ++ i ) {
for ( ; ~ j && swtsum[1][i].first + swtsum[0][j].first > maxv; -- cnt[swtsum[0][j --].second] );
for ( int k = 0; k <= lef; ++ k ) {
chose[k + swtsum[1][i].second] = ( chose[k + swtsum[1][i].second] + cnt[k] ) % MOD;
}
}
}
class SweetFruits {
public:
inline int countTrees ( std::vector<int> tval, const int tmaxv ) {
n = tval.size (), init ();
for ( int i = 1; i <= n; ++ i ) val[i] = tval[i - 1];
maxv = tmaxv;
for ( int i = 1; i <= n; ++ i ) if ( ~ val[i] ) swt.push_back ( val[i] );
calcSwt ();
for ( int i = 0; i <= ( int ) swt.size (); ++ i ) {
mt.clear ();
for ( int u = 1; u <= n; ++ u ) {
for ( int v = u + 1; v <= n; ++ v ) {
if ( v <= i || ( int ) swt.size () < v ) {
mt.add ( u, v );
}
}
}
mayswt[i] = mt.det ();
for ( int j = 1; j < i - 1; ++ j ) {
mayswt[i] = ( mayswt[i] - 1ll * comb ( i, j ) * mayswt[i - j] % MOD + MOD ) % MOD;
}
if ( i ) mayswt[i] = ( mayswt[i] - mayswt[0] + MOD ) % MOD;
}
int ans = 0;
for ( int i = 0; i <= ( int ) swt.size (); ++ i ) if ( i ^ 1 ) ans = ( ans + 1ll * mayswt[i] * chose[i] ) % MOD;
return ans;
}
};