T1: 树的直径(二)

考虑由关键点构成的虚树,答案一定是 2

由于关键点中深度最深的点一定是直径的某个端点,所以只需找到这个点,然后遍历其他点,通过lca求出两点间的距离,取最大值即可。

也可以跑两遍dfs求虚树直径,具体做法如下:

从任意一个点(不需要是关键点)开始第一遍dfs,求出关键点中距离这个点最远的点 u(如果有多个,任取一个)。再从 u 开始进行第二遍dfs,求出关键点中距离 u 最远的点 v(如果有多个,任取一个)。则 (u,v) 成为虚树的一条直径。

注:这里没必要构建出虚树

代码实现1
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < (n); ++i)
using namespace std;
template<typename T>
struct lca {
int n, l;
vector<vector<int>> to;
vector<vector<T>> co;
vector<int> dep;
vector<T> costs;
vector<vector<int>> par;
lca(int n): n(n), to(n), co(n), dep(n), costs(n) {
l = 0;
while (1<<l <= n) ++l;
par = vector<vector<int>>(n, vector<int>(l, -1));
}
void addEdge(int a, int b, T c=1) {
to[a].push_back(b); co[a].push_back(c);
to[b].push_back(a); co[b].push_back(c);
}
void dfs(int v, int d=0, T c=0, int p=-1) {
par[v][0] = p;
dep[v] = d;
costs[v] = c;
rep(i, to[v].size()) {
int u = to[v][i];
if (u == p) continue;
dfs(u, d+1, c+co[v][i], v);
}
}
void init(int root=0) {
dfs(root);
rep(i, l-1) {
rep(v, n) {
par[v][i+1] = par[v][i]==-1 ? -1 : par[par[v][i]][i];
}
}
}
// LCA
int operator()(int a, int b) {
if (dep[a] > dep[b]) swap(a, b);
int gap = dep[b]-dep[a];
for (int i = l-1; i >= 0; --i) {
int len = 1<<i;
if (gap >= len) {
gap -= len;
b = par[b][i];
}
}
if (a == b) return a;
for (int i = l-1; i >= 0; --i) {
int na = par[a][i];
int nb = par[b][i];
if (na != nb) a = na, b = nb;
}
return par[a][0];
}
int length(int a, int b) {
int c = this->operator()(a, b);
return dep[a]+dep[b]-dep[c]*2;
}
T dist(int a, int b) {
int c = this->operator()(a, b);
return costs[a]+costs[b]-costs[c]*2;
}
};
void solve() {
int n, k;
cin >> n >> k;
vector<int> vs(k);
rep(i, k) cin >> vs[i], vs[i]--;
lca<int> g(n);
rep(i, n-1) {
int u, v;
cin >> u >> v;
--u; --v;
g.addEdge(u, v);
}
g.init();
int maxd = -1, a = -1;
for (int v : vs) {
if (g.dep[v] > maxd) {
maxd = g.dep[v];
a = v;
}
}
int ans = 0;
for (int v : vs) {
ans = max(ans, g.dist(a, v));
}
ans = (ans+1)/2;
cout << ans << '\n';
}
int main() {
int t;
cin >> t;
while (t--) solve();
return 0;
}
代码实现2
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < (n); ++i)
using namespace std;
using P = pair<int, int>;
void solve() {
int n, k;
cin >> n >> k;
vector<bool> selected(n);
rep(i, k) {
int a;
cin >> a;
--a;
selected[a] = true;
}
vector<vector<int>> to(n);
rep(i, n-1) {
int u, v;
cin >> u >> v;
--u; --v;
to[u].push_back(v);
to[v].push_back(u);
}
auto dfs = [&](auto& f, int v, int d=0, int p=-1) -> P {
auto res = selected[v] ? P(d, v) : P(0, -1);
for (int u : to[v]) {
if (u == p) continue;
res = max(res, f(f, u, d+1, v));
}
return res;
};
int a = dfs(dfs, 0).second;
int diameter = dfs(dfs, a).first;
int ans = (diameter+1)/2;
cout << ans << '\n';
}
int main() {
int t;
cin >> t;
while (t--) solve();
return 0;
}

T2:简单 MST

对于 g=1,2,,r,遍历 [l,r]g 的所有倍数,找到 w 最小的那个,然后将它和剩下的倍数连边,接下来求MST即可

代码实现
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
struct UnionFind {
vector<int> d;
UnionFind(int n = 0): d(n, -1) {}
int find(int x) {
if (d[x] < 0) return x;
return d[x] = find(d[x]);
}
bool unite(int x, int y) {
x = find(x); y = find(y);
if (x == y) return false;
if (d[x] > d[y]) swap(x, y);
d[x] += d[y];
d[y] = x;
return true;
}
bool same(int x, int y) {
return find(x) == find(y);
}
int size(int x) {
return -d[find(x)];
}
};
void solve() {
int l, r;
cin >> l >> r;
vector<int> w(r+1);
for (int i = 2; i <= r; ++i) {
if (w[i]) continue;
for (int j = i; j <= r; j += i) {
w[j]++;
}
}
vector<vector<pair<int, int>>> es(15);
for (int g = 1; g <= r; ++g) {
int x = (l+g-1)/g*g;
int a = x;
for (int i = x; i <= r; i += g) {
if (w[a] > w[i]) {
a = i;
}
}
for (int b = x; b <= r; b += g) {
int c = w[a]+w[b]-w[gcd(a, b)];
es[c].emplace_back(a, b);
}
}
ll ans = 0;
UnionFind uf(r+1);
for (int c = 1; c <= 14; ++c) {
for (auto [a, b] : es[c]) {
if (uf.unite(a, b)) ans += c;
}
}
cout << ans << '\n';
}
int main() {
int t;
cin >> t;
while (t--) solve();
return 0;
}

T3: 序列切割

p(l,r,i) 表示操作 i 次后只剩下 a[l,r] 的概率,则答案为 p(l,r,k)×(al+al+1++ar)

初值为 p(1,n,0)=1p(l,r,0)=0

对于 i1,考虑转移:

SiLp(l,r,i)=x=r+1np(l,x,i1)1xl

l+r,还有一个额外的 p(l,l,i1) 贡献表示单元素序列不可分割

SiR,转移是类似的。

此时状态数为 O(n2k),转移次数为 O(n),总复杂度为 O(n3k)

前缀和优化即可。O(n2k)

代码实现
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < (n); ++i)
using namespace std;
using ll = long long;
//const int mod = 998244353;
const int mod = 1000000007;
struct mint {
ll x;
mint(ll x=0):x((x%mod+mod)%mod) {}
mint operator-() const {
return mint(-x);
}
mint& operator+=(const mint a) {
if ((x += a.x) >= mod) x -= mod;
return *this;
}
mint& operator-=(const mint a) {
if ((x += mod-a.x) >= mod) x -= mod;
return *this;
}
mint& operator*=(const mint a) {
(x *= a.x) %= mod;
return *this;
}
mint operator+(const mint a) const {
return mint(*this) += a;
}
mint operator-(const mint a) const {
return mint(*this) -= a;
}
mint operator*(const mint a) const {
return mint(*this) *= a;
}
mint pow(ll t) const {
if (!t) return 1;
mint a = pow(t>>1);
a *= a;
if (t&1) a *= *this;
return a;
}
// for prime mod
mint inv() const {
return pow(mod-2);
}
mint& operator/=(const mint a) {
return *this *= a.inv();
}
mint operator/(const mint a) const {
return mint(*this) /= a;
}
};
istream& operator>>(istream& is, mint& a) {
return is >> a.x;
}
ostream& operator<<(ostream& os, const mint& a) {
return os << a.x;
}
mint inv[505];
void solve() {
int n, k;
cin >> n >> k;
vector<int> a(n);
rep(i, n) cin >> a[i];
string s;
cin >> s;
vector dp(n, vector<mint>(n+1));
dp[0][n-1] = 1;
rep(i, k) {
vector old(n, vector<mint>(n+1));
swap(dp, old);
rep(l, n)for (int r = l; r < n; ++r) {
if (l == r) {
dp[l][r] += old[l][r];
if (s[i] == 'L') dp[l][r+1] -= old[l][r];
else if (l > 0) dp[l-1][l] -= old[l][r];
}
else {
mint val = old[l][r] * inv[r-l];
if (s[i] == 'L') dp[l][l] += val;
else dp[r][r] += val;
dp[l][r] -= val;
}
}
if (s[i] == 'L') {
rep(l, n) {
for (int r = 1; r < n; ++r) {
dp[l][r] += dp[l][r-1];
}
}
}
else {
rep(r, n) {
for (int l = n-2; l >= 0; --l) {
dp[l][r] += dp[l+1][r];
}
}
}
}
mint ans;
rep(l, n) {
mint sum;
for (int r = l; r < n; ++r) {
sum += a[r];
ans += sum*dp[l][r];
}
}
cout << ans << '\n';
}
int main() {
rep(i, 500) inv[i] = mint(i).inv();
int t;
cin >> t;
while (t--) solve();
return 0;
}