题解
- 由于如果 \(2^{2n}\) 枚举点对统计贡献那么复杂度便不允许做其他操作,不是很好搞。那么就考虑将点对间的贡献转成单点的贡献。
- 仔细观察题目中的表格,发现其实可以理解为在网络树中的非叶子节点,若其子树内选A节点数大于选B节点数,那么所有选B节点就会与其与当前子树内所有除自己以外的节点产生1倍贡献。
- 那么考虑枚举自上向下dfs,中途枚举每一个非叶子结点的子树内 选A节点多还是选B节点多 ,然后dfs到叶子节点时根据其所有祖先计算此叶子结点选A或B的贡献,并自底向上dp。状态为 \(dp[u][i]\) 表示u点子树内i个叶子结点选A的最小花费。注意非叶子节点状态转移时要和当前枚举的状态相同。即若枚举当前非叶子节点子树内选A节点多,那么便不能从其两个儿子的状态转移到 \(i < [区间叶子结点数]\) 的状态,反之同理。根节点的dp数组中最小值即为答案。
- 考虑时间复杂度,其相当于对于每一个叶子结点枚举其祖先的 \(01\) 状态,并有一个 \(O(n)\) 的统计贡献,那么复杂度就是叶子节点数 $\times $ 祖先状态数 \(\times\) \(n\) 。\(O(2^n \cdot 2^n \cdot n)\) 。
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define rep(i, s, t) for(int i = s, __ = t; i <= __; ++i)
#define dwn(i, s, t) for(int i = s, __ = t; i >= __; --i)
const int INF = 2147483647;
const int MAXN = 31;
const int MAXM = 100 + 2024;
const int MOD = 998244353;
using namespace std;
inline int read(int x = 0, int f = 1){
char ch = getchar();
for(; !isdigit(ch); ch = getchar())if(ch == '-')f = -1;
for(; isdigit(ch); ch = getchar())x = ch - '0' + x * 10;
return x * f;
}
inline void write(ll x){
if(x < 0)x = -x, putchar('-');
if(x >= 10)write(x / 10); putchar(x % 10 + '0');
return ;
}
ll dp[MAXM][MAXM];
int n, s[MAXM], c[MAXM][2], cst[MAXM][MAXN];
int typ[MAXM];
inline void upd(ll &x, ll y){x = min(x, y); return ;}
void dfs(int u, int L, int R, int dep){
if(L == R){
int x = u - (1 << dep - 1) + 1;
dp[u][0] = c[x][1], dp[u][1] = c[x][0];
rep(i, 1, dep - 1)dp[u][typ[i]] += cst[x][i];
return ;
}
int mid = L + R >> 1; rep(i, 0, R - L + 1)dp[u][i] = INF;
typ[dep] = 0;
dfs(u << 1, L, mid, dep + 1); dfs(u << 1 | 1, mid + 1, R, dep + 1);
rep(i, 0, mid - L + 1)rep(j, 0, R - mid)if(i + j >= R - mid)
upd(dp[u][i + j], dp[u << 1][i] + dp[u << 1 | 1][j]);
typ[dep] = 1;
dfs(u << 1, L, mid, dep + 1); dfs(u << 1 | 1, mid + 1, R, dep + 1);
rep(i, 0, mid - L + 1)rep(j, 0, R - mid)if(i + j < R - mid)
upd(dp[u][i + j], dp[u << 1][i] + dp[u << 1 | 1][j]);
return ;
}
int getlc(int x, int y, int tag = 0){
x += (1 << n) - 1, y += (1 << n) - 1;
dwn(i, n, 1)if((x >> 1) == (y >> 1))return i; else x >>= 1, y >>= 1;
}
int main(){
n = read(); int lim = 1 << n;
rep(i, 1, lim)s[i] = read();
rep(i, 1, lim)c[i][s[i] ^ 1] = read();
rep(i, 1, lim)rep(j, i + 1, lim){
int x = getlc(i, j), w = read(); cst[i][x] += w, cst[j][x] += w;
}
dfs(1, 1, lim, 1);
ll ans = INF; rep(i, 0, lim)upd(ans, dp[1][i]);
write(ans); return 0;
}