bzoj 2648 SJY摆棋子 kd树
初始的时候有一些棋子, 然后给两种操作, 一种是往上面放棋子。 另一种是给出一个棋子的位置, 问你离它最近的棋子的曼哈顿距离是多少。
写了指针版本的kd树, 感觉这个版本很好理解。
#include <bits/stdc++.h> using namespace std; #define mk(x, y) make_pair(x, y) #define mem(a) memset(a, 0, sizeof(a)) #define fi first #define se second typedef pair<int, int> pll; const int inf = 2e9; int cmpflag; pll a[500004]; struct kdTree { pll point; kdTree *l, *r; int x[2], y[2]; kdTree(){}; kdTree(const pll& par): point(par) { x[0] = x[1] = par.fi; y[0] = y[1] = par.se; l = r = NULL; } int getMin(const pll& par) { int ret = 0; if(par.fi < x[0]) ret += x[0] - par.fi; if(par.fi > x[1]) ret += par.fi - x[1]; if(par.se < y[0]) ret += y[0] - par.se; if(par.se > y[1]) ret += par.se - y[1]; return ret; } void pushUp(const kdTree* par) { x[0] = min(x[0], par->x[0]); x[1] = max(x[1], par->x[1]); y[0] = min(y[0], par->y[0]); y[1] = max(y[1], par->y[1]); } }; bool cmp(const pll& lhs, const pll& rhs) { if(cmpflag) return lhs.se < rhs.se; return lhs.fi < rhs.fi; } int getDistance(const pll& lhs, const pll& rhs) { return abs(lhs.fi-rhs.fi)+abs(lhs.se-rhs.se); } void build(kdTree*& p, int l, int r, int w) { if(l > r) return ; int mid = l + r >> 1; cmpflag = w; nth_element(a+l, a+mid, a+r+1, cmp); p = new kdTree(a[mid]); build(p->l, l, mid-1, w^1); build(p->r, mid+1, r, w^1); if(p->l) p->pushUp(p->l); if(p->r) p->pushUp(p->r); } void add(kdTree*& p, const pll& q, int w) { if(!p) { p = new kdTree(q); return ; } cmpflag = w; if(cmp(q, p->point)) { add(p->l, q, w^1); p->pushUp(p->l); } else { add(p->r, q, w^1); p->pushUp(p->r); } } void query(kdTree* p, const pll& q, int& ans) { ans = min(ans, getDistance(q, p->point)); int lDis = p->l?p->l->getMin(q):inf; int rDis = p->r?p->r->getMin(q):inf; if(lDis < rDis) { if(lDis < ans) query(p->l, q, ans); if(ans > rDis) query(p->r, q, ans); } else { if(rDis < ans) { query(p->r, q, ans); } if(ans > lDis) query(p->l, q, ans); } } int main() { int n, m, x, y, sign; cin>>n>>m; for(int i = 1; i <= n; i++) { scanf("%d%d", &a[i].fi, &a[i].se); } kdTree *root = new kdTree(); build(root, 1, n, 0); while(m--) { scanf("%d%d%d", &sign, &x, &y); if(sign == 1) { add(root, mk(x, y), 0); } else { int ans = inf; query(root, mk(x, y), ans); printf("%d\n", ans); } } return 0; }