同余最短路
主要用于解决“给定若干个数（每个可以使用任意多次），能否拼凑出某个数、或者能拼凑出多少个数”这一类问题。
例：给定三个正整数 \(x,y,z\)，问 \([0,h-1]\) 中有多少个数 \(w\) 能表示为 \(w=ax+by+cz\)（其中 \(a,b,c\) 均为非负整数）。
定义 \(dis_i\) 为能够由 \(x,y,z\) 组成、且满足 \(w\bmod x=i\) 的最小的数 \(w\)。因为能组成 \(w\) 就一定能组成 \(w+x,w+2x,\dots\)，所以每个模 \(x\) 的剩余类只需要记录这个最小值。
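以 \(0,1,\dots,x-1\) 为点，这样建图：\(i\stackrel{y}{\longrightarrow}(i+y)\bmod x\)，\(i\stackrel{z}{\longrightarrow}(i+z)\bmod x\)。即从 \(i\) 向 \((i+y)\bmod x\) 连一条边权为 \(y\) 的边，向 \((i+z)\bmod x\) 连一条边权为 \(z\) 的边。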
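这样只需以 \(0\) 为起点（\(dis_0=0\)）跑一遍最短路，就可以求出 \(dis\)。统计答案时，若 \(dis_i\le h-1\)，则该剩余类中 \(dis_i,dis_i+x,dis_i+2x,\dots\) 里不超过 \(h-1\) 的数都能被组成，共 \(\lfloor\frac{h-1-dis_i}{x}\rfloor+1\) 个，累加即可。到达不了的剩余类（\(dis_i\) 仍为无穷大）要跳过。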
代码
#include<iostream>
#include<cstring>
#include<cstdio>
#include<queue>
using namespace std;
// Global storage for the residue graph and the Dijkstra state.
// The graph is kept as forward-star lists: head[u] is the id of u's first
// edge, net[e] the next edge after e, ver[e] its target and edge[e] its
// weight. Edge ids start at 1 (idx is zero-initialised and pre-incremented
// when an edge is added), so id 0 marks the end of a list.
const int N = 200005;
typedef long long ll;
typedef std::pair<ll, int> PII;      // heap entry: (distance, node)

ll dis[N];                           // dis[u]: current best distance to node u
ll edge[N];                          // edge[e]: weight of edge e
int head[N];                         // head[u]: first edge leaving u (0 = none)
int ver[N];                          // ver[e]: target node of edge e
int net[N];                          // net[e]: next edge in the same list
int idx;                             // id of the most recently added edge
std::priority_queue<PII, std::vector<PII>, std::greater<PII> > q;  // min-heap on distance
bool vis[N];                         // vis[u]: u has been finalised by Dijkstra
// Append a directed edge from -> to with weight w to the front of from's
// forward-star list. Edge ids are 1-based because idx is pre-incremented.
void add(int from, int to, int w)
{
    ++idx;
    net[idx] = head[from];   // link to the previous first edge of `from`
    ver[idx] = to;
    edge[idx] = w;
    head[from] = idx;        // the new edge becomes the list head
}
void Dij()
{
memset(dis, 0x3f, sizeof(dis));
memset(vis, 0, sizeof(vis));
q.push(make_pair(0, 0));
dis[0] = 0;
while (!q.empty())
{
int u = q.top().second;
q.pop();
if (vis[u])
continue;
vis[u] = true;
for (int i = head[u]; i; i = net[i])
{
int v = ver[i];
if (dis[v] > dis[u] + edge[i])
{
dis[v] = dis[u] + edge[i];
q.push(make_pair(dis[v], v));
}
}
}
}
int main()
{
ll h;
int x, y, z;
scanf("%lld%d%d%d", &h, &x, &y, &z);
h--;
for (int i = 0; i < x; i++)
add(i, (i + y) % x, y), add(i, (i + z) % x, z);
Dij();
ll ans = 0;
for (int i = 0; i < x; i++)
if (dis[i] <= h)
ans += (h - dis[i]) / x + 1;
printf("%lld", ans);
return 0;
}
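下一题的思路和上面一样：给定 \(n\) 个数 \(a_1,\dots,a_n\)，求 \([l,r]\) 中有多少个数能表示为 \(\sum a_ib_i\)（\(b_i\) 为非负整数）。取其中一个正的 \(a_i\)（最好取最小的那个）作为模数建图、跑最短路，再用“不超过 \(r\) 的个数”减去“不超过 \(l-1\) 的个数”得到答案。下面直接放代码：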
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
// Global storage for the residue graph (forward-star lists) and the
// Dijkstra state. head[u] is the first edge of u, net[e] the next edge,
// ver[e] the target and edge[e] the weight. Edge id 0 marks end of list.
const int N = 6e6 + 5;
typedef long long ll;
// Heap entry (distance, node). The distance MUST be 64-bit. A shortest
// path may chain up to (modulus - 1) edges, each as heavy as the largest
// coefficient, so distances can exceed INT_MAX. The previous
// pair<int, int> truncated those keys, scrambled the heap order, and let
// Dijkstra finalise nodes with non-optimal distances.
typedef std::pair<ll, int> PII;
int a[15];                           // coefficients a[1..n]; assumes n <= 14
ll dis[N];                           // dis[u]: current best distance to node u
int head[N], net[N], ver[N], edge[N], idx;
std::priority_queue<PII, std::vector<PII>, std::greater<PII> > q;  // min-heap on distance
bool vis[N];                         // vis[u]: u has been finalised by Dijkstra
// Append a directed edge from -> to with weight w to the front of from's
// forward-star list. Edge ids are 1-based because idx is pre-incremented.
void add(int from, int to, int w)
{
    ++idx;
    net[idx] = head[from];   // link to the previous first edge of `from`
    ver[idx] = to;
    edge[idx] = w;
    head[from] = idx;        // the new edge becomes the list head
}
void Dij()
{
memset(dis, 0x3f, sizeof(dis));
memset(vis, 0, sizeof(vis));
dis[0] = 0;
q.push(make_pair(0, 0));
while (!q.empty())
{
int u = q.top().second;
q.pop();
if (vis[u])
continue;
vis[u] = true;
for (int i = head[u]; i; i = net[i])
{
int v = ver[i];
if (dis[v] > dis[u] + edge[i])
{
dis[v] = dis[u] + edge[i];
q.push(make_pair(dis[v], v));
}
}
}
}
int main()
{
int n;
ll l, r;
scanf("%d%lld%lld", &n, &l, &r);
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]);
for (int i = 0; i < a[1]; i++)
for (int j = 2; j <= n; j++)
add(i, (i + a[j]) % a[1], a[j]);
Dij();
ll ans = 0;
for (int i = 0; i < a[1]; i++)
{
if (dis[i] <= r)
ans += (r - dis[i]) / a[1] + 1;
if (dis[i] < l)
ans -= (l - 1 - dis[i]) / a[1] + 1;
}
printf("%lld", ans);
return 0;
}