CodeForces 698D Limak and Shooting Points
CodeForces 698D Limak and Shooting Points
https://codeforces.com/contest/698/problem/D
有 \(n\) 个怪兽和 \(k\) 个魔法石.第 \(i\) 只怪兽位于 \((mx_i, my_i)\)
每个魔法石最多只能使用一次,可以按任意顺序使用魔法石,使用第 \(i\) 个魔法石可以传送至 \((ax_i, ay_i)\) ,然后向任意方向射出一只箭,然后箭会消灭掉这个方向上第一只怪兽,然后消失.
问有多少只怪兽可能被消灭掉.
\(1 \le k \le 7, 1 \le n \le 1000\)
\(-10^9 \le ax_i,ay_i \le 10^9\)
\(-10^9 \le mx_i,my_i \le 10^9\)
给出的 \(n + k\) 个点坐标两两不同.
Tutorial
https://sunnuozhou.github.io/2020/01/16/codeforces-698D/
由于 \(k\) 很小, 考虑搜索.
首先,预处理出第 \(i\) 个魔法石位置和第 \(j\) 只怪兽之间的所有怪兽.
考虑依次判断第 \(x\) 只怪兽是否可以被消灭. 可以首先将 \(x\) 加入队列,然后进行以下循环
- 若队列为空,则可以被消灭
- 否则,取出队首元素,设为 \(u\)
- 分配一个未使用的魔法石 \(v\) ,表示是使用 \(v\) 消灭了 \(u\) .
- 将 \(v, u\) 之间的怪兽加入队列
注意若某一时刻进入过队列的元素超过 \(k\) , 则无解.
预处理的复杂度为 \(O(n^2k)\) ,判断部分复杂度为 \(O(nk \cdot k!)\)
Code
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#define debug(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
typedef long long ll;
const int maxk = 10;
const int maxn = 1000 + 5;
int k, n;
int ax[maxk], ay[maxk];
int mx[maxn], my[maxn];
bool flag;
bool inq[maxn];
bool vis[maxk];
vector<int> pos;
vector<int> r[maxk][maxn];
inline ll det(int x0, int y0, int x1, int y1) { return (ll)x0 * y1 - (ll)y0 * x1; }
inline ll dot(int x0, int y0, int x1, int y1) { return (ll)x0 * x1 + (ll)y0 * y1; }
inline ll norm(int x, int y) { return dot(x, y, x, y); }
inline bool check(int x0, int y0, int x1, int y1)
{
if (det(x0, y0, x1, y1))
return 0;
if (dot(x0, y0, x1, y1) < 0)
return 0;
if (norm(x0, y0) < norm(x1, y1))
return 0;
return 1;
}
void dfs(int step)
{
if (step == pos.size())
flag = 1;
if (flag)
return;
int u = pos[step], rec = pos.size();
for (int i = 1; i <= k; ++i) if (!vis[i])
{
vis[i] = 1;
for (int j = 0; j < r[i][u].size(); ++j)
{
int x = r[i][u][j];
if (!inq[x])
{
inq[x] = 1;
pos.push_back(x);
if (pos.size() > k)
break;
}
}
if (pos.size() <= k)
{
dfs(step + 1);
if (flag) return;
}
while (pos.size() > rec)
{
int x = pos.back(); pos.pop_back();
inq[x] = 0;
}
vis[i] = 0;
}
}
int sol()
{
for (int i = 1; i <= k; ++i)
for (int j = 1; j <= n; ++j)
for (int k = 1; k <= n; ++k) if(j != k)
if (check(mx[j] - ax[i], my[j] - ay[i], mx[k] - ax[i], my[k] - ay[i]))
r[i][j].push_back(k);
int an = 0;
for (int i = 1; i <= n; ++i)
{
flag = 0;
pos.clear(), pos.push_back(i);
memset(vis, 0, sizeof(vis));
memset(inq, 0, sizeof(inq));
dfs(0);
if (flag) ++an;
}
return an;
}
int main()
{
scanf("%d%d", &k, &n);
for (int i = 1; i <= k; ++i)
scanf("%d%d", &ax[i], &ay[i]);
for (int i = 1; i <= n; ++i)
scanf("%d%d", &mx[i], &my[i]);
printf("%d\n", sol());
return 0;
}