Codeforces 1666J. Job Lookup
题目大意
一个 \(n\times n(1\leq n\leq200)\) 的矩阵 \(c(0\leq c_{ij}\leq10^9)\) ,构造一棵节点编号为 \(1~n\) 的二叉树,其任意一个节点的左子树内所有节点编号都小于它,右子树内所有节点编号都大于它,设 \(d_{ij}\) 为 \(i\sim j\) 的最短路径,使得 \(\sum_{1\leq i<j\leq n}c_{ij}d_{ij}\) 最小,输出每个节点的根(根节点输出 \(0\) )。
思路
由题意可知我们要构造出的二叉树的中序遍历就是 \(1\sim n\) ,于是我们可以在这上面进行区间 \(dp\) ,我们考虑计算每一条边对答案的贡献,设 \(f_{i,j}\) 为考虑节点 \(i\sim j\) 所构成的子树时贡献的最小值。在区间内枚举子树根节点 \(k\) 进行转移,首先 \(f_{i,j}\) 会直接包含 \(f_{i,k-1}+f_{k+1,j}\) 之后对于左子树的根向 \(k\) 的连边,其贡献就为 \(\sum_{1\leq x\leq k-1,k\leq y\leq n}c_{xy}\) ,对于右子树连向 \(k\) 的边同理,我们可以对 \(c\) 预处理二维前缀和来求得边的贡献,在转移时记录当前区间最优时的根,最后 \(f_{1,n}\) 的值就是最小值,我们从 \([l,r]\) 再进行一次 \(dfs\) 就可以求出树中每个节点的根,复杂度 \(O(n^3)\) 。
代码
#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
//#define int LL
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
#pragma warning(disable :4996)
const double eps = 1e-8;
const LL mod = 1000000007;
const LL MOD = 998244353;
const int maxn = 310;
LL N, C[maxn][maxn], f[maxn][maxn], S[maxn][maxn], root[maxn][maxn], fa[maxn];
LL sum(int x1, int y1, int x2, int y2)
{
return S[x2][y2] - S[x1 - 1][y2] - S[x2][y1 - 1] + S[x1 - 1][y1 - 1];
}
int dfs(int l, int r)
{
if (l == r)
return l;
int k = root[l][r];
if (l <= k - 1)
{
int lf = dfs(l, k - 1);
fa[lf] = k;
}
if (k + 1 <= r)
{
int rf = dfs(k + 1, r);
fa[rf] = k;
}
return k;
}
void solve()
{
for (int i = 1; i <= N; i++)
{
for (int j = 1; j <= N; j++)
S[i][j] = S[i - 1][j] + S[i][j - 1] - S[i - 1][j - 1] + C[i][j];
}
for (int i = 1; i <= N; i++)
root[i][i] = i, f[i][i] = 0;
for (int i = 2; i <= N; i++)
{
for (int l = 1; l + i - 1 <= N; l++)
{
int r = l + i - 1;
for (int k = l; k <= r; k++)
{
LL tmp;
if (k == l)
tmp = f[k + 1][r] + sum(k + 1, 1, r, N) - sum(k + 1, k + 1, r, r);
else if (k == r)
tmp = f[l][k - 1] + sum(l, 1, k - 1, N) - sum(l, l, k - 1, k - 1);
else
tmp = f[l][k - 1] + sum(l, 1, k - 1, N) - sum(l, l, k - 1, k - 1) + f[k + 1][r] + sum(k + 1, 1, r, N) - sum(k + 1, k + 1, r, r);
if (tmp < f[l][r])
f[l][r] = tmp, root[l][r] = k;
}
}
}
dfs(1, N);
for (int i = 1; i <= N; i++)
cout << fa[i] << ' ';
cout << endl;
}
int main()
{
IOS;
cin >> N;
memset(f, INF, sizeof(f));
for (int i = 1; i <= N; i++)
{
for (int j = 1; j <= N; j++)
cin >> C[i][j];
}
solve();
return 0;
}