Codeforces Round 864 (Div. 2) E. Li Hua and Array
Codeforces Round 864 (Div. 2E. Li Hua and Array)(暴力修改线段树+lca和数论的结合)
Example
input
5 4
8 1 6 3 7
2 1 5
2 3 4
1 1 3
2 3 4
output
10
2
1
Solution
首先你得知道什么是欧拉函数
- 我们\(O(n)\)求出\([1, 5e6]\)范围内的每个数的欧拉函数后可以求一下最大的跳转次数是多少,也就是连着求几次欧拉函数之后这个数变成\(1\),利用桶的思想做一个\(O(n)\)的递推即可知道最多跳转\(23\)次,其实就是\(\log 5e6\)次
void pre() {//筛法求欧拉函数
for (int i = 1; i <= 5000000; i++) {
is_prime[i] = 1;
}
int cnt = 0;
is_prime[1] = 0;
phi[1] = 1;
for (int i = 2; i <= 5000000; i++) {
if (is_prime[i]) {
prime[++cnt] = i;
phi[i] = i - 1;
}
for (int j = 1; j <= cnt && i * prime[j] <= 5000000; j++) {
is_prime[i * prime[j]] = 0;
if (i % prime[j])
phi[i * prime[j]] = phi[i] * phi[prime[j]];
else {
phi[i * prime[j]] = phi[i] * prime[j];
break;
}
}
}
}
int nxt[maxn];
void get_max_time() {
int ans = 0;
for (int i = 2; i <= 5000000; ++i) {
nxt[i] = nxt[phi[i]] + 1;
ans = max(ans, nxt[i]);
}
cout << ans << '\n';
}
-
那也就是说我们进行区间修改的时候即使是逐一去跳转也只需要跳转大约\(n \times 23\)次,但我们需要快速判断这个区间是否还需要修改,这一点可以里用线段树来维护。然后还有一个问题就是如何实现查询,询问是区间内所有点跳到同一个点所需的最小次数,如果把这些跳转看成连边,其实就是这些点到他们\(lca\)的距离和。由于最深点的深度也只有\(20+\)所以倍增最多只会跳\(5\)次左右,我们可以直接在线段树上维护这个\(lca\),我们在记一个\(ans\)数组用来直接在线段树上维护出我们要的答案,这样我们每次\(push\_up\)时就应该是如下的式子
\[ ans[p] = (dep[lca[p << 1]] - dep[lca[p]]) * cnt[p << 1] + (dep[lca[p << 1 | 1]] - dep[lca[p]]) * cnt[p << 1 | 1] \]\[ + ans[p << 1] + ans[p << 1 | 1]; \]也就是每次上传要把答案更新成当前区间\(lca\)到两个子区间的\(lca\)的距离并且分别乘上子区间的点的个数,在\(query\)进行求答案的时候也要进行这样的计算,本质上\(query\)也是在做一个\(push\_up\)的事情。
-
或者也可以直接在线段树上维护每个区间最大最小的点的dfn序,这样的话区间的\(lca\)就是其中最大\(dfn\)和最小\(dfn\)的\(lca\)最后的答案是如下的式子
\[ans=(\sum_{i=l}^r dep[i]) - (r-l+1)\times dep[lca(l\cdots r)] \]但是这样的话要求一个\(dfn\)序,递归会让程序慢很多但也能过,并且求答案的时候思路更简单
CODE
线段树直接维护答案
int n, m;
bool is_prime[maxn];
int phi[maxn];
int prime[maxn];
void pre() {//筛法求欧拉函数
for (int i = 1; i <= 5000000; i++) {
is_prime[i] = 1;
}
int cnt = 0;
is_prime[1] = 0;
phi[1] = 1;
for (int i = 2; i <= 5000000; i++) {
if (is_prime[i]) {
prime[++cnt] = i;
phi[i] = i - 1;
}
for (int j = 1; j <= cnt && i * prime[j] <= 5000000; j++) {
is_prime[i * prime[j]] = 0;
if (i % prime[j])
phi[i * prime[j]] = phi[i] * phi[prime[j]];
else {
phi[i * prime[j]] = phi[i] * prime[j];
break;
}
}
}
}
int a[maxn];
struct node {
int lca, cnt, ans;
};
int dep[maxn];
int fa[maxn][6];
int lca[maxm << 2], sum[maxm << 2], ans[maxn << 2], cnt[maxm << 2];
void init_lca(int n) {
for (int j = 1; j < 6; ++j)
for (int i = 1; i <= n; ++i)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
int get_lca(int u, int v) {
if (dep[u] > dep[v]) swap(u, v);
if (u == 0) return v;
if (v == 0) return u;
for (int i = 5; i >= 0; --i)
if (dep[v] - dep[u] >= (1 << i))
v = fa[v][i];
if (u == v) return u;
for (int i = 5; i >= 0; --i)
if (fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
void push_up(int p) {
cnt[p] = cnt[p << 1] + cnt[p << 1 | 1];
sum[p] = sum[p << 1] + sum[p << 1 | 1];
lca[p] = get_lca(lca[p << 1], lca[p << 1 | 1]);
ans[p] = (dep[lca[p << 1]] - dep[lca[p]]) * cnt[p << 1] + (dep[lca[p << 1 | 1]] - dep[lca[p]]) * cnt[p << 1 | 1] + ans[p << 1] + ans[p << 1 | 1];
}
void build(int p, int l, int r) {
if (l == r) {
lca[p] = a[l];
sum[p] = dep[a[l]];
cnt[p] = 1;
return ;
}
int mid = l + r >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
push_up(p);
}
void update(int p, int l, int r, int ql, int qr) {
if (!sum[p]) return ;
if (l == r) {
sum[p]--;
lca[p] = fa[lca[p]][0];
return ;
}
int mid = l + r >> 1;
if (ql <= mid) update(p << 1, l, mid, ql, qr);
if (mid < qr) update(p << 1 | 1, mid + 1, r, ql, qr);
push_up(p);
}
node query(int p, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) return {lca[p], cnt[p], ans[p]};
node L = {0, 0, 0}, R = {0, 0, 0};
int mid = l + r >> 1;
node ans = {0, 0, 0};
if (ql <= mid) {
L = query(p << 1, l, mid, ql, qr);
}
if (mid < qr) {
R = query(p << 1 | 1, mid + 1, r, ql, qr);
}
ans.lca = get_lca(L.lca, R.lca);
ans.cnt = L.cnt + R.cnt;
if (L.lca) {
ans.ans += L.ans + L.cnt * (dep[L.lca] - dep[ans.lca]);
}
if (R.lca) {
ans.ans += R.ans + R.cnt * (dep[R.lca] - dep[ans.lca]);
}
return ans;
}
int nxt[maxn];
void get_max_time() {
int ans = 0;
for (int i = 2; i <= 5000000; ++i) {
nxt[i] = nxt[phi[i]] + 1;
ans = max(ans, nxt[i]);
}
cout << ans << '\n';
}
void solve(int cas) {
pre();
// get_max_time();
for (int i = 2; i <= 5000000; ++i) {
fa[i][0] = phi[i];
dep[i] = dep[phi[i]] + 1;
}
init_lca(5000000);
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> a[i];
build(1, 1, n);
while (m--) {
int op, l, r;
cin >> op >> l >> r;
if (op == 1) {
update(1, 1, n, l, r);
} else {
cout << query(1, 1, n, l, r).ans << '\n';
}
}
}
通过dfn序求lca
//省略了筛法求欧拉函数
int a[maxn];
vector<int> e[maxn];
int dep[maxn];
int fa[maxn][6];
int sum[maxm << 2], cnt[maxm << 2], mx[maxm << 2], mn[maxm << 2];
int dfn, in[maxn], id[maxn];
void dfs(int u) {
in[u] = ++dfn;
id[dfn] = u;
for (int v : e[u])
dfs(v);
}
void init_lca(int n) {
for (int j = 1; j < 6; ++j)
for (int i = 1; i <= n; ++i)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
int get_lca(int u, int v) {
if (dep[u] > dep[v]) swap(u, v);
if (u == 0) return v;
if (v == 0) return u;
for (int i = 5; i >= 0; --i)
if (dep[v] - dep[u] >= (1 << i))
v = fa[v][i];
if (u == v) return u;
for (int i = 5; i >= 0; --i)
if (fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
void push_up(int p) {
cnt[p] = cnt[p << 1] + cnt[p << 1 | 1];
sum[p] = sum[p << 1] + sum[p << 1 | 1];
mx[p] = max(mx[p << 1], mx[p << 1 | 1]);
mn[p] = min(mn[p << 1], mn[p << 1 | 1]);
}
void build(int p, int l, int r) {
if (l == r) {
sum[p] = dep[a[l]];
cnt[p] = 1;
mx[p] = mn[p] = in[a[l]];
return ;
}
int mid = l + r >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
push_up(p);
}
void update(int p, int l, int r, int ql, int qr) {
if (!sum[p]) return ;
if (l == r) {
sum[p]--;
int ID = id[mx[p]];
ID = phi[ID];
mx[p] = in[ID];
mn[p] = in[ID];
return ;
}
int mid = l + r >> 1;
if (ql <= mid) update(p << 1, l, mid, ql, qr);
if (mid < qr) update(p << 1 | 1, mid + 1, r, ql, qr);
push_up(p);
}
pii query_m(int p, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) return {mx[p], mn[p]};
int mid = l + r >> 1;
pii ans = {0, INF};
if (ql <= mid) {
pii L = query_m(p << 1, l, mid, ql, qr);
ans.fir = max(L.fir, ans.fir); ans.sec = min(L.sec, ans.sec);
}
if (mid < qr) {
pii R = query_m(p << 1 | 1, mid + 1, r, ql, qr);
ans.fir = max(R.fir, ans.fir); ans.sec = min(R.sec, ans.sec);
}
return ans;
}
pii query(int p, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) return {sum[p], cnt[p]};
int mid = l + r >> 1;
pii ans = {0, 0};
if (ql <= mid) {
pii L = query(p << 1, l, mid, ql, qr);
ans.fir += L.fir; ans.sec += L.sec;
}
if (mid < qr) {
pii R = query(p << 1 | 1, mid + 1, r, ql, qr);
ans.fir += R.fir; ans.sec += R.sec;
}
return ans;
}
void solve(int cas) {
cin >> n >> m;
pre();
for (int i = 2; i <= 5000000; ++i) {
fa[i][0] = phi[i];
e[phi[i]].pb(i);
dep[i] = dep[phi[i]] + 1;
}
dfs(1);
init_lca(5000000);
for (int i = 1; i <= n; ++i) cin >> a[i];
build(1, 1, n);
while (m--) {
int op, l, r;
cin >> op >> l >> r;
if (op == 1) {
update(1, 1, n, l, r);
} else {
pii res = query(1, 1, n, l, r);
pii k = query_m(1, 1, n, l, r);
int lca = get_lca(id[k.fir], id[k.sec]);
cout << res.fir - dep[lca] * res.sec << '\n';
}
}
}