luogu P5243 [USACO19FEB]Moorio Kart P
https://www.luogu.com.cn/problem/P5243
大意就是给出若\(k\)棵树,每棵树选一条路径,然后花\(k x\)的长度把他们连成一个环
问环的长度\(\ge Y\)的方案数
首先肯定可以先把\(Y -= kx\)
然后就变成了一个经典问题,每个树暴力\(n_i^2\)跑出来所有的方案,路径长度最多\(min(n_i^2,Y)\)种
然后背包转移,一共\(k\)个,时间复杂度为\(O(\sum n_i^2 +\sum min(Y, n_i^2)Y)\)
时间复杂度上限为\(O(nY\sqrt{Y})\)
code:
#include<bits/stdc++.h>
#define ll long long
#define mod 1000000007
#define N 2550
using namespace std;
struct edge {
int v, nxt, c;
} e[N << 1];
int p[N], eid;
void init() {
memset(p, -1, sizeof p);
eid = 0;
}
void insert(int u, int v, int c) {
e[eid].v = v;
e[eid].c = c;
e[eid].nxt = p[u];
p[u] = eid ++;
}
int fa[N];
int get(int x) {
return fa[x] == x? x : (fa[x] = get(fa[x]));
}
void merge(int x, int y) {
x = get(x), y = get(y);
fa[x] = y;
}
ll sum[N][N], gs[N][N];
int n, m, X, Y;
void dfs(int u, int fa, int l) {
if(fa) (sum[get(u)][min(l, Y)] += l) %= mod, gs[get(u)][min(l, Y)] ++;
for(int i = p[u]; i + 1; i = e[i].nxt) {
int v = e[i].v, c = e[i].c;
if(v == fa) continue;
dfs(v, u, l + c);
}
}
ll f[N][2], g[N][2];
int main() {
init();
scanf("%d%d%d%d", &n, &m, &X, &Y);
for(int i = 1; i <= n; i ++) fa[i] = i;
for(int i = 1; i <= m; i ++) {
int u, v, c;
scanf("%d%d%d", &u, &v, &c);
insert(u, v, c), insert(v, u, c);
merge(u, v);
}
int sz = 0;
for(int i = 1; i <= n; i ++) if(get(i) == i) sz ++;
Y = max(0, Y - sz * X);
for(int i = 1; i <= n; i ++) dfs(i, 0, 0);
f[0][0] = 1, f[0][1] = 0;
for(int i = 1; i <= n; i ++) if(get(i) == i) {
for(int j = 0; j <= Y; j ++) g[j][0] = f[j][0], g[j][1] = f[j][1], f[j][0] = f[j][1] = 0;
for(int j = 0; j <= Y; j ++) if(gs[i][j]) {
for(int k = 0; k <= Y; k ++) if(g[k][0]) {
int o = min(j + k, Y);
(f[o][0] += g[k][0] * gs[i][j] % mod) %= mod;
(f[o][1] += g[k][0] * sum[i][j] % mod + g[k][1] * gs[i][j]) %= mod;
}
}
}
// for(int i = 1; i <= n; i ++) {
// for(int j = 0; j <= Y; j ++) printf("%lld ", gs[i][j]); printf("\n");
// }
// printf("%d %lld %lld\n", Y, f[Y][0], f[Y][1]);
ll ans = (f[Y][1] + f[Y][0] * sz % mod * X % mod) % mod;
for(int i = 1; i < sz; i ++) ans = ans * i % mod;
printf("%lld", ans * ((mod + 1) / 2) % mod);
return 0;
}