【CDQ分治】三元环
三元环
思路
考虑 \(3\) 个点的有向图,要么成环,要么有一个点入度为 \(2\) ,假设第 个点的入度为 \(d_i\),答案为 \(C_n^3-\sum\limits_{i=1}^nC_{d_i}^2\)。
根据题目关系,\(i\rightarrow j\) 当且仅当 \(i<j \ and\ f_i <f_j \ and\ g_i < g_j\),否则就是 \(j\rightarrow i\),所以根据这个三维关系,我们可以先根据前两维求出 \(i<j\ and\ f_i\ge f_j\) 的入度,然后通过 cdq分治去求满足这个三维关系的各点的度数。
代码
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
template<typename T>
struct BIT {
#ifndef lowbit
#define lowbit(x) (x & (-x));
#endif
int n;
vector<T> t;
BIT () {}
BIT (int _n): n(_n) { t.resize(_n + 1); }
BIT (int _n, vector<T>& a): n(_n) {
t.resize(_n + 1);
for (int i = 1; i <= n; ++ i) {
t[i] += a[i];
int j = i + lowbit(i);
if (j <= n) t[j] += t[i];
}
}
//单点修改
void update(int i, T x) {
while (i <= n) {
t[i] += x;
i += lowbit(i);
}
}
//区间查询
T sum(int i) {
T ans = 0;
while (i > 0) {
ans += t[i];
i -= lowbit(i);
}
return ans;
}
T query(int i, int j) {
return sum(j) - sum(i - 1);
}
//区间修改则存入差分数组,[l, r] + k则update(x,k),update(y+1,-k)
//单点查询则直接求前缀和sum(x)
//求逆序对
/*
iota(d.begin(), d.end(), 0);
stable_sort(d.begin(), d.end(), [&](int x, int y) {
return a[x] < a[y];
});去重排序
BIT<i64> tree(n);
i64 ans = 0;
for (int i = 1; i <= n; i ++) {
tree.update(d[i], 1);
ans += i - tree.sum(d[i]);
}
*/
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin >> n;
vector<array<int, 3>> a(n + 1);
for (int i = 1; i <= n; i ++) {
cin >> a[i][0];
a[i][2] = i;
}
for (int i = 1; i <= n; i ++) {
cin >> a[i][1];
}
BIT<i64> bit(n);
vector<int> in(n + 1);
//求 i < j and fi >= fj
for (int i = n; i >= 1; i --) {
in[i] += bit.sum(a[i][0]);
bit.update(a[i][0], 1);
}
for (int i = n; i >= 1; i --) {
bit.update(a[i][0], -1);
}
auto cdq = [&](auto && self, int l, int r)->void{
if (l == r)
return ;
int mid = l + r >> 1;
self(self, l, mid);
self(self, mid + 1, r);
sort(a.begin() + l, a.begin() + mid + 1, [](auto x, auto y) {
if (x[0] != y[0]) return x[0] < y[0];
return x[1] < y[1];
});
sort(a.begin() + mid + 1, a.begin() + r + 1, [](auto x, auto y) {
if (x[0] != y[0]) return x[0] < y[0];
return x[1] < y[1];
});
//求 i < j and fi < fj and gi < gj
int i = l, j = mid + 1;
while (j <= r) {
while (i <= mid && a[i][0] < a[j][0]) {
bit.update(a[i][1], 1);
i ++;
}
in[a[j][2]] += bit.sum(a[j][1] - 1);
j ++;
}
for (int k = l; k < i; k ++) {
bit.update(a[k][1], -1);
}
//求 i < j and fi < fj and gi >= gj
i = mid, j = r;
while (i >= l) {
while (j > mid && a[j][0] > a[i][0]) {
bit.update(a[j][1], 1);
j --;
}
in[a[i][2]] += bit.sum(a[i][1]);
i --;
}
for (int k = r; k > j; k --) {
bit.update(a[k][1], -1);
}
};
cdq(cdq, 1, n);
i64 ans = 1ll * n * (n - 1) * (n - 2) / 6;
for (int i = 1; i <= n; i ++) {
ans -= 1ll * in[i] * (in[i] - 1) / 2;
}
cout << ans << '\n';
return 0;
}