uoj424

题意

设 \(F_A(l,\ r)\ =\ \min\{i\ |\ A_i\ =\ \min_{j\ =\ l}^r\{A_j\}\}\)(即区间 \([l,\ r]\) 中最靠左的最小值的位置),称 \(A\) 与 \(B\) 同构当且仅当 \(\forall\ l,\ r,\ F_A(l,\ r)\ =\ F_B(l,\ r)\)。
问有多少个两两不同构的、长度为 \(n\) 的序列 \(A\) 满足 \(A_i\ \in\ [1,\ m]\),且 \([1,\ m]\) 中的每个整数都至少出现一次(因此 \(n\ <\ m\) 时答案为 \(0\))。

\(1\ \leq\ n,\ m\ \leq\ 10^5\)

做法1

两序列同构当且仅当它们的笛卡尔树(以区间最靠左的最小值为根)相同。定义一个点的深度为其到根路径上(含自身)左儿子点的个数加一,即根的深度为 \(1\),走到右儿子深度不变,走到左儿子深度加一。所求即为大小为 \(n\)、最大深度不超过 \(m\) 的二叉树个数。注意 \(n\ <\ m\) 时答案为 \(0\)。

\(f(i,\ j)\) 表示 \(i\) 个点,最大深度不超过 \(j\) 的个数,则有 \(f(i,\ j)\ =\ \sum_{k\ =\ 0}^{i\ -\ 1}\ f(k,\ j\ -\ 1)\cdot\ f(i\ -\ k\ -\ 1,\ j)\)。令 \(F_j(x)\ =\ \sum_{i}\ f(i,\ j)\ x^i\),则有 \(F_j(x)\ =\ \frac{1}{1\ -\ x\cdot F_{j-1}(x)}\)。令 \(F_j(x)\ =\ \frac{A_j(x)}{B_j(x)}\),则可以发现 \(A_j(x)\ =\ B_{j-1}(x),\ B_j(x)\ =\ B_{j-1}(x)\ -\ x\cdot A_{j-1}(x)\),可以用矩阵快速幂求 \(A_m(x),\ B_m(x)\)

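注意到模数为 \(998244353\ =\ 119\cdot 2^{23}\ +\ 1\),存在 \(2^{23}\) 次单位根。取 2 的幂 \(N\ >\ \max(n,\ m)\),\(\omega\) 为 \(N\) 次单位根。对每个 \(\omega^k\) 代入 \(x\) 做矩阵快速幂,即可求出点值表示 \(A_m(\omega^0),\ A_m(\omega^1),\ ...,\ A_m(\omega^{N-1}),\ B_m(\omega^0),\ B_m(\omega^1),\ ...,\ B_m(\omega^{N-1})\)。之后用 \(IDFT\) 即可得到 \(A_m(x),\ B_m(x)\)。最后对 \(B_m(x)\) 求逆,答案即为 \(A_m(x)\cdot B_m(x)^{-1}\) 的 \(x^n\) 项系数。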

时间复杂度 \(O(n\log n)\)

代码

#include <bits/stdc++.h>

#ifdef __WIN32
#define LLFORMAT "I64"
#else
#define LLFORMAT "ll"
#endif

using namespace std;

constexpr int mod = 998244353, proot = 3; // 998244353 = 119 * 2^23 + 1; 3 is used as the NTT root
constexpr int maxn = (1 << 19) | 10; // scratch-buffer length; large enough for the transforms needed when n <= 1e5

// x^n modulo `mod` by binary exponentiation.
// A negative x is accepted; the result is then a residue in (-mod, mod) congruent to x^n.
inline int pow_mod(int x, int n) {
    int y = 1;
    while(n) {
        if(n & 1) y = (long long) y * x % mod;
        x = (long long) x * x % mod;
        n >>= 1;
    }
    return y;
}

// Polynomial / formal-power-series toolkit over Z/mod.
// NOTE: most routines use function-local static scratch storage, so they are not
// thread-safe. Their results keep C++ `%` semantics: coefficients can be negative
// residues in (-mod, mod). Normalize with (v % mod + mod) % mod before output.
namespace Poly {
    // Dense polynomial; p[i] is the coefficient of x^i.
    struct poly {
        vector<int> p;

        poly() { p.clear(); }
        poly(const vector<int> &q): p(q) {}
        // Copies the first n entries of the raw array a.
        poly(int n, int *a) { p.resize(n); for (int i = 0; i < n; ++i) p[i] = a[i]; }

        inline int size() const { return p.size(); }
        inline void resize(int n) { p.resize(n); return; }
    };

    // In-place iterative number-theoretic transform of a[0..n).
    // Assumes n is a power of two dividing mod - 1. With rev = true it computes the
    // inverse transform, including the final multiplication by 1/n.
    void dft(int n, int *a, bool rev) {
        // Bit-reversal permutation.
        for (int i = 0, j = 0; i < n; ++i) {
            if(i < j) swap(a[i], a[j]);
            for (int k = n >> 1; (j ^= k) < k; k >>= 1);
        }
        // Butterflies over blocks of length l = 2, 4, ..., n (hl = l / 2).
        for (int hl = 1, l = 2; l <= n; hl = l, l <<= 1) {
            int wn = pow_mod(proot, (mod - 1) / l); if(rev) wn = pow_mod(wn, mod - 2);
            for (int i = 0; i < n; i += l) for (int j = 0, *x = a + i, *y = x + hl, w = 1; j < hl; ++j, ++x, ++y, w = (long long) w * wn % mod) {
                int t = (long long) *y * w % mod; *y = (*x - t) % mod; *x = (*x + t) % mod;
            }
        }
        if(rev) { int inv = pow_mod(n, mod - 2); for (int i = 0; i < n; ++i) a[i] = (long long) a[i] * inv % mod; }
        return;
    }

    // Full product A * B with A.size() + B.size() - 1 coefficients.
    // Uses schoolbook multiplication for short operands and the NTT otherwise.
    // An empty operand yields the empty (zero) polynomial.
    // The static buffers hold maxn ints, which bounds the supported operand sizes.
    poly operator * (const poly &A, const poly &B) {
        static int a[maxn], b[maxn];
        int n = A.size(), m = B.size();
        // With an empty operand n + m - 1 <= 0; a negative length must not reach memset.
        if(n == 0 || m == 0) return poly();
        if(n < 10 || m < 10 || n + m - 1 < 80) {
            int N = n + m - 1; static int c[maxn]; memset(c, 0, sizeof(c[0]) * N);
            for (int i = 0; i < n; ++i) for (int x = A.p[i], j = 0; j < m; ++j) c[i + j] = ((long long) x * B.p[j] + c[i + j]) % mod;
            return poly(N, c);
        }
        int N = 1; while(N < n + m - 1) N <<= 1;
        for (int i = 0; i < N; ++i) a[i] = i < n ? A.p[i] : 0, b[i] = i < m ? B.p[i] : 0;
        dft(N, a, 0); dft(N, b, 0); for (int i = 0; i < N; ++i) a[i] = (long long) a[i] * b[i] % mod; dft(N, a, 1);
        return poly(n + m - 1, a);
    }

    // Scalar multiple a * B.
    poly operator * (const int &a, const poly &B) {
        static int b[maxn]; int n = B.size();
        for (int i = 0; i < n; ++i) b[i] = (long long) a * B.p[i] % mod;
        return poly(n, b);
    }

    // Coefficient-wise A - B; the result has max(A.size(), B.size()) coefficients.
    poly operator - (const poly &A, const poly &B) {
        static int a[maxn]; int n = A.size(), m = B.size(), N = max(n, m);
        for (int i = 0; i < N; ++i) a[i] = ((i < n ? A.p[i] : 0) - (i < m ? B.p[i] : 0)) % mod;
        return poly(N, a);
    }

    // Coefficient-wise A + B; the result has max(A.size(), B.size()) coefficients.
    poly operator + (const poly &A, const poly &B) {
        static int a[maxn]; int n = A.size(), m = B.size(), N = max(n, m);
        for (int i = 0; i < N; ++i) a[i] = ((i < n ? A.p[i] : 0) + (i < m ? B.p[i] : 0)) % mod;
        return poly(N, a);
    }

    // Power-series inverse: returns B (n coefficients) with A * B == 1 (mod x^n).
    // It uses the Newton step B = 2*B0 - B0^2 * A, where B0 is the inverse mod x^ceil(n/2).
    // Requires A.p[0] to be invertible modulo `mod`. Returns the empty poly for n <= 0.
    poly inv(int n, const poly &A) { // A(x)^-1 mod x^n
        // Without this guard, n <= 0 recursed forever (inv(0) called inv(0)).
        if(n <= 0) return poly();
        if(n == 1) { return poly(vector<int>{pow_mod(A.p[0], mod - 2)}); }
        // Static to reuse storage. Each level reassigns B0/tA/B only after its recursive
        // call has returned, so the nested calls do not clobber values still in use.
        static poly B0, B, tA;
        B0 = inv(n + 1 >> 1, A);
        // tA = A truncated or zero-padded to exactly n coefficients.
        if(A.size() < n) tA = A, tA.resize(n);
        else tA.p.clear(), tA.p.insert(tA.p.end(), A.p.begin(), A.p.begin() + n);
        B = 2 * B0 - B0 * B0 * tA;
        B.resize(n);
        return B;
    }

    // Coefficient reversal: x^(deg) * A(1/x).
    poly rev(const poly &A) { static poly B; B = A; reverse(B.p.begin(), B.p.end()); return B; }

    // Polynomial quotient of A by B, computed via inversion of the reversed divisor.
    // Assumes the last stored coefficient of B is nonzero.
    // Returns the empty (zero) polynomial when A.size() < B.size().
    poly operator / (const poly &A, const poly &B) {
        static poly rA, rB, C, D; int n = A.size(), m = B.size();
        // deg A < deg B: the quotient is zero. Without this guard, inv() would get a
        // non-positive length and D.resize() a negative size.
        if(n < m) return poly();
        rA = rev(A); rB = rev(B); C = inv(n - m + 1, rB); D = C * rA; D.resize(n - m + 1);
        return rev(D);
    }

    // Remainder A mod B, with B.size() - 1 coefficients.
    poly operator % (const poly &A, const poly &B) { static poly D, ret; D = A / B; ret = A - B * D; ret.resize(B.size() - 1); return ret; }

    // Remainder A - B * D, for a caller-supplied quotient D; B.size() - 1 coefficients.
    poly Rem(const poly &A, const poly &B, const poly &D) { static poly ret; ret = A - B * D; ret.resize(B.size() - 1); return ret; }

    // Formal derivative A'(x) ("dao"). An empty A yields the empty polynomial.
    poly dao(const poly &A) {
        static int a[maxn]; int n = A.size() - 1;
        if(n < 0) n = 0; // empty input: avoid resizing to -1
        for (int i = 0; i < n; ++i) a[i] = (long long) A.p[i + 1] * (i + 1) % mod;
        return poly(n, a);
    }

    // Formal integral ("ji") with zero constant term: coefficient i is A.p[i-1] / i.
    poly ji(const poly &A) {
        static int a[maxn]; int n = A.size();
        for (int i = 1; i <= n; ++i) a[i] = (long long) A.p[i - 1] * pow_mod(i, mod - 2) % mod;
        a[0] = 0;
        return poly(n + 1, a);
    }

    // ln(A(x)) mod x^n, computed as the integral of A' * A^-1.
    // The constant term of the result is always 0, so this is meaningful when A.p[0] == 1.
    poly ln(int n, const poly &A) { // ln(A(x)) mod x^n
        static poly B, C;
        C = dao(A); B = inv(n, A);
        B = B * C; B = ji(B); B.resize(n);
        return B;
    }

    // exp(A(x)) mod x^n, via the Newton step B = B0 * (1 - ln B0 + A).
    // The base case returns 1, so A.p[0] is presumably expected to be 0 (not checked).
    poly exp(int n, const poly &A) { // e^A(x) mod x^n
        if(n == 1) { return poly(vector<int>{1}); }
        static poly B0, B, C, tA;
        B0 = exp(n + 1 >> 1, A);
        if(A.size() < n) tA = A, tA.resize(n);
        else tA.p.clear(), tA.p.insert(tA.p.end(), A.p.begin(), A.p.begin() + n);
        C = ln(n, B0);
        B = B0 * (poly(vector<int>{1}) - C + tA);
        B.resize(n);
        return B;
    }

    // A(x)^N mod x^n.
    // Works in three steps:
    //   1. Strip the leading x^i factor; the result is zero when N * i >= n.
    //   2. Scale A so its lowest coefficient is 1.
    //   3. Compute exp(N * ln A).
    poly pow_mod(int n, const poly &A, int N) { // A(x)^N mod x^n
        static poly B;
        int na = A.size(), i = 0;
        while(i < na && A.p[i] == 0) ++i;
        if(i == na || (long long) N * i >= n) return poly(vector<int>(n, 0));
        if(i) {
            // A = x^i * B  =>  A^N = x^(N*i) * B^N.
            B.p.clear();
            for (int j = i; j < na; ++j) B.p.push_back(A.p[j]);
            static poly C;
            C = pow_mod(n - N * i, B, N);
            B.p = vector<int>(N * i, 0);
            B.p.insert(B.p.end(), C.p.begin(), C.p.end());
            return B;
        }
        if(A.p[0] != 1) {
            // A = a0 * (A / a0)  =>  A^N = a0^N * (A / a0)^N.
            int t = ::pow_mod(A.p[0], mod - 2), s = ::pow_mod(A.p[0], N); B.resize(na);
            for (int i = 0; i < na; ++i) B.p[i] = (long long) A.p[i] * t % mod;
            B = pow_mod(n, B, N);
            for (int i = 0; i < n; ++i) B.p[i] = (long long) s * B.p[i] % mod;
            return B;
        }
        B = ln(n, A);
        for (int i = 0; i < n; ++i) B.p[i] = (long long) B.p[i] * N % mod;
        return exp(n, B);
    }
}

using namespace Poly;

// Dense integer matrix stored row-major in num[row][col]; products are reduced modulo `mod`.
struct matrix {
	vector<vector<int> > num;

	matrix() {}
	// rows x cols zero matrix.
	matrix(int n, int m) : num(n, vector<int>(m, 0)) {}
	// Wraps an existing row-major table.
	matrix(const vector<vector<int> > &a) : num(a) {}

	// Matrix product a * b. Entries are reduced with C++ `%`, so they stay in
	// (-mod, mod) and can be negative when the inputs are.
	friend matrix operator *(const matrix &a, const matrix &b) {
		const int rows = a.num.size();
		const int cols = b.num[0].size();
		const int inner = b.num.size();
		matrix prod(rows, cols);
		for (int r = 0; r < rows; ++r) {
			const vector<int> &lhs = a.num[r];
			vector<int> &out = prod.num[r];
			for (int c = 0; c < cols; ++c) {
				int acc = 0;
				for (int k = 0; k < inner; ++k)
					acc = ((long long) lhs[k] * b.num[k][c] + acc) % mod;
				out[c] = acc;
			}
		}
		return prod;
	}
};

// Returns the n x n identity matrix.
matrix I(int n) {
	matrix id(n, n);
	int diag = 0;
	for (auto &row : id.num) row[diag++] = 1;
	return id;
}

int n, m;

// Brute-force reference for tiny inputs (main() calls it when n <= 8 and m <= 8).
// dp[len][last] counts sequences of length len that satisfy all of:
//   - they start at 1,
//   - every value lies in [1, m],
//   - each step rises by at most 1.
// Prints the total over all last values. NOTE: dp is never cleared, so this is
// meant to run once per process.
namespace BF {
	constexpr int mod = 998244353, maxn = 2010;

	int dp[maxn][maxn];

	void main() {
		dp[1][1] = 1;
		for (int len = 1; len < n; ++len) {
			for (int last = 1; last <= m; ++last) {
				const int ways = dp[len][last];
				if (!ways) continue;
				const int top = min(last + 1, m);
				for (int next = 1; next <= top; ++next)
					dp[len + 1][next] = (dp[len + 1][next] + ways) % mod;
			}
		}
		int total = 0;
		for (int last = 1; last <= m; ++last) total = (total + dp[n][last]) % mod;
		cout << total << endl;
	}
}

int main() {
	auto pow_mod = [&](int x, int n) {
		int y = 1;
		while(n) {
			if(n & 1) y = (long long) y * x % mod;
			x = (long long) x * x % mod;
			n >>= 1;
		}
		return y;
	};

	cin >> n >> m;
	if(m > n) { cout << "0\n"; return 0; }
	if(n <= 8 && m <= 8) {
		BF::main();
		return 0;
	}
	int N = 1;
	while(N <= m || N <= n) N <<= 1;
	int *a = new int[N], *b = new int[N];
	for (int wn = pow_mod(proot, (mod - 1) / N), i = 0, x = 1; i < N; ++i, x = (long long) x * wn % mod) {
		matrix trans(vector<vector<int> >{vector<int>{0, 1}, vector<int>{-x, 1}}), res = I(2);
		for (int n = m; n; ) {
			if(n & 1) res = res * trans;
			trans = trans * trans;
			n >>= 1;
		}
		a[i] = (res.num[0][0] + res.num[0][1]) % mod;
		b[i] = (res.num[1][0] + res.num[1][1]) % mod;
	}
	dft(N, a, 1);
	dft(N, b, 1);
	poly A(N, a), B(N, b);
	B = inv(N, B);
	A = A * B;
	cout << (A.p[n] + mod) % mod << endl;
	return 0;
}

做法2

仍然考虑求二叉树个数。打表发现其等于:从 \((0,\ 0)\) 出发,每次向右或向上走一步,走到 \((n,\ n)\)。要求既不经过直线 \(y\ =\ x\ +\ 1\) 及其上方的点,也不经过直线 \(y\ =\ x\ -\ m\ -\ 1\) 及其下方的点。所求为满足条件的路径数。

考虑容斥,按以下各项交替加减:
- 先计算全部路径数 \(\binom{2n}{n}\)。
- 再分别减去经过 \(y\ =\ x\ +\ 1\) 的与经过 \(y\ =\ x\ -\ m\ -\ 1\) 的。
- 再加上先经过 \(y\ =\ x\ +\ 1\) 再经过 \(y\ =\ x\ -\ m\ -\ 1\) 的,以及先经过 \(y\ =\ x\ -\ m\ -\ 1\) 再经过 \(y\ =\ x\ +\ 1\) 的。
- 再减去依次经过 \(y\ =\ x\ +\ 1\ \to\ y\ =\ x\ -\ m\ -\ 1\ \to\ y\ =\ x\ +\ 1\) 的,以及依次经过 \(y\ =\ x\ -\ m\ -\ 1\ \to\ y\ =\ x\ +\ 1\ \to\ y\ =\ x\ -\ m\ -\ 1\) 的。
- 如此交替进行下去......

计算依次经过 \(y\ =\ x\ +\ 1\ \to\ y\ =\ x\ -\ m\ -\ 1\ \to\ y\ =\ x\ +\ 1\) 的方案数时使用反射原理:将终点 \((n,\ n)\) 依次关于 \(y\ =\ x\ +\ 1\)、\(y\ =\ x\ -\ m\ -\ 1\)、\(y\ =\ x\ +\ 1\) 作对称,得到 \((n\ -\ m\ -\ 3,\ n\ +\ m\ +\ 3)\)。所求即为从 \((0,\ 0)\) 走到该点的路径数 \(\binom{2n}{n\ +\ m\ +\ 3}\)。其余各项同理。当反射后的终点超出范围时组合数为 \(0\),容斥即可终止。

时间复杂度 \(O(n)\)

代码

#include <bits/stdc++.h>

#ifdef __WIN32
#define LLFORMAT "I64"
#else
#define LLFORMAT "ll"
#endif

using namespace std;

constexpr int mod = 998244353;

// Solution 2: count lattice paths (0,0) -> (n,n) strictly between y = x + 1 and
// y = x - m - 1 by inclusion-exclusion over repeated reflections.
int main() {
	// base^exp modulo `mod` by binary exponentiation.
	auto pow_mod = [](int base, int exp) {
		long long result = 1, b = base;
		for (; exp > 0; exp >>= 1) {
			if (exp & 1) result = result * b % mod;
			b = b * b % mod;
		}
		return static_cast<int>(result);
	};

	int n, m;
	cin >> n >> m;
	if (n < m) {
		cout << "0\n";
		return 0;
	}

	// Factorials and inverse factorials up to 2n, for binomials C(2n, k).
	const int total = 2 * n;
	vector<int> fac(total + 1), ifac(total + 1);
	fac[0] = 1;
	for (int i = 1; i <= total; ++i) fac[i] = (long long) fac[i - 1] * i % mod;
	ifac[total] = pow_mod(fac[total], mod - 2);
	for (int i = total; i > 0; --i) ifac[i - 1] = (long long) ifac[i] * i % mod;
	auto binom = [&](int top, int k) -> long long {
		if (k < 0 || top < k) return 0;
		return (long long) fac[top] * ifac[k] % mod * ifac[top - k] % mod;
	};

	// up / down: reflected endpoints of chains that first touch the upper
	// (y = x + 1) resp. lower (y = x - m - 1) line; terms alternate in sign.
	long long ans = binom(total, n);
	int up = n + 1, down = n + m + 1;
	bool add = false;
	while (up <= total || down <= total) {
		const long long term = binom(total, up) + binom(total, down);
		if (add) {
			ans = (ans + term) % mod;
			up += 1;
			down += m + 1;
		} else {
			ans = (ans - term) % mod;
			down += 1;
			up += m + 1;
		}
		add = !add;
	}
	cout << (ans + mod) % mod << endl;
	return 0;
}
posted @ 2018-11-19 11:19  King_George  阅读(408)  评论(0)  编辑  收藏  举报