CF1401D 最大分布树
1 CF1401D 最大分布树
2 题目描述
给出一个由 \(𝑛\) 个节点组成的树。用满足以下条件的方式给它的 \(𝑛−1\) 条边赋予权值:
- 每个权值必须大于0;
- 所有 \(𝑛−1\) 个权值的乘积应等于 \(𝑘\);
- 所有 \(𝑛−1\) 个权值中 \(1\) 的个数必须是可能的最小值。
让我们将 \(𝑓(𝑢,𝑣)\) 定义为从节点 \(𝑢\) 到节点 \(𝑣\) 的简单路径上的数字之和。另外,
定义 \(\begin{matrix} \sum_{i=1}^{n-1} \sum_{j=i+1}^n f(i,j) \end{matrix}\) 是树的分布指数,找出最大可能分布指数。由于答案可能太大,输出模 \(10^9 + 7\) 的结果。由于 \(𝑘\) 可能非常大,所以给出的是 \(𝑘\) 的质数因子。
3 题解
我们先来看一个性质:当 \(a < b\) 且 \(c < d\) 时,\(ad + bc < bd + ac\)。推导如下:
\[\because a < b \\
\]
\[\therefore b - a > 0 \\
\]
\[\because c < d, b-a > 0 \\
\]
\[\therefore (b-a)c < (b-a)d \\
\]
\[\therefore bc - ac < bd - ad \\
\]
\[\therefore ad + bc < bd + ac
\]
这个式子的意思是,用大的数乘以大的数,再用次大的数乘以次大的数,以此类推得到的和是最大的,这证明了这道题可以用贪心。
回到本题,我们发现:一条边对于答案的贡献为 左边节点数 \(*\)右边节点数\(*\) 边权。而每条边的左边节点数与右边节点数已经给出,所以我们可以把所有的边的左边节点数 \(*\) 右边节点数这一定值放入优先队列中。根据上述性质,我们只需要将所有质因子也递减排序,然后用最大的质因子乘以优先队列中最大的数,次大的质因子乘以次大的数,以此类推即可。
这里需要注意:如果 \(m < n-1\),即质因子个数比边数少,那么由于质数无法拆解,我们就需要把一些边的权值变为 \(1\)。如果 \(m > n-1\),即质因子个数比边数多,那么就需要将最大的 \(m - n + 2\) 个质因子乘在一起,当作第一个因子。
4 代码(空格警告):
#include <iostream>
#include <queue>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 2e6 + 10;
const int mod = 1e9+7;
#define int long long
int T, n, m, u, v, tot, root, ans, cnt;
int head[N], ver[N], last[N], size[N];
int p[N], w[N];
priority_queue<int> q;
bool cmp(int x, int y)
{
return x > y;
}
void add(int x, int y)
{
ver[++tot] = y;
last[tot] = head[x];
head[x] = tot;
}
void dfs(int x, int fa)
{
size[x] = 1;
for (int i = head[x]; i; i = last[i])
{
int y = ver[i];
if (y == fa) continue;
dfs(y, x);
size[x] += size[y];
}
}
void dfs1(int x, int fa)
{
for (int i = head[x]; i; i = last[i])
{
int y = ver[i];
if (y == fa) continue;
q.push(size[y] * (n - size[y]));
dfs1(y, x);
}
}
signed main()
{
cin >> T;
while (T--)
{
ans = 0;
cnt = 0;
tot = 0;
memset(head, 0, sizeof(head));
memset(ver, 0, sizeof(ver));
memset(last, 0, sizeof(last));
cin >> n;
for (int i = 1; i < n; i++)
{
cin >> u >> v;
add(u, v);
add(v, u);
}
cin >> m;
for (int i = 1; i <= m; i++) cin >> p[i];
sort(p+1, p+m+1, cmp);
root = 1;
dfs(root, 0);
dfs1(root, 0);
if (m < n-1)
{
for (int i = 1; i <= m; i++) w[i] = p[i];
for (int i = m+1; i < n; i++) w[i] = 1;
}
if (m > n-1)
{
w[1] = 1;
for (int i = 1; i <= m-n+2; i++) w[1] = (w[1] * p[i]) % mod;
for (int i = 2; i <= n-1; i++) w[i] = p[m-n+i+1];
}
if (m == n-1) for (int i = 1; i < n; i++) w[i] = p[i];
while (q.size())
{
ans = (ans + (q.top() * w[++cnt]) % mod) % mod;
q.pop();
}
cout << ans << endl;
}
return 0;
}