uoj335
题意
问 \(\sum_{d_1,\ d_2,\ ...,\ d_n}\ [\sum_{i\ =\ 1}^n\ d_i\ =\ n\ -\ 2]\ \binom{n\ -\ 2}{d_1,\ d_2,\ ...,\ d_n}\ (\Pi_{i\ =\ 1}^n\ a_i^{d_i\ +\ 1}\ (d_i\ +\ 1)^m)\ (\sum_{i\ =\ 1}^n\ (d_i\ +\ 1)^m)\)
\(1\ \leq\ n\ \leq\ 3\ *\ 10^4\)
做法1
\(ans\ =\ (n\ -\ 2)!\ \sum_{i\ =\ 1}^n\ \sum_x\ \sum_{d_1,\ d_2,\ ...,\ d_n}\ [\sum_{j\ =\ 1}^n\ d_j\ =\ n\ -\ 2][d_i\ =\ x]\frac{(d_i\ +\ 1)^{2m}\ a_i^{d_i\ +\ 1}}{d_i!}\ \Pi_{j\ \neq\ i}\ \frac{(d_j\ +\ 1)^m\ a_j^{d_j\ +\ 1}}{d_j!}\)。
令 \(F(x)\ =\ \sum_{i\ \geq\ 0}\ \frac{(i\ +\ 1)^m\ x^i}{i!},\ G(x)\ =\ \sum_{i\ \geq\ 0}\ \frac{(i\ +\ 1)^{2m}\ x^i}{i!}\),则 \(ans\ =\ (n\ -\ 2)!\ \Pi_{i\ =\ 1}^n\ a_i\ [x^{n\ -\ 2}]\sum_{i\ =\ 1}^n\ G(a_i\ x)\ \Pi_{j\ \neq\ i}\ F(a_j\ x)\)。
令 \(H(x)\ =\ F(x)^{-1}\ G(x)\ \bmod\ x^{n\ -\ 1},\ I(x)\ = \Pi_{i\ =\ 1}^n\ F(a_i\ x)\),则 \(ans\ =\ (n\ -\ 2)!\ \Pi_{i\ =\ 1}^n\ a_i\ [x^{n\ -\ 2}]\sum_{i\ =\ 1}^n\ I(x)\ H(a_i\ x)\ =\ (n\ -\ 2)!\ \Pi_{i\ =\ 1}^n\ a_i\ [x^{n\ -\ 2}]I(x)\ \sum_{i\ =\ 1}^n\ H(a_i\ x)\)。
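考虑如何求 \(I(x)\)。\(I(x)\ =\ e^{\sum_{i\ =\ 1}^n\ \ln(F(a_i\ x))}\),令 \(\ln(F(x))\ =\ \sum_{i\ \geq\ 0}\ p_i\ x^i\),则 \(I(x)\ =\ e^{\sum_{i\ =\ 1}^n\ \sum_{j\ \geq\ 0}\ p_j\ (a_i\ x)^j}\ =\ e^{\sum_{j\ \geq\ 0}\ p_j\ x^j\ \sum_{i\ =\ 1}^na_i^j}\)。同理,若 \(H(x)\ =\ \sum_{j\ \geq\ 0}\ h_j\ x^j\),则 \(\sum_{i\ =\ 1}^n\ H(a_i\ x)\ =\ \sum_{j\ \geq\ 0}\ h_j\ x^j\ \sum_{i\ =\ 1}^n\ a_i^j\)。而 \(\sum_{i\ =\ 1}^n\ a_i^j\ =\ [x^j]\ \sum_{i\ =\ 1}^n\ \frac{1}{1\ -\ a_ix}\),分治通分求出该有理函数的分子与分母,再做一次多项式求逆即可得到所有幂和。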
总时间复杂度 \(O(n\ log^2\ n)\)。
代码
#include <bits/stdc++.h>
#ifdef __WIN32
#define LLFORMAT "I64"
#else
#define LLFORMAT "ll"
#endif
using namespace std;
// Length of every static work buffer in namespace Poly. The largest products
// built by main() (num * den^-1, exp(...) * g, and the Newton steps inside
// inv/ln/exp) have about 2n coefficients, so the NTT length reaches 2^18 once
// n exceeds roughly 65536; 2^18 (+ slack) therefore covers n <= 1e5.
const int maxn = (1 << 18) | 10; // deal with n <= 1e5
// NTT-friendly prime (mod - 1 = 119 * 2^23) and a primitive root of it.
const int mod = 998244353, proot = 3;
// Binary exponentiation: returns a value congruent to x^n modulo `mod`.
// The exponent n must be non-negative. When x is non-negative the result lies
// in [0, mod); in general it lies in (-mod, mod).
inline int pow_mod(int x, int n) {
    int result = 1;
    for (int base = x, e = n; e != 0; e >>= 1) {
        if (e & 1) result = (long long) result * base % mod;
        base = (long long) base * base % mod;
    }
    return result;
}
namespace Poly {
struct poly {
vector<int> p;
poly() { p.clear(); }
poly(int n) { p.resize(n); }
poly(const vector<int> &q): p(q) {}
poly(int n, int *a) { p.resize(n); for (int i = 0; i < n; ++i) p[i] = a[i]; }
inline int size() const { return p.size(); }
inline void resize(int n) { p.resize(n); return; }
};
vector<vector<int> > w, rw;
void dft(int n, int *a, bool rev) {
for (int i = 0, j = 0; i < n; ++i) {
if(i < j) swap(a[i], a[j]);
for (int k = n >> 1; (j ^= k) < k; k >>= 1);
}
while((n >> w.size()) > 1) {
if(!w.size()) w = rw = vector<vector<int> >(1, vector<int>{1, mod - 1});
else {
int l = 1 << w.size() + 1;
int wn = pow_mod(proot, (mod - 1) / l), rwn = pow_mod(wn, mod - 2);
vector<int> a(l), b(l);
for (int i = 0; i < l; ++i) {
if(i & 1) a[i] = (long long) a[i - 1] * wn % mod, b[i] = (long long) b[i - 1] * rwn % mod;
else a[i] = w.back()[i >> 1], b[i] = rw.back()[i >> 1];
}
w.push_back(a); rw.push_back(b);
}
}
for (int foo = 0, hl = 1, l = 2; l <= n; hl = l, l <<= 1, ++foo) {
auto &bar = rev ? rw[foo] : w[foo];
for (int i = 0; i < n; i += l) for (int j = 0, *x = a + i, *y = x + hl; j < hl; ++j, ++x, ++y) {
int t = (long long) *y * bar[j] % mod; *y = (*x - t) % mod; *x = (*x + t) % mod;
}
}
if(rev) { int inv = pow_mod(n, mod - 2); for (int i = 0; i < n; ++i) a[i] = (long long) a[i] * inv % mod; }
return;
}
poly operator * (const poly &A, const poly &B) {
static int a[maxn], b[maxn];
int n = A.size(), m = B.size();
if(n < 10 || m < 10 || n + m - 1 < 80) {
int N = n + m - 1; static int c[maxn]; memset(c, 0, sizeof(c[0]) * N);
for (int i = 0; i < n; ++i) for (int x = A.p[i], j = 0; j < m; ++j) c[i + j] = ((long long) x * B.p[j] + c[i + j]) % mod;
return poly(N, c);
}
int N = 1; while(N < n + m - 1) N <<= 1;
for (int i = 0; i < N; ++i) a[i] = i < n ? A.p[i] : 0, b[i] = i < m ? B.p[i] : 0;
dft(N, a, 0); dft(N, b, 0); for (int i = 0; i < N; ++i) a[i] = (long long) a[i] * b[i] % mod; dft(N, a, 1);
return poly(n + m - 1, a);
}
poly operator * (const int &a, const poly &B) {
static int b[maxn]; int n = B.size();
for (int i = 0; i < n; ++i) b[i] = (long long) a * B.p[i] % mod;
return poly(n, b);
}
poly operator - (const poly &A, const poly &B) {
static int a[maxn]; int n = A.size(), m = B.size(), N = max(n, m);
for (int i = 0; i < N; ++i) a[i] = ((i < n ? A.p[i] : 0) - (i < m ? B.p[i] : 0)) % mod;
return poly(N, a);
}
poly operator + (const poly &A, const poly &B) {
static int a[maxn]; int n = A.size(), m = B.size(), N = max(n, m);
for (int i = 0; i < N; ++i) a[i] = ((i < n ? A.p[i] : 0) + (i < m ? B.p[i] : 0)) % mod;
return poly(N, a);
}
poly inv(int n, const poly &A) { // A(x)^-1 mod x^n
if(n == 1) { return poly(vector<int>{pow_mod(A.p[0], mod - 2)}); }
static poly B0, B, tA;
B0 = inv(n + 1 >> 1, A);
if(A.size() < n) tA = A, tA.resize(n);
else tA.p.clear(), tA.p.insert(tA.p.end(), A.p.begin(), A.p.begin() + n);
B = 2 * B0 - B0 * B0 * tA;
B.resize(n);
return B;
}
poly rev(const poly &A) { static poly B; B = A; reverse(B.p.begin(), B.p.end()); return B; }
poly operator / (const poly &A, const poly &B) {
static poly rA, rB, C, D; int n = A.size(), m = B.size();
rA = rev(A); rB = rev(B); C = inv(n - m + 1, rB); D = C * rA; D.resize(n - m + 1);
return rev(D);
}
poly operator % (const poly &A, const poly &B) { static poly D, ret; D = A / B; ret = A - B * D; ret.resize(B.size() - 1); return ret; }
poly Rem(const poly &A, const poly &B, const poly &D) { static poly ret; ret = A - B * D; ret.resize(B.size() - 1); return ret; }
poly dao(const poly &A) {
static int a[maxn]; int n = A.size() - 1;
for (int i = 0; i < n; ++i) a[i] = (long long) A.p[i + 1] * (i + 1) % mod;
return poly(n, a);
}
poly ji(const poly &A) {
static int a[maxn]; int n = A.size();
for (int i = 1; i <= n; ++i) a[i] = (long long) A.p[i - 1] * pow_mod(i, mod - 2) % mod; a[0] = 0;
return poly(n + 1, a);
}
poly ln(int n, const poly &A) { // ln(A(x)) mod x^n
static poly B, C;
C = dao(A); B = inv(n, A);
B = B * C; B = ji(B); B.resize(n);
return B;
}
poly exp(int n, const poly &A) { // e^A(x) mod x^n
if(n == 1) { return poly(vector<int>{1}); }
static poly B0, B, C, tA;
B0 = exp(n + 1 >> 1, A);
if(A.size() < n) tA = A, tA.resize(n);
else tA.p.clear(), tA.p.insert(tA.p.end(), A.p.begin(), A.p.begin() + n);
C = ln(n, B0);
B = B0 * (poly(vector<int>{1}) - C + tA);
B.resize(n);
return B;
}
poly pow_mod(int n, const poly &A, int N) { // A(x)^N mod x^n
static poly B;
int na = A.size(), i = 0;
while(i < na && A.p[i] == 0) ++i;
if(i == na || (long long) N * i >= n) return poly(vector<int>(n, 0));
if(i) {
B.p.clear();
for (int j = i; j < na; ++j) B.p.push_back(A.p[j]);
static poly C;
C = pow_mod(n - N * i, B, N);
B.p = vector<int>(N * i, 0);
B.p.insert(B.p.end(), C.p.begin(), C.p.end());
return B;
}
if(A.p[0] != 1) {
int t = ::pow_mod(A.p[0], mod - 2), s = ::pow_mod(A.p[0], N); B.resize(na);
for (int i = 0; i < na; ++i) B.p[i] = (long long) A.p[i] * t % mod;
B = pow_mod(n, B, N);
for (int i = 0; i < n; ++i) B.p[i] = (long long) s * B.p[i] % mod;
return B;
}
B = ln(n, A);
for (int i = 0; i < n; ++i) B.p[i] = (long long) B.p[i] * N % mod;
return exp(n, B);
}
}
using namespace Poly;
int main() {
int n, m;
cin >> n >> m;
if(n == 1) { puts("0"); return 0; }
vector<int> a(n);
for (int i = 0; i < n; ++i) cin >> a[i];
vector<int> fac(n), ifac(n);
fac[0] = 1; for (int i = 1; i < n; ++i) fac[i] = (long long) fac[i - 1] * i % mod;
ifac[n - 1] = pow_mod(fac[n - 1], mod - 2); for (int i = n - 1; i; --i) ifac[i - 1] = (long long) ifac[i] * i % mod;
poly f(n - 1), g(n - 1);
for (int i = 0; i < n - 1; ++i) f.p[i] = (long long) pow_mod(i + 1, m) * ifac[i] % mod, g.p[i] = (long long) pow_mod(i + 1, 2 * m) * ifac[i] % mod;
g = inv(n - 1, f) * g; g.resize(n - 1);
f = ln(n - 1, f);
function<pair<poly, poly>(int, int)> work = [&](int l, int r) {
if(l == r) { return make_pair(poly(vector<int>{1}), poly(vector<int>{1, (mod - a[l]) % mod})); }
int mid = l + r >> 1;
auto a = work(l, mid), b = work(mid + 1, r);
return make_pair(a.first * b.second + a.second * b.first, a.second * b.second);
};
auto t = work(0, n - 1);
poly p = t.first * inv(n - 1, t.second);
p.resize(n - 1);
for (int i = 0; i < n - 1; ++i) {
f.p[i] = (long long) f.p[i] * p.p[i] % mod;
g.p[i] = (long long) g.p[i] * p.p[i] % mod;
}
f = exp(n - 1, f) * g;
f.resize(n - 1);
int ans = (long long) fac[n - 2] * f.p.back() % mod;
for (int x: a) ans = (long long) x * ans % mod;
cout << (ans + mod) % mod << endl;
return 0;
}