倍增并查集(萌萌哒)
说实话,这道题一开始我还真没想到是并查集。看到题目,第一个反应是暴力打标记,因为相同的一段的只需要找一次,打上标记后就意味着不会再对答案做出贡献。但这样显然是超时的对吧,所以就想着可不可以拿个啥数据结构来维护。虽然没有想出来。
先不说思想有没有什么bug,但感觉就算这样打,维护标记也是不好维护的。一整块打标记用线段树就很好实现,但是在下一条信息的时候,万一和前面的信息区间有重叠,准确找出又需要打多少标记还是有点麻烦的(应该是对于这两段都求一下是否有部分被打了标记,两者的标记合起来,再被区间长度减去才是这一次对ans的贡献,但是怎么求出两者的标记合起来覆盖了多长的区间呢?不可能是单纯的相加,它的位置是不好定位的)
所以我的代码一开始是这样的(十分暴力但还有一个点没有考虑到)
for(int i=1;i<=m;++i) { w[i].l1=read();w[i].r1=read(); w[i].l2=read();w[i].r2=read(); w[i].c=w[i].r1-w[i].l1+1; int s1=w[i].l1,s2=w[i].l2; while(s1<=w[i].r1&&s2<=w[i].r2) { if(!flagg[s1]&&!flagg[s2])ans++; flagg[s1]=1,flagg[s2]=1; s1++;s2++; } } for(int i=1;i<=n;++i) if(!flagg[i])ans++;
然后为什么这样打连思想都是错的呢?
我对拍的时候发现了这样一组数据
8 3
6 6 2 2
3 3 7 7
6 7 3 4
输出的答案是90000,但正确的答案应该是9000。
于是发现这种思想是有bug的==
单纯的打标记是不对的,比如这组数据:我们给6,2打上标记,ans+1,再给3,7打上标记,ans+1,然后第三条信息的时候,我们不会再打标记
加上1,5,8三个没有标记的,最后的ans是5
但是这个打标记应该是有标记的“序号”的
第一条信息
我们给6,2打上标记1
第二条信息
我们给3,7打上标记2
第三条信息
我们发现6,3是一样的,那么标记1和标记2就是一样的,也就是说2,6,3,7,4都是一样的数!那么只对ans贡献1
所以最后的ans是4
再把模型抽象一下,这,这不就是并查集嘛==
朴素的思考就是每个点都建并查集,每条信息时把对应的点一个一个合并。
但时间过不了,所以这里就有一个很神奇的算法,倍增并查集,把n优化成logn就可以过了
#include<bits/stdc++.h> #define N 100003 #define mod 1000000007 #define LL long long using namespace std; int read() { int x=0,f=1;char s=getchar(); while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();} while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();} return x*f; } int f[N][22]; int getfa(int x,int j) { if(f[x][j]==x) return x; return f[x][j]=getfa(f[x][j],j); } void merge(int x,int y,int j) { f[getfa(x,j)][j]=getfa(y,j); } int main() { int n=read(),m=read(); int ans=0,op=0; for(int i=1;i<=n;++i) for(int j=0;j<=20;++j) f[i][j]=i;//从i起始,2^j的长度的区间 for(int i=1;i<=m;++i) { int l1=read(),r1=read(); int l2=read(),r2=read(); for(int j=20;j>=0;--j) if(l1+(1<<j)-1<=r1){merge(l1,l2,j);l1+=(1<<j);l2+=(1<<j);}//找到最大的区间合并,再起点挪动 } for(int j=20;j>=1;--j) { for(int i=1;i+(1<<j)-1<=n;++i)//!!! 边界处理注意 { merge(i,getfa(i,j),j-1); //再一层层将区间分半合并 merge(i+(1<<(j-1)),getfa(i,j)+(1<<(j-1)),j-1); } } int md=20; for(int i=1;i<=n;++i) { if(getfa(i,0)==i)ans++; } LL res=1; ans--; res=res*9%mod; for(int i=1;i<=ans;++i) res=res*10%mod; printf("%lld\n",res); }