链接
Description
小王来到了一片森林,森林中有一些树和连接两棵树的无向道路,保证这些道路能把森林连通。
小王对这片森林做了一些考察,有了两个奇怪的发现:
1)森林中的树总共分为两种,不妨记为 0 型树和 1 型树。
2)这些道路的长度都是 2 的整数次幂且互不相同,第 i 条道路的长度为 \(2_i\)。
小王又发现了这片森林的一个神奇之处,任何两棵类型不同的树之间都可以构成一组链接,这一对链接的能量值为两棵树之间的最短路。
好奇的小王想知道这片森林所有链接的能量值之和,请你来帮帮他。
思路
因为每一条边的边权是\(2^i\)因为\(\sum_{i=0}^{n-1}2^i = 2^n-1\)看看二进制就明白了
然后我们可以发现如果两个点已经联通了,如果再加一条边,那条边的权值一定比最短路长,所以可以直接跑\(kruskal\)求生成树,然后在生成树上稿下边的操作。
因为我们已经有了生成树了,所以可以直接树形\(DP\)(换根可以做),我们可以直接把\(1\)结点点权看做\(0\),然后将点权和\(1\)相同的看成\(0\),不同的看成\(1\),然后我们设\(siz_x\)为以\(1\)位根节点时\(x\)的子树上有多少点权为\(1\)的点,\(dp_x\)为结点\(x\)到所有点权为\(1\)的点的权值。
我们第一遍\(dfs\)求出\(siz\)和\(dp1\),然后考虑如何转移到别的结点上,对于\(x\)的子节点\(to\)我们可以发现,\(dp_{to}=dp_x+(k-2*siz_{to})*dis\)
code
#include <bits/stdc++.h>
#define N 100010
#define M 1010
#define int long long
using namespace std;
const int mod = 1e9 + 7;
int n, m, add_edge, fath[N], cnt, dit[N], point;
int siz[N], ye[N], head[N << 1], dp[N];
struct node {
int next, to, dis;
} edge[N << 1];
int read() {
int s = 0, f = 0; char ch = getchar();
while (!isdigit(ch)) f |= (ch == '-'), ch = getchar();
while (isdigit(ch)) s = s * 10 + (ch ^ 48), ch = getchar();
return f ? -s : s;
}
void add(int from, int to, int dis) {
edge[++add_edge].next = head[from];
edge[add_edge].to = to;
edge[add_edge].dis = dis;
head[from] = add_edge;
}
int father(int a) {
if (fath[a] != a) fath[a] = father(fath[a]);
return fath[a];
}
int q_pow(int a, int b) {
int ans = 1;
while (b) {
if (b & 1) ans = (ans * a) % mod;
a = (a * a) % mod;
b >>= 1;
}
return ans;
}
void dfs(int x, int fa) {
for (int i = head[x]; i; i = edge[i].next) {
int to = edge[i].to;
if (to == fa) continue;
dit[to] = (dit[x] + edge[i].dis) % mod;
dfs(to, x), siz[x] += siz[to];
}
}
void dfs2(int x, int fa) {
for (int i = head[x]; i; i = edge[i].next) {
int to = edge[i].to;
if (to == fa) continue;
int jia = ((point - siz[to]) % mod + mod) % mod, jian = (siz[to] * edge[i].dis) % mod;
jia = (jia * edge[i].dis) % mod;
dp[to] = ((dp[x] + jia - jian) % mod + mod) % mod;
dfs2(to, x);
}
}
signed main() {
n = read(), m = read();
ye[1] = read();
int sy = 0;
if (ye[1] == 1) sy = 1, ye[1] ^= 1;
for (int i = 2, x; i <= n; i++) {
x = read();
if (sy) ye[i] = (x ^ 1);
else ye[i] = x;
siz[i] = ye[i];
}
for (int i = 1; i <= n; i++) fath[i] = i;
for (int i = 1, x, y; i <= m; i++) {
x = read(), y = read();
int fx = father(x), fy = father(y), d = q_pow(2, i);
if (fx != fy) {
fath[fx] = fy, cnt++;
add(x, y, d), add(y, x, d);
}
if (cnt == n - 1) break;
}
dfs(1, 1), point = siz[1];
for (int i = 1; i <= n; i++)
if (ye[i] == 1) dp[1] = (dp[1] + dit[i]) % mod;
dfs2(1, 1);
int ans = 0;
for (int i = 1; i <= n; i++)
if (ye[i] == 0) ans = (ans + dp[i]) % mod;
cout << ans;
return 0;
}