P10380 「ALFR Round 1」D 小山的元力
历时两天,算是搞出来了。
P10380 「ALFR Round 1」D 小山的元力 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
提醒
首先如果你是用 Lucas 定理并用阶乘形式来求组合数的,请判断组合数是否成立,即 \(C_a^b\),\(a\) 是否大于等于 \(b\)。如果小于你将 re 几个点,如果是直接用快速幂求解逆元来做的,恭喜你,你将 WA#20和#46。因为 \(p\) 可能小于 \(n,m\) 或 \(n + m\),导致你求出的阶乘到了\(f[p]\) 时变为 \(0\),使得后面的计算错误。
正题
分析题意,首先对于第 \(i\) 堆,当这堆放这 \(k\) 个元素的时候,无论这是哪一堆,它外面的情况的总数都是一样的,也就是分配剩下 \(n - k\) 个数的情况的数量是一样的。每一堆都如此,那么这堆放 \(k\) 个元素的总贡献就是 \(k \times 总情况数 \times sum\) 。 \(sum = 1! + 2! + \dots + m!\) 因为 sum 是固定的所以可以提出来最后相乘,那么目标就是求所有的 \(k \times 总情况数\)。
设一共有 \(n\) 个相同元素,放 \(m\) 堆(每堆可以不放),\(k\in [0,n]\),那么剩下的就是求每个 \(k\) 对应的总情况数。
设当前堆放 \(k\) 个数,那么就剩下 \(n - k\) 个数要分配到 \(m - 1\) 个空堆(可以不放)。这里可以用隔板法,用 \(m - 2\) 个隔板,把剩下 \(n - k\) 个数分成 \(m - 1\) 块。因为有空堆,而隔板法不能有空的分配,所以可以人为添加 \(m - 1\) 个元素,把情况变成必须放,这时候元素有 \(n - k + m - 1\) 个,空隙(两边不算)有 \(n - k + m - 1 - 1\) 个,即 \(n - k + m - 2\) 个。剩下的就是在这 \(n - k - 2\) 个空隙里面选 \(m - 2\) 个放上隔板,也就是求 \(C_{n - k + m - 2} ^ {m - 2}\)。
求组合数有很多方法,当时看数据范围,我直接用的快速幂求通过阶乘求解,然后因为 \(p\) 的大小寄掉了,也就是上面的提醒。因此这里用 Lucas 定理求解组合数,这样就不会因为 \(p\) 的事情寄掉了,时间上也够。且因为模数 \(p\) 不变,所以可以先预处理阶乘,但是不要先预处理逆元,否则时间复杂度会变成 \(O(p\log n)\) 容易 TLE。动态求解逆元,求出组合数,也就是总情况数。
最后把每种 \(k\) 的总情况数乘上 \(k\) 相加即
而这就是答案 \(ans\)。
最后输出 \(ans \times sum \bmod p\) 即可。
注意当 \(m = 1\) 的时候需要特判
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long LL;
const int N = 11000100;
int n, m, p;
int f[N];
int sum;
int qmi(int a, int k, int p)
{
int res = 1;
while (k)
{
if (k & 1) res = (LL)res * a % p;
k >>= 1;
a = (LL)a * a % p;
}
return res;
}
int in_f(int x) // 逆元
{
return qmi(x, p - 2, p);
}
int C(int a, int b)
{
if (a < b) return 0; // 一定要判断,不然会re,可能本地没问题,但那只是越界不够大而已
return (LL)f[a] * in_f(f[b]) % p * in_f(f[a - b]) % p;
}
int lucas(int a, int b) // 获取C(a, b)组合数
{
if (a < p && b < p) return C(a, b);
return (LL)C(a % p, b % p) * lucas(a / p, b / p) % p;
}
int main()
{
cin >> n >> m >> p;
if (m == 1) // 特判
{
cout << n % p;
return 0;
}
f[0] = 1;
int maxv = max(m, p - 1);
for (int i = 1; i <= maxv; i ++ )
{
f[i] = (LL)f[i - 1] * i % p;
if (i <= m) sum = (sum + f[i]) % p;
}
int ans = 0;
for (int i = 0; i <= n; i ++ )
{
ans = (ans + (LL)i * lucas(n - i + m - 2, m - 2) % p) % p;
}
cout << (LL)ans * sum % p << endl;
return 0;
}