NOIP 模拟 $90\; \rm 校门外歪脖树上的鸽子$
题解 \(by\;zj\varphi\)
树上问题,采用树链剖分。
考虑一下线段树区间查询的过程,放到这题上就是:递归下去的从儿子的区间交点分开。
如果递归右儿子,那么就会对左儿子造成贡献,反之同理。
具体实现就是开两棵线段树,一棵表示从右链递归,对左儿子的贡献,另一棵相反。
求出每个点从下往上走左子树到的深度最浅的点,和走右子树到的。
注意根节点的兄弟是它自己。
Code
#include<bits/stdc++.h>
#define ri signed
#define pd(i) ++i
#define bq(i) --i
#define func(x) std::function<x>
namespace IO{
char buf[1<<21],*p1=buf,*p2=buf;
#define gc() p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?(-1):*p1++
#define debug1(x) std::cerr << #x"=" << x << ' '
#define debug2(x) std::cerr << #x"=" << x << std::endl
#define Debug(x) assert(x)
struct nanfeng_stream{
template<typename T>inline nanfeng_stream &operator>>(T &x) {
bool f=false;x=0;char ch=gc();
while(!isdigit(ch)) f|=ch=='-',ch=gc();
while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=gc();
return x=f?-x:x,*this;
}
}cin;
}
using IO::cin;
namespace nanfeng{
#define mk std::make_pair
#define FI FILE *IN
#define FO FILE *OUT
template<typename T>inline T cmax(T x,T y) {return x>y?x:y;}
template<typename T>inline T cmin(T x,T y) {return x>y?y:x;}
using ull=long long;
static const int N=4e5+7;
int ch[N][2],ws[N],top[N],siz[N],hs[N],le[N],ll[N],rl[N],ul[N],ur[N],br[N],fa[N],dfn[N],bc[N],dep[N],tot,X,Y,n,m,opt,l,r,w,rt,typ,al;
std::map<std::pair<int,int>,int> mp;
func(void(int)) dfs1=[](int x) {
siz[x]=le[x]=1;
int sl=ch[x][0],sr=ch[x][1];
if (!fa[x]) ul[x]=ur[x]=x;
else ul[x]=!ws[x]?ul[fa[x]]:x,ur[x]=ws[x]?ur[fa[x]]:x;
ll[x]=rl[x]=x;
dep[x]=dep[fa[x]]+1;
if (x>n) {
dfs1(sl),dfs1(sr);
siz[x]+=siz[sl]+siz[sr];
hs[x]=siz[sl]>siz[sr]?sl:sr;
ll[x]=ll[sl],rl[x]=rl[sr];
le[x]=le[sl]+le[sr];
}
mp[mk(ll[x],rl[x])]=x;
};
func(void(int,int)) dfs2=[](int x,int tp) {
dfn[bc[tot]=x]=++tot;
top[x]=tp;
if (hs[x]) dfs2(hs[x],tp);
if (x>n) {
int v=ch[x][0]==hs[x]?ch[x][1]:ch[x][0];
dfs2(v,v);
}
};
struct Segmenttree{
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define up(x) T[x].sum=T[ls(x)].sum+T[rs(x)].sum
struct segmenttree{ull sum,le,lz;}T[N<<2];
func(void(int,int,int)) build=[&](int x,int l,int r) {
if (l==r) {
int k=bc[l];
if (ws[k]!=typ) T[x].le=le[br[k]];
return;
}
int mid=(l+r)>>1;
build(ls(x),l,mid);
build(rs(x),mid+1,r);
T[x].le=T[ls(x)].le+T[rs(x)].le;
};
func(void(int)) down=[&](int x) {
if (!T[x].lz) return;
T[ls(x)].lz+=T[x].lz,T[ls(x)].sum+=T[ls(x)].le*T[x].lz;
T[rs(x)].lz+=T[x].lz,T[rs(x)].sum+=T[rs(x)].le*T[x].lz;
T[x].lz=0;
};
func(void(int,int,int,int,int,int)) update=[&](int x,int k,int l,int r,int lt,int rt) {
if (l<=lt&&rt<=r) return T[x].lz+=k,T[x].sum+=T[x].le*k,void();
int mid=(lt+rt)>>1;
down(x);
if (l<=mid) update(ls(x),k,l,r,lt,mid);
if (r>mid) update(rs(x),k,l,r,mid+1,rt);
up(x);
};
func(ull(int,int,int,int,int)) query=[&](int x,int l,int r,int lt,int rt) {
if (l<=lt&&rt<=r) return T[x].sum;
int mid=(lt+rt)>>1;
ull res=0;
down(x);
if (l<=mid) res+=query(ls(x),l,r,lt,mid);
if (r>mid) res+=query(rs(x),l,r,mid+1,rt);
return res;
};
}T[2];
auto Aupdate=[](int x,int w) {
if (ws[x]) T[0].update(1,w,dfn[br[x]],dfn[br[x]],1,al);
else T[1].update(1,w,dfn[br[x]],dfn[br[x]],1,al);
};
auto Aquery=[](int x) {
if (ws[x]) return T[0].query(1,dfn[br[x]],dfn[br[x]],1,al);
else return T[1].query(1,dfn[br[x]],dfn[br[x]],1,al);
};
auto update=[](const int opt,int x,int v,int w) {
while(top[x]!=top[v]) {
T[opt].update(1,w,dfn[top[x]],dfn[x],1,al);
x=fa[top[x]];
}
if (x!=v) T[opt].update(1,w,dfn[v]+1,dfn[x],1,al);
};
auto query=[](const int opt,int x,int v) {
ull res=0;
while(top[x]!=top[v]) {
res+=T[opt].query(1,dfn[top[x]],dfn[x],1,al);
x=fa[top[x]];
}
if (x!=v) res+=T[opt].query(1,dfn[v]+1,dfn[x],1,al);
return res;
};
auto Getlca=[](int x,int v) {
while(top[x]!=top[v]) {
if (dep[top[x]]<dep[top[v]]) std::swap(x,v);
x=fa[top[x]];
}
return dep[x]<dep[v]?x:v;
};
auto find=[](int x,int v) {
x=top[x];
while(x!=top[v]) {
if (fa[x]==v) return x;
x=top[fa[x]];
}
return hs[v];
};
inline int main() {
FI=freopen("pigeons.in","r",stdin);
FO=freopen("pigeons.out","w",stdout);
cin >> n >> m;
al=(n<<1)-1;
memset(ws,-1,sizeof(ws));
for (ri i(1);i<n;pd(i)) {
cin >> X >> Y;
ch[n+i][0]=X,ch[n+i][1]=Y;
ws[X]=0,ws[Y]=1;
fa[X]=fa[Y]=n+i;
br[X]=Y,br[Y]=X;
}
for (ri i(1);i<=(n<<1)-1;pd(i)) if (!fa[i]) {rt=i;break;}
br[rt]=rt;
dfs1(rt),dfs2(rt,rt);
typ=1,T[0].build(1,1,al);
typ=0,T[1].build(1,1,al);
for (ri i(1);i<=m;pd(i)) {
cin >> opt >> l >> r;
int k=-1;
std::pair<int,int> tmp=mk(l,r);
if (mp.find(tmp)!=mp.end()) k=mp[tmp];
if (opt==1) {
cin >> w;
if (k!=-1) Aupdate(k,w);
else {
int nx=ul[l],ny=ur[r],lca=Getlca(l,r);
if (dep[nx]<=dep[lca]) Aupdate(find(l,lca),w);
else Aupdate(nx,w),update(0,nx,find(l,lca),w);
if (dep[ny]<=dep[lca]) Aupdate(find(r,lca),w);
else Aupdate(ny,w),update(1,ny,find(r,lca),w);
}
} else {
ull res=0;
if (k!=-1) res=Aquery(k);
else {
int nx=ul[l],ny=ur[r],lca=Getlca(l,r);
// debug1(nx),debug1(ny),debug2(lca);
if (dep[nx]<=dep[lca]) res+=Aquery(find(l,lca));
else res+=Aquery(nx)+query(0,nx,find(l,lca));
if (dep[ny]<=dep[lca]) res+=Aquery(find(r,lca));
else res+=Aquery(ny)+query(1,ny,find(r,lca));
}
printf("%llu\n",res);
}
}
return 0;
}
}
int main() {return nanfeng::main();}