【BZOJ】3626 [LNOI2014]LCA
【算法】树链剖分+线段树(区间加值,区间求和)
【题解】http://hzwer.com/3891.html
中间不要取模不然相减会出错。
血的教训:线段树修改时标记下传+上传,查询时下传。如果修改时标记不下传,下面的结果就会覆盖上面的标记上传造成的影响。
读入后全部排序(离线处理)
链剖之后按顺序每个solve_insert(1,j),对于每次的z询问solve_sum(1,z)。
LCA其实就是两点到达根节点的路径的最近交点。
差分思想的运用:将区间差转为r-(l-1)。
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int maxn=50010,MOD=201314; int n,first[maxn],tot=0,top[maxn],deep[maxn],pos[maxn],q,ansz[maxn],size[maxn],f[maxn],dfsnum=0; long long anss[maxn*3]; struct edge{int from,v;}e[maxn*3]; struct node{int l,r,delta,sum;}t[maxn*3]; struct numbers{int num,ord;bool flag;}num[maxn*3]; void insert(int u,int v) {tot++;e[tot].v=v;e[tot].from=first[u];first[u]=tot;} bool cmp(numbers a,numbers b) {return a.num<b.num;} void dfs1(int x,int fa) { size[x]=1; for(int i=first[x];i;i=e[i].from) if(e[i].v!=fa) { int y=e[i].v; f[y]=x; deep[y]=deep[x]+1; dfs1(y,x); size[x]+=size[y]; } } void dfs2(int x,int tp,int fa) { pos[x]=++dfsnum; top[x]=tp; int k=0; for(int i=first[x];i;i=e[i].from) if(e[i].v!=fa&&size[e[i].v]>size[k])k=e[i].v; if(k==0)return; dfs2(k,tp,x); for(int i=first[x];i;i=e[i].from) if(e[i].v!=fa&&e[i].v!=k)dfs2(e[i].v,e[i].v,x); } void build(int k,int l,int r) { t[k].l=l;t[k].r=r;t[k].delta=0;t[k].sum=0; if(l==r)return; int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); t[k].sum=t[k<<1].sum+t[k<<1|1].sum; } void add(int k,int l,int r) { int left=t[k].l,right=t[k].r; if(l<=left&&right<=r) { t[k].delta++; t[k].sum+=right-left+1; return; } if(t[k].delta) { t[k<<1].delta+=t[k].delta; t[k<<1].sum+=(t[k<<1].r-t[k<<1].l+1)*t[k].delta; t[k<<1|1].delta+=t[k].delta; t[k<<1|1].sum+=(t[k<<1|1].r-t[k<<1|1].l+1)*t[k].delta; t[k].delta=0; } int mid=(left+right)>>1; if(l<=mid)add(k<<1,l,r); if(r>mid)add(k<<1|1,l,r); t[k].sum=t[k<<1].sum+t[k<<1|1].sum; } long long query(int k,int l,int r) { int left=t[k].l,right=t[k].r; if(l<=left&&right<=r)return t[k].sum; if(t[k].delta) { t[k<<1].delta+=t[k].delta; t[k<<1].sum+=(t[k<<1].r-t[k<<1].l+1)*t[k].delta; t[k<<1|1].delta+=t[k].delta; t[k<<1|1].sum+=(t[k<<1|1].r-t[k<<1|1].l+1)*t[k].delta; t[k].delta=0; } int mid=(left+right)>>1; long long ans=0; if(l<=mid)ans=query(k<<1,l,r); if(r>mid)ans+=query(k<<1|1,l,r); return ans; } void solve_ins(int x,int y) { while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]])swap(x,y); add(1,pos[top[x]],pos[x]); x=f[top[x]]; } if(pos[x]>pos[y])swap(x,y); add(1,pos[x],pos[y]); } long long solve_sum(int x,int y) { long long ans=0; while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]])swap(x,y); ans+=query(1,pos[top[x]],pos[x]); x=f[top[x]]; } if(pos[x]>pos[y])swap(x,y); ans+=query(1,pos[x],pos[y]); return ans; } int main() { scanf("%d%d",&n,&q); int u; for(int i=2;i<=n;i++) { scanf("%d",&u); insert(u+1,i); insert(i,u+1); } int ll,rr,zz; for(int i=1;i<=q;i++) { scanf("%d%d%d",&ll,&rr,&zz);ll++;rr++;zz++; ansz[i]=zz; num[i*2-1].num=ll-1;num[i*2-1].ord=i;num[i*2-1].flag=0; num[i*2].num=rr;num[i*2].ord=i;num[i*2].flag=1; } sort(num+1,num+q*2+1,cmp); build(1,1,n);dfs1(1,-1);dfs2(1,1,-1); int now=0; memset(anss,0,sizeof(anss)); for(int i=1;i<=q*2;i++) { if(num[i].num>now) for(int j=now+1;j<=num[i].num;j++)solve_ins(1,j); now=num[i].num; if(num[i].flag)anss[num[i].ord]+=solve_sum(1,ansz[num[i].ord]); else anss[num[i].ord]-=solve_sum(1,ansz[num[i].ord]); } for(int i=1;i<=q;i++)printf("%lld\n",anss[i]%MOD); return 0; }