CF1257G - Divisor Set(分治,fft)

题目

给定\(n\)个质数,设\(S\)代表这些质数的乘积所代表的数的所有正因子。对于一个有限正整数集合\(D\),如果任意\(a\in D\)\(b\in D\)\(a\neq b\),都满足\(a\nmid b\),那么\(D\)就是好的。问\(S\)最大好子集是多大。

题解

假如这\(n\)个质数互不相等,显然答案就是\(C(n,\frac{n}{2})\)。因为这样取的数互相不会整除彼此,而且是\(C(n,k)\)中最大的数。那么如果有相等的数呢?类比一下,答案就是所有质因子个数(相同也算)为\(\frac{n}{2}\)的数的集合。它们显然互不整除,而且也是所有质因子个数相等的数的集合中最大的那个。

这样问题就是一个简单背包dp了。一共有\(m\)个质数,设第\(i\)种质数\(p_i\)\(c_i\)个,\(dp[pos][sum]\)代表前\(i\)种质数,剩余质因子个数为\(sum\)时的方案数。转移:

\[dp[i][j]=\sum\limits_{k=0}^{c_i}{dp[i-1][j-k]} \]

可以优化成\(O(n^2)\),用前缀和。btw,如果这里直接用\(dp\)存储前缀和,有\(dp[0][k]=1\),这样最后答案应该是\(dp[m][\frac{n}{2}]-dp[m][\frac{n}{2}-1]\);但是如果只令\(dp[0][0]=1\),即\(dp[0]\)是差分的形式,然后直接按照前缀和优化来存储和计算\(dp\),最后答案直接就是\(dp[m][\frac{n}{2}]\),因为一开始\(dp[0]\)就是差分。

显然会超时。如果用生成函数的思想,\(p_i\)\(c_i\)个,那么就构造多项式\((x^0+x^1+...+x^{c_i})\),然后将所有\(p_i\)的多项式乘起来,最后\(x^{\frac{n}{2}}\)的系数就是答案。直接分治+fft。

vector作为返回值时是以作为右值返回,时间复杂度为\(O(1)\),常数不会太大,非常快,不用担心超时。将vector作为返回值问题不大。
时间复杂度\(O(n\log^2{n})\)

#include <bits/stdc++.h>

#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f
const int N = 3e5 + 10;
const int M = 998244353;

int rev[N];
inline ll qpow(ll a, ll b, ll m) {
    ll res = 1;
    while (b) {
        if (b & 1)
            res = (res * a) % m;

        a = (a * a) % m;
        b = b >> 1;
    }
    return res;
}

void change(vector<ll>& y, int len) { // 蝴蝶变换
    for (int i = 0; i < len; ++i) {
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1) {
            rev[i] |= len >> 1;
        }
    }
    for (int i = 0; i < len; ++i) {
        if (i < rev[i]) {
            swap(y[i], y[rev[i]]);
        }
    }
    return;
}

void ntt(vector<ll>& y, int len, int on) { // -1逆变换
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        ll gn = qpow(3, (M - 1) / h, M); // 原根为3
        if (on == -1)
            gn = qpow(gn, M - 2, M);
        for (int j = 0; j < len; j += h) {
            ll g = 1;

            for (int k = j; k < j + h / 2; k++) {
                ll u = y[k];
                ll t = g * y[k + h / 2] % M;
                y[k] = (u + t) % M;
                y[k + h / 2] = (u - t + M) % M;
                g = g * gn % M;
            }
        }
    }
    if (on == -1) {
        ll inv = qpow(len, M - 2, M);
        for (int i = 0; i < len; i++) {
            y[i] = y[i] * inv % M;
        }
    }
}

int get(int x) {
    int res = 1;
    while(res < x) {
        res <<= 1;
    }
    return res;
}

int arr[N];
vector<int> num;
vector<int> pre;

vector<ll> solve(int l, int r) {
    if(l == r) {
        vector<ll> f;
        for(int i = 0; i <= num[l]; i++) {
            f.push_back(1);
        }
        return f;
    }
    int mid = (l + r) / 2;
    vector<ll> f = solve(l, mid);
    vector<ll> g = solve(mid + 1, r);
    int nl = f.size(), nr = g.size();
    int len = get(nl + nr - 1);
    f.resize(len, 0);
    g.resize(len, 0);
    ntt(f, len, 1);
    ntt(g, len, 1);
    for(int i = 0; i < len; i++) f[i] = f[i] * g[i] % M;
    ntt(f, len, -1);
    f.resize(nl + nr - 1);
    return f;
}

int main() {
    IOS;
    int n;
    cin >> n;
    for(int i = 1; i <= n; i++) {
        cin >> arr[i];
    }
    sort(arr + 1, arr + 1 + n);
    int cnt = 1;
    num.push_back(0);
    for(int i = 2; i <= n + 1; i++) {
        if(i > n || arr[i] != arr[i - 1]) {
            num.push_back(cnt);
            cnt = 1;
        } else cnt++;
    }
    for(int i = 0; i < num.size(); i++) {
        pre.push_back(num[i]);
        if(i) pre[i] = (pre[i] + pre[i - 1]) % M;
    }
    vector<ll> ans = solve(1, num.size() - 1);
    cout << ans[n / 2] << endl;
}
posted @ 2021-09-28 23:00  limil  阅读(46)  评论(0编辑  收藏  举报