luogu P4211 [LNOI2014]LCA
题面传送门
一道典型的树剖题目。
这东西如果暴力肯定是没法算的。除非能转化一下,比如算贡献。
然后会发现HHHOJ上有一道题和这个很像。
这样的话可以把每个点向上算贡献,一直加\(1\)到根节点。
这样当一个点加到时那么就自然算到了贡献。
其实质是差分,只不过没那么明显罢了。
这个东西可以用树剖+线段树维护。
但是这道题是区间查询。如果每次暴力加进去只有\(O(mnlog^2n)\)的复杂度,比那个\(O(nmlogn)\)的暴力还劣。
首先这个区间肯定能差分,将\(l-r\)区间改成\(1-l-1\)与\(1-r\)区间两部分。
然后把所有询问离线下来,从左往右扫描线,边扫边查询答案。
这样复杂度就是\(O((n+m)log^2n)\)
代码实现:
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
int n,m,k,x,y,z;
int siz[50039],d[50039],son[50039],top[50039],id[50039],fa[50039],idea,ans[50039],f[200039],sum[200039];
inline void read(int &x){
char s=getchar();x=0;
while(s<'0'||s>'9') s=getchar();
while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+(s^48),s=getchar();
}
struct yyy{int to,z;};
struct ljb{
int head,h[50039];
yyy f[100039];
inline void add(int x,int y){
f[++head]=(yyy){y,h[x]};
h[x]=head;
}
}s;
inline void dfs1(int x,int last){
d[x]=d[last]+1;
fa[x]=last;
siz[x]=1;
int cur=s.h[x],pus=0;
yyy tmp;
while(cur!=-1){
tmp=s.f[cur];
if(tmp.to!=last){
dfs1(tmp.to,x);
siz[x]+=siz[tmp.to];
if(siz[pus]<siz[tmp.to]) pus=tmp.to;
}
cur=tmp.z;
}
son[x]=pus;
}
inline void dfs2(int x,int last){
top[x]=last;
id[x]=++idea;
if(!son[x]) return;
dfs2(son[x],last);
int cur=s.h[x];
yyy tmp;
while(~cur){
tmp=s.f[cur];
if(tmp.to!=son[x]&&tmp.to!=fa[x]) dfs2(tmp.to,tmp.to);
cur=tmp.z;
}
}
struct ques{int to,num,flag;}tmp;
vector<ques> fs[50039];
inline void swap(int &x,int &y){x^=y^=x^=y;}
inline void push(int l,int r,int now){
if(f[now]){
int m=(l+r)>>1;
f[now<<1]+=f[now];f[now<<1|1]+=f[now];
sum[now<<1]+=(m-l+1)*f[now];sum[now<<1|1]+=(r-m)*f[now];
f[now]=0;
}
}
inline void get(int x,int y,int l,int r,int now){
if(x<=l&&r<=y) {
f[now]++;
sum[now]+=r-l+1;
return;
}
push(l,r,now);
int m=(l+r)>>1;
if(x<=m) get(x,y,l,m,now<<1);
if(y>m) get(x,y,m+1,r,now<<1|1);
sum[now]=sum[now<<1]+sum[now<<1|1];
}
inline int find(int x,int y,int l,int r,int now){
if(x<=l&&r<=y) return sum[now];
int m=l+r>>1,fs=0;
push(l,r,now);
if(x<=m) fs+=find(x,y,l,m,now<<1);
if(y>m) fs+=find(x,y,m+1,r,now<<1|1);
return fs;
}
inline void gets(int x,int y){
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
get(id[top[x]],id[x],1,n,1);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
get(id[x],id[y],1,n,1);
}
inline int finds(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(d[top[x]]<d[top[y]]) swap(x,y);
ans+=find(id[top[x]],id[x],1,n,1);
x=fa[top[x]];
}
if(d[x]>d[y]) swap(x,y);
return ans+find(id[x],id[y],1,n,1);
}
int main(){
// freopen("1.in","r",stdin);
register int i,j;
memset(s.h,-1,sizeof(s.h));
scanf("%d%d",&n,&m);
for(i=2;i<=n;i++) read(x),x++,s.add(i,x),s.add(x,i);
dfs1(1,0);dfs2(1,1);
for(i=1;i<=m;i++){
read(x);read(y);read(z);
x++;y++;z++;
fs[y].push_back((ques){z,i,1});
fs[x-1].push_back((ques){z,i,-1});
}
for(i=1;i<=n;i++){
gets(1,i);
for(j=0;j<fs[i].size();j++){
tmp=fs[i][j];
ans[tmp.num]+=tmp.flag*finds(1,tmp.to);
}
}
for(i=1;i<=m;i++) printf("%d\n",ans[i]%201314);
}