AT-abc214-g题解

题目描述

给定两个排列 \(p, q\),要求统计满足 \(\forall i, r_i\not= p_i, r_i \not= q_i\) 的排列 \(r\) 的数量。对 \(1000000007\) 取模

数据范围 \(n \le 3000\)

solution

本题要求统计数量,反正我想了半天没想到怎么正向统计(bushi),因此我们考虑容斥。

\(h_i\) 为只看其中 \(i\) 条限制,这 \(i\) 条限制都不满足的情况数。因此答案即为 \(\sum \limits_{i=0}^{n}(-1)^ih_i\)

现在问题转变为了如何求这个 \(h\)

这里有一个及其巧妙的转化:从 \(p_i\)\(q_i\) 连边,这样一定会出现至少一个环。第 \(i\) 个位置填 \(r_i\) 就意味着第 \(i\) 条边对应着第 \(r_i\) 个点。

而我们考虑不满足要求的情况,就是这条边正好对应着它的起点或者终点。而没有对应的情况我们就可以直接把这条边从图中删去。(注意:下文中所说的选择某条边就意味着不删去这条边)

在一个大小为 \(n\) 的环中,如果我们删去 \(i\) 条边,我们就会剩下 \(i + 1\) 个连通块,如果我们在这个环上不删去任何一条边,那么就意味着这个环上每条边都对应着自己的起点或者终点。这时有两种情况:全都对应自己的起点和全都对应自己的终点,一共两种情况。

而当出现链时,我们考虑这条链上边选择点的方案数。我们可以枚举每个点,这个点左边的边全部选起点,右边的点全部选终点,这样一条链的方案数就是这条链的大小。

图示,相同颜色为对应关系

\(dp[n, m]\) 为一个大小为 \(n\)*删去 \(m\) 条边的方案数,我们可以枚举第一个点所在链的大小 \(i\),便可推出方程 \(dp_{n, m}=\sum \limits _{i=1}^{n}dp_{n-i,m-1}i\)。这样就求出了链上的情况统计。

我们再考虑环上的统计。设 \(f[n,m]\) 为一个大小为 \(n\)删去 \(m\) 条边的方案数。根据上面的分析,我们可以得到边界条件 \(f[n,0]=2\)。我们也可以像上面一样,枚举第一个点所在的连通块的大小,我们可以得到转移方程 \(f_{n, m}=\sum \limits _{i=1}^{n}i^2dp_{n-i,m-1}\)


关于上面转移方程中 \(i^2\) 的解释(因为我一开始也没搞懂)

我们枚举一号点所在连通块的大小 \(i\),因为删去了边,所以这个连通块一定是一条链。而这条链的选择有 \(i\) 种,大小为 \(i\) 的链的方案数也是有 \(i\) 种的。因此根据乘法原理,应该是 \(i^2\)


我们统计了一个环的情况,我们再考虑多个环。我们设 \(g[n,m]\) 为前 \(n\) 个环删去 \(m\) 条边的方案数。枚举最后一个环删去多少条边,我们能得到转移方程 \(g_{n, m}=\sum \limits _{i=1}^{v_n}g_{n-1,m-i}f_{v_n,i}\)\(v_n\) 为第 \(n\) 个环的大小)。

我们已经求出了所有环删去一些边的方案数,这时我们再考虑回去,我们会发现,删去一条边就意味着这条边不对应它的起点或终点,也就意味着它在 \(r\) 排列中合法,那么减去这些删去的边,剩下的就是不满足要求的个数 \(h\) 了。也就是 \(h_i=g_{cnt,n-i}\)

这时我们看似应该已经做完这题了,但是我们可能还会有一些问题。我们可以想一下哪里有问题。

剩下的问题就是数据范围了,我们可以注意到我们上面写的所有转移方程好像都是 \(O(n^3)\) 的(枚举 \(n, m, i\)),这必然无法通过 \(n \le 3000\) 的数据。所以我们需要考虑优化。

我们再把上面的那几个转移方程拿下来一个一个分析。

\[dp_{n, m}=\sum \limits _{i=1}^{n}dp_{n-i,m-1}i \]

这个很明显是 \(O(n^3)\) 的方程。我们考虑如何快速求出后面这一项。首先我们很容易就能注意到:这是一个极其明显的卷积形式。我们自然可以想到 NTT,但是首先复杂度只能到 \(O(n^2\log n)\),可能也过不去。而且取模数也不是通常的 \(998244353\),而是 \(1000000007\),只能使用 MTT,而 MTT 的常数...就看下面 MTT 的代码吧。

NTT(A[0], G1, 1), NTT(B[0], G1, 1);
for(int i = 0; i < lim; i ++ ) a1[i] = (ll)A[0][i] * B[0][i] % G1;
NTT(a1, G1, -1);

NTT(A[1], G2, 1), NTT(B[1], G2, 1);
for(int i = 0; i < lim; i ++ ) a2[i] = (ll)A[1][i] * B[1][i] % G2;
NTT(a2, G2, -1);

NTT(A[2], G3, 1), NTT(B[2], G3, 1);
for(int i = 0; i < lim; i ++ ) a3[i] = (ll)A[2][i] * B[2][i] % G3;
NTT(a3, G3, -1);

for(int i = 0; i < lim; i ++ )
{
    __int128 M = (__int128)G1 * G2 * G3;
    ans[i] = ((__int128)a1[i] * M / G1 * inv(M / G1, G1) % M
                + (__int128)a2[i] * M / G2 * inv(M / G2, G2) % M + 
                (__int128)a3[i] * M / G3 * inv(M / G3, G3) % M) % M % p;
}

9 遍 NTT 根本跑不动好吧......

因此我们考虑如何才能把转移弄成 \(O(1)\)。我们可以类似前缀和的思想,下面有一个引理可以帮助我们完成构建前缀和的任务。


引理:一个常数数列(例如 \(1, 1, 1, 1\cdots\)) 的 \(k\) 阶前缀和是一个 \(k\) 次多项式,而一个 \(k\) 次多项式的 \(k\) 阶差分为常数数列。

这里就先显然带过,我们只需要数列 \(1, 1, 1, 1\cdots\) 的二阶(下面还有三阶)前缀和。

(其实这玩意也有应用,例如这题


因此我们考虑如何拼凑出来转移方程前面的 \(i\)。我们可以记录两个变量 \(sum_1\)\(sum_2\),前者存储上面所说的常数数列,后者存储 \(sum_1\) 的前缀和数列(注意这里说的数列的意思是 \(dp\) 数组的系数)。然后 \(sum_2\) 就是我们要转移到的值。

画图理解就是:

同理,\(f_{n, m}=\sum \limits _{i=1}^{n}i^2dp_{n-i,m-1}\) 也可以这样优化。\(i^2\) 是个二次多项式,我们就需要 \(3\) 个变量来计算它。

\(g_{n, m}=\sum \limits _{i=1}^{v_n}g_{n-1,m-i}f_{v_n,i}\) 其实根本不需要优化。我们考虑下把这三个循环写出来是什么样子:

for i = 1 ~ cnt
    for j = 1 ~ n
        for k = 1 ~ v[i]
            转移

我们可以发现复杂度即为 \(O(n \sum v)\),就是 \(O(n^2)\)

因此现在三个转移方程复杂度都为 \(O(n^2)\) 了,我们就可以愉快的 AC 了!

代码这里放一下吧,因为这题算思维题里面码量大的了,而且边界也很多。

#define LOCAL
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef pair<int, int> PII;

const int mod = 1e9 + 7, N = 3010;
ll f[N][N], g[N][N], dp[N][N], h[N];
ll fact[N];
ll p[N], q[N];
ll ne[N], la[N];
bitset<N> vis;
vector<int> d;
ll n, ans;

int main()
{
    n = read();

    fact[0] = 1;
    for(int i = 1; i <= n + 1; i ++ ) 
    {
        fact[i] = fact[i - 1] * i % mod;
        dp[1][i] = i;
    }

    for(int i = 2; i <= n + 1; i ++ )
    {
        ll sum1 = 0, sum2 = 0;
        sum1 = sum2 = dp[i - 1][i - 1];
        for(int j = i; j <= n + 1; j ++ )
        {
            dp[i][j] = sum2;
            sum1 = (sum1 + dp[i - 1][j]) % mod;
            sum2 = (sum2 + sum1) % mod;
        }
    }

    for(int i = 1; i <= n + 1; i ++ )
    {
        f[1][i] = i * i % mod;
        f[0][i] = 2;
    }

    for(int i = 2; i <= n + 1; i ++ )
    {
        ll sum1 = dp[i - 1][i - 1];
        ll sum2 = dp[i - 1][i - 1];
        ll sum3 = dp[i - 1][i - 1];
        for(int j = i; j <= n + 1; j ++ )
        {
            f[i][j] = sum3;
            sum2 = (sum2 + sum1 * 2) % mod;
            sum1 = (sum1 + dp[i - 1][j]) % mod;
            sum2 = (sum2 + dp[i - 1][j]) % mod;
            sum3 = (sum3 + sum2) % mod;
        }
    }

    for(int i = 1; i <= n; i ++ ) p[i] = read();
    for(int i = 1; i <= n; i ++ ) 
    {
        q[i] = read();
        ne[p[i]] = q[i];
        la[q[i]] = p[i];
    }

    for(int i = 1; i <= n; i ++ )
    {
        if(!vis[p[i]])
        {
            int num = 1;
            int st = p[i];
            vis[st] = true;
            while(ne[st] != p[i])
            {
                st = ne[st];
                num ++;
                vis[st] = true;
            }
            d.push_back(num);
        }
    }

    ll sum = 0;
    for(int i = 0; i < d.size(); i ++ )
    {
        ll v = d[i];
        sum += v;
        if(i == 0)
        {
            if(v == 1)
            {
                g[i][0] = 1;
                g[i][1] = 1;
                continue;
            }
            g[i][0] = 2;
            for(int j = 1; j <= v; j ++ )
                g[i][j] = f[j][v];
            continue;
        }
        if(v == 1)
        {
            for(int j = 0; j <= sum; j ++ )
            {
                g[i][j] = g[i - 1][j];
                if(j) g[i][j] = (g[i][j] + g[i - 1][j - 1]) % mod;
            }
            continue;
        }
        for(int j = 0; j <= sum; j ++ )
            for(int k = max(j - v, 0ll); k <= sum - v; k ++ )
                g[i][j] = (g[i][j] + g[i - 1][k] * f[j - k][v] % mod) % mod;
    }

    for(int i = 0; i <= n; i ++ )
        h[n - i] = g[d.size() - 1][i];

    for(int i = 0; i <= n; i ++ )
    {
        if(i & 1) ans = (ans - h[i] * fact[n - i] % mod + mod) % mod;
        else ans = (ans + h[i] * fact[n - i] % mod) % mod;
    }
    
    cout << ans << endl;

    return 0;
}
posted @ 2023-07-10 18:38  crimson000  阅读(8)  评论(0编辑  收藏  举报