USACO Section 1.3 Wormholes 解题报告
题目
题目描述
在一个二维平面上有N个点,这N个点是(N/2)个虫洞的端点,虫洞的特点就是,你以什么状态从某个端点进去,就一定会以什么状态从另一端的端点出来。现在有一头牛总是沿着与X轴正方向平行的直线前进(也就是总是向右边走),所以它有可能被某些虫洞给困住。例如有一个虫洞的两个端点是A(0,0)、B(1,0),那么它如果从AB之间的某点开始出发,那么它一定会进入B点,但是当它进入B点之后会立刻传送到A点,这个时候它的朝向是不会变的,也就是说还是朝着X轴正方向,那么如果它继续前进,它肯定会又进入B点,然后就这样一直循环,被这个虫洞所困住。现在这个农夫不知道这头牛的位置(所以这头牛可以在平面上的任意一点),也不知道这个二维平面上的N个点是怎么连接的,所以你要找到所有的连接方式,并且记录有多少种方式,会让这头牛可能被虫洞困住。
数据范围
2 <= N <= 12
,N为偶数- 每个点的坐标的x,y都是属于
0~100000000
之间的正整数
样例输入
第一行输入N,下面N行输入N个端点的坐标
4
0 0
1 0
1 1
0 1
样例输出
2
解题思路
由于N很小,所以我们可以考虑直接枚举每一种可能的连接方式,然后再分别判断该种连接方式是否存在可以困住牛的虫洞,这也是最直接的想法。
我们可以计算一下枚举每一种连接方式的情况,如果N=12,那么我们会产生6条边,我们按照下面这种方式来找到这6条边:
- 将所有的点都编号,我们从编号为1的点开始连接,这个点有N-1个点可以与它连接。
- 我们将剩下的N-2个点中找到编号最小的点,从这个点开始连接,所以有N-3个点可以与它连接。
- 与上同理,我们可以一直处理到最后一条边。
这样我们可以计算(N-1)*(N-3)*(N-5)*……*3*1 = 10395
,然后我们对每一种情况都必须遍历每一个点来判断是否能够困住这头牛,所以最终的结果是 10395 * 12
这个结果约为100000,所以这个时间复杂度是可以接受的。
Tip:我们在判断是否能够被虫洞困住的时候一定要注意这样一个特点,那就是,“从一个端点A进入虫洞”与“从虫洞端点A出来”这两种状态是不一样的,例如这两个点A(0,0)、B(1,0)构成一个虫洞,我从B点进入,那么我肯定会被困住,但是如果我从B点出来,朝右走,那么这个虫洞是困不住我的。
解题代码
/*
ID: yinzong2
PROG: wormhole
LANG: C++11
*/
#define MARK
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<set>
#include<cstdlib>
#include<cstring>
using namespace std;
const int MAXN = 15;
const int INF = 0x3fffffff;
int n, ans;
int pr[MAXN];
int right[MAXN];
bool vis[MAXN];
struct WormHole{
int x,y;
}w[MAXN];
set<int>s;
bool outFrom[MAXN];
//判断是否有环可以困住这头牛
bool testSolution() {
memset(outFrom, false, sizeof(outFrom));
for(int i = 1; i <= n; i++) {
if(!outFrom[i]) {
int start = i;
s.clear();//这个set是用来保存有哪些点在行进的过程中作为虫洞的出口,一定不能将虫洞的入口保存进去,保证一致性
s.insert(start);
outFrom[start] = true;
while(right[start]) {
start = pr[ right[start] ];
if(s.find(start) != s.end()) {
return true;
}
s.insert(start);//进入一个端点,和从这个端点出来的状态是不一样的,所以我们一定要注意
outFrom[start] = true;
}
}
}
return false;
}
//枚举每一种可能的虫洞连接方式
void makeSolution(int start, int curLevel) {
if(curLevel == (n/2)) {
for(int i = 1; i <= n; i++) {
if(!vis[i] && start != i) {
pr[start] = i;
pr[i] = start;
}
}
if(testSolution()) {
ans++;
}
return ;
}
vis[start] = true;
for(int i = 1; i <= n; i++) {
if(i != start && !vis[i]) {
vis[i] = true;
pr[start] = i;
pr[i] = start;
for(int j = 1; j <= n; j++) {
if(!vis[j]) {
makeSolution(j, curLevel+1);
break; // 这里别忘记了
}
}
// 状态还原很重要
vis[i] = false;
}
}
// 状态还原很重要
vis[start] = false;
}
//找到每个点的右边最近的点
bool findRight() {
int dis, id;
bool flag = false;
for(int i = 1; i <= n; i++) {
dis = INF;
for(int j = 1; j <= n; j++) {
if(i != j && w[i].y == w[j].y && w[i].x < w[j].x) {
if(dis > (w[j].x-w[i].x)) {
id = j;
dis = w[j].x-w[i].x;
}
}
}
if(dis != INF) {
right[i] = id;
flag = true;
}
}
return flag;
}
int main() {
#ifdef MARK
freopen("wormhole.in", "r", stdin);
freopen("wormhole.out", "w", stdout);
#endif // MARK
while(~scanf("%d", &n)) {
for(int i = 1; i <= n; i++) {
scanf("%d%d", &w[i].x, &w[i].y);
right[i] = 0;
}
bool flag = findRight();
ans = 0;
//优化,当所有的点的Y坐标都不相同的时候,肯定是没有解的。
if(flag) {
memset(vis, false, sizeof(vis));
makeSolution(1, 1);
}
printf("%d\n", ans);
}
return 0;
}
官方给的题解还附带上youtube上的视频讲解,特别好。可以看到代码是怎么一步一步实现的。而且代码特别的优雅,值得好好学习。
#include <iostream>
#include <fstream>
using namespace std;
#define MAX_N 12
int N, X[MAX_N+1], Y[MAX_N+1];
int partner[MAX_N+1];
int next_on_right[MAX_N+1];
bool cycle_exists(void)
{
for (int start=1; start<=N; start++) {
// does there exist a cylce starting from start
int pos = start;
for (int count=0; count<N; count++)
pos = next_on_right[partner[pos]];
if (pos != 0) return true;
}
return false;
}
// count all solutions
int solve(void)
{
// find first unpaired wormhole
int i, total=0;
for (i=1; i<=N; i++)
if (partner[i] == 0) break;
// everyone paired?
if (i > N) {
if (cycle_exists()) return 1;
else return 0;
}
// try pairing i with all possible other wormholes j
for (int j=i+1; j<=N; j++)
if (partner[j] == 0) {
// try pairing i & j, let recursion continue to
// generate the rest of the solution
partner[i] = j;
partner[j] = i;
total += solve();
partner[i] = partner[j] = 0;
}
return total;
}
int main(void)
{
ifstream fin("wormhole.in");
fin >> N;
for (int i=1; i<=N; i++) fin >> X[i] >> Y[i];
fin.close();
for (int i=1; i<=N; i++) // set next_on_right[i]...
for (int j=1; j<=N; j++)
if (X[j] > X[i] && Y[i] == Y[j]) // j right of i...
if (next_on_right[i] == 0 ||
X[j]-X[i] < X[next_on_right[i]]-X[i])
next_on_right[i] = j;
ofstream fout("wormhole.out");
fout << solve() << "\n";
fout.close();
return 0;
}