同余最短路
prologue
都快 csp-s 了还啥也不会的废柴一根,真不知道能不能进队(痴人说梦)。
(后日谈:没进队,回去上 whk 了,不知不觉间也已经过了一年了,明年考上大学之后继续启动 xcpc。 还会回来!)
main body
同余最短路的适用题型
当出现形如「给定 n 个整数,求这 n 个整数能拼凑出多少的其他整数(n 个整数可以重复取)」,以及「给定 n 个整数,求这 n 个整数不能拼凑出的最小(最大)的整数」,或者「至少要拼几次才能拼出模 K 余 p 的数」的问题时可以使用同余最短路的方法。
看到上述的问题,其实很像完全背包,但是你如果开个完全背包很容易就 MLE 喜提0pts 的好成绩,所以我们考虑出来的一种优化方法,优化掉空间,从而实现大跃进( 0pts -> 100pts
同余最短路的通见转移形式
通常我们面对一个题目可以推出来如下的式子:
我们很容易类比到单源最短路。(哪里容易,要不是学了我能想到这?)
最后的答案统计由于我们会对于这个数字一直取模,所以我们的统计答案范围应该是在 $[0, mod) $ 之间的。但是左边界通常不是固定的,会随着一个题目的具体背景然后改变。
例题
T1
P2371 [国家集训队] 墨墨的等式(这个题目也比较板子,不要担心。)
这个题目中我们有 n 个数字,然后让我们去配凑一个 b。然后就按照我们上面说的板子就可以了。只有在我们最后统计答案的时候,因为给定了我们一个区间,所以我们可以借助一种类似于前缀和的思想,统计出来从 \([1, r]\) 的答案 \(cntr\),然后再统计 \([1, l - 1]\) 这个范围的答案,做差。
普通最短路做法(我选择的是 dijkstra + heap):
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define rl register ll
template <class T>
inline void read(T &res)
{
char ch; bool f = 0;
while((ch = getchar()) < '0' || ch > '9') f |= ch == '-';
res = ch ^ 48;
while((ch = getchar()) <= '9' && ch >= '0') res = (res << 1) + (res << 3) + (ch ^ 48);
res = f ? ~res + 1 : res;
}
const ll N = 15, M = 5e5 + 10;
ll n, l, r, a[M];
ll tot, ne[M * N], e[M * N], h[M], w[M * N];
ll dis[M];
bool st[M];
struct node
{
ll id, dis;
bool operator <(const node &x ) const
{
return dis > x.dis;
}
};
priority_queue<node> q;
inline void add(ll a, ll b, ll c)
{
ne[++tot] = h[a], h[a] = tot, e[tot] = b, w[tot] = c;
}
inline void dij()
{
memset(dis, 0x3f, sizeof dis);
memset(st, 0, sizeof st);
dis[0] = 0;
q.push({0, 0});
while(q.size())
{
auto t = q.top(); q.pop();
ll u = t.id;
if(st[u]) continue;
st[u] = true;
for(rl i=h[u]; ~i; i = ne[i])
{
ll v = e[i];
if(dis[v] > dis[u] + w[i])
{
dis[v] = dis[u] + w[i];
q.push({v, dis[v]});
}
}
}
}
int main()
{
// freopen("1.in", "r", stdin), freopen("1.out", "w", stdout);
read(n), read(l), read(r);
memset(h, -1, sizeof h);
for(rl i=1; i <= n; ++ i) read(a[i]);
sort(a + 1, a + 1 + n);
for(rl i=2; i <= n; ++ i)
for(rl j=0; j < a[1]; ++ j) add(j, (j + a[i]) % a[1], a[i]); // 注意一下这里的建图。
dij();
ll ans = 0;
for(rl i=0; i < a[1]; ++ i)
{
ll cntr = (dis[i] <= r) ? (r - dis[i]) / a[1] + 1 : 0;
ll cntl = (dis[i] < l) ? (l - 1 - dis[i]) / a[1] + 1 : 0;
ans += cntr - cntl;
}
cout << ans << endl;
return 0;
}
下面是我们的神奇转圈做法:
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define rl register ll
template <class T>
inline void read(T &res)
{
char ch; bool f = 0;
while((ch = getchar()) < '0' || ch > '9') f |= ch == '-';
res = ch ^ 48;
while((ch = getchar()) <= '9' && ch >= '0') res = (res << 1) + (res << 3) + (ch ^ 48);
res = f ? ~res + 1 : res;
}
const ll N = 5e5 + 10, M = 15;
ll n, l, r, a[M], m;
ll f[N], ans;
inline ll gcd(ll a, ll b)
{
return b ? gcd(b, a % b) : a;
}
int main()
{
// freopen("1.in", "r", stdin), freopen("1.out", "w", stdout);
cin >> n >> l >> r;
for(rl i=1; i <= n; ++ i) cin >> a[i];
memset(f, 0x3f, sizeof f); f[0] = 0;
sort(a + 1, a + 1 + n), m = a[1];
for(rl i=2; i <= n; ++ i)
for(rl j=0, lim = gcd(m, a[i]); j < lim; ++ j)
for(rl t=j, c = 0; c < 2; c += t == j)
{
ll p = (t + a[i]) % m;
f[p] = min(f[p], f[t] + a[i]), t = p;
}
for(rl i=0; i < a[1]; ++ i)
{
ll cntr = (f[i] <= r) ? (r - f[i]) / a[1] + 1 : 0;
ll cntl = (f[i] < l) ? (l - 1 - f[i]) / a[1] + 1 : 0;
ans += cntr - cntl;
}
printf("%lld\n", ans);
return 0;
}
(这代码是又短又香,这你不学起来?空间复杂度还低。)
T2
P3403 跳楼机(放大心,纯板子,不至于让你跳楼。)
这个题目是上面的简化版本,只有三个数字。
同时这个题目还有上面我说到的结合具体含义背景。(应该没有人的家里有第零层吧,如果有能不能给孩子拍一下,孩子见识少,没见过。)
统计答案是和上面统计答案的方法一样的。
这个题放最短路的。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define rl register ll
const ll N = 1e5 + 10, M = N << 2;
ll n, a[5];
ll tot, ne[M], e[M], h[N], w[M];
ll dis[N];
bool st[N];
struct node
{
ll id, dis;
bool operator <(const node &x) const
{
return dis > x.dis;
}
};
priority_queue<node> q;
inline void add(ll a, ll b, ll c)
{
ne[++tot] = h[a], h[a] = tot, e[tot] = b, w[tot] = c;
}
inline void dij()
{
memset(dis, 0x3f, sizeof dis);
memset(st, 0, sizeof st);
dis[1] = 1;
q.push({1, 1});
while(q.size())
{
auto t = q.top(); q.pop();
ll u = t.id;
if(st[u]) continue;
st[u] = true;
for(rl i=h[u]; ~i; i = ne[i])
{
ll v = e[i];
if(dis[v] > dis[u] + w[i])
{
dis[v] = dis[u] + w[i];
q.push({v, dis[v]});
}
}
}
}
int main()
{
// freopen("1.in", "r", stdin), freopen("1.out", "w", stdout);
cin >> n;
memset(h, -1, sizeof h);
for(rl i=1; i <= 3; ++ i) cin >> a[i];
for(rl i=2; i <= 3; ++ i)
for(rl j=0; j < a[1]; ++ j)
add(j, (j + a[i]) % a[1], a[i]);
for(rl i=1; i <= 3; ++ i) if(a[i] == 1)
{
cout << n << endl;
return 0;
}
dij();
ll ans = 0;
for(rl i=0; i < a[1]; ++ i)
ans += (dis[i] <= n) ? (n - dis[i]) / a[1] + 1 : 0;
cout << ans << endl;
return 0;
}
T3
对于这个题目的准确范围我也不太清楚,如果有人搞懂请联系我,我修正这个地方。我觉得应该是 3000 就够了,但是在没想清楚(看到\(l_i \le 3000\))之前我认为这个木棒的空间得开到 \(3000 \times 3000\)。
也可以去本人题解看看。
这个题目把每根木棒削去 m 之后的木棒的都表示出来就行了,就这一个转化。唯一判断没有没有可以表示的数的条件是我们的取模数是 1,但是我们可以用一些手法直接给省去判断,下面代码中注释。
这里放个转圈圈的。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define rl register ll
const ll N = 1e7 + 10, M = 110;
ll n, m,l[M], tot, a[N], f[N];
ll ans = -1; // 手法1:将 ans 赋值成 -1.
inline void add()
{
if(l[1] < m)
{
cout << "-1" << endl;
exit(0);
}
for(rl i=1; i <= n; ++ i)
for(rl j=0; j <= m; ++ j)
{
ll x = l[i] - j;
if(x <= 0) break;
a[ ++ tot] = x;
}
}
inline ll gcd(ll a, ll b)
{
return b ? gcd(b, a % b) : a;
}
int main()
{
// freopen("1.in", "r", stdin), freopen("1.out", "w", stdout);
cin >> n >> m;
for(rl i=1; i <= n; ++ i) cin >> l[i];
add();
sort(a + 1, a + 1 + tot);
m = a[1];
memset(f, 0x3f, sizeof f); f[0] = 0;
for(rl i=2; i <= tot; ++ i) // 这个地方别写错了,千万别写成 n,我调了 1h 才视力恢复看到了。
for(rl j=0, lim = __gcd(m, a[i]); j < lim; ++ j)
for(rl t=j, c = 0; c < 2; c += t == j)
{
ll p = (t + a[i]) % m;
f[p] = min(f[p], f[t] + a[i]), t = p;
}
for(rl i=0; i < a[1]; ++ i)
ans = max(ans, f[i] - m);
// 手法2:如果我们的取模数是 1,那么得到的 f 数组应该全是 0,所以这里不用担心 ans 从-1更新到别的数字。
cout << ans << endl;
return 0;
}
T4
支持前往本人题解查看此题更详细讲述。
中间过程太长了,就省去了。(强制去我的题解看)
马蜂优良。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define rl register ll
constexpr ll N = 55, M = 1e5 +10;
ll n, m = 1, q, a[N], f[M], c[N], ans, w;
inline ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
int main()
{
// freopen("1.in", "r", stdin), freopen("1.out", "w", stdout);
cin >> n >> q;
for(rl i=1; i <= n; ++ i)
{
cin >> a[i] >> c[i];
if(w * a[i] < m * c[i]) w = c[i], m = a[i];
}
for(rl i=1; i < m; ++ i) f[i] = -1e18;
for(rl i=1; i <= n; ++ i)
for(rl j=0, lim = gcd(m, a[i]); j < lim; ++ j)
for(rl t = j, asd = 0; asd < 2; asd += t == j)
{
ll p = (t + a[i]) % m;
f[p] = max(f[p], f[t] + c[i] - ((t + a[i]) / m) * w), t = p;
}
while(q -- )
{
ll v; cin >> v;
ll p = v % m;
if(f[p] < -1e17) puts("-1");
else cout << f[p] + v / m * w << endl;
}
return 0;
}
后面估计还会有几个题目,就先到这里了,得去学别的了,那几个做完之后补上,不会咕咕咕。