bzoj3351:[ioi2009]Regions
思路:首先如果颜色相同直接利用以前的答案即可,可以离线排序或是在线hash,然后考虑怎么快速统计答案。
首先如果点a是点b的祖先,那么一定有点b在以点a为根的子树的dfs序区间内的,于是先搞出dfs序。
然后如果颜色a的点数很小,颜色b的点数很大,那么可以考虑枚举a的点数,然后对于每一种颜色开个vector记录一下有哪些点是这种颜色,然后按照它们的dfs序排序,就可以用颜色a中的每个点在颜色b中二分出哪些点属于以该点为根的子树对应的dfs序区间了。复杂度O(size(a)*log(size(b))),size(a)表示颜色a的vector的size()。
然后如果颜色b的点数很小,颜色a的点数很大,那么就枚举b的点数,这时要考虑的问题就成了一个点被多少段区间覆盖了,然后离散化差分预处理,再去二分(我写的是vector的离散化)。复杂度O(size(b)*log(size(a)))
但如果a,b的点数差不多且都很大(也就是几乎为sqrt(n)),那么算法复杂度就会变成O(sqrt(n)*log(n))了,再乘以一个q就会GG,于是只能另寻他法,然后可以发现直接两个指针扫过去,一个扫区间端点另一个扫要询问的点,然后如果扫到一个点就直接统计答案,然后这就变成了O(size(a)+size(b))了。
那这个很大是有多大,很小是有多小呢?
对于第一种算法使用条件是size(b)>x,第二种算法使用条件是size(a)>x,其余则用第三种算法。
对于第一、二种情况,时间复杂度最大是O(n^2logn/x),然后对于第三种则是O(n*x),然后根据基本不等式x=sqrt(nlogn),总时间复杂度为O(n*sqrt(nlogn))。
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<vector> #include<cmath> using namespace std; #define maxn 200005 #define maxr 30000 int n,r,Q,tot,cnt; int now[maxn],pre[2*maxn],son[2*maxn],color[maxn],dfn[maxn],size[maxn]; long long ans[maxn]; inline int read(){ int x=0,f=1;char ch=getchar(); for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1; for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0'; return x*f; } struct node{ int dfn,bo; node(){} node(int a,int b){dfn=a,bo=b;} bool operator <(const node &a)const{return dfn<a.dfn;} }; struct query{ int x,y,id; bool operator <(const query &a)const{return x<a.x||(x==a.x&&y<a.y);} }q[maxn]; bool cmp(int a,int b){return dfn[a]<dfn[b];} vector<int> col[maxr],val[maxr]; vector<node> v[maxr]; vector<int> fuckpps[maxr]; void add(int a,int b){ son[++tot]=b; pre[tot]=now[a]; now[a]=tot; } void link(int a,int b){ add(a,b),add(b,a); } void dfs(int x,int fa){ dfn[x]=++cnt; for (int p=now[x];p;p=pre[p]) if (son[p]!=fa) dfs(son[p],x),size[x]+=size[son[p]]+1; } int binary_search(int l,int r,int b,int pos){ int ans=-1; while (l<=r){ int mid=(l+r)>>1; if (pos>=fuckpps[b][mid]) ans=mid,l=mid+1; else r=mid-1; } return ans+1; } long long solve1(int a,int b){ long long ans=0; for (unsigned int i=0;i<col[a].size();i++){ int x=col[a][i],l=binary_search(0,fuckpps[b].size()-1,b,dfn[x]-1),r=binary_search(0,fuckpps[b].size()-1,b,dfn[x]+size[x]); ans+=r-l; } return ans; } int binary_search2(int l,int r,int b,int pos){ int ans=-1; while (l<=r){ int mid=(l+r)>>1; if (v[b][mid].dfn<=pos) ans=mid,l=mid+1; else r=mid-1; } return ans; } long long solve2(int a,int b){ long long ans=0; for (unsigned int i=0;i<col[b].size();i++){ int x=col[b][i],pos=binary_search2(0,v[a].size()-1,a,dfn[x]); if (pos!=-1) ans+=val[a][pos]; } return ans; } long long solve3(int a,int b){ long long ans=0;unsigned int i=0,j=0,tt=0; while (i<v[a].size() && j<col[b].size()) if (v[a][i].dfn<=dfn[col[b][j]]) tt=val[a][i],i++;else ans+=tt,j++; return ans; } int main(){ n=read(),r=read(),Q=read();int siz=sqrt(n*log2(n)); for (int i=1,x;i<=n;i++){ if (i!=1) x=read(),link(i,x); color[i]=read();col[color[i]].push_back(i); } dfs(1,0); for (int i=1;i<=n;i++) fuckpps[color[i]].push_back(dfn[i]); for (int i=1;i<=r;i++) sort(col[i].begin(),col[i].end(),cmp),sort(fuckpps[i].begin(),fuckpps[i].end()); for (int i=1;i<=r;i++){ for (unsigned int j=0;j<col[i].size();j++) v[i].push_back(node(dfn[col[i][j]],1)),v[i].push_back(node(dfn[col[i][j]]+size[col[i][j]]+1,-1)); sort(v[i].begin(),v[i].end());int sum=0; for (unsigned int j=0;j<v[i].size();j++){ sum+=v[i][j].bo; val[i].push_back(sum); } } for (int i=1;i<=Q;i++) q[i].x=read(),q[i].y=read(),q[i].id=i; sort(q+1,q+Q+1); for (int i=1;i<=Q;i++){ if (q[i].x==q[i-1].x && q[i].y==q[i-1].y){ans[q[i].id]=ans[q[i-1].id];continue;} if (col[q[i].y].size()+1>=siz&&col[q[i].x].size()+1<siz) ans[q[i].id]=solve1(q[i].x,q[i].y); else if (col[q[i].x].size()+1>=siz&&col[q[i].y].size()+1<siz) ans[q[i].id]=solve2(q[i].x,q[i].y); else ans[q[i].id]=solve3(q[i].x,q[i].y); } for (int i=1;i<=Q;i++) printf("%lld\n",ans[i]); return 0; }