题解 CF375D 【Tree and Queries】

\[\text{CF375D Tree and Queries} \]

\(\quad\)题目链接:CF375D Tree and Queries(洛谷的链接)

思路

\(\quad\)标准做法是动态规划,但看到 \(4.5s\) 的时限,似乎可以树上启发式合并水过去,只要用 \(num_i\)\(cnt_k\) 的数组来记录出现颜色 \(i\) 的数量及超过 \(k\) 的颜色数量即可。

\(\quad\)下面就简单讲讲树上启发式合并 (Dsu on Tree)算法,如果有不懂的可以提出来,如果已经学过了的可以直接看最后的完整代码。

\[\text{前置知识} \]

\(\quad\)学这个之前需要对树上操作、 \(dfs\) 序和轻重链剖分等知识有一定了解,最好已经掌握了树链剖分。

\[\text{算法思想} \]

\(\quad\)树上启发式合并 (Dsu on Tree),是一个在 \(O(n\log n)\) 时间内解决许多树上问题的有力算法,对于某些树上离线问题可以速度大于等于大部分算法且更易于理解和实现。

\(\quad\)先想一下暴力算法,对于每一次询问都遍历整棵子树,然后统计答案,最后再清空 \(cnt\) 数组,最坏情况是时间复杂度为 \(O(n^2)\) ,对于 \(10^5\) 的数据肯定是过不去的。

\(\quad\)现在考虑优化算法,暴力算法跑得慢的原因就是多次遍历,多次清空数组,一个显然的优化就是将询问同一个子树的询问放在一起处理,但这样还是没有处理到关键,最坏情况时间复杂度还是 \(O(n^2)\),考虑到询问 \(x\) 节点时,\(x\) 的子树对答案有贡献,所以可以不用清空数组,先统计 \(x\) 的子树中的答案,再统计 \(x\) 的答案,这样就需要提前处理好 \(dfs\) 序。

\(\quad\)然后我们可以考虑一个优化,遍历到最后一个子树时是不用清空的,因为它不会产生对其他节点影响了,根据贪心的思想我们当然要把节点数最多的子树(即重儿子形成的子树)放在最后,之后我们就有了一个看似比较快的算法,先遍历所有的轻儿子节点形成的子树,统计答案但是不保留数据,然后遍历重儿子,统计答案并且保留数据,最后再遍历轻儿子以及父节点,合并重儿子统计过的答案。

\(\quad\)关键代码

il void add(int x){num[col[x]]++;cnt[num[col[x]]]++;}//单点增加贡献
il void raise(int x)//增加x的子树的贡献
{
  for(re i=seg[x];i<=seg[x]+size[x]-1;i++)
    add(rev[i]);
}
il void clear(int x)//清空子树
{
  for(re i=seg[x];i<=seg[x]+size[x]-1;i++)
    {
      int y=rev[i];
      cnt[num[col[y]]]--;num[col[y]]--;
    }
}
il void dfs1(int x,int fa)//预处理
{
  size[x]=1;seg[x]=++seg[0];rev[seg[x]]=x;father[x]=fa;
  for(re i=head[x],y;i,y=go[i];i=next[i])
    {
      if(y==fa)continue;dfs1(y,x);
      size[x]+=size[y];
      if(size[y]>size[son[x]])son[x]=y;
    }
}
il void dfs2(int x,int flag)//flag表示是否为重儿子,1表示重儿子,0表示轻儿子
{
  for(re i=head[x],y;i,y=go[i];i=next[i])
    {
      if(y==father[x]||y==son[x])continue;
      dfs2(y,0);//先遍历轻儿子
    }if(son[x])dfs2(son[x],1);//再遍历重儿子
  for(re i=head[x],y;i,y=go[i];i=next[i])
    {
      if(y==father[x]||y==son[x])continue;
      raise(y);//更新轻儿子的贡献
    }add(x);//加上x结点本身的贡献
  for(re i=0;i<q[x].size();i++)//更新答案
    ans[q[x][i].id]=cnt[q[x][i].k];
  if(!flag)clear(x);//如果是轻儿子,就清空
}

\(\quad\)完整 \(AC\) 代码,建议自己写出来。

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<map>
#include<vector>
using namespace std;
#define re register int
#define int long long
#define LL long long
#define il inline
#define next nee
#define inf 1e18
il int read()
{
  int x=0,f=1;char ch=getchar();
  while(!isdigit(ch)&&ch!='-')ch=getchar();
  if(ch=='-')f=-1,ch=getchar();
  while(isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
  return x*f;
}
il void print(int x)
{
  if(x<0)putchar('-'),x=-x;
  if(x/10)print(x/10);
  putchar(x%10+'0');
}
const int N=1e5+5;
int n,m,next[N<<1],go[N<<1],head[N],tot,seg[N],col[N];
int rev[N],size[N],son[N],father[N],cnt[N],ans[N],num[N];
struct node{int k,id;};
vector<node>q[N];
il void Add(int x,int y)
{
  next[++tot]=head[x];
  head[x]=tot;go[tot]=y;
}
il void add(int x){num[col[x]]++;cnt[num[col[x]]]++;}
il void raise(int x)
{
  for(re i=seg[x];i<=seg[x]+size[x]-1;i++)
    add(rev[i]);
}
il void clear(int x)
{
  for(re i=seg[x];i<=seg[x]+size[x]-1;i++)
    {
      int y=rev[i];
      cnt[num[col[y]]]--;num[col[y]]--;
    }
}
il void dfs1(int x,int fa)
{
  size[x]=1;seg[x]=++seg[0];rev[seg[x]]=x;father[x]=fa;
  for(re i=head[x],y;i,y=go[i];i=next[i])
    {
      if(y==fa)continue;dfs1(y,x);
      size[x]+=size[y];
      if(size[y]>size[son[x]])son[x]=y;//重儿子
    }
}
il void dfs2(int x,int flag)//flag表示是否为重儿子,1表示重儿子,0表示轻儿子
{
  for(re i=head[x],y;i,y=go[i];i=next[i])
    {
      if(y==father[x]||y==son[x])continue;
      dfs2(y,0);//先遍历轻儿子
    }if(son[x])dfs2(son[x],1);//再遍历重儿子
  for(re i=head[x],y;i,y=go[i];i=next[i])
    {
      if(y==father[x]||y==son[x])continue;
      raise(y);//更新轻儿子的贡献
    }add(x);//加上x结点本身的贡献
  for(re i=0;i<q[x].size();i++)//更新答案
    ans[q[x][i].id]=cnt[q[x][i].k];
  if(!flag)clear(x);//如果是轻儿子,就清空
}
signed main()
{
  n=read();m=read();
  for(re i=1;i<=n;i++)col[i]=read();
  for(re i=1,x,y;i<n;i++)x=read(),y=read(),Add(x,y),Add(y,x);
  for(re i=1,x,y;i<=m;i++)x=read(),y=read(),q[x].push_back((node){y,i});
  dfs1(1,0);dfs2(1,1);
  for(re i=1;i<=m;i++)print(ans[i]),putchar('\n');
  return 0;
}
posted @ 2020-11-23 22:00  Farkas_W  阅读(89)  评论(0编辑  收藏  举报