题解:CF607E Cross Sum
Problem
给定 \(N\) 条不平行的直线 \(y = \frac{k_i}{1000}x+\frac{b_i}{1000}\),\(N\) 条直线总共会有 \(\frac{N(N-1)}{2}\) 个交点(包含在同一个位置的点,即相同位置算不同的点),找出距离原点前 \(K\) 近的交点的距离和。
$2\le N\le 5\times 10^4 $,\(1\le K\le \frac{N(N-1)}{2}\),$-100000\le|A_i|,|B_i|,|C_i|\le1000(1\le i\le N) $,保证 \(A_i\),\(B_i\) 均非零。
Sol
先考虑一个弱化版:求第 \(K\) 近的交点的距离,具体可以参考 [ABC263Ex] Intersection 2。
发现题目点的个数其实是有 \(\mathcal{O}(n^2)\) 个的,显然不好直接统计,考虑二分答案。
具体地,二分出距离 \(r\),变为求原点与交点的距离 \(\le r\) 的个数。思考直线 \(L_i\) 与直线 \(L_j\) 的交点会被统计的条件。可以画出一个 \(x^2 +y^2 = r^2\) 的圆,即变为交点在圆内。由于不好直接考虑 \(L_i\),\(L_j\) 的交点,可以算出所有直线与圆的交点 \((x_{i1}, y_{i1}), (x_{i2}, y_{i2})\)。将两个交点都转化为一个连向原点的与 \(x\) 轴的夹角,这样可以得到一个度数区间 \([l_i, r_i](0\le l_i\le r_i<2\pi)\),然后两条直线满足交点的条件就是这两个度数区间相交但不包含。统计相交但不包含的区间对数可以通过按左端点从小到大排序后用树状数组统计。这一部分的时间复杂度:\(\mathcal{O(n\log n\log \frac{V}{\text{eps}})}\)。这里的 \(\text{eps}\) 是程序中设置的精度。
\(x^2 + y^2 = R^2\) 与 \(Ax + By + C = 0\) 的交点。
- \(B \neq 0\) 时,有 \(y = -\frac{Ax + C}{B}\):
$ y $ 可以用 $\frac{Ax + C}{B} $ 计算。
- \(B = 0\) 时,\(x = -\frac{C}{A}\)。
有了弱化版做铺垫后,这道题就很容易了。
由于交点个数不超过 \(m\),所以可以直接使用线段树套 vector 做到 \(\mathcal{O}(n\log n + m)\)。
总时间复杂度 \(\mathcal{O}(n\log n\log + m)\)。
注意最后计算距离和时圆上的交点要单独计算(因为可能加上圆上的所有点点数就超过 \(m\) 了),然后还要判断所有合法的点与 \((x_0, y_0)\) 重合的情况。
Code
#include<bits/stdc++.h>
#define ll long long
#define sz(a) ((int) (a).size())
#define vi vector < int >
#define pb emplace_back
using namespace std;
#define double long double
#define fi first
#define se second
const double eps = 1e-9;
#define x0 x_0
#define y0 y_0
int n, m, k; double x0, y0;
double a[50010], b[50010], c[50010], lsh[100010];
#define pdd pair < double, double >
pdd tmp[50010];
struct BIT {
int n;
vi v;
BIT(int _n = 100000) : v(_n + 10) { n = _n; }
void add(int x, int y) {
for(; x <= n; x += x & -x)
v[x] += y;
}
int ask(int x) {
int ret = 0;
for(; x; x -= x & -x)
ret += v[x];
return ret;
}
int ask(int x, int y) { return ask(y) - ask(x - 1); }
} ;
double sqr(double x) { return x * x; }
double dis(double ax, double ay, double bx, double by) { return sqrt(sqr(ax - bx) + sqr(ay - by)); }
double dis(pdd a, pdd b) { return sqrt(sqr(a.fi - b.fi) + sqr(a.se + b.se)); }
pdd calc(double a, double b, double c1, double R) {
double c = a * x0 + b * y0 + c1;
double sa = sqr(a), sb = sqr(b), sc = sqr(c), sr = sqr(R);
double p1x, p2x, p1y, p2y;
if(fabs(b) > eps) {
double delta = (sa + sb) * sb * sr - sb * sc;
if(delta < 0)
return make_pair(1011451423, 1011451423);
p1x = (-a * c + sqrt(delta)) / (sa + sb);
p2x = (-a * c - sqrt(delta)) / (sa + sb);
p1y = -(a * p1x + c) / b;
p2y = -(a * p2x + c) / b;
}
double l = atan2(p1x, p1y), r = atan2(p2x, p2y);
if(l > r)
swap(l, r);
return make_pair(l, r);
}
int len;
pair < int, int > pos[50010];
bool check(double mid) {
m = 0;
for(int i = 1; i <= n; ++i) {
auto ret = calc(a[i], b[i], c[i], mid);
if(ret.fi != 1011451423 || ret.se != 1011451423)
++m, tmp[m] = ret, lsh[2 * m - 1] = ret.fi, lsh[2 * m] = ret.se;
}
sort(lsh + 1, lsh + 2 * m + 1);
len = unique(lsh + 1, lsh + 2 * m + 1) - lsh - 1;
for(int i = 1; i <= m; ++i)
pos[i].fi = lower_bound(lsh + 1, lsh + len + 1, tmp[i].fi) - lsh,
pos[i].se = lower_bound(lsh + 1, lsh + len + 1, tmp[i].se) - lsh;
sort(pos + 1, pos + m + 1);
BIT T(len);
int cnt = 0;
for(int i = 1; i <= m; ++i) {
cnt += T.ask(pos[i].fi, pos[i].se);
T.add(pos[i].se, 1);
}
return cnt >= k;
}
pdd gnd(double a0, double b0, double c0, double a1, double b1, double c1) {
if(fabs(a0 * b1 - a1 * b0) < eps)
return make_pair(1011451423, 1011451423);
double x, y;
if(fabs(a0) > eps) {
y = (a1 * c0 - a0 * c1) / (a0 * b1 - a1 * b0);
x = -(b0 * y + c0) / a0;
}
else {
y = -c0 / b0;
x = (-c1 - y * b1) / a1;
}
return make_pair(x, y);
}
int num;
pair < double, pdd > ft[50010];
pair < double, pdd > gln(double p, double q, double r, double s) {
if(fabs(p - r) < eps)
return make_pair(1.0, make_pair(0.0, -p));
double k = (q - s) / (p - r);
double b = q - k * p;
return make_pair(k, make_pair(-1.0, b));
}
vector < int > tr[400010];
void modify(int k, int v, int x, int L, int R) {
tr[x].pb(v);
if(L == R)
return;
int mid = (L + R) >> 1;
k <= mid ? modify(k, v, x << 1, L, mid) : modify(k, v, x << 1 | 1, mid + 1, R);
}
vector < int > res;
void query(int l, int r, int x, int L, int R) {
if(l <= L && R <= r) {
for(auto i : tr[x])
res.pb(i);
return;
}
int mid = (L + R) >> 1;
if(l <= mid)
query(l, r, x << 1, L, mid);
if(r > mid)
query(l, r, x << 1 | 1, mid + 1, R);
}
void query(int l, int r) {
res.clear();
query(l, r, 1, 1, len);
}
int main() {
cin >> n >> x0 >> y0 >> k;
x0 /= 1000, y0 /= 1000;
for(int i = 1; i <= n; ++i)
cin >> a[i] >> c[i], a[i] /= 1000, b[i] = -1, c[i] /= 1000; //, cout << a[i] << " " << b[i] << " " << c[i] << "\n";
cout << fixed << setprecision(12);
double L = eps, R = 3e9;
while(R - L > eps) {
double mid = (L + R) / 2;
if(check(mid))
R = mid;
else
L = mid;
}
if(L <= eps)
return cout << 0 << "\n", 0;
check(L);
for(int i = 1; i <= m; ++i)
ft[i] = gln(L * cos(lsh[pos[i].fi]), L * sin(lsh[pos[i].fi]), L * cos(lsh[pos[i].se]), L * sin(lsh[pos[i].se]));
int cnt = 0; double ans = 0;
for(int i = 1; i <= m; ++i) {
if(pos[i].fi + 1 > pos[i].se - 1) continue;
query(pos[i].fi + 1, pos[i].se - 1);
for(auto j : res) {
double ret = dis(make_pair(0.0, 0.0), gnd(ft[i].fi, ft[i].se.fi, ft[i].se.se, ft[j].fi, ft[j].se.fi, ft[j].se.se));
ans += ret, ++cnt;
}
modify(pos[i].se, i, 1, 1, len);
}
cout << ans + L * (k - cnt) << "\n";
return 0;
}