【题解】Luogu-P3301 [SDOI2013]方程

P3301 [SDOI2013]方程

Description

给定方程及不等式组

\[\begin{cases} x_1+x_2+\cdots+x_n=m\\ \\ x_1\le a_1\\ x_2\le a_2\\ \cdots\\ x_{n1}\le a_{n1}\\ \\ x_{n1+1}\ge a_{n1+1}\\ x_{n1+2}\ge a_{n1+2}\\ \cdots\\ x_{n1+n2}\ge a_{n1+n2} \end{cases} \]

请求出该方程组的正整数解的个数 \(\bmod p\)

  • 对于 \(100\%\) 的数据:\(n\le 10^9,n1\le 8,n2\le 8,m\le 10^9,p\le 437367875,T\le 5,1\le a_{1\dots n1+n2}\le m,n1+n2\le n\)

Solution

前置芝士:

  • 基础的计数 + 组合知识

  • exLucas

对于形如 \(x_i\ge a_i\) 的,用小奥思路将 \(m\gets m-(a_i-1)\),这时限制就变成了 \(x_i\ge 1\),也就是去掉了限制。

对于形如 \(x\le a_i\) 的,反面考虑 \(x>a_i\),即 \(x\ge a_i+1\),其它无限制的情况数,然后就和上面一样了。

注意一下容斥。

假设当前为 \(nowm\),那么根据插板法,情况数就为 \(C_{nowm-1}^{n-1}\) ,这里直接用 exLucas 即可。

时间复杂度为 \(O(n1!\cdot p\log m)\)

但是你需要坚信它是跑不满的(

然后 \(70\) 了。

亿些小优化:

  • 提前分解 \(p\)

Code

//18 = 9 + 9 = 18.
#include <iostream>
#include <cstdio>
#define Debug(x) cout << #x << "=" << x << endl
#define int long long
using namespace std;

int qpow(int a, int b, int p)
{
	int base = a, ans = 1;
	while (b)
	{
		if (b & 1)
		{
			ans = ans * base % p;
		}
		base = base * base % p;
		b >>= 1;
	}
	return ans;
}

int fac[10];

int cal(int n, int p, int pos, int pa)
{
	if (!n)
	{
		return 1;
	}
	int ans = qpow(fac[pos], n / pa, pa);
	for (int i = 1; i <= n % pa; i++)
	{
		if (i % p)
		{
			ans = ans * i % pa;
		}
	}
	return ans * cal(n / p, p, pos, pa) % pa;
}

int cnt_p(int n, int m, int p)
{
	int cnt = 0;
	for (int i = p; i <= n; i *= p)
	{
		cnt += n / i;
	}
	for (int i = p; i <= m; i *= p)
	{
		cnt -= m / i;
	}
	for (int i = p; i <= n - m; i *= p)
	{
		cnt -= (n - m) / i;
	}
	return cnt;
}

int x, y;

void exgcd(int a, int b)
{
	if (!b)
	{
		x = 1, y = 0;
		return;
	}
	exgcd(b, a % b);
	int tmp = x;
	x = y;
	y = tmp - a / b * y;
}

int inv(int a, int p)
{
	exgcd(a, p);
	x = (x % p + p) % p;
	return x;
}

int C(int n, int m, int p, int pos, int pa)
{
	int a = cal(n, p, pos, pa), b = cal(m, p, pos, pa), c = cal(n - m, p, pos, pa), cnt = cnt_p(n, m, p);
	return a * inv(b, pa) % pa * inv(c, pa) % pa * qpow(p, cnt, pa) % pa;
}

int prime[10], a[10], b[10];

int CRT(int n)
{
	int m = 1;
	for (int i = 1; i <= n; i++)
	{
		m *= a[i];
	}
	int ans = 0;
	for (int i = 1; i <= n; i++)
	{
		int mi = m / a[i];
		int Mi = inv(mi, a[i]);
		ans = (ans + b[i] * mi % m * Mi % m) % m;
	}
	return ans;
}

int k;

void pre(int p)
{
	for (int i = 2; i * i <= p; i++)
	{
		if (p % i == 0)
		{
			prime[++k] = i;
			a[k] = 1;
			while (p % i == 0)
			{
				a[k] *= i;
				p /= i;
			}
		}
	}
	if (p > 1)
	{
		prime[++k] = p;
		a[k] = p;
	}
	for (int i = 1; i <= k; i++)
	{
		fac[i] = 1;
		for (int j = 1; j <= a[i]; j++)
		{
			if (j % prime[i])
			{
				fac[i] = fac[i] * j % a[i];
			}
		}
	}
}

int exLucas(int n, int m)
{
	if (n < m)
	{
		return 0;
	}
	for (int i = 1; i <= k; i++)
	{
		b[i] = C(n, m, prime[i], i, a[i]);
	}
	return CRT(k);
}

int p, n, n1, ans;
int w[20];

void dfs(int tot, int bound, int nega, int nowm)
{
//	Debug(nowm), Debug(nega);
//	Debug(exLucas(nowm - 1, n - 1, p));
	ans = (ans + nega * exLucas(nowm - 1, n - 1) + p) % p;
	if (tot > n1)
	{
		return;
	}
	for (int i = bound; i <= n1; i++)
	{
		dfs(tot + 1, i + 1, -nega, nowm - w[i]);
	}
}

signed main()
{
	int t;
	scanf("%lld%lld", &t, &p);
	pre(p);
	while (t--)
	{
		int n2, m;
		scanf("%lld%lld%lld%lld", &n, &n1, &n2, &m);
		for (int i = 1; i <= n1 + n2; i++)
		{
			scanf("%lld", w + i);
		}
		for (int i = 1; i <= n2; i++)
		{
			m -= (w[n1 + i] - 1);
		}
		ans = 0;
		dfs(1, 1, 1, m);
		printf("%lld\n", ans);
	}
	return 0;
}
posted @ 2021-12-17 20:25  mango09  阅读(25)  评论(0编辑  收藏  举报
-->