bzoj5392 [Lydsy1806月赛]路径统计

传送门

分析

我们设sum[x]为小于等于x的点现在有多少联通

于是一个序列合法当且只当sum[R]-sum[L-1]=len且所有点度数不大于2

我们知道如果对于序列[L,R]满足条件则[L+1,R]一定满足

如果[L,R]不满足则[L-1,R]一定不满足

所以我们可以枚举R然后找最靠左的满足度数都小于2的L

用线段树维护信息查询区间内最大值是R的数的个数就是贡献

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cctype>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<ctime>
#include<vector>
#include<set>
#include<map>
#include<stack>
using namespace std;
int n,L,R,num,Max,cnt,col[1200000],d[1200000],sum[1200000],du[260000];
long long Ans;
vector<int>high[260000],low[260000];
inline void build(int le,int ri,int wh){
    sum[wh]=1;
    d[wh]=ri;
    if(le==ri)return;
    int mid=(le+ri)>>1;
    build(le,mid,wh<<1);
    build(mid+1,ri,wh<<1|1);
}
inline void pd(int wh){
    if(col[wh]){
      d[wh<<1]+=col[wh];
      col[wh<<1]+=col[wh];
      d[wh<<1|1]+=col[wh];
      col[wh<<1|1]+=col[wh];
      col[wh]=0;
    }
}
inline void up(int wh){
    d[wh]=max(d[wh<<1],d[wh<<1|1]);
    sum[wh]=(d[wh<<1]==d[wh]?sum[wh<<1]:0)+(d[wh<<1|1]==d[wh]?sum[wh<<1|1]:0);
}
inline void update(int le,int ri,int wh,int x,int y){
    if(le>=x&&ri<=y){
      col[wh]++;
      d[wh]++;
      return;
    }
    pd(wh);
    int mid=(le+ri)>>1;
    if(mid>=x)update(le,mid,wh<<1,x,y);
    if(mid<y)update(mid+1,ri,wh<<1|1,x,y);
    up(wh);
}
inline void que(int le,int ri,int wh,int x,int y){
    if(le>=x&&ri<=y){
      if(d[wh]>Max)Max=d[wh],cnt=sum[wh];
        else if(d[wh]==Max)cnt+=sum[wh];
      return;
    }
    pd(wh);
    int mid=(le+ri)>>1;
    if(mid>=x)que(le,mid,wh<<1,x,y);
    if(mid<y)que(mid+1,ri,wh<<1|1,x,y);
    up(wh);
}
inline void add(int x){
    for(int i=0;i<low[x].size();i++){
      update(1,n,1,1,low[x][i]);
      if(low[x][i]>=L)num+=((++du[x])==3)+((++du[low[x][i]])==3);
    }
}
inline void deal(int x){
    for(int i=0;i<high[x].size();i++){
      if(high[x][i]<=R)num-=((--du[x])==2)+((--du[high[x][i]])==2);
    }
}
int main(){
    int i,j,k;
    scanf("%d",&n);
    for(i=1;i<n;i++){
      int x,y;
      scanf("%d%d",&x,&y);
      if(x>y)swap(x,y);
      high[x].push_back(y);
      low[y].push_back(x);
    }
    L=1;
    build(1,n,1);
    for(i=1;i<=n;i++){
      R=i;
      add(i);
      while(num)deal(L),L++;
      Max=0;
      que(1,n,1,L,i);
      if(Max==i)Ans+=1ll*cnt;
    }
    cout<<Ans;
    return 0;
}
posted @ 2019-02-20 14:17  水题收割者  阅读(217)  评论(0编辑  收藏  举报