【HDOJ】4347 The Closest M Points

居然要用KD树（k-d tree）来解。

  1 /* 4347 */
  2 #include <iostream>
  3 #include <sstream>
  4 #include <string>
  5 #include <map>
  6 #include <queue>
  7 #include <set>
  8 #include <stack>
  9 #include <vector>
 10 #include <deque>
 11 #include <algorithm>
 12 #include <cstdio>
 13 #include <cmath>
 14 #include <ctime>
 15 #include <cstring>
 16 #include <climits>
 17 #include <cctype>
 18 #include <cassert>
 19 #include <functional>
 20 #include <iterator>
 21 #include <iomanip>
 22 using namespace std;
 23 //#pragma comment(linker,"/STACK:102400000,1024000")
 24 
 25 #define sti                set<int>
 26 #define stpii            set<pair<int, int> >
 27 #define mpii            map<int,int>
 28 #define vi                vector<int>
 29 #define pii                pair<int,int>
 30 #define vpii            vector<pair<int,int> >
 31 #define rep(i, a, n)     for (int i=a;i<n;++i)
 32 #define per(i, a, n)     for (int i=n-1;i>=a;--i)
 33 #define clr                clear
 34 #define pb                 push_back
 35 #define mp                 make_pair
 36 #define fir                first
 37 #define sec                second
 38 #define all(x)             (x).begin(),(x).end()
 39 #define SZ(x)             ((int)(x).size())
 40 #define lson            l, mid, rt<<1
 41 #define rson            mid+1, r, rt<<1|1
 42 
 43 int id, n, m, K;
 44 
 45 typedef struct node_t {
 46     int x[5];
 47 
 48     friend bool operator< (const node_t& a, const node_t& b) {
 49         return a.x[id] < b.x[id];
 50     }
 51 
 52     double Distance(const node_t& a) {
 53         __int64 ret = 0;
 54 
 55         rep(i, 0, K)
 56             ret += 1LL*(a.x[i]-x[i])*(a.x[i]-x[i]);
 57 
 58         return ret;
 59     }
 60 
 61     void print() {
 62         printf("%d", x[0]);
 63         rep(i, 1, K)
 64             printf(" %d", x[i]);
 65         putchar('\n');
 66     }
 67 
 68 } node_t;
 69 
 70 typedef struct Node {
 71     __int64 d;
 72     node_t p;
 73 
 74     Node() {}
 75     Node(__int64 d, node_t& p):
 76         d(d), p(p) {}
 77 
 78     friend bool operator< (const Node& a, const Node& b) {
 79         return a.d < b.d;
 80     }
 81 
 82 } Node;
 83 
 84 const int maxn = 50005;
 85 node_t nd[maxn];
 86 vector<node_t> ans;
 87 
 88 typedef struct KD_tree{
 89     static const int maxd = 5;
 90     static const int maxn = 50005;
 91     node_t P[maxn<<2];
 92     int idx[maxn<<2];
 93     priority_queue<Node> Q;
 94 
 95     void Build(int deep, int l, int r, int rt) {
 96         if (l > r)
 97             return ;
 98         idx[rt] = deep % K;
 99         if (l == r) {
100             P[rt] = nd[l];
101             return ;
102         }
103 
104         id = idx[rt];
105         int mid = (l + r) >> 1;
106         nth_element(nd+l, nd+mid, nd+r+1);
107         P[rt] = nd[mid];
108         Build(deep+1, l, mid-1, rt<<1);
109         Build(deep+1, mid+1, r, rt<<1|1);
110     }
111 
112     void Query(node_t x, int l, int r, int rt) {
113         if (l > r)
114             return ;
115         
116         int id = idx[rt];
117         __int64 tmp = x.Distance(P[rt]);
118         if (l == r) {
119             if (SZ(Q) < m) {
120                 Q.push(Node(tmp, P[rt]));
121             } else {
122                 if (tmp < Q.top().d) {
123                     Q.pop();
124                     Q.push(Node(tmp, P[rt]));
125                 }
126             }
127             return ;
128         }
129         
130         int mid = (l + r) >> 1;
131 
132         if (P[rt].x[id] >= x.x[id]) {
133             Query(x, l, mid-1, rt<<1);
134             if (SZ(Q) < m) {
135                 Q.push(Node(tmp, P[rt]));
136                 Query(x, mid+1, r, rt<<1|1);
137             } else {
138                 if (tmp < Q.top().d) {
139                     Q.pop();
140                     Q.push(Node(tmp, P[rt]));
141                 }
142                 if (1LL*(x.x[id]-P[rt].x[id])*(x.x[id]-P[rt].x[id]) < Q.top().d) {
143                     Query(x, mid+1, r, rt<<1|1);
144                 }
145             }
146         } else {
147             Query(x, mid+1, r, rt<<1|1);
148             if (SZ(Q) < m) {
149                 Q.push(Node(tmp, P[rt]));
150                 Query(x, l, mid-1, rt<<1);
151             } else {
152                 if (tmp < Q.top().d) {
153                     Q.pop();
154                     Q.push(Node(tmp, P[rt]));
155                 }
156                 if (1LL*(x.x[id]-P[rt].x[id])*(x.x[id]-P[rt].x[id]) < Q.top().d) {
157                     Query(x, l, mid-1, rt<<1);
158                 }
159             }
160         }
161     }
162 
163     void Dump() {
164         ans.clr();
165         while (!Q.empty()) {
166             ans.pb(Q.top().p);
167             Q.pop();
168         }
169     }
170 } KD_tree;
171 
172 KD_tree kd;
173 
174 void solve() {
175     int q, sz;
176     node_t p;
177 
178     kd.Build(0, 0, n-1, 1);
179     scanf("%d", &q);
180     while (q--) {
181         rep(j, 0, K)
182             scanf("%d", &p.x[j]);
183         scanf("%d", &m);
184         kd.Query(p, 0, n-1, 1);
185         kd.Dump();
186         sz = SZ(ans);
187         printf("the closest %d points are:\n", m);
188         per(i, 0, sz) {
189             ans[i].print();
190         }
191     }
192 }
193 
194 int main() {
195     ios::sync_with_stdio(false);
196     #ifndef ONLINE_JUDGE
197         freopen("data.in", "r", stdin);
198         freopen("data.out", "w", stdout);
199     #endif
200 
201     while (scanf("%d %d", &n, &K)!=EOF) {
202         rep(i, 0, n)
203             rep(j, 0, K)
204                 scanf("%d", &nd[i].x[j]);
205         solve();
206     }
207 
208     #ifndef ONLINE_JUDGE
209         printf("time = %d.\n", (int)clock());
210     #endif
211 
212     return 0;
213 }

数据发生器。

 1 from copy import deepcopy
 2 from random import randint, shuffle
 3 import shutil
 4 import string
 5 
 6 
 7 def GenDataIn():
 8     with open("data.in", "w") as fout:
 9         t = 10
10         bound = 10**5
11         for tt in xrange(t):
12             n = randint(100, 200)
13             k = randint(1, 5)
14             fout.write("%d %d\n" % (n, k))
15             for i in xrange(n):
16                 L = []
17                 for j in xrange(k):
18                     x = randint(-1000, 1000)
19                     L.append(x)
20                 fout.write(" ".join(map(str, L)) + "\n")
21             q = randint(20, 30)
22             fout.write("%d\n" % (q))
23             for qq in xrange(q):
24                 L = []
25                 for j in xrange(k):
26                     x = randint(-1000, 1000)
27                     L.append(x)
28                 fout.write(" ".join(map(str, L)) + "\n")
29                 m = randint(1, 10)
30                 fout.write("%d\n" % (m))
31             
32                 
def MovDataIn(desFileName=r"F:\eclipse_prj\workspace\hdoj\data.in",
              srcFileName="data.in"):
    """Copy the generated input file ``srcFileName`` to ``desFileName``.

    The default destination is a raw string. Written as an ordinary literal,
    the Windows path contains invalid escape sequences (a backslash followed
    by e, w, h or d), which newer Python versions warn about.
    """
    shutil.copyfile(srcFileName, desFileName)
36 
37     
def _run():
    """Regenerate data.in, then copy it into the IDE project directory."""
    GenDataIn()
    MovDataIn()


if __name__ == "__main__":
    _run()

 

posted on 2016-02-12 15:49  Bombe  阅读(192)  评论(0)  编辑  收藏  举报

导航