poj 3417 树形dp+LCA
思路:我以前一直喜欢用根号n分段的LCA。在这题上挂了,第一次发现这样的LCA被卡。果断改用Tarjan离线算法求LCA。
当前节点为u,其子节点为v。那么:
当以v根的子树中含有连接子树以外点的边数为out[v]。
out[v]==0,dp[u]+=m;
out[v]==1,dp[u]+=1;
else dp[u]+=0;
最后就是dp[u]+=dp[v]。
对于u点的out[u]+=out[v];
最后out[u]-=cnt[u];cnt[u]表示以u为根,在子树内的边数。
#include<map> #include<set> #include<cmath> #include<queue> #include<cstdio> #include<vector> #include<string> #include<cstdlib> #include<cstring> #include<iostream> #include<algorithm> #define Maxn 100010 #define Maxm 200010 #define LL __int64 #define Abs(x) ((x)>0?(x):(-x)) #define lson(x) (x<<1) #define rson(x) (x<<1|1) #define inf 0x7fffffff #define Mod 1000000007 using namespace std; int head[Maxn],vi[Maxn],val[Maxn],e,dp[Maxn],fs[Maxn],fa[Maxn],out[Maxn],cnt[Maxn],anc[Maxn],vis[Maxn],n,m; struct Edge{ int u,v,next; }edge[Maxm]; vector<int> ll[Maxn]; void init() { memset(head,-1,sizeof(head)); memset(vi,0,sizeof(vi)); memset(out,0,sizeof(out)); memset(cnt,0,sizeof(cnt)); memset(fs,0,sizeof(fs)); e=0; } void add(int u,int v) { edge[e].u=u,edge[e].v=v,edge[e].next=head[u],head[u]=e++; edge[e].u=v,edge[e].v=u,edge[e].next=head[v],head[v]=e++; } void Treedp(int u) { int i,v; vi[u]=1; dp[u]=0; for(i=head[u];i!=-1;i=edge[i].next){ v=edge[i].v; if(vi[v]) continue; Treedp(v); dp[u]+=dp[v]; if(!out[v]) dp[u]+=m; else if(out[v]==1) dp[u]++; out[u]+=out[v]; } out[u]-=cnt[u]; } int find(int x) { if(x!=fa[x]) fa[x]=find(fa[x]); return fa[x]; } void merg(int a,int b) { int x=find(a); int y=find(b); if(fs[y]<=fs[x]) fa[y]=x,fs[x]+=fs[y]; else fa[x]=y,fs[y]+=fs[x]; } void LCA(int u) { int i,v,sz; sz=ll[u].size(); vi[u]=1; anc[u]=u; for(i=head[u];i!=-1;i=edge[i].next){ v=edge[i].v; if(vi[v]) continue; LCA(v); merg(u,v); anc[find(u)]=u; } vis[u]=1; for(i=0;i<sz;i++){ v=ll[u][i]; if(vis[v]){ int lca=anc[find(v)]; if(lca==u){ out[v]++; cnt[u]++; }else if(lca==v){ out[u]++; cnt[v]++; } else { cnt[lca]+=2; out[u]++,out[v]++; } } } } int main() { int i,j,u,v; scanf("%d%d",&n,&m); memset(head,-1,sizeof(head)); for(i=0;i<Maxn;i++) fa[i]=i,fs[i]=1; for(i=1;i<n;i++){ scanf("%d%d",&u,&v); add(u,v); } for(i=1;i<=m;i++){ scanf("%d%d",&u,&v); ll[u].push_back(v); ll[v].push_back(u); } LCA(1); memset(vi,0,sizeof(vi)); Treedp(1); printf("%d\n",dp[1]); return 0; }