闲话 23.2.2

闲话

看很多人在学习龚体
我的评价是
为啥不做小孩召开法

今日推歌:回马枪 covered by 兰音Reine

再论某搞笑题

比赛刚结束就被爆标了(
我不知道是我太弱还是巨佬们太强(
反正我太弱是肯定的

\(\text{Sol 1}:O(k\log^2 k + k\log k \log n)\) by Sol1

Sol 1 by Sol1

我们容易表出 \(F(1, n)\)——这是广义斐波那契数列的形式,记 \(F(1, n) = Ax^n + By^n\) 的形式,这里可以扩域处理根式,因为可能二次剩余不存在;复杂度不会增加。

随后我们按经典的方式,枚举每一种造成贡献的路径。我们枚举一个长度为 \(k\) 的序列 \(\langle a\rangle\),其中每个元素 \(a_i \in [0, n]\),表示这路径在第 \(i\) 层走了 \(a_i\) 步;恒有 \(\sum a_i = n\)

把所有转移的贡献都一次乘入,我们可以写出

\[F(k, n) = \sum_{a} \left(Ax^{a_1} + By^{a_1}\right) \times \prod_{i = 2}^k t_i^{a_i} \prod_{j = 1}^{i - 1} s^{a_j} \]

观察 \(s\) 的上标里每个 \(a_i\) 的贡献,我们发现 \(a_i\) 会出现 \(k - i\) 次,因此提出去 \(s^{(k - 1)a_1}\) 就有

\[F(k, n) = \sum_{a} \left(Ax^{a_1} + By^{a_1}\right) s^{(k - 1)a_1} \times \prod_{i = 2}^k t_i^{a_i} s^{(k - i)a_i} \]

稍微整理一下就可以得到

\[F(k, n) = \sum_{a} \left(A\left(s^{k - 1}x\right)^{a_1} + B\left(s^{k - 1}y\right)^{a_1}\right) \times \prod_{i = 2}^k t_i^{a_i} s^{(k - i)a_i} \]

也就是说我们要求的值形如

\[\sum_{a} (s^{k - 1}x)^{a_1} \prod_{i = 2}^k \left(t_is^{k - i}\right)^{a_i} \]

这个形式可以直接构造生成函数,要求的就是

\[[z^n]\left( \frac{1}{1 - \left(s^{k - 1}x\right) z } \times \prod_{i = 2}^k \frac{1}{1 - \left(t_is^{k - 1}\right) z } \right) \]

这个可以朴素分治乘法+线性递推做到 \(O(k\log^2 k + k\log k \log n)\)

Sol1 的做法就到这。NaCly_Fish 通过一个我之前想过但是感觉没用的结构推导出了和这个式子同等的结构,并得到了更优的做法。

\(\text{Sol 2}:\)\(o(k\log^2 k + k\log k \log n)\) by NaCly_Fish

我们容易对第二维求和得到一个生成函数

\[F_k(x) = \frac{F_{k - 1}(sx)}{1 - t_kx} \]

边界是(没分解成 \(Ax^n + By^n\) 的形式)

\[F_1(x) = \frac{st_0 + (st_1 - st_0 a) x}{1 - ax - bx^2} \]

我当时的推导到这就结束了(

可以将 \(F_k(x)\) 展开得到

\[F_k(x) = \frac{st_0 + (st_1 - st_0 a) \left(s^{k - 1}x\right)}{1 - a\left(s^{k - 1}x\right) - b\left(s^{k - 1}x\right)^2} \prod_{i = 2}^k \frac{1}{1 - t_i s^{k - i}x} \]

这便得到了和上面同等的形式。直接做就行,挺好写的;线性递推甚至可以 \(O(k^2)\)。鰰的实现在这里可以看到复杂度薄纱标程

“EI 讲过,这种形式可以做分式分解。”

分子本质上是 \(O(1)\) 次求系数操作,不对复杂度作影响,只需要考虑处理分母。首先将分母做因式分解。我们记分母 \(Q(x)\) 被分解为

\[\frac{1}{Q(x)} = \prod_{i = 1}^m \frac{1}{(1 - q_i x)^{d_i}} \]

这里可能有重根,需要记一个 \(d_i\),上界是 \(m\) 而不是 \(k + 1\)。这里可能有二次非剩余,可能需要扩域。

如果我们可以把他改写成

\[\frac{1}{Q(x)} = \sum_{i = 1}^m \frac{P_i(x)}{(1 - q_ix)^{d_i}} \]

的形式,就可以轻松提取系数了,最终无非是一系列二项式系数加加乘乘的形式。

对每个 \(i\),由于这是通分的形式,我们可以知道

\[\frac{P_i(x)}{(1 - q_ix)^{d_i}} Q(x) \equiv 1 \pmod{(1 - q_ix)^{d_i}} \]

这样我们只需要求出每个 \(\dfrac{Q(x)}{(1 - q_ix)^{d_i}} \pmod{(1 - q_ix)^{d_i}}\) 后求出它在 \(\text{mod }(1 - q_ix)^{d_i}\) 意义下的逆元。

对第一步,由于

\[Q(x) = \prod_{i = 1}^m (1 - q_ix)^{d_i} \]

\(i = t\) 时,我们要计算的就是

\[F_t(x) = \prod_{i \neq t} (1 - q_i x)^{d_i} \text{ mod } (1 - q_tx)^{d_t} \]

这的做法应当与多项式多点求值类似。只需按照线段树结构维护,以计算 \(t = m\) 为例,初始时计算 \([1, mid]\) 段乘积模 \([mid + 1, m]\) 段乘积的结果,并带着该结果向右递归。每次将其乘入 \([l, mid]\) 段乘积并对 \([mid+1, r]\) 段乘积取模,到达叶子时结果即 \(F_m(x)\)。应当是 \(O(k \log^2 k)\) 的吧。

对第二步,假设我们要计算 \(F(x)\) 在模 \((1 - qx)^{d}\) 意义下的逆。鰰不复读那我复读吧 /qd

若答案为 \(P(x)\),则我们知道

\[P(x) F(x) \equiv 1 \pmod{(1 - qx)^{d}} \]

\[P(q^{-1}(1-x)) F(q^{-1}(1-t)) \equiv 1 \pmod{x^{d}} \]

它对答案的贡献即为

\[[x^n]\frac{P(x)}{(1-qx)^d} = [x^n] \sum_{i = 0}^{d-1} \frac{[t^i]P(q^{-1}(1-t))}{(1-qx)^{d-i}} \]

\(P(q^{-1}(1-t))\) 的系数可以在模 \(x^d\) 意义下做多项式求逆得到。假设 \(f_i = [t^i]P(q^{-1}(1-t))\),原式即

\[[x^n] \sum_{i = 0}^{d-1} f_i\frac{1}{(1-qx)^{n-i}} = q^n \sum_{i = 0}^{d-1} \binom{n+i}{i} f_{d-1-i} \]

大概可以 \(O(k)\)

总时间复杂度 \(O(k\log^2 k)\)

没实现完 /qd
#include <bits/stdc++.h>
using namespace std;
#define inline __attribute__((__gnu_inline__, __always_inline__, __artificial__)) inline
using pii = pair<int,int>; using vi = vector<int>; using vp = vector<pii>; using ll = long long; 
using ull = unsigned long long; using db = double; using ld = long double; using lll = __int128_t;
template<typename T1, typename T2> T1 max(T1 a, T2 b) { return a > b ? a : b; }
template<typename T1, typename T2> T1 min(T1 a, T2 b) { return a < b ? a : b; }
#define multi int _T_; cin >> _T_; for (int TestNo = 1; TestNo <= _T_; ++ TestNo)
#define timer cerr << 1. * clock() / CLOCKS_PER_SEC << '\n';
#define iot ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define rep(i,s,t) for (register int i = (s), i##_ = (t) + 1; i < i##_; ++ i)
#define pre(i,s,t) for (register int i = (s), i##_ = (t) - 1; i > i##_; -- i)
#define eb emplace_back
#define pb pop_back
const int N = 1e6 + 10;
const int inf = 0x3f3f3f3f;
const ll infll = 0x3f3f3f3f3f3f3f3fll;
const int mod = 998244353;

int k, m, st0, st1, a, b, s, t[N], d[N];
ll n;

int fac[N], inv[N], ifc[N];
inline int Norm(const int& u) { return u >= mod ? u - mod : u; }
inline int C(int n, int m) { return 1ll * fac[n] * ifc[m] % mod * ifc[n - m] % mod; }


int dlt;
struct Fd { // 域 F_p[dlt]
    int a, b;
    Fd(const int& _a = 0, const int& _b = 0) : a(Norm(_a)), b(Norm(_b)) {}
    inline friend Fd operator+ (const Fd& A, const Fd& B) { return Fd(Norm(A.a + B.a), Norm(A.b + B.b)); }
    inline friend Fd operator+ (const Fd& A, const int& B) { return Fd(Norm(A.a + B), A.b); }
    inline friend Fd operator- (const Fd& A, const Fd& B) { return Fd(Norm(A.a - B.a + mod), Norm(A.b - B.b + mod)); }
    inline friend Fd operator- (const Fd& A, const int& B) { return Fd(Norm(A.a - B + mod), A.b); }
    inline friend Fd operator- (const int& A, const Fd& B) { return Fd(Norm(A - B.a + mod), Norm(mod - B.b)); }
    inline friend Fd operator* (const Fd& A, const Fd& B) { return Fd((1ll * A.a * B.a + 1ll * A.b * B.b % mod * dlt) % mod, (1ll * A.b * B.a + 1ll * A.a * B.b) % mod); }
    inline friend Fd operator* (const Fd& A, const int& B) { return Fd(1ll * A.a * B % mod, 1ll * A.b * B % mod); }
    inline Fd operator -() { return Fd(Norm(mod - a), Norm(mod - b)); }
    inline friend bool operator< (const Fd& A, const Fd& B) { if (A.a != B.a) return A.a < B.a; else return A.b < B.b; }
} q[N];
map<Fd, int> qi; 
using poly = vector<Fd>;

int wl, w[N];
unsigned long long W[N];
void get(int n) { wl = 1; while (wl < n) wl <<= 1; }
int qp(int a, int b) { int ans = 1; while (b) { if (b & 1) ans = 1ll * ans * a % mod; a = 1ll * a * a % mod; b >>= 1; } return ans; }
Fd qp(Fd a, int b) { Fd ans = 1; while (b) { if (b & 1) ans = ans * a; a = a * a; b >>= 1; } return ans; }
void init(int n) {
    int t = 1;
    while ((1 << t) < n) t++;
    t = min(t - 1, 21);
    w[0] = 1, w[1 << t] = qp(31, 1 << 21 - t);
    for (int i = t; i >= 1; i--)
        w[1 << i - 1] = 1ll * w[1 << i] * w[1 << i] % mod;
    for (int i = 1; i < (1 << t); i++)
        w[i] = 1ll * w[i & i - 1] * w[i & -i] % mod;
    for (int i = 0; i <= (1 << t); i++)
        W[i] = (((unsigned __int128)w[i] << 64) + mod - 1) / mod;
    for (int i = 0; i <= (1 << t); i++) W[i] = W[i] % mod;
}
void DIF(poly &a) {
    int n = a.size();
    for (int mid = n >> 1; mid >= 1; mid >>= 1) {
        for (int i = 0, k = 0; i < n; i += mid << 1, k++) {
            for (int j = 0; j < mid; j++) {
                Fd x = a[i + j], y = a[i + j + mid] * W[k];
                a[i + j] = x + y;
                a[i + j + mid] = x - y;
            }
        }
    }
}
void DIT(poly &a) {
    int n = a.size();
    for (int mid = 1; mid < n; mid <<= 1) {
        for (int i = 0, k = 0; i < n; i += mid << 1, k++) {
            for (int j = 0; j < mid; j++) {
                Fd x = a[i + j], y = a[i + j + mid];
                a[i + j] = x + y;
                a[i + j + mid] = x - y;
                a[i + j + mid] = a[i + j + mid] * W[k];
            }
        }
    }
    int inv = qp(n, mod - 2);
    for (int i = 0; i < n; i++) a[i] = a[i] * inv;
    reverse(a.begin() + 1, a.begin() + n);
}
inline ostream& operator<<(ostream& out, const Fd& v) {
    out << '(' << v.a << ", " << v.b << ")";
    return out;
}
inline ostream& operator<<(ostream& out, const poly& p) {
    rep(i,0,p.size() - 1) out << p[i] << ' ';
    return out;
}
poly operator*(poly a, poly b) {
    int n = a.size(), m = b.size();
    poly ans(n + m - 1);
    rep(i,0,n+m-1) rep(j,0,i) if (j < n and i - j < m) ans[i] = ans[i] + a[j] * b[i - j];
    // get(n + m);
    // a.resize(wl), b.resize(wl);
    // DIF(a), DIF(b);
    // for (int i = 0; i < wl; i++) a[i] = a[i] * b[i];
    // DIT(a), a.resize(n + m - 1);
    return ans;
}
inline poly operator+(poly a, poly b) {
    int n = max(a.size(), b.size());
    a.resize(n), b.resize(n);
    poly ans(n);
    for (int i = 0; i < n; i++) ans[i] = a[i] + b[i];
    return ans;
}
inline poly operator-(poly a, poly b) {
    int n = max(a.size(), b.size());
    a.resize(n), b.resize(n);
    poly ans(n);
    for (int i = 0; i < n; i++) ans[i] = a[i] - b[i];
    return ans;
}
inline poly operator+(poly a, int b) {
    a[0] = a[0] + b;
    return a;
}
inline poly operator-(poly a, int b) {
    a[0] = a[0] - b;
    return a;
}
inline poly operator+(int b, poly a) {
    a[0] = a[0] + b;
    return a;
}
inline poly operator-(int b, poly a) {
    a[0] = a[0] - b;
    for (auto &x : a) x = -x;
    return a;
}
inline poly Inv(poly &f){
    int n = f.size();
    get(n), f.resize(wl);
    poly g(wl), tmp, ret;
    g[0] = qp(f[0], mod - 2);
    for (int len = 2; len <= wl; len <<= 1) {
        tmp.resize(len), ret.resize(len);
        for (int i = 0; i < len; i++) tmp[i] = f[i];
        for (int i = 0; i < (len >> 1); i++) ret[i] = g[i];

        DIF(tmp), DIF(ret);
        for (int i = 0; i < len; i++) tmp[i] = tmp[i] * ret[i];
        DIT(tmp);
        
        for (int i = 1; i < (len >> 1); i++) tmp[i] = 0;
        tmp[0] = mod - 1;
        DIF(tmp);
        for (int i = 0; i < len; i++) ret[i] = tmp[i] * ret[i];
        DIT(ret);
        for (int i = len >> 1; i < len; i++) g[i] = - ret[i];
    } g.resize(n);
    return g;
}
pair<poly, poly> operator/(poly a, poly b) {
    int n = a.size(), m = b.size();
    if (n < m) return {poly{Fd(1, 0)}, a};
    poly q(a), f(b);
    reverse(q.begin(), q.end()), reverse(f.begin(), f.end());
    f.resize(n - m + 1), f = Inv(f);
    q = q * f, q.resize(n - m + 1);
    reverse(q.begin(), q.end());
    b = b * q;
    poly r = a - b;
    r.resize(m - 1);
    return make_pair(q, r);
}
inline poly shift(poly f, int c) {
    c %= mod;
    c = c + (c < 0) * mod;
    if (c == 0) return f;
    int n = f.size();
    poly A(n), B(n), ret(n);
    for (int i = 0; i < n; ++i) A[n - i - 1] = f[i] * fac[i];
    for (int i = 0, pc = 1; i < n; ++i, pc = 1ll * pc * c % mod)
        B[i] = Fd(1ll * pc * ifc[i] % mod, 0);
    A = A * B, A.resize(n);
    for (int i = 0; i < n; ++i) ret[i] = A[n - i - 1] * ifc[i];
    return ret;
}

map<pair<int, int>, poly> P;
poly Fm[N];
poly prep(int l, int r) {
    // cout << l << ' ' << r << endl;
    if (l == r) {
        poly p; p.resize(d[l] + 1);
        Fd coef = Fd(1,0);
        rep(i,0,d[l]) p[i] = coef * C(d[l], i), coef = coef * (mod - q[l]);
        // cout << p << endl;
        return P[{l, l}] = p;
    } int mid = (l + r) >> 1;
    P[{l, r}] = prep(l, mid) * prep(mid + 1, r);
    return P[{l, r}];
}
void calc(int l, int r, const poly& Fupper) {
    cout << l << ' ' << r << endl;
    cout << Fupper << endl;
    poly f = poly{Fd(1, 0)}, g = f; 
    rep(i,1,m) {
        if (l <= i and i <= r) g = g * P[{i, i}];
        else f = f * P[{i, i}];
    }
    // cout << "f : " << f << endl;
    // cout << "g : " << g << endl;
    // cout << "Fp : " << P[{l, r}] << endl;
    cout << (f / g).second << endl;
    if (l == r) { Fm[l] = (Fupper / P[{l, l}]).second; cerr << Fm[l].size() << ' ' << d[l] << endl; return ; }
    int mid = (l + r) >> 1;
    calc(l, mid, (Fupper * P[{mid + 1, r}] / P[{l, mid}]).second);
    calc(mid + 1, r, (Fupper * P[{l, mid}] / P[{mid + 1, r}]).second);
}

int F(ll n) { // [x^n] (1 / Q(x))
    int ans = 0, nmod = n % mod, nmod1 = n % (mod - 1);
    rep(t,1,m) {
        int nume = 1;
        Fd nowans = 0;
        rep(i,0,d[t]-1) nowans = nowans + nume * Fm[t][d[t] - 1 - i] * ifc[i], nume = 1ll * nume * (nmod + i + 1) % mod;
        nowans = nowans * qp(q[t], nmod1);
        ans = Norm(ans + nowans.a);
    } return ans;
}

signed main() {
    poly f(5);
    f[0].a = 1; 
    f[0].a = 6; 
    f[0].a = 3; 
    f[0].a = 4; 
    f[0].a = 9; 
    cout << Inv(f) << endl;

    cin >> k >> n;
    cin >> st0 >> st1 >> a >> b >> s;
    rep(i,2,k) cin >> t[i];
    init(k);
	fac[0] = ifc[0] = 1;
	fac[1] = inv[1] = ifc[1] = 1;
	rep(i,2,k + 1) {
		fac[i] = 1ll * fac[i - 1] * i % mod;
		inv[i] = Norm(mod - 1ll * mod / i * inv[mod % i] % mod);
		ifc[i] = 1ll * ifc[i - 1] * inv[i] % mod;
	}

    // srand(time(0));
    // Fd aaa = Fd(rand(), 0);
    // cout << aaa << endl;
    // cout << poly{aaa, aaa, aaa} << endl;
    // cout << aaa * Fd(1, 0) << endl;
    // cout << poly{aaa, aaa, aaa} * poly{Fd(2, 0)} << endl;

    dlt = (1ll * a * a + 4ll * b) % mod;
    int s_ = 1;
    rep(i,0,k-2) qi[Norm(mod - 1ll * t[k - i] * s_ % mod)]++, s_ = 1ll * s_ * s % mod;
    s_ = 1ll * s_ * s % mod; // s_ = s^{k - 1}
    int inv2b = qp(2ll * b % mod, mod - 2);
    qi[Fd(Norm(mod - 1ll * inv2b * a % mod * s_ % mod), Norm(mod - 1ll * inv2b * s_ % mod))] ++;
    qi[Fd(Norm(mod - 1ll * inv2b * a % mod * s_ % mod), 1ll * inv2b * s_ % mod)] ++;

    for (auto [qt, dt] : qi) {
        ++ m; q[m] = qt, d[m] = dt;
    }

    cerr << "Tag" << endl;
    prep(1, m);
    cerr << "omg" << endl;
    rep(i,1,m) cerr << i << ' ' << q[i].a << ' ' << q[i].b << ' ' << d[i] << endl;
    calc(1, m, poly{Fd(1, 0)});
    rep(i,1,m) {
        cout << "my calc: " << endl << Fm[i] << endl;
        poly Fmtrue = poly{Fd(1, 0)};
        rep(j,1,m) if (i != j) {
            Fmtrue = Fmtrue * P[{j, j}];
            // cout << Fmtrue << endl;
            // cout << P[{j,j}] << endl;
        }
        // cout << "prev: " << endl << Fmtrue << endl;
        cout << "true ans: " << endl << (Fmtrue/P[{i, i}]).second << endl;
    } 
    rep(i,1,m) {
        Fd coef = mod - qp(q[i], mod - 2), cur = coef;
        rep(j,1,d[i]-1) Fm[i][j] = Fm[i][j] * cur, cur = cur * coef;
        shift(Fm[i], - 1);
        Fm[i] = Inv(Fm[i]);
    }

    cout << (1ll * st0 * F(n) + 1ll * Norm(st1 - 1ll * st0 * a % mod + mod) * s_ % mod * F(n - 1)) % mod << '\n';
}

关于为啥要开到负数……为了加个 corner case?

posted @ 2023-02-02 20:36  joke3579  阅读(108)  评论(3编辑  收藏  举报