CF613D "Kingdom and its Cities" — virtual tree (虚树) + tree DP (树形DP)
Code:
// CF613D "Kingdom and its Cities".
// Per query: mark the k important cities; print -1 if some important city's
// parent is also important. Otherwise build the virtual tree of the important
// cities (anchored at root 1) in DFS-entry order and run a DP on it to get the
// minimum number of captured cities.
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Array capacity: node ids must be < maxn.
constexpr int maxn = 300003;

// Reads one (optionally negative) decimal integer from stdin, skipping any
// characters before it. Returns 0 if EOF is reached before a number starts.
inline int read() {
    int x = 0;
    int sign = 1;
    int ch = std::getchar();
    // Stop on EOF so truncated input cannot make this loop spin forever.
    while (ch != EOF && (ch < '0' || ch > '9') && ch != '-') ch = std::getchar();
    if (ch == '-') {
        sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// Local-debugging helper: reopens stdin from "<s>.in". Only stdin is redirected.
inline void setIO(const std::string& s) {
    const std::string in = s + ".in";
    std::freopen(in.c_str(), "r", stdin);
}

// Original tree stored as a forward-star adjacency list (edge ids start at 1).
int edges, tim, n;
int hd[maxn], to[maxn << 1], nex[maxn << 1];

// Appends the directed edge u -> v.
inline void add(int u, int v) {
    nex[++edges] = hd[u];
    hd[u] = edges;
    to[edges] = v;
}

// Heavy-light decomposition data.
int fa[maxn], top[maxn], dfn[maxn], hson[maxn], siz[maxn], dep[maxn];

// First pass: parent, depth, preorder index (dfn), subtree size, heavy child.
void dfs1(int u, int ff) {
    siz[u] = 1;
    fa[u] = ff;
    dfn[u] = ++tim;
    dep[u] = dep[ff] + 1;
    for (int i = hd[u]; i; i = nex[i]) {
        const int v = to[i];
        if (v == ff) continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[v] > siz[hson[u]]) hson[u] = v;
    }
}

// Second pass: top[u] = head of the heavy chain containing u.
void dfs2(int u, int head) {
    top[u] = head;
    if (hson[u]) dfs2(hson[u], head);
    for (int i = hd[u]; i; i = nex[i]) {
        const int v = to[i];
        if (v == fa[u] || v == hson[u]) continue;
        dfs2(v, v);
    }
}

// Lowest common ancestor via heavy-light chain jumps.
inline int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            x = fa[top[x]];
        } else {
            y = fa[top[y]];
        }
    }
    return dep[x] < dep[y] ? x : y;
}

// Virtual-tree state: S[1..tp] is the construction stack, G holds
// parent -> child virtual edges, mk marks the current query's cities.
int tp;
std::vector<int> G[maxn];
int arr[maxn], mk[maxn], S[maxn], g[maxn], f[maxn];

// Orders nodes by DFS entry time.
bool cmp(int i, int j) { return dfn[i] < dfn[j]; }

// Adds virtual-tree edge u -> v (u is the parent).
inline void addvir(int u, int v) { G[u].push_back(v); }

// Pushes x onto the stack S, emitting virtual edges for nodes popped below
// LCA(x, S[tp]). Must be called in increasing dfn order with S[1] == 1.
inline void insert(int x) {
    if (tp <= 1) {
        S[++tp] = x;
        return;
    }
    const int lca = LCA(x, S[tp]);
    if (lca == S[tp]) {
        S[++tp] = x;
        return;
    }
    while (tp > 1 && dep[S[tp - 1]] >= dep[lca]) {
        addvir(S[tp - 1], S[tp]);
        --tp;
    }
    if (S[tp] != lca) {
        addvir(lca, S[tp]);
        S[tp] = lca;
    }
    S[++tp] = x;
}

// DP over the virtual subtree of x; clears G[x] when done.
//   f[x]: captures needed inside x's virtual subtree.
//   g[x]: 1 if an important city in that subtree is still connected up to x.
// An important x pays one capture per child branch still carrying an important
// city. A non-important x with >= 2 such branches is captured itself; with
// exactly one, it passes that connection upward.
void DP(int x) {
    g[x] = f[x] = 0;
    for (const int v : G[x]) {
        DP(v);
        f[x] += f[v];
        g[x] += g[v];
    }
    if (mk[x]) {
        f[x] += g[x];
        g[x] = 1;
    } else {
        f[x] += (g[x] > 1);
        g[x] = (g[x] == 1);
    }
    G[x].clear();
}

// Unmarks the k cities of the current query so mk[] is clean for the next one.
inline void clearMarks(int k) {
    for (int j = 1; j <= k; ++j) mk[arr[j]] = 0;
}

// Answers one query read from stdin and prints the result.
inline void work() {
    const int k = read();
    for (int i = 1; i <= k; ++i) {
        arr[i] = read();
        mk[arr[i]] = 1;
    }
    std::sort(arr + 1, arr + 1 + k, cmp);

    // An important city whose parent is also important can never be isolated.
    for (int i = 1; i <= k; ++i) {
        if (mk[arr[i]] && mk[fa[arr[i]]]) {
            clearMarks(k);
            std::printf("-1\n");
            return;
        }
    }

    tp = 0;
    if (arr[1] != 1) {
        // Root 1 always anchors the virtual tree.
        tp = 1;
        S[1] = 1;
    }
    for (int i = 1; i <= k; ++i) insert(arr[i]);
    while (tp > 1) {
        addvir(S[tp - 1], S[tp]);
        --tp;
    }
    DP(1);
    std::printf("%d\n", f[1]);
    clearMarks(k);
}

int main() {
    // setIO("input");
    n = read();
    for (int i = 1; i < n; ++i) {
        const int a = read();
        const int b = read();
        add(a, b);
        add(b, a);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    const int Q = read();
    for (int i = 1; i <= Q; ++i) work();
    return 0;
}