@codechef - TREDEG@ Trees and Degrees
@description@
考虑所有的有 N 个节点的树(节点编号为 1 ∼ N)。我们认为两棵树不同,当且仅当存在一对节点 u 和 v,满足一棵树中 u 和 v 之间有边,另一棵中没有。
设 T 为等概率随机选择的一棵树。我们记 T 中每个节点的度数分别为 d1,d2,...,dN。令 A = (d1 ·d2 ·····dN)^K,请求出 A 的期望值。
我们可以证明,A 的期望值可以写作分数 P/Q 的形式,其中 P 和 Q 为互质的正整数,且 Q 与 998,244,353 互质。
你需要计算 P ·Q^(−1) 对 998,244,353 取模的结果,其中 Q^(−1) 代表 Q 在模 998,244,353 意义下的乘法逆元。
输入格式
输入的第一行包含一个整数 T,代表测试数据的组数。接下来是 T 组数据。 每组数据仅有一行,包含两个整数 N 和 K。
输出格式
对于每组数据,输出一行,包含一个整数 P ·Q^(−1) (mod 998244353)。
数据范围与子任务
• 1 ≤ T ≤ 1,00 • 2 ≤ N ≤ 2,000,000 • 1 ≤ K ≤ 10^9 • 单个输入中∑N ≤ 2,000,000
子任务 1(20 分): • T = 10 • 2 ≤ N ≤ 7
子任务 2(30 分): • 单 个 输 入 中 ∑N ≤ 100,000
子任务 3(50 分): • K = 1
样例输入
2
3 1
4 2
样例输出
2
748683279
样例解释
第一组数据:共有三种不同的树,可以表示为三条路径:1−2−3、1−3−2、3−1−2。 期望值为((1·2·1)1 + (1·1·2)1 + (2·1·1)1)/3 = 2,因此 P = 2、Q = 1、Q^(−1) = 1,答案为 2。
@solution@
这道题的数据范围。。。是把两道题合成一道题来考的 。。。
关于点的度数的树计数问题,自然想到 prufer 序列。又因 n 个点的树共有 \(n^{n-2}\) 种,由 prufer 序列的一些性质可以得到答案的表达式:
再稍加变形可以变成卷积的形式:
我们令 \(D(x) = \sum_{i=0}\frac{(i+1)^K}{i!}*x^i\),则 \(D^{n}(x)\) 中 \(x^{n-2}\) 对应的系数就是上面卷积的答案,不妨令其为 A。
则 \(ans = \frac{n!}{n^{n-2}}*A\)。问题在于怎么求解 A。
一种方法是使用多项式 exp + ln 实现多项式求幂。不过该方法常数较大,不能通过子任务 3,所以子任务 3 我们需要另辟蹊径。
考虑 K = 1 时:
于是就可以线性求出 n-2 项的值了。
@accepted code@
#include<cstdio>
#include<algorithm>
using namespace std;
const int MAXN = 4000000;
const int MOD = 998244353;
const int G = 3;
int pow_mod(int b, int p) {
int ret = 1;
while( p ) {
if( p & 1 ) ret = 1LL*ret*b%MOD;
b = 1LL*b*b%MOD;
p >>= 1;
}
return ret;
}
int fct[MAXN + 5], ifct[MAXN + 5];
int pw[20 + 5], ipw[20 + 5];
struct poly{
void debug(int *A, int n, char ch) {
printf("%c ", ch);
for(int i=0;i<n;i++)
printf("%d ", A[i]);
puts("");
}
void copy(int *A, int *B, int n) {
for(int *p=A,*q=B,i=0;i<n;p++,q++,i++)
(*p) = (*q);
}
void clear(int *A, int l, int r) {
for(int *p=A+l,i=l;i<r;p++,i++)
(*p) = 0;
}
void ntt(int *A, int len, int type) {
for(int i=0,j=0;i<len;i++) {
if( i < j ) swap(A[i], A[j]);
for(int k=(len>>1);(j^=k)<k;k>>=1);
}
for(int i=1;(1<<i)<=len;i++) {
int s = (1<<i), t = (s>>1);
int u = (type == 1) ? pw[i] : ipw[i];
for(int j=0;j<len;j+=s) {
for(int k=0,p=1;k<t;k++,p=1LL*p*u%MOD) {
int x = A[j+k], y = 1LL*p*A[j+k+t]%MOD;
A[j+k] = (x + y)%MOD, A[j+k+t] = (x + MOD - y)%MOD;
}
}
}
if( type == -1 ) {
int iv = pow_mod(len, MOD-2);
for(int i=0;i<len;i++)
A[i] = 1LL*A[i]*iv%MOD;
}
}
int length(int n) {
int len; for(len = 1; len < n; len <<= 1);
return len;
}
int tmp1[MAXN + 5];
void poly_inv(int *A, int *B, int n) {
int len = length((n<<1) - 1);
clear(B, 0, len);
if( n == 1 ) {
B[0] = pow_mod(A[0], MOD-2);
return ;
}
poly_inv(A, B, (n + 1) >> 1);
copy(tmp1, A, n);
ntt(tmp1, len, 1), ntt(B, len, 1);
for(int *p=B,*q=tmp1,i=0;i<len;p++,q++,i++)
(*p) = 1LL*(*p)*(2 + MOD - 1LL*(*q)*(*p)%MOD)%MOD;
ntt(B, len, -1);
clear(tmp1, 0, len), clear(B, n, len);
}
int tmp2[MAXN + 5], tmp3[MAXN + 5];
void poly_mul(int *A, int *B, int *C, int n, int m, int k) {
int len = length(n+m-1);
copy(tmp2, A, n), copy(tmp3, B, m);
ntt(tmp2, len, 1), ntt(tmp3, len, 1);
for(int i=0;i<len;i++) C[i] = 1LL*tmp2[i]*tmp3[i]%MOD;
ntt(C, len, -1);
clear(C, k, len), clear(tmp2, 0, len), clear(tmp3, 0, len);
}
void poly_int(int *A, int *B, int n) {
for(int i=n-1;i>=0;i--)
A[i+1] = 1LL*B[i]*pow_mod(i+1, MOD-2)%MOD;
A[0] = 0;
}
void poly_deri(int *A, int *B, int n) {
for(int i=1;i<n;i++)
A[i-1] = 1LL*B[i]*i%MOD;
A[n-1] = 0;
}
int tmp4[MAXN + 5];
void poly_ln(int *A, int *B, int n) {
poly_inv(A, B, n);
poly_deri(tmp4, A, n);
poly_mul(tmp4, B, B, n, n, n);
poly_int(B, B, n-1);
}
int tmp5[MAXN + 5], tmp6[MAXN + 5];
void poly_exp(int *A, int *B, int n) {
int len = length((n<<1) - 1);
clear(B, 0, len);
if( n == 1 ) {
B[0] = 1;
return ;
}
poly_exp(A, B, (n + 1) >> 1);
copy(tmp5, A, n); poly_ln(B, tmp6, n);
ntt(tmp5, len, 1), ntt(tmp6, len, 1), ntt(B, len, 1);
for(int i=0;i<len;i++)
B[i] = 1LL*B[i]*(1 + MOD - tmp6[i] + tmp5[i])%MOD;
ntt(B, len, -1);
clear(B, n, len), clear(tmp5, 0, len), clear(tmp6, 0, len);
}
int tmp7[MAXN + 5];
void poly_pow(int *A, int n, int k) {
poly_ln(A, tmp7, n);
for(int i=0;i<n;i++)
tmp7[i] = 1LL*tmp7[i]*k%MOD;
poly_exp(tmp7, A, n);
}
}oper;
int f[MAXN + 5];
void solve(){
int N, K; scanf("%d%d", &N, &K);
if( K == 1 ) {
int ans = 0;
for(int i=0,p=1;i<=N-2;i++,p=1LL*p*N%MOD)
ans = (ans + 1LL*p*ifct[i]%MOD*fct[N]%MOD*ifct[N-2-i]%MOD*ifct[i+2]%MOD)%MOD;
printf("%lld\n", 1LL*pow_mod(N, (MOD-1)-(N-2))*fct[N-2]%MOD*ans%MOD);
return ;
}
int len; for(len = 1; len <= N; len <<= 1);
for(int i=0;i<len;i++)
f[i] = 1LL*pow_mod(i+1, K)*pow_mod(fct[i], MOD-2)%MOD;
oper.poly_pow(f, len, N);
printf("%lld\n", 1LL*pow_mod(N, (MOD-1)-(N-2))*fct[N-2]%MOD*f[N-2]%MOD);
}
void init() {
for(int i=0;i<=21;i++)
pw[i] = pow_mod(G, (MOD-1)/(1<<i)), ipw[i] = pow_mod(pw[i], MOD-2);
fct[0] = 1;
for(int i=1;i<=MAXN;i++)
fct[i] = 1LL*fct[i-1]*i%MOD;
ifct[MAXN] = pow_mod(fct[MAXN], MOD-2);
for(int i=MAXN-1;i>=0;i--)
ifct[i] = 1LL*ifct[i+1]*(i+1)%MOD;
}
int main() {
init();
int T; scanf("%d", &T);
for(int i=1;i<=T;i++)
solve();
}
@details@
算是比较中规中矩的一道题吧。
不过,把数据范围拆成两类,把一道题的算法拆成两种这个操作。。。
我还以为我的常数太大了过不了。