【题解】Luogu-P3301 [SDOI2013]方程
Description
给定方程及不等式组
\[\begin{cases}
x_1+x_2+\cdots+x_n=m\\
\\
x_1\le a_1\\
x_2\le a_2\\
\cdots\\
x_{n1}\le a_{n1}\\
\\
x_{n1+1}\ge a_{n1+1}\\
x_{n1+2}\ge a_{n1+2}\\
\cdots\\
x_{n1+n2}\ge a_{n1+n2}
\end{cases}
\]
请求出该方程组的正整数解的个数 \(\bmod p\)。
- 对于 \(100\%\) 的数据:\(n\le 10^9,n1\le 8,n2\le 8,m\le 10^9,p\le 437367875,T\le 5,1\le a_{1\dots n1+n2}\le m,n1+n2\le n\)
Solution
前置芝士:
-
基础的计数 + 组合知识
-
exLucas
对于形如 \(x_i\ge a_i\) 的,用小奥思路将 \(m\gets m-(a_i-1)\),这时限制就变成了 \(x_i\ge 1\),也就是去掉了限制。
对于形如 \(x\le a_i\) 的,反面考虑 \(x>a_i\),即 \(x\ge a_i+1\),其它无限制的情况数,然后就和上面一样了。
注意一下容斥。
假设当前为 \(nowm\),那么根据插板法,情况数就为 \(C_{nowm-1}^{n-1}\) ,这里直接用 exLucas
即可。
时间复杂度为 \(O(n1!\cdot p\log m)\)。
但是你需要坚信它是跑不满的(
然后 \(70\) 了。
亿些小优化:
- 提前分解 \(p\)
Code
//18 = 9 + 9 = 18.
#include <iostream>
#include <cstdio>
#define Debug(x) cout << #x << "=" << x << endl
#define int long long
using namespace std;
int qpow(int a, int b, int p)
{
int base = a, ans = 1;
while (b)
{
if (b & 1)
{
ans = ans * base % p;
}
base = base * base % p;
b >>= 1;
}
return ans;
}
int fac[10];
int cal(int n, int p, int pos, int pa)
{
if (!n)
{
return 1;
}
int ans = qpow(fac[pos], n / pa, pa);
for (int i = 1; i <= n % pa; i++)
{
if (i % p)
{
ans = ans * i % pa;
}
}
return ans * cal(n / p, p, pos, pa) % pa;
}
int cnt_p(int n, int m, int p)
{
int cnt = 0;
for (int i = p; i <= n; i *= p)
{
cnt += n / i;
}
for (int i = p; i <= m; i *= p)
{
cnt -= m / i;
}
for (int i = p; i <= n - m; i *= p)
{
cnt -= (n - m) / i;
}
return cnt;
}
int x, y;
void exgcd(int a, int b)
{
if (!b)
{
x = 1, y = 0;
return;
}
exgcd(b, a % b);
int tmp = x;
x = y;
y = tmp - a / b * y;
}
int inv(int a, int p)
{
exgcd(a, p);
x = (x % p + p) % p;
return x;
}
int C(int n, int m, int p, int pos, int pa)
{
int a = cal(n, p, pos, pa), b = cal(m, p, pos, pa), c = cal(n - m, p, pos, pa), cnt = cnt_p(n, m, p);
return a * inv(b, pa) % pa * inv(c, pa) % pa * qpow(p, cnt, pa) % pa;
}
int prime[10], a[10], b[10];
int CRT(int n)
{
int m = 1;
for (int i = 1; i <= n; i++)
{
m *= a[i];
}
int ans = 0;
for (int i = 1; i <= n; i++)
{
int mi = m / a[i];
int Mi = inv(mi, a[i]);
ans = (ans + b[i] * mi % m * Mi % m) % m;
}
return ans;
}
int k;
void pre(int p)
{
for (int i = 2; i * i <= p; i++)
{
if (p % i == 0)
{
prime[++k] = i;
a[k] = 1;
while (p % i == 0)
{
a[k] *= i;
p /= i;
}
}
}
if (p > 1)
{
prime[++k] = p;
a[k] = p;
}
for (int i = 1; i <= k; i++)
{
fac[i] = 1;
for (int j = 1; j <= a[i]; j++)
{
if (j % prime[i])
{
fac[i] = fac[i] * j % a[i];
}
}
}
}
int exLucas(int n, int m)
{
if (n < m)
{
return 0;
}
for (int i = 1; i <= k; i++)
{
b[i] = C(n, m, prime[i], i, a[i]);
}
return CRT(k);
}
int p, n, n1, ans;
int w[20];
void dfs(int tot, int bound, int nega, int nowm)
{
// Debug(nowm), Debug(nega);
// Debug(exLucas(nowm - 1, n - 1, p));
ans = (ans + nega * exLucas(nowm - 1, n - 1) + p) % p;
if (tot > n1)
{
return;
}
for (int i = bound; i <= n1; i++)
{
dfs(tot + 1, i + 1, -nega, nowm - w[i]);
}
}
signed main()
{
int t;
scanf("%lld%lld", &t, &p);
pre(p);
while (t--)
{
int n2, m;
scanf("%lld%lld%lld%lld", &n, &n1, &n2, &m);
for (int i = 1; i <= n1 + n2; i++)
{
scanf("%lld", w + i);
}
for (int i = 1; i <= n2; i++)
{
m -= (w[n1 + i] - 1);
}
ans = 0;
dfs(1, 1, 1, m);
printf("%lld\n", ans);
}
return 0;
}