dsu on tree入门
最近入门了dsu on tree的坑,应该是略有些感悟....仅此记录下.
首先dsu应该是并查集的意思,那如字面意思,dsu on tree的意思就是在树上并查集喽!呸,当然不是,它只是借助了并查集按秩合并的思想,通过优化搜索的顺序,起到优化复杂度的作用!
感觉这点和莫队有点像,莫队不也是用两个指针暴力的跳来跳去的吗?其优化时间的原理是保存当前询问区间的信息,从而只修改下一个询问区间与当前区间不同的地方来加快程序的效率!
当然这种优化不是随随便便的,莫队是经过严格的排序与证明,最后得出O(n根号n)的复杂度.当然这样的原理也注定了其询问的区间必须是本质相同的信息,必须满足两个询问区间的交集两个区间都能用才行.
好了,不说莫队,莫名的跑题了....
其实本人觉得dsu on tree和莫队真的很像,都是统计信息且用其共同部分来加快询问.所以能用dsu on tree写的题必须也要满足这一性质,即交集是有效的.
好了,不逼逼了,直接进入正题吧,到底什么是dsu on tree?
先来简述下定义:关于统计子树内信息的静态题目,可以优先统计轻儿子的信息,之后清除其影响,再统计重儿子的信息,保留他的影响.之后利用重儿子的存留信息再重新统计轻儿子的信息,并计算当前点的答案.
什么?你不知道什么是重儿子,什么是轻儿子,抱歉请先学习树剖...
其实重儿子就是所有儿子中子树中点最多的一个儿子.
考虑下这样做的复杂度为何可以节省时间,从每个点出发,都便利了两遍的轻儿子,一遍的重儿子,其中轻儿子+重儿子==所有儿子,这样的复杂度是O(n)的,剩下就需要算从每个点出发便利一遍轻儿子的复杂度.
首先有个显然的性质:轻儿子的子树数量一定<=当前节点子树数量的一半.显然,不然它就不是轻儿子了.所以我们每次往下便利轻儿子时,每找一个轻儿子,节点数量就可以减少一半.故从根节点出发到任意一个点不会经过超过logn条的轻边(父亲连向轻儿子的边).所以我们便利所有轻儿子的复杂度是log的,所以dsu on tree的复杂度就是nlogn的!
这里放一下模板:
https://codeforces.com/contest/600/problem/E
这个真的算是dsu的模板题了,题目要求统计树中每个节点字数内最多颜色的编号和.
//不等,不问,不犹豫,不回头. #include<bits/stdc++.h> #define _ 0 #define db double #define RE register #define ll long long #define P 1000000007 #define INF 1000000000 #define int ll #define get(x) x=read() #define PLI pair<ll,int> #define PII pair<int,int> #define max(a,b) (a>b?a:b) #define min(a,b) (a<b?a:b) #define pb(x) push_back(x) #define ull unsigned long long #define put(x) printf("%d\n",x) #define putl(x) printf("%lld\n",x) #define rep(i,x,y) for(RE int i=x;i<=y;++i) #define fep(i,x,y) for(RE int i=x;i>=y;--i) #define go(x) for(int i=link[x],y=a[i].y;i;y=a[i=a[i].next].y) using namespace std; const int N=1e5+10; int dfn[N],num,size[N],wson[N],link[N],c[N],tot,ans[N],f[N],cnt[N],pre[N]; int l[N],r[N],skp,sum,mx,n; struct edge{int y,next;}a[N<<1]; inline int read() { int x=0,ff=1; char ch=getchar(); while(!isdigit(ch)) {if(ch=='-') ff=-1;ch=getchar();} while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} return x*ff; } inline void add(int x,int y) { a[++tot].y=y;a[tot].next=link[x];link[x]=tot; a[++tot].y=x;a[tot].next=link[y];link[y]=tot; } inline void dfs1(int x) { dfn[x]=++num;pre[num]=x; l[x]=num;size[x]=1; go(x) { if(y==f[x]) continue; f[y]=x;dfs1(y); size[x]+=size[y]; if(size[y]>size[wson[x]]) wson[x]=y; } r[x]=num; } inline void get_data(int x,int op) { rep(i,l[x],r[x]) { if(pre[i]==skp) { i=r[skp]; continue; } cnt[c[pre[i]]]+=op; if(cnt[c[pre[i]]]>mx) mx=cnt[c[pre[i]]],sum=c[pre[i]];//这个更新sum和mx只在更新答案起作用, else if(cnt[c[pre[i]]]==mx) sum+=c[pre[i]];//因为不需要保留时,之后还会将其清零.所以直接这样写不影响. } } inline void dsu(int x,bool del) { go(x) { if(y==f[x]||y==wson[x]) continue; dsu(y,1);//便利轻儿子,且不保留轻儿子的信息. } if(wson[x]) { dsu(wson[x],0);//重儿子信息保留. skp=wson[x];//标记重儿子,便于在之后的统计信息时,跳过重儿子的便利. } get_data(x,1);//得到其他轻儿子的信息. ans[x]=sum;//标记答案. if(del) { skp=0; get_data(x,-1);//如果不需要保留的话,删除信息. sum=mx=0; } } signed main() { // freopen("1.in","r",stdin); get(n); rep(i,1,n) get(c[i]); rep(i,1,n-1) { int get(x),get(y); add(x,y); } dfs1(1);dsu(1,0); rep(i,1,n) printf("%lld ",ans[i]); return (0^_^0); } //以吾之血,祭吾最后的亡魂
这里统计轻儿子的信息时,我采用了时间戳的形式,反常与平时模板用的递归,主要是鄙人觉得递归自带大常数.
这个是第二题,大家都说是模板题,方正我是做了很长长长....时间.
//不等,不问,不犹豫,不回头. //优先考虑暴力的做法.考虑一个点x,如何统计经过x的最长合法链...考虑和树形DP求直径类似的方法求这个东西. //只不过这个不能直接简单的拼接.考虑对于他的每一个儿子y以此考虑,我们用一个数组存一下之前的儿子的信息. //扫到一个儿子时,直接查一下数组找到信息与当前的儿子进行拼接就行了.然后再把当前儿子的信息丢进数组即可. //注意这里不能边查找边更新,由于我们要求必经x,而x到当前儿子y只有一条边,所以y只能贡献一条链,所以我们只能将 //y所有的节点都尝试拼接后,再用y的节点进行更新数组,这样就保证了每次拼接都是在x种不同儿子之间的. //这里的数组考虑记录什么,我们对于当前儿子y的每一个节点都扫一下,由于我们要求的合法链两个端点异或起来必须是 //0或是2^x的形式,所以我们可以直接枚举异或后的结果,看是否存在能和当前节点异或为当前值得链就行了. //所以我们数组记录的应该是值域.由于我们关心的只有链长而已,所以直接保存最长深度的就可以了. #include<bits/stdc++.h> #define _ 0 #define db double #define RE register #define ll long long #define P 1000000007 #define INF 1000000000 #define get(x) x=read() #define PLI pair<ll,int> #define PII pair<int,int> #define max(a,b) (a>b?a:b) #define min(a,b) (a<b?a:b) #define pb(x) push_back(x) #define ull unsigned long long #define getc(c) scanf("%s",c+1) #define put(x) printf("%d\n",x) #define putl(x) printf("%lld\n",x) #define rep(i,x,y) for(RE int i=x;i<=y;++i) #define fep(i,x,y) for(RE int i=x;i>=y;--i) #define go(x) for(int i=link[x],y=a[i].y;i;y=a[i=a[i].next].y) using namespace std; const int N=5e5+10,M=1<<22; int link[N],tot,n,d[N],f[N],size[N],wson[N],D[N],C[M];//D表示点i到根节点的异或和,C表示权值为D[i]的最大深度. int ans[N],now; char str[10]; struct edge{int y,next,v;}a[N]; inline int read() { int x=0,ff=1; char ch=getchar(); while(!isdigit(ch)) {if(ch=='-') ff=-1;ch=getchar();} while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} return x*ff; } inline void add(int x,int y,int v) { a[++tot].y=y;a[tot].v=v;a[tot].next=link[x];link[x]=tot; } inline void dfs1(int x) { size[x]=1; go(x) { f[y]=x;d[y]=d[x]+1; D[y]=D[x]^a[i].v; dfs1(y); size[x]+=size[y]; if(size[y]>size[wson[x]]) wson[x]=y; } } inline void updata(int x) { C[D[x]]=max(C[D[x]],d[x]); go(x) updata(y); } inline void dfs(int x,bool ad) { if(ad) { ans[now]=max(ans[now],d[x]+C[D[x]]); rep(i,0,21) ans[now]=max(ans[now],d[x]+C[D[x]^1<<i]); } else C[D[x]]=-INF; go(x) dfs(y,ad); } inline void dsu(int x,bool del) { go(x) { if(y==wson[x]) continue; dsu(y,1); } if(wson[x]) dsu(wson[x],0); now=x; go(x) { if(y==wson[x]) continue; dfs(y,1);updata(y); } C[D[x]]=max(C[D[x]],d[x]); ans[x]=max(ans[x],d[x]+C[D[x]]); rep(i,0,21) ans[x]=max(ans[x],d[x]+C[D[x]^1<<i]); ans[x]-=d[x]<<1; go(x) ans[x]=max(ans[x],ans[y]); if(del) dfs(x,0); } int main() { //freopen("1.in","r",stdin); get(n); rep(i,2,n) { int get(x);getc(str); add(x,i,1<<(str[1]-'a')); } memset(C,0xcf,sizeof(C)); d[1]=1;dfs1(1);dsu(1,0); rep(i,1,n) printf("%d ",ans[i]); return (0^_^0); } //以吾之血,祭吾最后的亡魂
陈年更新了,最近又看了一道dsu on tree的题目,感觉还是挺好的就在这写一下,以免将这个知识点完全忘记。
题目意思也含简单,就是给定你一个树,树上每个节点都有一个权值ai,问满足所有点对ai^aj=a lca(i,j),即所有满足两个点的权值异或起来等于其lca的ai值得点对的i^j的和。
首先想到的就是枚举每一个子树,然后统计该子树根节点对答案的贡献。由于必须保证两个节点的lca是当前子树的根节点,我们要保证两个节点必须来自该子树中的两个不同子树才行。最初的想法就是开个vector,存节点权值为x的节点编号有哪些,但发现统计答案的时候可能会被卡成O(n^2)的....这就很...,这个时候我们就要用到异或的性质了,因为异或也可以写成不进位加法,所以他可以把每一位拆开异或再相加。例如:10110^01110=10000^0000+0000^1000+100^100+10^10+0^0=11001.所以我们可以这样记,cnt[x][20][0/1],表示权值为x的第i位为0/1的个数。每次找到一个权值为k的节点,我们先找到他需要配对的节点的权值,为k^a root,然后将当前节点编号二进制拆分,一位一位统计答案。最后将这个节点标号也统计进cnt数组即可。
#include<bits/stdc++.h> #define ll long long using namespace std; const int N=1e5+10,M=1e6+10,qwq=1e6; int n,a[N],cnt[M][22][2],wson[N],size[N],f[N]; int lca; ll ans; vector<int>son[N]; inline void dfs1(int x,int fa) { size[x]=1; for(auto y:son[x]) { if(y==fa) continue; dfs1(y,x);f[y]=x; size[x]+=size[y]; if(size[y]>size[wson[x]]) wson[x]=y; } } inline void dfs(int x,int op) { //op为1表示统计答案。 //op为2表示将该节点加入cnt数组中。 for(int j=20;j>=0;--j) { int d=(x&(1<<j))?1:0; if(op==1) { if((a[x]^a[lca])<=qwq) ans+=(ll)(1<<j)*cnt[a[x]^a[lca]][j][d^1]; } else if(op==2) cnt[a[x]][j][d]++; else if(op==3) cnt[a[x]][j][d]--; } for(auto y:son[x]) { if(y==f[x]) continue; dfs(y,op); } } inline void dsu(int x,bool del) { for(auto y:son[x]) { if(y==f[x]||y==wson[x]) continue; dsu(y,1); } if(wson[x]) dsu(wson[x],0); lca=x; for(auto y:son[x]) { if(y==f[x]||y==wson[x]) continue; dfs(y,1);dfs(y,2); } for(int j=20;j>=0;--j) { int d=(x&(1<<j))?1:0; cnt[a[x]][j][d]++; } if(del) dfs(x,3); } int main() { // freopen("1.in","r",stdin); scanf("%d",&n); for(int i=1;i<=n;++i) scanf("%d",&a[i]); for(int i=1;i<n;++i) { int x,y; scanf("%d%d",&x,&y); son[x].push_back(y); son[y].push_back(x); } dfs1(1,0);dsu(1,0); printf("%lld",ans); return 0; }