【BZOJ】3626 [LNOI2014]LCA

【算法】树链剖分+线段树(区间加值,区间求和)

【题解】http://hzwer.com/3891.html

中间不要取模不然相减会出错。

血的教训:线段树修改时标记下传+上传,查询时下传。如果修改时标记不下传,下面的结果就会覆盖上面的标记上传造成的影响。

读入后全部排序(离线处理)

链剖之后按顺序每个solve_insert(1,j),对于每次的z询问solve_sum(1,z)。

LCA其实就是两点到达根节点的路径的最近交点。

差分思想的运用:将区间差转为r-(l-1)。

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int maxn=50010,MOD=201314;
int n,first[maxn],tot=0,top[maxn],deep[maxn],pos[maxn],q,ansz[maxn],size[maxn],f[maxn],dfsnum=0;
long long anss[maxn*3];
struct edge{int from,v;}e[maxn*3];
struct node{int l,r,delta,sum;}t[maxn*3];
struct numbers{int num,ord;bool flag;}num[maxn*3];
void insert(int u,int v)
{tot++;e[tot].v=v;e[tot].from=first[u];first[u]=tot;}
bool cmp(numbers a,numbers b)
{return a.num<b.num;}
void dfs1(int x,int fa)
{
    size[x]=1;
    for(int i=first[x];i;i=e[i].from)
     if(e[i].v!=fa)
      {
          int y=e[i].v;
          f[y]=x;
          deep[y]=deep[x]+1;
          dfs1(y,x);
          size[x]+=size[y];
      }
}
void dfs2(int x,int tp,int fa)
{
    pos[x]=++dfsnum;
    top[x]=tp;
    int k=0;
    for(int i=first[x];i;i=e[i].from)
     if(e[i].v!=fa&&size[e[i].v]>size[k])k=e[i].v;
    if(k==0)return;
    dfs2(k,tp,x);
    for(int i=first[x];i;i=e[i].from)
     if(e[i].v!=fa&&e[i].v!=k)dfs2(e[i].v,e[i].v,x);
    
}
void build(int k,int l,int r)
{
    t[k].l=l;t[k].r=r;t[k].delta=0;t[k].sum=0;
    if(l==r)return;
    int mid=(l+r)>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
    t[k].sum=t[k<<1].sum+t[k<<1|1].sum;
}
void add(int k,int l,int r)
{
    int left=t[k].l,right=t[k].r;
    if(l<=left&&right<=r)
     {
         t[k].delta++;
         t[k].sum+=right-left+1;
         return;
     }
    if(t[k].delta)
     {
         t[k<<1].delta+=t[k].delta;
         t[k<<1].sum+=(t[k<<1].r-t[k<<1].l+1)*t[k].delta;
         t[k<<1|1].delta+=t[k].delta;
         t[k<<1|1].sum+=(t[k<<1|1].r-t[k<<1|1].l+1)*t[k].delta;
         t[k].delta=0;
     }
    int mid=(left+right)>>1;
    if(l<=mid)add(k<<1,l,r);
    if(r>mid)add(k<<1|1,l,r);
    t[k].sum=t[k<<1].sum+t[k<<1|1].sum;
}
long long query(int k,int l,int r)
{
    int left=t[k].l,right=t[k].r;
    if(l<=left&&right<=r)return t[k].sum;
    if(t[k].delta)
     {
         t[k<<1].delta+=t[k].delta;
         t[k<<1].sum+=(t[k<<1].r-t[k<<1].l+1)*t[k].delta;
         t[k<<1|1].delta+=t[k].delta;
         t[k<<1|1].sum+=(t[k<<1|1].r-t[k<<1|1].l+1)*t[k].delta;
         t[k].delta=0;
     }
    int mid=(left+right)>>1;
    long long ans=0;
    if(l<=mid)ans=query(k<<1,l,r);
    if(r>mid)ans+=query(k<<1|1,l,r);
    return ans;
}
void solve_ins(int x,int y)
{
    while(top[x]!=top[y])
     {
         if(deep[top[x]]<deep[top[y]])swap(x,y);
         add(1,pos[top[x]],pos[x]);
         x=f[top[x]];
     }
    if(pos[x]>pos[y])swap(x,y);
    add(1,pos[x],pos[y]);
}
long long solve_sum(int x,int y)
{
    long long ans=0;
    while(top[x]!=top[y])
     {
         if(deep[top[x]]<deep[top[y]])swap(x,y);
         ans+=query(1,pos[top[x]],pos[x]);
         x=f[top[x]];
     }
    if(pos[x]>pos[y])swap(x,y);
    ans+=query(1,pos[x],pos[y]);
    return ans;
}
int main()
{
    scanf("%d%d",&n,&q);
    int u;
    for(int i=2;i<=n;i++)
     {
         scanf("%d",&u);
         insert(u+1,i);
         insert(i,u+1);
     }
    int ll,rr,zz;
    for(int i=1;i<=q;i++)
     {
         scanf("%d%d%d",&ll,&rr,&zz);ll++;rr++;zz++;
         ansz[i]=zz;
         num[i*2-1].num=ll-1;num[i*2-1].ord=i;num[i*2-1].flag=0;
         num[i*2].num=rr;num[i*2].ord=i;num[i*2].flag=1;
     }
    sort(num+1,num+q*2+1,cmp);
    build(1,1,n);dfs1(1,-1);dfs2(1,1,-1);
    int now=0;
    memset(anss,0,sizeof(anss));
    for(int i=1;i<=q*2;i++)
     {
         if(num[i].num>now)
          for(int j=now+1;j<=num[i].num;j++)solve_ins(1,j);
         now=num[i].num;
         if(num[i].flag)anss[num[i].ord]+=solve_sum(1,ansz[num[i].ord]);
          else anss[num[i].ord]-=solve_sum(1,ansz[num[i].ord]);
     }
    
    for(int i=1;i<=q;i++)printf("%lld\n",anss[i]%MOD);
    return 0;
}
View Code

 

posted @ 2017-04-23 00:03  ONION_CYC  阅读(181)  评论(0编辑  收藏  举报