题解【bzoj2002 [Hnoi2010]Bounce 弹飞绵羊】
Description
给 \(n\) 个点以及它们的弹力系数 \(k_i\) ,含义为 可以弹到 \(i + k_i\) 的位置。
支持两个东西,修改一个点的弹力系数;求一个点要弹多少次弹出 \(n\)
Solution
用 LCT 做。弹力系数是 \(k_i\) 可以看作是 \(i\) 和 \(i+k_i\) 连了一条边。如果弹出去了就不妨设和 \(0\) 连一条边。
对于修改操作,先把原来的边删除,修改 k 数组,再连上新边
对于查询操作,维护一个子树大小 siz (这里是 splay 上的 siz,不是原树上的),然后询问就相当于当前这个点 \(u\) 到 \(0\) 这条链上有几个点。所以就 split 出来这条链然后输出 siz - 1 就行了(注意要减 \(1\) 因为问的是弹多少次)
然后就做完了(注意输入的时候要加 1)
Code
#include <bits/stdc++.h>
using namespace std;
const int N = 200100;
int n, m, K[N];
struct node {
int siz, rev;
node *ch[2], *prt;
int dir() { return prt->ch[1] == this; }
int isr() { return !prt || (prt->ch[0] != this && prt->ch[1] != this); }
void setc(node *p, int k) { this->ch[k] = p; if(p) p->prt = this; }
void upd() { int s = 1; if(ch[0]) s += ch[0]->siz; if(ch[1]) s += ch[1]->siz; siz = s; }
void push() { if(!rev) return ; swap(ch[0], ch[1]);
if(ch[0]) ch[0]->rev ^= 1; if(ch[1]) ch[1]->rev ^= 1; rev = 0; }
}*P[N], pool[N], *cur = pool; node *sta[N]; int top;
node *New() { node *p = cur++; p->siz = 1; return p; }
void rotate(node *p) {
node *prt = p->prt; int k = p->dir();
if(!prt->isr()) prt->prt->setc(p, prt->dir());
else p->prt = prt->prt; prt->setc(p->ch[!k], k);
p->setc(prt, !k); prt->upd(); p->upd();
}
void splay(node *p) {
node *q = p;
while(1) { sta[++top] = q; if(q->isr()) break; q = q->prt; }
while(top) sta[top--]->push();
while(!p->isr()) {
if(p->prt->isr()) rotate(p);
else if(p->dir() == p->prt->dir()) rotate(p->prt), rotate(p);
else rotate(p), rotate(p);
} p->upd();
}
node *access(node *p) { node *q = 0; for(; p; p = p->prt) splay(p), p->ch[1] = q, (q = p)->upd(); return q; }
void mkroot(node *p) { access(p); splay(p); p->rev ^= 1; p->push(); }
void split (node *p, node *q) { mkroot(p); access(q); splay(p); }
void link (node *p, node *q) { mkroot(p); mkroot(q); p->setc(q, 1); }
void cut (node *p, node *q) { split(p, q); p->ch[1] = q->prt = 0; }
int main() {
scanf("%d", &n); P[0] = New();
for(int i = 1; i <= n; i++) scanf("%d", &K[i]), P[i] = New();
for(int i = 1; i <= n; i++) {
if(i + K[i] <= n) link(P[i], P[i + K[i]]);
else link(P[i], P[0]);
} int m; scanf("%d", &m);
for(int i = 1; i <= m; i++) {
int op, u; scanf("%d %d", &op, &u); u++;
if(op == 1) {
split(P[0], P[u]); printf("%d\n", P[0]->siz - 1);
}
if(op == 2) { int k; scanf("%d", &k);
if(u + K[u] <= n) cut(P[u], P[u + K[u]]);
else cut(P[u], P[0]); K[u] = k;
if(u + K[u] <= n) link(P[u], P[u + K[u]]);
else link(P[u], P[0]);
}
}
return 0;
}