BZOJ4012 [HNOI2015]开店 【动态点分治 + splay】
题目链接
题解
Mychael 并没有 AC，而是 TLE（超时）了
讲讲主要思路
在点分树上为每个点开两棵 splay：
平衡树\(A\)以年龄为键，维护该点在点分树上的子树（即它作为分治中心时的连通块）中各年龄的点到该点的距离和与点数
平衡树\(B\)以年龄为键，维护同一子树中各年龄的点到该点在点分树上的父亲的距离和与点数
然后询问就可以在点分树上用两棵平衡树相减计算了
复杂度为 \(O(n\log^2 n)\)，但常数太大，被卡超时了
// luogu-judger-enable-o2
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<map>
// Iterate over the adjacency list of u; `to` is declared for the caller to fill.
// `register` was dropped: it never had any effect and is ill-formed since C++17.
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
// 1-based inclusive loop: i = 1 .. n.
#define REP(i,n) for (int i = 1; i <= (n); i++)
// Build a cp directly. The old make_pair<int,LL>(a,b) with explicit template
// arguments stops compiling in C++11+ whenever an argument is an lvalue,
// because the parameters become int&& / LL&&.
#define mp(a,b) cp((a),(b))
#define cls(s) memset(s,0,sizeof(s))
#define cp pair<int,LL>
#define LL long long int
// Left / right child of the node held in a variable named `u` at the use site.
#define ls ch[u][0]
#define rs ch[u][1]
// 1 if u is the right child of its parent, 0 otherwise (including the root).
#define isr(u) (fa[u] && ch[fa[u]][1] == u)
// Formerly `register`; kept as an empty marker so existing `res int x` still compiles.
#define res
using namespace std;
const int maxn = 200005,maxm = 10000005,INF = 1000000000;
// Read the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Returns 0 at end of input instead of looping
// forever (the old `char c` could never reliably compare equal to EOF).
inline int read(){
    int out = 0, flag = 1;
    int c = getchar();  // int, so EOF stays distinguishable from a real byte
    while (c != EOF && (c < '0' || c > '9')){
        if (c == '-') flag = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        out = out * 10 + (c - '0');
        c = getchar();
    }
    return out * flag;
}
// Print x in decimal to stdout, with no trailing newline.
// Handles negative values, including LLONG_MIN. The magnitude is computed in
// unsigned arithmetic, where negation is well-defined and cannot overflow.
inline void write(long long x){
    unsigned long long mag = static_cast<unsigned long long>(x);
    if (x < 0){
        putchar('-');
        mag = 0ULL - mag;  // unsigned negation yields |x| without signed overflow
    }
    char digits[24];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (len > 0) putchar(digits[--len]);
}
// Node pool shared by every Splay_Tree; node 0 is the null node.
//   w[x]   key of node x (an age)
//   siz[x] number of points inserted with this key; Siz[x] total over x's subtree
//   sum[x] summed distance stored on this key;      Sum[x] total over x's subtree
//   ch[x]  left/right children; fa[x] parent; cnt last allocated node index
int Siz[maxm],siz[maxm],w[maxm],ch[maxm][2],fa[maxm],cnt;
LL Sum[maxm],sum[maxm];
// A splay tree keyed by age. Each key aggregates how many points were
// inserted with it and the sum of their distances, so query(l,r) returns
// (number of points with age in [l,r], sum of their distances).
// init() must be called first: it inserts the -INF / INF sentinels that
// query() relies on to bracket any range inside (-INF, INF).
struct Splay_Tree{
    int rt;  // root node of this tree

    // Recompute the subtree aggregates of u from its children.
    void upd(int u){
        Siz[u] = Siz[ls] + Siz[rs] + siz[u];
        Sum[u] = Sum[ls] + Sum[rs] + sum[u];
    }
    // Rotate u above its parent, keeping the grandparent link and aggregates valid.
    void spin(int u){
        int s = isr(u),f = fa[u];
        fa[u] = fa[f]; if (fa[f]) ch[fa[f]][isr(f)] = u;
        ch[f][s] = ch[u][s ^ 1]; if (ch[u][s ^ 1]) fa[ch[u][s ^ 1]] = f;
        fa[f] = u; ch[u][s ^ 1] = f;
        upd(f); upd(u);
    }
    // Splay u until its parent is f; with f == 0, u becomes this tree's root.
    void splay(int u,int f = 0){
        for (; fa[u] != f; spin(u))
            if (fa[fa[u]] != f) spin((isr(u) ^ isr(fa[u])) ? u : fa[u]);
        if (!f) rt = u;
    }
    // Add one point with key v and distance Val below slot u (whose parent
    // is f), creating the key's node if needed, then splay it to the root.
    // Iterative, so a tall tree cannot overflow the call stack.
    void insert(int& u,int f,int v,int Val){
        int* slot = &u;  // the child link where key v lives or would be attached
        while (*slot && w[*slot] != v){
            f = *slot;
            slot = &ch[f][w[f] < v];  // right if the node's key is smaller than v
        }
        int x = *slot;
        if (!x){
            x = ++cnt;
            *slot = x;
            w[x] = v; Sum[x] = sum[x] = Val;
            fa[x] = f; siz[x] = Siz[x] = 1;
        }
        else {
            Sum[x] += Val; sum[x] += Val; Siz[x]++; siz[x]++;
        }
        splay(x);  // also refreshes the aggregates of every former ancestor
    }
    void ins(int pos,int v){insert(rt,0,pos,v);}
    // Node with the largest key strictly less than v in u's subtree (0 if none).
    int pre(int u,int v){
        int best = 0;
        while (u){
            if (w[u] < v){ best = u; u = rs; }
            else u = ls;
        }
        return best;
    }
    // Node with the smallest key strictly greater than v in u's subtree (0 if none).
    int post(int u,int v){
        int best = 0;
        while (u){
            if (w[u] > v){ best = u; u = ls; }
            else u = rs;
        }
        return best;
    }
    // (count, distance sum) over keys in [l,r]. Splaying the predecessor of l
    // to the root and the successor of r beneath it isolates exactly that key
    // range in the successor's left subtree.
    cp query(int l,int r){
        int L = pre(rt,l),R = post(rt,r);
        splay(L); splay(R,L);
        int mid = ch[R][0];
        if (!mid) return cp(0,0);
        // cp(...) rather than make_pair<int,LL>(...): the latter rejects lvalues in C++11+.
        return cp(Siz[mid],Sum[mid]);
    }
    // Reset to an empty tree holding only the two sentinels.
    void init(){rt = 0; ins(-INF,0); ins(INF,0);}
}A[maxn],B[maxn];
LL ans;                        // answer to the previous query; main uses it to decode the next one
int n,m,Limit,val[maxn],L,R;   // node count, query count, age modulus, ages, current query range
int h[maxn],ne = 1;            // adjacency-list heads; index of the last stored arc
struct EDGE{int to,nxt,w;}ed[maxn << 1];
// Append the directed arc u->v with weight w to u's adjacency list.
// Fields are assigned one by one: the previous (EDGE){...} compound literal
// is a GNU extension, not standard C++.
static inline void add_arc(int u,int v,int w){
    ++ne;
    ed[ne].to = v;
    ed[ne].nxt = h[u];
    ed[ne].w = w;
    h[u] = ne;
}
// Add the undirected edge (u,v) of weight w as two arcs, u->v stored first.
// Because ne starts at 1, each pair occupies indices (2k, 2k+1).
inline void build(int u,int v,int w){
    add_arc(u,v,w);
    add_arc(v,u,w);
}
LL Dis[maxn][23];
int F[maxn],Fa[maxn],size[maxn],vis[maxn],N,rt;
void getrt(int u){
F[u] = 0; size[u] = 1;
Redge(u) if (!vis[to = ed[k].to] && to != Fa[u]){
Fa[to] = u; getrt(to);
size[u] += size[to];
F[u] = max(F[u],size[to]);
}
F[u] = max(F[u],N - size[u]);
if (F[u] < F[rt]) rt = u;
}
int c[maxn],d[maxn],ci;
void dfs1(int u){
size[u] = 1; c[++ci] = u;
Redge(u) if (!vis[to = ed[k].to] && to != Fa[u]){
Fa[to] = u; d[to] = d[u] + ed[k].w;
dfs1(to);
size[u] += size[to];
}
}
int pre[maxn],dep[maxn];
void solve(int u,int D){
vis[u] = true; size[u] = 1; ci = 0; d[u] = 0; dep[u] = D;
A[u].init(); A[u].ins(val[u],0); B[u].init();
Redge(u) if (!vis[to = ed[k].to]){
Fa[to] = u; d[to] = d[u] + ed[k].w;
dfs1(to);
}
int v = pre[u];
REP(i,ci){
A[u].ins(val[c[i]],d[c[i]]);
Dis[c[i]][D] = d[c[i]];
}
if (v){
B[u].ins(val[u],Dis[u][D - 1]);
REP(i,ci) B[u].ins(val[c[i]],Dis[c[i]][D - 1]);
}
Redge(u) if (!vis[to = ed[k].to]){
N = size[to]; F[rt = 0] = INF;
getrt(to); pre[rt] = u;
solve(rt,D + 1);
}
}
void work(int x){
ans = A[x].query(L,R).second;
LL dd; cp t1,t2;
for (res int u = x,i = dep[x] - 1; pre[u]; u = pre[u],i--){
dd = Dis[x][i];
t1 = A[pre[u]].query(L,R);
t2 = B[u].query(L,R);
ans += t1.second - t2.second + dd * (t1.first - t2.first);
}
printf("%lld\n",ans);
}
// Read the tree and ages, build the centroid decomposition, then answer the
// online queries. `register` loop counters were dropped (ill-formed since C++17).
int main(){
    n = read(); m = read(); Limit = read();
    LL a,b,w;
    for (int i = 1; i <= n; i++) val[i] = read();
    for (int i = 1; i < n; i++){
        a = read(); b = read(); w = read();
        build(a,b,w);
    }
    // Root the centroid tree at the centroid of the whole tree.
    N = n; F[rt = 0] = INF;
    getrt(1);
    solve(rt,0);
    int u;
    while (m--){
        u = read(); a = read(); b = read();
        // Decode the query range with the previous answer (ans starts at 0).
        L = (a + ans) % Limit;
        R = (b + ans) % Limit;
        if (L > R) swap(L,R);
        work(u);
    }
    return 0;
}