【luogu AT3728】Squirrel Migration(思维)(扩展:分治NTT)

Squirrel Migration

题目链接:luogu AT3728

题目大意

给你一棵 n 个点的树,问你有多少个长度为 n 的排列是所有排列中权值最大的。
一个排列的权值是它每个位置跟它下标在树上距离的和。

思路

考虑先看怎样会有最大的贡献,发现不好想想,于是考虑改变统计方法。
考虑每条边的贡献,那它有一个上界,就是 \(2\min(S_1,S_2)\),其中 \(S_1,S_2\) 是把它删掉之后两个连通块的大小。

那一个很好的想法就是用重心做根,因为每个子树大小都不会大于它补数的大小,那问题就是每个点对应的连着的边就是一定不能在它所在的子树内。
那考虑容斥,有 \(i\) 个连向自己就是 \(\binom{S}{i}^2i!\)\(S\) 是所在子树大小,就是选谁连,连向谁,怎么安排匹配)

那如果总共有 \(i\) 个连向了自己,得到的权值是 \(f_i\),那给答案的贡献就是 \((-1)^if_i(n-i)!\)(剩下的自己随便匹配,因为是容斥)
那在 ATcoder 上这个你背包转移一下就可以,不过不难发现是卷积形式,所以可以暴力卷积。

那如果你把模数改成 \(998244353\) 的话就就可以分治 NTT 跑更大的数据。
如果不是这样的质数用 MTT 也不是不行。

代码

#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
//#define mo 998244353
//如果要分治 NTT 要模数满足,否则就要分治 MTT
#define mo 1000000007
#define cpy(f, g, x) memcpy(f, g, sizeof(int) * (x))
#define clr(f, x) memset(f, 0, sizeof(int) * (x))

using namespace std;

const int N = 1e5 + 100;
const int pN = N * 8;
struct node {
	int to, nxt;
}e[N << 1];
int n, le[N], KK, sz[N], max_root, root;
int f[pN], g[pN];
int jc[N], inv[N], invs[N];

int add(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int dec(int x, int y) {return x < y ? x - y + mo : x - y;}
int mul(int x, int y) {return 1ll * x * y % mo;}
int ksm(int x, int y) {int re = 1; while (y) {if (y & 1) re = mul(re, x); x = mul(x, x); y >>= 1;} return re;}
int C(int n, int m) {if (m < 0 || m > n) return 0; return mul(jc[n], mul(invs[m], invs[n - m]));}

void Init() {
	jc[0] = 1; for (int i = 1; i < N; i++) jc[i] = mul(jc[i - 1], i);
	inv[0] = inv[1] = 1; for (int i = 2; i < N; i++) inv[i] = mul(inv[mo % i], mo - mo / i);
	invs[0] = 1; for (int i = 1; i < N; i++) invs[i] = mul(invs[i - 1], inv[i]);
}

void Add(int x, int y) {e[++KK] = (node){y, le[x]}; le[x] = KK;}

void dfs0(int now, int father) {
	sz[now] = 1;
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father) {
			dfs0(e[i].to, now); sz[now] += sz[e[i].to];
		}
}

void find_root(int now, int father, int sum) {
	int maxn = sum - sz[now];
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father) {
			find_root(e[i].to, now, sum);
			maxn = max(maxn, sz[e[i].to]);
		}
	if (maxn < max_root) max_root = maxn, root = now;
}

int dfs1(int now, int father) {
	int re = 1;
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father) re += dfs1(e[i].to, now);
	return re;
}

//struct Poly {
//	int G, Gv, an[pN];
//	vector <int> D[31], Dv[31];
//	
//	void Init() {
//		G = 3; Gv = ksm(G, mo - 2);
//		for (int i = 1, d = 0; i < pN; i <<= 1, d++) {
//			int Gs = ksm(G, (mo - 1) / (i << 1)), Gvs = ksm(Gv, (mo - 1) / (i << 1));
//			for (int j = 0, w = 1, wv = 1; j < i; j++, w = mul(w, Gs), wv = mul(wv, Gvs))
//				D[d].push_back(w), Dv[d].push_back(wv);
//		}
//	}
//	
//	void get_an(int limit, int l_size) {
//		for (int i = 0; i < limit; i++)
//			an[i] = (an[i >> 1] >> 1) | ((i & 1) << (l_size - 1));
//	}
//	
//	void NTT(int *f, int op, int limit) {
//		for (int i = 0; i < limit; i++) if (an[i] < i) swap(f[i], f[an[i]]);
//		for (int mid = 1, d = 0; mid < limit; mid <<= 1, d++) {
////			int Wn = ksm((op == 1) ? G : Gv, (mo - 1) / (mid << 1));
////			int Wn = (op == 1) ? Gs[mid] : Gvs[mid];
//			for (int j = 0, R = mid << 1; j < limit; j += R)
////				for (int k = 0, w = 1; k < mid; k++, w = mul(w, Wn)) {
////					int x = f[j | k], y = mul(w, f[j | mid | k]);
//				for (int k = 0; k < mid; k++) {
////					int x = f[j | k], y = mul(tmp[k], f[j | mid | k]);
//					int x = f[j | k], y = mul((op == 1) ? D[d][k] : Dv[d][k], f[j | mid | k]);
//					f[j | k] = add(x, y); f[j | mid | k] = dec(x, y);
//				}
//		}
//		if (op == -1) {
//			int limv = ksm(limit, mo - 2);
//			for (int i = 0; i < limit; i++) f[i] = mul(f[i], limv); 
//		}
//	}
//	
//	void px(int *f, int *g, int limit) {
//		for (int i = 0; i < limit; i++) f[i] = mul(f[i], g[i]);
//	}
//	
//	void times(int *f, int *g, int n, int m, int T) {
//		int limit = 1, l_size = 0;
//		while (limit < n + m) limit <<= 1, l_size++;
//		get_an(limit, l_size);
//		clr(f + n, limit - n); clr(g + m, limit - m);
//		static int tmp[pN]; cpy(tmp, g, m); clr(tmp + m, limit - m);
//		NTT(f, 1, limit); NTT(tmp, 1, limit);
//		px(f, tmp, limit); NTT(f, -1, limit);
//		clr(f + T, limit - T); clr(tmp, limit);
//	}
//}P;
//分治 NTT 的扩展

vector <int> Mul(vector <int> X, vector <int> Y) {
//	for (int i = 0; i < X.size(); i++) f[i] = X[i];
//	for (int i = 0; i < Y.size(); i++) g[i] = Y[i];
//	P.times(f, g, X.size(), Y.size(), X.size() + Y.size() - 1);
//	vector <int> Z; Z.resize(X.size() + Y.size());
//	for (int i = 0; i < Z.size(); i++) Z[i] = f[i];
//	while (Z.size() && Z.back() == 0) Z.pop_back();
//	return Z;
	
	vector <int> Z; Z.resize(X.size() + Y.size());
	for (int i = 0; i < X.size(); i++)
		for (int j = 0; j < Y.size(); j++)
			Z[i + j] = add(Z[i + j], mul(X[i], Y[j]));
	return Z;
}

struct cmp {
	bool operator ()(vector <int> x, vector <int> y) {
		return x.size() > y.size();
	}
};
priority_queue <vector <int>, vector <vector <int> >, cmp> q;

vector <int> clacF() {
	while (q.size() > 1) {
		vector <int> X = q.top(); q.pop();
		vector <int> Y = q.top(); q.pop();
//		for (int i = 0; i < X.size(); i++) printf("%d ", X[i]); printf("\n");
//		for (int i = 0; i < Y.size(); i++) printf("%d ", Y[i]); printf("\n");
		q.push(Mul(X, Y));
	}
	return q.top();
}

int main() {
	scanf("%d", &n); Init();
	for (int i = 1; i < n; i++) {
		int x, y; scanf("%d %d", &x, &y);
		Add(x, y); Add(y, x);
	}
	
	dfs0(1, 0); max_root = n + 1; find_root(1, 0, n);
	for (int i = le[root]; i; i = e[i].nxt) {
		int siz = dfs1(e[i].to, root);
		vector <int> tmp;
		for (int j = 0; j <= siz; j++) tmp.push_back(mul(mul(C(siz, j), C(siz, j)), jc[j]));
		q.push(tmp);
	}
	
//	P.Init();
	vector <int> f = clacF(); int ans = 0;
	for (int i = 0, di = 1; i < n; i++, di = mo - di) {
		ans = add(ans, mul(di, mul(f[i], jc[n - i])));
	}
	printf("%d", ans);
	
	return 0;
}
posted @ 2022-09-27 10:41  あおいSakura  阅读(13)  评论(0编辑  收藏  举报