【luogu AT3728】Squirrel Migration(思维)(扩展:分治NTT)
Squirrel Migration
题目链接:luogu AT3728
题目大意
给你一棵 n 个点的树,问你有多少个长度为 n 的排列是所有排列中权值最大的。
一个排列的权值是它每个位置跟它下标在树上距离的和。
思路
考虑先看怎样会有最大的贡献,发现不好想想,于是考虑改变统计方法。
考虑每条边的贡献,那它有一个上界,就是 \(2\min(S_1,S_2)\),其中 \(S_1,S_2\) 是把它删掉之后两个连通块的大小。
那一个很好的想法就是用重心做根,因为每个子树大小都不会大于它补数的大小,那问题就是每个点对应的连着的边就是一定不能在它所在的子树内。
那考虑容斥,有 \(i\) 个连向自己就是 \(\binom{S}{i}^2i!\)(\(S\) 是所在子树大小,就是选谁连,连向谁,怎么安排匹配)
那如果总共有 \(i\) 个连向了自己,得到的权值是 \(f_i\),那给答案的贡献就是 \((-1)^if_i(n-i)!\)(剩下的自己随便匹配,因为是容斥)
那在 ATcoder 上这个你背包转移一下就可以,不过不难发现是卷积形式,所以可以暴力卷积。
那如果你把模数改成 \(998244353\) 的话就就可以分治 NTT 跑更大的数据。
如果不是这样的质数用 MTT 也不是不行。
代码
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
//#define mo 998244353
//如果要分治 NTT 要模数满足,否则就要分治 MTT
#define mo 1000000007
#define cpy(f, g, x) memcpy(f, g, sizeof(int) * (x))
#define clr(f, x) memset(f, 0, sizeof(int) * (x))
using namespace std;
const int N = 1e5 + 100;
const int pN = N * 8;
struct node {
int to, nxt;
}e[N << 1];
int n, le[N], KK, sz[N], max_root, root;
int f[pN], g[pN];
int jc[N], inv[N], invs[N];
int add(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int dec(int x, int y) {return x < y ? x - y + mo : x - y;}
int mul(int x, int y) {return 1ll * x * y % mo;}
int ksm(int x, int y) {int re = 1; while (y) {if (y & 1) re = mul(re, x); x = mul(x, x); y >>= 1;} return re;}
int C(int n, int m) {if (m < 0 || m > n) return 0; return mul(jc[n], mul(invs[m], invs[n - m]));}
void Init() {
jc[0] = 1; for (int i = 1; i < N; i++) jc[i] = mul(jc[i - 1], i);
inv[0] = inv[1] = 1; for (int i = 2; i < N; i++) inv[i] = mul(inv[mo % i], mo - mo / i);
invs[0] = 1; for (int i = 1; i < N; i++) invs[i] = mul(invs[i - 1], inv[i]);
}
void Add(int x, int y) {e[++KK] = (node){y, le[x]}; le[x] = KK;}
void dfs0(int now, int father) {
sz[now] = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father) {
dfs0(e[i].to, now); sz[now] += sz[e[i].to];
}
}
void find_root(int now, int father, int sum) {
int maxn = sum - sz[now];
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father) {
find_root(e[i].to, now, sum);
maxn = max(maxn, sz[e[i].to]);
}
if (maxn < max_root) max_root = maxn, root = now;
}
int dfs1(int now, int father) {
int re = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father) re += dfs1(e[i].to, now);
return re;
}
//struct Poly {
// int G, Gv, an[pN];
// vector <int> D[31], Dv[31];
//
// void Init() {
// G = 3; Gv = ksm(G, mo - 2);
// for (int i = 1, d = 0; i < pN; i <<= 1, d++) {
// int Gs = ksm(G, (mo - 1) / (i << 1)), Gvs = ksm(Gv, (mo - 1) / (i << 1));
// for (int j = 0, w = 1, wv = 1; j < i; j++, w = mul(w, Gs), wv = mul(wv, Gvs))
// D[d].push_back(w), Dv[d].push_back(wv);
// }
// }
//
// void get_an(int limit, int l_size) {
// for (int i = 0; i < limit; i++)
// an[i] = (an[i >> 1] >> 1) | ((i & 1) << (l_size - 1));
// }
//
// void NTT(int *f, int op, int limit) {
// for (int i = 0; i < limit; i++) if (an[i] < i) swap(f[i], f[an[i]]);
// for (int mid = 1, d = 0; mid < limit; mid <<= 1, d++) {
//// int Wn = ksm((op == 1) ? G : Gv, (mo - 1) / (mid << 1));
//// int Wn = (op == 1) ? Gs[mid] : Gvs[mid];
// for (int j = 0, R = mid << 1; j < limit; j += R)
//// for (int k = 0, w = 1; k < mid; k++, w = mul(w, Wn)) {
//// int x = f[j | k], y = mul(w, f[j | mid | k]);
// for (int k = 0; k < mid; k++) {
//// int x = f[j | k], y = mul(tmp[k], f[j | mid | k]);
// int x = f[j | k], y = mul((op == 1) ? D[d][k] : Dv[d][k], f[j | mid | k]);
// f[j | k] = add(x, y); f[j | mid | k] = dec(x, y);
// }
// }
// if (op == -1) {
// int limv = ksm(limit, mo - 2);
// for (int i = 0; i < limit; i++) f[i] = mul(f[i], limv);
// }
// }
//
// void px(int *f, int *g, int limit) {
// for (int i = 0; i < limit; i++) f[i] = mul(f[i], g[i]);
// }
//
// void times(int *f, int *g, int n, int m, int T) {
// int limit = 1, l_size = 0;
// while (limit < n + m) limit <<= 1, l_size++;
// get_an(limit, l_size);
// clr(f + n, limit - n); clr(g + m, limit - m);
// static int tmp[pN]; cpy(tmp, g, m); clr(tmp + m, limit - m);
// NTT(f, 1, limit); NTT(tmp, 1, limit);
// px(f, tmp, limit); NTT(f, -1, limit);
// clr(f + T, limit - T); clr(tmp, limit);
// }
//}P;
//分治 NTT 的扩展
vector <int> Mul(vector <int> X, vector <int> Y) {
// for (int i = 0; i < X.size(); i++) f[i] = X[i];
// for (int i = 0; i < Y.size(); i++) g[i] = Y[i];
// P.times(f, g, X.size(), Y.size(), X.size() + Y.size() - 1);
// vector <int> Z; Z.resize(X.size() + Y.size());
// for (int i = 0; i < Z.size(); i++) Z[i] = f[i];
// while (Z.size() && Z.back() == 0) Z.pop_back();
// return Z;
vector <int> Z; Z.resize(X.size() + Y.size());
for (int i = 0; i < X.size(); i++)
for (int j = 0; j < Y.size(); j++)
Z[i + j] = add(Z[i + j], mul(X[i], Y[j]));
return Z;
}
struct cmp {
bool operator ()(vector <int> x, vector <int> y) {
return x.size() > y.size();
}
};
priority_queue <vector <int>, vector <vector <int> >, cmp> q;
vector <int> clacF() {
while (q.size() > 1) {
vector <int> X = q.top(); q.pop();
vector <int> Y = q.top(); q.pop();
// for (int i = 0; i < X.size(); i++) printf("%d ", X[i]); printf("\n");
// for (int i = 0; i < Y.size(); i++) printf("%d ", Y[i]); printf("\n");
q.push(Mul(X, Y));
}
return q.top();
}
int main() {
scanf("%d", &n); Init();
for (int i = 1; i < n; i++) {
int x, y; scanf("%d %d", &x, &y);
Add(x, y); Add(y, x);
}
dfs0(1, 0); max_root = n + 1; find_root(1, 0, n);
for (int i = le[root]; i; i = e[i].nxt) {
int siz = dfs1(e[i].to, root);
vector <int> tmp;
for (int j = 0; j <= siz; j++) tmp.push_back(mul(mul(C(siz, j), C(siz, j)), jc[j]));
q.push(tmp);
}
// P.Init();
vector <int> f = clacF(); int ans = 0;
for (int i = 0, di = 1; i < n; i++, di = mo - di) {
ans = add(ans, mul(di, mul(f[i], jc[n - i])));
}
printf("%d", ans);
return 0;
}