HDU 5927 Auxiliary Set
这是2016CCPC东北四省赛的F题。
现场的做法有点繁琐,赛后和队友LSH讨论了一下,他提出了一个排序后\(O(N)\)的做法.
题意
每次询问给出一个unimportant node (以下简称u-node) 的集合 \(S\).
要求统计出\(S\)中有多少个节点是某两个不同的important node (以下简称i-node) 的LCA.
(当然这是转化后的题意, 题目并不是这样问的)
做法
对u-node按逆DFS序排序.
遍历\(S\), 对每个u-node \(u\), 维护\(u\)的满足如下条件的儿子\(v\)数目:
子树\(v\)中不全是u-node
不妨将此数目记作\(cnt[u]\).
维护的方法是
对某个 \(u\), 若 cnt[u]==0
就 --cnt[par[u]]
.
\(par[u]\)表示 \(u\) 的父亲.
最后, 满足 \(cnt[u] \ge 2\) 的 \(u\) 的数目即为所求.
Implementation
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int L[N];
int cnt[N];
int tail;
vector<int> g[N];
int par[N];
void dfs(int u, int f){
L[u]=++tail;
par[u]=f;
for(auto v: g[u]){
if(v!=f){
++cnt[u];
dfs(v, u);
}
}
}
bool cmp(int x, int y){
return L[x]>L[y];
}
vector<int> a; // unimportant nodes
int _cnt[N];
// bool flag[N];
int main(){
int T, cas{};
for(cin>>T; T--; ){
int n, q;
cin>>n>>q;
for(int i=1; i<=n; i++){
g[i].clear();
cnt[i]=0; // error-prone
}
for(int i=1; i<n; i++){
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
tail=0;
dfs(1, 0);
printf("Case #%d:\n", ++cas);
for(; q--; ){
int m;
cin>>m;
a.clear();
for(; m--; ){
int x;
scanf("%d", &x);
a.push_back(x);
// flag[x]=true;
}
sort(a.begin(), a.end(), cmp);
for(auto x: a){
_cnt[x]=cnt[x];
}
for(auto x: a){
if(_cnt[x]==0){
--_cnt[par[x]];
}
}
int res=0;
for(auto x: a){
res+=_cnt[x]>=2;
// _cnt[x]=cnt[x];
}
res += n-a.size();
printf("%d\n", res);
}
}
return 0;
}
UPD
ICPCCamp给出的题解上这题的做法是
题意:给定一棵树,每次询问一个集合,问所有不在这个集合的点两两 LCA (可以相同)的并集大小。
题解:实际上就是对于每个被删掉的点,check 一下是不是除了一个子树之外的点都被删掉了。那么对于每个被删掉的点的连通块,从最高点开始 dfs 就行。
一开始并不能看懂 (太弱), 现在大概理解了.
其实思路和上面是一样的, 最终的目的还是: 对每个被删掉的点 \(u\), 统计出\(u\)有多少个儿子\(v\)的子树全被删了, 要统计这个量, 只要考虑每个被删掉的点的连通块就好了. 因为如果\(u\)的某个儿子\(v\)的子树全被删了, 那么这些被删掉的点一定也出现在\(u\)所在的连通块中.
求每个连通块的最高点 (根) 和dfs的复杂度都是\(O(n)\)的. 而我的做法处理单次询问的复杂度是\(O(m\log m) + O(m)\), \(m\)是被删掉点的数目.
并不用找一个删除点的连通块的最高点, 按通常DFS的写法写就可以了, 然而我的DFS()
函数第一次却写跪了.
void DFS(int u){
used[u]=true;
cnt[u]=1;
int c=0;
for(auto v: G[u]){
if(!used[v]) DFS(v), cnt[u]+=cnt[v];
// if(cnt[v]<size[v]) c++;
c+=cnt[v]==size[v];
}
ans+=c+1 < g[u].size()-(u!=1);
}
正确的写法:
void DFS(int u){
used[u]=true;
cnt[u]=1;
int c=0;
for(auto v: G[u]){
if(!used[v]) DFS(v);
cnt[u]+=cnt[v]; // error-prone
// if(cnt[v]<size[v]) c++;
c+=cnt[v]==size[v];
}
ans+=c+1 < g[u].size()-(u!=1);
}
Implementation
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+5;
vector<int> g[N];
vector<int> G[N];
int a[N];
bool used[N];
int par[N], size[N];
void dfs(int u, int f){
par[u]=f;
size[u]=1;
for(auto v: g[u]){
if(v!=f) dfs(v, u), size[u]+=size[v];
}
}
int cnt[N], ans;
void DFS(int u){
used[u]=true;
cnt[u]=1;
int c=0;
for(auto v: G[u]){
if(!used[v]) DFS(v);
cnt[u]+=cnt[v]; // error-prone
// if(cnt[v]<size[v]) c++;
c+=cnt[v]==size[v];
}
ans+=c+1 < g[u].size()-(u!=1);
}
bool f[N];
int main(){
int T, n, q, cs{};
for(cin>>T; T--; ){
printf("Case #%d:\n", ++cs);
cin>>n>>q;
for(int i=1; i<=n; i++) g[i].clear();
for(int i=1; i<n; i++){
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);
for(; q--; ){
int m;
scanf("%d", &m);
for(int i=0; i<m; i++){
scanf("%d", a+i);
used[a[i]]=false;
G[a[i]].clear();
f[a[i]]=true;
}
for(int i=0; i<m; i++){
if(f[par[a[i]]])
G[par[a[i]]].push_back(a[i]);
}
ans=0;
for(int i=0; i<m; i++)
if(!used[a[i]]) DFS(a[i]);
for(int i=0; i<m; i++)
f[a[i]]=false;
printf("%d\n", n-m+ans);
}
}
return 0;
}