题解 LA4390
题目大意 多组数据,每组数据给定两个个正整数 \(n, m\) 和一棵 \(n\) 个节点的树,输出给树标号使儿子的编号大于父亲的编号的方案数对 \(m\) 取模的值,不保证 \(m\) 是质数。
分析 考虑这样一棵子树,它的所有根节点的大小 \(siz[son[i]]\) 和标号方案数 \(ans[son[i]]\) 都已知,那么可以得到,整棵树的方案数等于所有子树的方案数的乘积乘上对每个子树分配不同编号的总的方案数。换句话说,假如整棵树的编号为 \(1-tot\),则每个字树的编号对应着 \([2,tot]\) 中的 \(siz[son[i]]\) 个数,那么分配标号的方案数就是
\[\prod_{i=1}^{sonnum}C_{\sum_{j=i}^{n}siz[son[i]]}^{siz[son[i]]}
\]
那么可以得到,总的方案数为
\[\prod_{i=1}^{sonnum}ans[son[i]]C_{\sum_{j=i}^{n}siz[son[i]]}^{siz[son[i]]}
\]
对这个表达式化简代入后我们可以得到一个很漂亮的表达式
\[ans=\frac{n!}{\prod_{i=1}^nsiz[i]}
\]
由于模数不一定是质数,我们需要分别把分子和分母中 \(m\) 的质因数先去除,再逆元,最后再乘上这些质因数,还要注意这道题 DFS 可能会 RE,所以最好用 BFS 处理。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 5E+5 + 5;
int T, n, mod, tot, ans;
int siz[maxn];
int head, tail, que[maxn];
bool inque[maxn];
vector<int> son[maxn];
pair<int, int> d[100];
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if(!b) return x = 1, y = 0, a;
ll d = exgcd(b, a % b, y, x);
return y -= a / b * x, d;
}
ll Inv(ll x)
{
ll res, y;
exgcd(x, mod, res, y);
return (res % mod + mod) % mod;
}
ll qPow(ll a, ll b)
{
a %= mod;
ll res = 1;
while(b) {
if(b & 1) res = res * a % mod;
a = a * a % mod, b >>= 1;
}
return res;
}
void BFS()
{
memset(inque, false, sizeof inque);
head = 0, que[tail = 1] = 1, inque[1] = 1;
while(head < tail) {
++head;
for(int i = 0; i < son[que[head]].size(); ++i) {
if(!inque[son[que[head]][i]]) {
que[++tail] = son[que[head]][i];
inque[son[que[head]][i]] = 1;
}
}
}
for(int i = n; i >= 1; --i) {
int u = que[i];
siz[u] = 1;
for(int j = 0; j < son[u].size(); ++j)
siz[u] += siz[son[u][j]];
}
}
int main()
{
scanf("%d", &T);
while(T--) {
tot = 0, ans = 0;
scanf("%d%d", &n, &mod);
for(int i = 1; i <= n; ++i)
son[i].clear();
ll x;
for(int i = 2; i <= n; ++i)
scanf("%lld", &x), son[x].push_back(i);
BFS();
x = mod;
for(int i = 2; i * i <= x; ++i) {
if(x % i == 0) {
d[++tot] = { i, 0 };
while(x % i == 0) x /= i;
}
}
if(x > 1) d[++tot] = { x, 0 };
ll p = 1, q = 1;
for(int i = 1; i <= n; ++i) {
x = i;
for(int j = 1; j <= tot && x > 1; ++j)
while(x % d[j].first == 0) x /= d[j].first, ++d[j].second;
p = p * x % mod;
x = siz[i];
for(int j = 1; j <= tot && x > 1; ++j)
while(x % d[j].first == 0) x /= d[j].first, --d[j].second;
q = q * x % mod;
}
ans = p * Inv(q) % mod;
for(int i = 1; i <= tot; ++i)
ans = ans * qPow(d[i].first, d[i].second) % mod;
printf("%lld\n", ans);
}
}