拓扑排序最长链-P3119 [USACO15JAN]草鉴定Grass Cownoisseur
https://www.luogu.org/problem/show?pid=3119
本来我是来练习tarjan的,结果tarjan部分直接copy了,反而拓扑排序部分想了好久;
这道题SZB大神两次就AC;
但我等到AC,写好题解就只能洗洗睡了;
唉~
差距怎么这么大呢?;
这道题的题意就说,你可以改变一条边的方向,去找到一个环,让环上的点数最大;
网上的题解,大多都在嚷嚷tarjan+拓扑排序最长链;
我先讲讲什么是拓扑排序最长链把;
很显然啊,上面的图里,1~3的最长链是3;
我们考虑一下最朴素的dfs;
我们从1开始,先搜索到了3;
这个时候我们把1~3的最长链更新为 1-3,就是2个节点;
假如3出度有好多,那么凡是3后面的点我们都要遍历一边;
然后遍历完3,我发现 1~3的最长链不是1-3而是1-2-3;
原先的1~3最长链的节点个数2被更新为3;
这个时候我们发现原来3连出去的点,他们与1的最长链都不是最优的;
所以我们又要dfs一遍;
这样太烦了;
那怎么办呢?
看到这里,我想您一定知道什么是拓扑排序最长链了;
对啊,假如我们搜索到3的时候先不去往后面搜索,先去遍历2;
这样1~3的最长链会被及时更新,3后面的节点就不用重复更新了;
其实这样就是按照拓扑排序的顺序去遍历节点啊,不断找入度为0的节点去更新其它节点;
这就是拓扑排序最长链
这道题就简单了啊;
我先缩点一下,让这个有环图变成有向无环图;
然后用一个dfs去算出那些点可以到达1;那些点会从1到达;
分别算出这些点到1的最长链,然后枚举每一条边,看看把这条边反一下,加上两端到1的最长链然后更新ans就好啦;
很显然啊,这两条链不会重复,因为他们的方向是不同的;
我的超级优美的代码,60行!;
超级无敌大压行!!!!
#include<cstdio>//cfb
#include<iostream>
#include<cstring>
using namespace std;
struct cs{int to,next;}a[100001]; //lin[i]是i再那个分量里面 sum[i]就是第i个分量有几个点 d[i]是当前分量i的入度
int head[100001],low[100001],tt[100001],q[100001],lin[100001],cc[100001][2],sum[100001],A[100001][1],d[100001];
bool in[100001],AA[100001][2];//AA[i][0]表示1是否可以到i,[1]是i是否可以到j;A[i][0/1]即他们的最长链
int ll,n,m,x,y,z,t,nn,l,r,ans;
void init(int x,int y){a[++ll].to=y; a[ll].next=head[x]; head[x]=ll;}
void dfs(int x){
tt[x]=++t; low[x]=t; q[++q[0]]=x; in[x]=1;
for(int k=head[x];k;k=a[k].next){
if(!tt[a[k].to])dfs(a[k].to);
if(in[a[k].to])low[x]=min(low[x],low[a[k].to]);
}
if(low[x]==tt[x]){
nn++;
while(1){
in[q[q[0]]]=0;
lin[q[q[0]]]=nn;
q[0]--; sum[nn]++;
if(q[q[0]+1]==x)break;
}
}
}
void make(int x,int num){//然后用一个dfs去算出那些点可以到达1;那些点会从1到达;
in[x]=1; AA[x][num]=1;
for(int k=head[x];k;k=a[k].next){d[a[k].to]++; if(!in[a[k].to])make(a[k].to,num);}
}
void TP(int num){//拓扑排序
r=1; q[1]=lin[1]; A[lin[1]][num]=0;
make(lin[1],num);
while(r>l){
x=q[++l];
for(int k=head[x];k;k=a[k].next){
A[a[k].to][num]=max(A[a[k].to][num],A[x][num]+sum[a[k].to]);
d[a[k].to]--;
if(!d[a[k].to])q[++r]=a[k].to;
}
}
}
void S(){
memset(q,0,sizeof q);memset(head,0,sizeof head);
memset(d,0,sizeof d);memset(in,0,sizeof in);
ll=0; r=l=0;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=m;i++){scanf("%d%d",&cc[i][0],&cc[i][1]);init(cc[i][0],cc[i][1]);}
for(int i=1;i<=n;i++)if(!tt[i])dfs(i);
S(); for(int i=1;i<=m;i++){x=cc[i][0]; y=cc[i][1]; if(lin[x]!=lin[y])init(lin[x],lin[y]);} TP(0);
S(); for(int i=1;i<=m;i++){x=cc[i][0]; y=cc[i][1]; if(lin[x]!=lin[y])init(lin[y],lin[x]);} TP(1);
for(int i=1;i<=m;i++){
x=lin[cc[i][0]]; y=lin[cc[i][1]];
if(x==y||!AA[y][0]||!AA[x][1])continue;
ans=max(ans,A[y][0]+A[x][1]);
}
printf("%d",ans+sum[lin[1]]);
}