【模板】子集卷积 / 集合幂级数 1
日记 2024.4.7:子集卷积
为什么是集合幂级数 \(1\) 呢,因为有 【笔记】集合幂级数 2 - caijianhong - 博客园
记号
\(F(x)=\sum_{S\subseteq [n]}f_Sx^{S}\) 是一个集合幂级数,其中 \([n]=\{1, 2, \cdots, n\}\),\(f_S\) 是一个数组,然后写成像生成函数(其实是“形式幂级数”)的形式,\(x^S\) 的这个指数是没意义的。
FWT / FMT
即两个集合幂级数的并、交、异或卷积运算。
https://www.cnblogs.com/caijianhong/p/template-fwt.html
子集卷积
感觉集合幂级数的乘法卷积会好一点?
给出 \(F(x)=\sum_{S\subseteq [n]}f_Sx^{S}\) 与 \(G(x)=\sum_{S\subseteq [n]}g_Sx^{S}\),想求的是 \(H(x)=\sum_{S\in[n]}\sum_{T\subseteq S}f_{T}g_{S\setminus T}x^S\)。就和 \(H(x)=\sum_{S\in[n]}\sum_{T\cup R = S, T\cap R = \varnothing}f_{T}g_{R}x^S\) 一个意思,都是要求卷起来的集合不交。
我们已经掌握的集合幂级数的并运算,是能求 \(\sum_{S\subseteq[n]}\sum_{T\cup R}f_{T}g_{R}x^S\),可能 \(T, R\) 有交。但是我们考虑如果它们有交,那么 \(|T|+|R|>|T\cup R|\),我们发现只要控制好 \(|T|+|R|=|T\cup R|\) 就能使其无交。
引入新的 \(y\) 记号,将 \(F(x)\) 写为 \(\hat F(x, y) = \sum_{S\subseteq [n]}f_Sx^{S}y^{|S|}\)(也许可以叫作……集合占位幂级数)。然后考虑求 \(\hat F(x, y)\) 与 \(\hat G(x, y)\) 的,\(x\) 上做集合的并,\(y\) 上做整数加法,求卷积,得到 \(\hat H(x, y)=\sum_{S\subseteq [n]}\sum_{k=0}^n\hat h_{S, k}x^Sy^k\)。我们转回去,使得 \(h_S=[x^Sy^{|S|}]\hat H(x, y)\) 即可,其余的项都不要了。
实现时对 \(y\) 这一维暴力加法卷积,每次都是将 \(y^k\) 的系数,即一个集合幂级数,做并卷积。但是不需要每次都如此暴力,对于同一个 \(y\),我们只用做一次 FWT,然后过程中不断点乘、加法,最后统一 FWT 回去。这是基于 FWT 是一个线性的运算。
一共做了 \(O(n)\) 次 \(O(n2^n)\) 的 FWT,所以复杂度为 \(O(n^22^n)\)。
求导公式复习
后面要反复使用,默认所求导的元是 \(x\),即 \(F'(x)\) 这个记号说的是 \(\dfrac{\mathrm d}{\mathrm dx}F(x)\),这个非常重要!
子集 exp
原理
\(\exp x=\sum_{i=0}x^i/(i!)\),这里 \(x\) 可以是任何东西,只要幂运算有定义,所以考虑扔个集合幂级数进去,定义乘法是子集卷积。
\(\exp F(x)=\sum_{i=0}(F(x))^i/(i!)\)?很深刻的东西。不过我们最好先声明一下 \([x^\varnothing]F(x)=0\),不然这个级数不收敛,有可能无意义。这样,我们规定一个全新的东西,我们算 \(\exp F(x)-1\) 这个东西,即 \(\sum_{i=1}(F(x))^i/(i!)\),这样避开收敛之类的讨论。
考虑怎么算,将 \(F(x)\) 改成 \(\hat F(x, y)\),然后求 \(\exp \hat F(x, y)\)。以 \(y\) 为主元(?),就是对 \(y\) 那一维算 \(\exp\),将集合幂级数当作 \(y\) 的系数,然后集合幂级数做并卷积。对 \(x\) 维做 \(O(n)\) 次 FWT 之后,\(x\) 这一维做的就是加法、点乘,每一个数字都没什么关系。所以不妨做完 FWT 之后回去枚举 \(S\),以 \(x\) 为主元,将 \(y\) 这一维拿出来,一共 \(O(n)\) 个数字,算 \(\exp\)。然后再将 \(\exp\) 算的东西按顺序塞回去,再 IFWT 还原。
现在我们只需要关心形式幂级数的 \(\exp\) 怎么算。原来的“形式”是一个 FWT 之后的数组,我们枚举每一位拆成 \(O(2^n)\) 次 \(\exp\),现在“形式”就是一些 modint 了。
计算
\(O(n\log n)\) 的 \(\exp\) 常数比较厉害。因为 \(n\) 太小了,考虑 \(O(n^2)\) 能否求 \(\exp\)?
设 \(G(x)=\exp F(x)-1\),注意 \(G(0)=0\)。然后两边求导:
有一个 \(G\) 的求导,提取 \([x^{n-1}]\) 系数:(因为非常自闭的 \(G(0)=0\) 所以不得不特判一下了)
然后这是可以计算的。\(O(n^2)\) 递推每一项系数。
组合意义
想必是选出若干不相交集合组合成 \(S\) 的方案数。
子集 ln
原理
\(\ln 0\) 无意义,这使人很自闭,我们只好去考虑 \(\ln(1+x)=\sum_{i=1}(-1)^{i+1}x^i/i\)。这个东西不太好看啊!但是这个定义是很好的。我们再次声明 \([x^\varnothing]F(x)=0\)。
然后复读一遍 \(\exp\) 的过程,考虑 \(O(n^2)\) 的 \(\ln\):
计算
设 \(G(x)=\ln(1+F(x))\)。注意 \(G(0)=0\)。两边同时求导:
两边提取 \([x^{n-1}]\) 项系数:
注意这里不是出现了两次 \(ng_n\) 而是有一个 \(ng_n\) 的系数为 \(f_{n-n}=0\)。
组合意义
想必是将 \(S\) 划分为若干不相交外面无序内部无序集合的方案数。(?对这里有疑问)
子集 k-exp
即求出 \(\sum_{i=1}^K(F(x))^i/(i!)\)。考虑记为 \(G\),然后注意 \(G(0)=0\)。然后两边求导:
假如已经求出 \(H(x)=(F(x))^K\)。那么提取 \([x^{n-1}]\):
那么怎么求 \(H(x)=(F(x))^K\) 呢?两边求导:
提取 \([x^{n-1}]\) 项系数?
竟然能算?\(O(n^2)\)。这里截断到 \(x^{n+1}\) 即可,因为我们只需要答案的 \([x^n]\)。注意求快速幂时使 \(f\) 有常数项,是一些简单位移。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
#define popcount __builtin_popcount
typedef long long LL;
template <class T>
using must_int = enable_if_t<is_integral<T>::value, int>;
template <unsigned umod>
struct modint {
static constexpr int mod = umod;
unsigned v;
modint() : v(0) {}
template <class T, must_int<T> = 0>
modint(T x) {
x %= mod;
v = x < 0 ? x + mod : x;
}
modint operator+() const { return *this; }
modint operator-() const { return modint() - *this; }
friend int raw(const modint &self) { return self.v; }
friend ostream& operator<<(ostream& os, const modint &self) {
return os << raw(self);
}
modint &operator+=(const modint &rhs) {
v += rhs.v;
if (v >= umod) v -= umod;
return *this;
}
modint &operator-=(const modint &rhs) {
v -= rhs.v;
if (v >= umod) v += umod;
return *this;
}
modint &operator*=(const modint &rhs) {
v = 1ull * v * rhs.v % umod;
return *this;
}
modint &operator/=(const modint &rhs) {
assert(rhs.v);
return *this *= rhs.inv();
}
modint inv() const {
if (v >= 2e7) return qpow(*this, mod - 2);
static vector<modint> inv;
if (inv.empty()) inv = {0, 1};
while (inv.size() <= v) {
int n = inv.size();
inv.resize(n << 1);
for (int i = n; i < n << 1; i++) inv[i] = -(mod / i) * inv[mod % i];
}
return inv[v];
}
template <class T, must_int<T> = 0>
friend modint qpow(modint a, T b) {
modint r = 1;
for (; b; b >>= 1, a *= a)
if (b & 1) r *= a;
return r;
}
friend modint operator+(modint lhs, const modint &rhs) { return lhs += rhs; }
friend modint operator-(modint lhs, const modint &rhs) { return lhs -= rhs; }
friend modint operator*(modint lhs, const modint &rhs) { return lhs *= rhs; }
friend modint operator/(modint lhs, const modint &rhs) { return lhs /= rhs; }
bool operator==(const modint &rhs) const { return v == rhs.v; }
bool operator!=(const modint &rhs) const { return v != rhs.v; }
};
typedef modint<998244353> mint;
void fort(vector<mint>& a, int op) {
int n = a.size();
for (int k = 1, len = 2; len <= n; k <<= 1, len <<= 1) {
for (int i = 0; i < n; i += len) {
for (int j = 0; j < k; j++) {
a[i + j + k] += a[i + j] * op;
}
}
}
}
int n, m, K;
vector<mint> _qpow(const vector<mint>& f, int K, int lim) {
vector<mint> h(lim);
h[0] = 1;
for (int i = 1; i <= K; i++) h[0] *= f[0];
for (int i = 1; i < h.size(); i++) {
mint rhs = 0;
for (int j = 1; j <= i; j++) rhs += K * j * f[j] * h[i - j];
for (int j = 1; j < i; j++) rhs -= j * h[j] * f[i - j];
h[i] = rhs / i / f[0];
}
return h;
}
vector<mint> qpow(const vector<mint>& f, int K) {
int pos = 0;
while (pos < f.size() && f[pos] == 0) pos++;
if (pos * K >= f.size()) return vector<mint>(n, 0);
auto ret = _qpow(vector<mint>(f.begin() + pos, f.end()), K, f.size() - pos * K);
ret.insert(ret.begin(), pos * K, 0);
return ret;
}
vector<mint> kexp(const vector<mint>& f, const vector<mint>& h, int K) {
vector<mint> g(f.size());
g[0] = 0;
mint fac = 1;
for (int i = 1; i <= K; i++) fac *= i;
mint ifac = 1 / fac;
for (int i = 1; i < f.size(); i++) {
mint rhs = 0;
for (int j = 1; j <= i; j++) rhs += j * f[j] * (g[i - j] + (i == j) - h[i - j] * ifac);
g[i] = rhs / i;
}
return g;
}
int main() {
#ifndef LOCAL
cin.tie(nullptr)->sync_with_stdio(false);
#endif
cin >> n >> m >> K;
vector<vector<mint>> a(n + 1, vector<mint>(1 << n));
for (int i = 1; i <= m; i++) {
int x;
cin >> x;
a[popcount(x)][x] += 1;
}
for (int i = 0; i <= n; i++) fort(a[i], 1);
for (int S = 1; S < 1 << n; S++) {
vector<mint> b(n + 1);
for (int i = 0; i <= n; i++) b[i] = a[i][S];
auto h = qpow(b, K);
auto c = kexp(b, h, K);
for (int i = 0; i <= n; i++) a[i][S] = c[i];
}
fort(a[n], -1);
cout << a[n][(1 << n) - 1] << endl;
return 0;
}
本文来自博客园,作者:caijianhong,转载请注明原文链接:https://www.cnblogs.com/caijianhong/p/18120012