poj3417 Network——LCA+树上差分

题目:http://poj.org/problem?id=3417

根据一条边被几个环覆盖来判断能不能删、有几种情况等;

用树上差分,终点 s++,LCA s-=2,统计时计算子树s值的和即可;

用ST表做LCA,不知为何WA了:

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
int const MAXN=1e5+5;
int n,m,head[MAXN],ct,dep[MAXN],pre[MAXN][20],s[MAXN];
long long ans;
struct N{
    int to,next;
    N(int t=0,int n=0):to(t),next(n) {}
}edge[MAXN<<1];
void dfs(int x,int f)
{
    dep[x]=dep[f]+1;
    pre[x][0]=f;
    for(int i=head[x];i;i=edge[i].next)
        if(edge[i].to!=f)dfs(edge[i].to,x);
}
void init()
{
    dfs(1,0);
    for(int k=1;k<=17;k++)
        for(int i=1;i<=n;i++)
            pre[i][k]=pre[pre[i][k-1]][k-1];
}
int lca(int x,int y)
{
    if(dep[x]>dep[y])swap(x,y);
    int d=dep[y]-dep[x];
    for(int i=0;i<=17;i++)
        if(/*d&(1<<i)*/(d>>i)&1)y=pre[y][i];
    for(int i=17;i>=0;i--)
        if(pre[x][i]!=pre[y][i])
        {
            x=pre[x][i];
            y=pre[y][i];
        }
    return pre[x][0];
}
long long dfs2(int x,int f)
{
    long long sum=s[x];
    for(int i=head[x],u;i;i=edge[i].next)
    {
        if((u=edge[i].to)==f)continue;
        long long k=dfs2(u,x);
        if(k==0)ans+=m;
        if(k==1)ans++;
        sum+=k;
    }
    return sum;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1,x,y;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        edge[++ct]=N(y,head[x]);head[x]=ct;
        edge[++ct]=N(x,head[y]);head[y]=ct;
    }
    init();
    for(int i=1,x,y;i<=m;i++)
    {
        scanf("%d%d",&x,&y);
        if(x==y)continue;
        s[x]++;s[y]++;
        s[lca(x,y)]-=2;
    }
    dfs2(1,0);
    printf("%lld",ans);
    return 0;
}
ST表为什么WA

于是改成了tarjan,过程中求答案;

注意非树边加边时判掉x=y的情况。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
int const MAXN=1e5+5;
int n,m,head[MAXN],ct,s[MAXN],fa[MAXN],ct2,head2[MAXN];
long long ans;
bool vis[MAXN];
struct N{
    int to,next;
    N(int t=0,int n=0):to(t),next(n) {}
}edge[MAXN<<1],ed[MAXN<<1];
int find(int x)  
{  
    if(fa[x]==x)return x;  
    return fa[x]=find(fa[x]);  
}  
void tarjan(int x)  
{  
    fa[x]=x;vis[x]=1;
    for(int i=head2[x],u;i;i=ed[i].next)
        if(vis[u=ed[i].to])s[find(u)]-=2;
    for(int i=head[x],u;i;i=edge[i].next)
    {
        if(vis[u=edge[i].to])continue;//fa
        tarjan(u);fa[u]=x;s[x]+=s[u];
        if(s[u]==0)ans+=m;
        if(s[u]==1)ans++;
    }
}  
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1,x,y;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        edge[++ct]=N(y,head[x]);head[x]=ct;
        edge[++ct]=N(x,head[y]);head[y]=ct;
    }
    for(int i=1,x,y;i<=m;i++)
    {
        scanf("%d%d",&x,&y);
        if(x==y)continue;//!
        s[x]++;s[y]++;
        ed[++ct2]=N(y,head2[x]);head2[x]=ct2;
        ed[++ct2]=N(x,head2[y]);head2[y]=ct2;
    }
    tarjan(1);
    printf("%d",ans);
    return 0;
}

 

posted @ 2018-04-24 19:05  Zinn  阅读(167)  评论(0编辑  收藏  举报