题解 校门外歪脖树上的鸽子
毒瘤数据结构
正解思路很特别
对于一个闭区间修改 \([l, r]\),将其写成开区间 \((l-1, r+1)\)
于是类似zkw线段树,我们发现在原树上向上跳链(到lca的孙子辈)的过程中应该修改的节点恰好是访问到的节点的兄弟
于是分成左链和右链分别树剖,每个节点维护其兄弟的信息
然后因为写成了开区间,需要加两个虚点0和n+1
发现这两个虚点恰好使得包含边界的修改可以打在相应的节点上了(相当于提高了边界的优先级/为边界提供了兄弟)
于是可以树剖修改了
复杂度 \(O(n + m\log^2 n)\):每次操作沿重链跳 \(O(\log n)\) 次,每条链上做一次 \(O(\log n)\) 的线段树操作
Code:
#include <bits/stdc++.h>
using namespace std;
// Not referenced anywhere in the visible code.
#define INF 0x3f3f3f3f
// Array capacity for node ids: task::solve() creates ids up to 2n+2, so N must exceed that.
#define N 400010
#define ll long long
// Makes every `int` in the rest of the file 64-bit (which is why main is declared `signed`).
#define int long long
// 2 MiB input buffer consumed by the fread-based getchar below; p1/p2 are the read cursor and end.
char buf[1<<21], *p1=buf, *p2=buf;
// Buffered replacement for getchar: refills buf with fread when the cursor reaches the end,
// and yields EOF once fread returns nothing.
// NOTE(review): bytes come back through `char`, so where char is signed a 0xFF byte compares
// equal to EOF and other high bytes are negative.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads the next (optionally negative) integer from the input stream.
// Non-digit characters before the number are skipped; every '-' seen while skipping
// flips the sign. Returns 0 if EOF is reached before any digit (previously this looped
// forever). The character is kept as an integer and compared against '0'..'9' directly,
// so negative byte values never reach isdigit (which would be undefined behaviour).
inline int read() {
    int ans = 0, f = 1;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}
// Number of real leaves and number of operations.
int n, m;
// Leaf range [tl[p], tr[p]] of each node (filled by force::build).
int tl[N], tr[N];
// Left / right child of each node; main() initialises both to -1 ("no child").
int ls[N], rs[N];
// Current root of the tree (task::solve() replaces it with a virtual root).
int rot;
// vis[x] is set when x appears as someone's child; main() uses it to locate the root.
bool vis[N];
namespace force{
// Brute-force reference solution (disabled in main): every operation walks the
// tree directly. dat[p] is the value stored on node p itself; it is never pushed
// down, so a query only sees updates that landed on the same decomposition nodes.
ll dat[N];
// Fills the leaf range [tl[p], tr[p]] for every node in p's subtree.
// The code treats leaf id p as position p.
void build(int p) {
    // main() marks a missing child with -1, so a leaf is a node whose ls is -1.
    // (The old `!ls[p]` test never matched such leaves and recursed into index -1.)
    if (ls[p]==-1) {tl[p]=tr[p]=p; return ;}
    build(ls[p]); build(rs[p]);
    tl[p]=tl[ls[p]]; tr[p]=tr[rs[p]];
}
// Adds val * (leaf count) to every maximal node fully inside [l, r].
void upd(int p, int l, int r, int val) {
    if (l<=tl[p]&&r>=tr[p]) {dat[p]+=1ll*val*(tr[p]-tl[p]+1); return ;}
    if (l<=tr[ls[p]]) upd(ls[p], l, r, val);
    if (r>=tl[rs[p]]) upd(rs[p], l, r, val);
}
// Sums dat over the maximal nodes fully inside [l, r].
ll query(int p, int l, int r) {
    if (l<=tl[p]&&r>=tr[p]) return dat[p];
    ll ans=0;
    if (l<=tr[ls[p]]) ans+=query(ls[p], l, r);
    if (r>=tl[rs[p]]) ans+=query(rs[p], l, r);
    return ans;
}
// Processes all m operations: odd opcode -> update (l, r, d), otherwise print query(l, r).
void solve() {
    build(rot);
    for (int i=1,l,r,d; i<=m; ++i) {
        if (read()&1) {
            l=read(); r=read(); d=read();
            upd(rot, l, r, d);
        }
        else {
            l=read(); r=read();
            printf("%lld\n", query(rot, l, r));
        }
    }
    exit(0);
}
}
namespace task{
int dep[N], top[N], id[N], rk[N], siz[N], siz2[N], fa[N], mson[N], tot, num;
ll extra;
struct seg{
bool whitch; // 1->left
int tl[N<<2], tr[N<<2];
ll tag[N<<2], k[N<<2], dat[N<<2];
seg(bool t):whitch(t){}
#define tl(p) tl[p]
#define tr(p) tr[p]
#define tag(p) tag[p]
#define k(p) k[p]
#define dat(p) dat[p]
#define pushup(p) dat(p)=dat(p<<1)+dat(p<<1|1)
void spread(int p) {
if (!tag(p)) return ;
dat(p<<1)+=tag(p)*k(p<<1); tag(p<<1)+=tag(p);
dat(p<<1|1)+=tag(p)*k(p<<1|1); tag(p<<1|1)+=tag(p);
tag(p)=0;
}
void build(int p, int l, int r) {
tl(p)=l; tr(p)=r;
if (l==r) {k(p)=whitch?(ls[fa[rk[l]]]==rk[l]?siz2[rs[fa[rk[l]]]]:0):(rs[fa[rk[l]]]==rk[l]?siz2[ls[fa[rk[l]]]]:0); return ;}
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
k(p)=k(p<<1)+k(p<<1|1);
}
void upd(int p, int l, int r, ll val) {
if (l<=tl(p)&&r>=tr(p)) {dat(p)+=val*k(p); tag(p)+=val; return ;}
spread(p);
int mid=(tl(p)+tr(p))>>1;
if (l<=mid) upd(p<<1, l, r, val);
if (r>mid) upd(p<<1|1, l, r, val);
pushup(p);
}
ll query(int p, int l, int r) {
if (l<=tl(p)&&r>=tr(p)) return dat(p);
spread(p);
int mid=(tl(p)+tr(p))>>1; ll ans=0;
if (l<=mid) ans+=query(p<<1, l, r);
if (r>mid) ans+=query(p<<1|1, l, r);
return ans;
}
}left(1), right(0);
void dfs1(int u) {
// cout<<"dfs1: "<<u<<endl;
siz[u]=u<=n+1; siz2[u]=u&&u<=n;
if (~ls[u]) {
dep[ls[u]]=dep[u]+1; fa[ls[u]]=u;
dfs1(ls[u]);
siz[u]+=siz[ls[u]]; siz2[u]+=siz2[ls[u]];
}
if (~rs[u]) {
dep[rs[u]]=dep[u]+1; fa[rs[u]]=u;
dfs1(rs[u]);
siz[u]+=siz[rs[u]]; siz2[u]+=siz2[rs[u]];
}
mson[u]=siz[ls[u]]>siz[rs[u]]?ls[u]:rs[u];
}
void dfs2(int u, int t) {
top[u]=t;
id[u]=++tot; rk[tot]=u;
if (ls[u]==-1 && rs[u]==-1) return ;
if (mson[u]==ls[u]) dfs2(ls[u], t), dfs2(rs[u], rs[u]);
else dfs2(rs[u], t), dfs2(ls[u], ls[u]);
}
int lca(int a, int b) {
while (top[a]!=top[b]) {
if (dep[top[a]]<dep[top[b]]) swap(a, b);
a=fa[top[a]];
}
return dep[a]<dep[b]?a:b;
}
void ladd(int s, int t, ll val) {
// cout<<"ladd: "<<s<<' '<<t<<' '<<val<<endl;
while (top[s]!=top[t]) {
left.upd(1, id[top[s]], id[s], val);
// cout<<"show: "<<s<<' '<<top[s]<<endl;
// cout<<"add: "<<id[top[s]]<<' '<<id[s]<<endl;
s=fa[top[s]];
}
if (id[t]+1<=id[s]) left.upd(1, id[t]+1, id[s], val); //, cout<<"add2: "<<id[t]+1<<' '<<id[s]<<endl;
}
void radd(int s, int t, ll val) {
// cout<<"radd: "<<s<<' '<<t<<' '<<val<<endl;
while (top[s]!=top[t]) {
right.upd(1, id[top[s]], id[s], val);
// cout<<"add: "<<id[top[s]]<<' '<<id[s]<<endl;
s=fa[top[s]];
}
if (id[t]+1<=id[s]) right.upd(1, id[t]+1, id[s], val); //, cout<<"add2: "<<id[t]+1<<' '<<id[s]<<endl;
}
ll lqsum(int s, int t) {
ll ans=0;
while (top[s]!=top[t]) {
ans+=left.query(1, id[top[s]], id[s]);
// cout<<"ans: "<<ans<<endl;
s=fa[top[s]];
}
if (id[t]+1<=id[s]) ans+=left.query(1, id[t]+1, id[s]);
return ans;
}
ll rqsum(int s, int t) {
// cout<<"rqsum: "<<s<<' '<<t<<endl;
ll ans=0;
while (top[s]!=top[t]) {
ans+=right.query(1, id[top[s]], id[s]);
s=fa[top[s]];
}
if (id[t]+1<=id[s]) ans+=right.query(1, id[t]+1, id[s]);
return ans;
}
void upd(int l, int r, ll val) {
if (l<1&&r>n) {extra+=val; return ;}
// cout<<"upd: "<<l<<' '<<r<<' '<<val<<endl;
int t=lca(l, r);
// cout<<"lca: "<<t<<endl;
ladd(l, ls[t], val); radd(r, rs[t], val);
}
ll query(int l, int r) {
if (l<1&&r>n) return extra*n;
// cout<<"query: "<<l<<' '<<r<<endl;
int t=lca(l, r);
// cout<<"sum: "<<lqsum(l, ls[t])<<' '<<rqsum(r, rs[t])<<endl;
return lqsum(l, ls[t])+rqsum(r, rs[t]);
}
void solve() {
int num=n*2;
++num; ls[num]=0; rs[num]=rot; rot=num;
++num; ls[num]=rot; rs[num]=n+1; rot=num;
dep[rot]=1; dfs1(rot); dfs2(rot, rot);
// cout<<"id: "; for (int i=0; i<=num; ++i) cout<<id[i]<<' '; cout<<endl;
left.build(1, 1, num); right.build(1, 1, num);
for (int i=1,l,r,d; i<=m; ++i) {
if (read()&1) {
l=read(); r=read(); d=read();
upd(l-1, r+1, d);
}
else {
l=read(); r=read();
printf("%lld\n", query(l-1, r+1));
}
}
// cout<<"qval: "<<right.query(1, 4, 4)<<endl;
// cout<<"qval: "<<right.query(1, 8, 8)<<endl;
exit(0);
}
}
signed main()
{
freopen("pigeons.in", "r", stdin);
freopen("pigeons.out", "w", stdout);
n=read(); m=read();
memset(ls, -1, sizeof(ls));
memset(rs, -1, sizeof(rs));
for (int i=1,t1,t2; i<n; ++i) {
t1=read(); t2=read();
if (t1>n) ++t1; if (t2>n) ++t2;
vis[ls[n+i+1]=t1]=1, vis[rs[n+i+1]=t2]=1;
}
for (int i=1; i<=2*n; ++i) if (i!=n+1 && !vis[i]) rot=i;
// force::solve();
task::solve();
return 0;
}