食物链——带权并查集模板题
题意:
存在食物链 A->B->C->A (A->B 表示A吃B),在下面n中说法中,判断有多少句假话
题解:
我们通过他们之间的距离关系分为三类
0->1 记为0 (0->1)
1->2 记为1
2->0 记为2
把所有能确定关系的元素放到一个集合,每个集合的元素分为三类
d【i】表示 i到根节点的距离(也就是吃、被吃、同类的关系)
现在就是的问题如何进行集合的合并
如果输入1 x y
如果x和y在同一个集合 并且 x,y是同类 所以判断 (d[x]-d[y]+3)%3是否等于0
如果x和y不在同一个集合 那么就需要进行区间和并
首先 f[fx]=fy; 让x的父节点fx指向fy 即fx的父节点是fy
因为 x,y是同类所以他们的距离关系相等 所以 d[x]+d[fx]==d[y]
即 d[fx]=d[y]-d[x];
如果输入2 x y
如果x和y在同一个集合 并且 x->y 所以x到根节点的距离比y的距离大1 所以判断 (d[x]-d[y]-1+3)%3是否等于0
如果x和y不在同一个集合 那么就需要进行区间和并
首先 f[fx]=fy; 让x的父节点fx指向fy 即fx的父节点是fy
因为 x->y 所以他们的距离关系满足 d[x]=d[y]+1 所以d[x]+d[fx]==d[y]+1
即 d[fx]=d[y]-d[x]+1;
代码:

#include<iostream> #include<stdio.h> #include<math.h> using namespace std; typedef long long ll; const int maxn=2e5+5; int d[maxn]; int f[maxn]; int Find(int x) { if(x==f[x])return x; int root=Find(f[x]); d[x]+=d[f[x]]; return f[x]=root; } int main() { int n,k; //printf("%d",(-5)%3); scanf("%d%d",&n,&k); for(int i=1;i<=n;i++)f[i]=i; int ans=0; while(k--) { int op,x,y; scanf("%d%d%d",&op,&x,&y); if(x>n || y>n){ans++;continue;} int fx=Find(x); int fy=Find(y); if(op==1) { ///if(fx==fy && ((d[x]+3)%3)!=((d[y]+3)%3))ans++; if(fx==fy && (d[x]-d[y]+3)%3)ans++; else if(fx!=fy) { f[fx]=fy; d[fx]=d[y]-d[x];/// 因为x,y在同一个集合,所以d[x]+d[fx]==d[y]; } } else { if(fx==fy && ((d[x]%3)+3)%3!=(((d[y]+1)%3)+3)%3)ans++; ///if(fx==fy && (d[x]-d[y]-1)%3)ans++;/// 因为 x 吃 y 所以 d[x]=d[y]+1; else if(fx!=fy) { f[fx]=fy; d[fx]=d[y]-d[x]+1; } } } printf("%d\n",ans); return 0; }
把上面两种情况合并

#include<iostream> #include<stdio.h> #include<math.h> using namespace std; typedef long long ll; const int maxn=2e5+5; int f[maxn]; int n,m; int d[maxn]; int Find(int x) { if(x==f[x])return x; int root=Find(f[x]); d[x]=(d[x]+d[f[x]])%3; return f[x]=root; } void join(int op,int x,int y) { int fx=Find(x); int fy=Find(y); f[fy]=fx; d[fy]=(d[x]+op-d[y]+3)%3; } int main() { int n,m; scanf("%d%d",&n,&m); for(int i=0;i<=n;i++)f[i]=i,d[i]=0; int ans=0; while(m--) { int x,y,op; scanf("%d%d%d",&op,&x,&y); op--; if(x>n || y>n){ans++;continue;} if(Find(x)!=Find(y))join(op,x,y); else { if((d[x]+op)%3!=d[y])ans++; } } printf("%d\n",ans); return 0; }