cf #786 (div3)
AK后因细节fst之痛……
[G]
题意:
给一个有向无环图,现要删除一些边使得删除后所有点的入度和出度都减小了(如果是0则不变),求删边后“可爱”点集最大能有多少点。“可爱”点集是其中任意两点间都存在(至少)一条路径的点集。点数、边数均 \(\leq 2*10^5\) 。
分析:
删边的操作可以具体到每个点上来看,就是每个点都至少要减少一条入边一条出边;
考虑dfs统计答案,那么删边操作可以在dfs到每个点上时进行。
设当前点为 \(u\),计算它能给来路返回的最大贡献,其实就是dfs它的所有出点,选一个最大的;其他那些边就可以看做删除了。所以如果它只有一条出边,那它能返回的贡献只有本身的\(1\);同样,如果入度为\(1\),那它返回的贡献应该是\(0\)。
所以这里要注意,每个点自身的\(ans[u]\)代表从它出发可以走到的最大答案,和它返回来路的贡献并不是同一个东西。
然后注意细节……dfs里面不要提前特判return,或者主函数里遍历每个点dfs……
代码如下
#include<bits/stdc++.h>
using namespace std;
int const N=2e5+5;
int n,m,hd[N],cnt,nxt[N<<1],to[N<<1],in[N],out[N];
int ans[N];
bool vis[N];
void add(int u,int v){cnt++; to[cnt]=v; nxt[cnt]=hd[u]; hd[u]=cnt;}
int dfs(int u)
{
// printf("u=%d\n",u);
vis[u]=1;
// if(in[u]==1&&out[u]==1){ans[u]=1; return 0;} // 不该提前return,小心没有dfs到别的点!
// else if(out[u]==1){ans[u]=1; return 1;}
int mx=0;
for(int i=hd[u],v;i;i=nxt[i])
{
if(!vis[v=to[i]]) mx=max(mx,dfs(v));
else
{
if(in[v]==1&&out[v]==1)mx=mx;
else if(out[v]==1)mx=max(mx,1);
else if(in[v]==1)mx=mx;
else mx=max(mx,ans[v]);
}
}
ans[u]=1+mx;
// printf("ans[%d]=%d\n",u,ans[u]);
if(in[u]==1&&out[u]==1){ans[u]=1; return 0;}
else if(out[u]==1){ans[u]=1; return 1;}
if(in[u]==1)return 0;
else return ans[u];
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1,u,v;i<=m;i++)
{
scanf("%d%d",&u,&v);
add(u,v); out[u]++; in[v]++;
}
int mx=0;
for(int i=1;i<=n;i++)
if(!in[i])mx=max(mx,dfs(i));
for(int i=1;i<=n;i++) mx=max(mx,ans[i]);
printf("%d\n",mx);
return 0;
}
代码2如下
#include<bits/stdc++.h>
using namespace std;
int const N=2e5+5;
int n,m,hd[N],cnt,nxt[N<<1],to[N<<1],in[N],out[N];
int ans[N];
bool vis[N];
void add(int u,int v){cnt++; to[cnt]=v; nxt[cnt]=hd[u]; hd[u]=cnt;}
int dfs(int u)
{
// printf("u=%d\n",u);
vis[u]=1;
if(in[u]==1&&out[u]==1){ans[u]=1; return 0;} // 提前return + main函数里遍历每个点
else if(out[u]==1){ans[u]=1; return 1;}
int mx=0;
for(int i=hd[u],v;i;i=nxt[i])
{
if(!vis[v=to[i]]) mx=max(mx,dfs(v));
else
{
if(in[v]==1&&out[v]==1)mx=mx;
else if(out[v]==1)mx=max(mx,1);
else if(in[v]==1)mx=mx;
else mx=max(mx,ans[v]);
}
}
ans[u]=1+mx;
// printf("ans[%d]=%d\n",u,ans[u]);
if(in[u]==1)return 0;
else return ans[u];
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1,u,v;i<=m;i++)
{
scanf("%d%d",&u,&v);
add(u,v); out[u]++; in[v]++;
}
int mx=0;
for(int i=1;i<=n;i++)
mx=max(mx,dfs(i));
for(int i=1;i<=n;i++) mx=max(mx,ans[i]);
printf("%d\n",mx);
return 0;
}