CF193D Two Segments (线段树+dp)(外加两个扩展题)

大概算是个系列整理

(最强版是模拟赛原题))

首先,我们先来看这个题目。

QWQ一开始是毫无头绪,除了枚举就是枚举

首先,我们可以枚举一个右端点,然后算一下当前右端点的答案

我们令\(f[l,r]\)表示\(a_l到a_r\)这些数,能够最少划分成几段连续的数。

显然,我们要求的是以每个端点为右端点,\(f值<=2的\)
QWQ那么这个玩意应该怎么维护+更新呢

考虑右端点移动,会造成什么后果。

我们令新扩展的位置是\(r\),数是\(x\),他的前驱的位置是\(pre\),后继的位置是\(last\)

\(比如3的前驱就是2,后继是4\)

首先,\(f[1,r]....f[r,r]\)都会加1,不考虑和之前的数能合并的情况下,他自己就需要一段来完成

如果\(pre\)在当前位置的前面,那么\(f[1..r]....[pre,r]\)应该要-1,因为所有从pre之前出发的左端点,新的数可以和前驱的合并,就可以减少一段

那么如果\(last\)前面,也是同理的。

所以,我们需要一个支持区间维护最小值,最小值个数,次小值,次小值个数,还支持区间加和减的一个数据结构

线段树!

这里有几个要注意的地方就是:

1.维护的是严格的最小值和次小值,也就是说两个值不能相同,所以\(up\)的时候,会有一些小细节

void up(int root)
{
 if (f[2*root].mn<f[2*root+1].mn)
 {
  f[root].mn=f[2*root].mn;
  f[root].cimn=min(f[2*root].cimn,f[2*root+1].mn);
 }
 else
 {
  if (f[2*root].mn>f[2*root+1].mn)
  {
    f[root].mn=f[2*root+1].mn;
    f[root].cimn=min(f[2*root].mn,f[2*root+1].cimn);
     }
     else
     {
      f[root].mn=min(f[2*root].mn,f[2*root+1].mn);
      f[root].cimn=min(f[2*root].cimn,f[2*root+1].cimn);
  }
 }
 if(f[root].cimn==f[root].mn) f[root].cimn=1e9; 
 f[root].sum1=f[2*root].sum1*(f[2*root].mn==f[root].mn)+f[2*root+1].sum1*(f[2*root+1].mn==f[root].mn);
 f[root].sum2=f[2*root].sum2*(f[2*root].cimn==f[root].cimn)+f[2*root+1].sum2*(f[2*root+1].cimn==f[root].cimn && f[root].cimn!=1e9);
 f[root].sum2+=f[2*root].sum1*(f[2*root].mn==f[root].cimn)+f[2*root+1].sum1*(f[2*root+1].mn==f[root].cimn);
}

2.\(query\)由于我们要求的是\(<=2\)的值的个数,所以求和的时候,要注意满足\(mn<=2\)\(cimn==2\))

long long query(int root,int l,int r,int x,int y)
{
 if (x<=l && r<=y)
 {
  //cout<<l<<" "<<r<<endl;
  //cout<<f[root].mn<<" "<<f[root].cimn<<endl;
  return f[root].sum1*(f[root].mn<=2) + f[root].sum2*(f[root].cimn!=f[root].mn && f[root].cimn<=2);
 }
 int mid = l+r >> 1;
 long long ans=0;
 pushdown(root,l,r);
 if (x<=mid) ans=ans+query(2*root,l,mid,x,y);
 if (y>mid) ans=ans+query(2*root+1,mid+1,r,x,y);
 return ans;
}

QWQ大概就是这样了?

具体直接看代码吧

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
  int x=0,f=1;char ch=getchar();
  while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
  while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
  return x*f;
}
const int maxn= 4e5+1e2;
struct Node
{
 int mn,cimn;
 int sum1,sum2;
};
Node f[4*maxn];
int add[4*maxn];
int n,m;
int a[maxn],b[maxn];
long long ans;
void up(int root)
{
 if (f[2*root].mn<f[2*root+1].mn)
 {
  f[root].mn=f[2*root].mn;
  f[root].cimn=min(f[2*root].cimn,f[2*root+1].mn);
 }
 else
 {
  if (f[2*root].mn>f[2*root+1].mn)
  {
    f[root].mn=f[2*root+1].mn;
    f[root].cimn=min(f[2*root].mn,f[2*root+1].cimn);
     }
     else
     {
      f[root].mn=min(f[2*root].mn,f[2*root+1].mn);
      f[root].cimn=min(f[2*root].cimn,f[2*root+1].cimn);
  }
 }
 if(f[root].cimn==f[root].mn) f[root].cimn=1e9; 
 f[root].sum1=f[2*root].sum1*(f[2*root].mn==f[root].mn)+f[2*root+1].sum1*(f[2*root+1].mn==f[root].mn);
 f[root].sum2=f[2*root].sum2*(f[2*root].cimn==f[root].cimn)+f[2*root+1].sum2*(f[2*root+1].cimn==f[root].cimn && f[root].cimn!=1e9);
 f[root].sum2+=f[2*root].sum1*(f[2*root].mn==f[root].cimn)+f[2*root+1].sum1*(f[2*root+1].mn==f[root].cimn);
}
void pushdown(int root,int l,int r)
{
 if (add[root])
 {
  add[2*root]+=add[root];
  add[2*root+1]+=add[root];
  f[2*root].mn+=add[root];
  f[2*root].cimn+=add[root];
  f[2*root+1].mn+=add[root];
  f[2*root+1].cimn+=add[root];
  add[root]=0;
 }
}
void build(int root,int l,int r)
{
 if(l==r)
 {
  f[root].sum1=1;
  f[root].sum2=0;
  f[root].cimn=1e9;
  return;
 }
 int mid = l+r >> 1;
 build(2*root,l,mid);
 build(2*root+1,mid+1,r);
 up(root);
}
void update(int root,int l,int r,int x,int y,int p)
{
 if (x<=l && r<=y)
 {
  add[root]+=p;
  f[root].mn+=p;
  f[root].cimn+=p;
  return;
 }
 int mid = l+r >> 1;
 pushdown(root,l,r);
 if(x<=mid) update(2*root,l,mid,x,y,p);
 if(y>mid) update(2*root+1,mid+1,r,x,y,p);
 up(root);
}
long long query(int root,int l,int r,int x,int y)
{
 if (x<=l && r<=y)
 {
  //cout<<l<<" "<<r<<endl;
  //cout<<f[root].mn<<" "<<f[root].cimn<<endl;
  return f[root].sum1*(f[root].mn<=2) + f[root].sum2*(f[root].cimn!=f[root].mn && f[root].cimn<=2);
 }
 int mid = l+r >> 1;
 long long ans=0;
 pushdown(root,l,r);
 if (x<=mid) ans=ans+query(2*root,l,mid,x,y);
 if (y>mid) ans=ans+query(2*root+1,mid+1,r,x,y);
 return ans;
} 
signed main()
{
  n=read();
  for (int i=1;i<=n;i++) a[i]=read();
  for (int i=1;i<=n;i++) b[a[i]]=i;
  build(1,1,n); 
 // update(1,1,n,1,3,1);
  //update(1,1,n,1,2,1);
  //cout<<query(1,1,n,1,3)<<endl;
  //return 0;
  for (int i=1;i<=n;i++)
  {
   int x = a[b[i]-1],y=a[b[i]+1];
   update(1,1,n,1,i,1);
   if (x && x<i) update(1,1,n,1,x,-1);
   if (y && y<i) update(1,1,n,1,y,-1);
   ans=ans+query(1,1,n,1,i);
   //cout<<ans<<endl;
  }  
  cout<<ans-n<<endl;
  return 0;
}

QWQ嘤嘤嘤

那么如果换一种问法,应该怎么办呢?

给定一个你长度为\(n\)的序列,然后求出来有多少个区间满足最大值减去最小值等于区间长度-1

其实和上个题目差不多了啦。

只不过我们只需要维护最小值,然后求和的时候,只需要满足最小值等于1即可

直接给代码(只呈现关键部分的)

void up(int root)
{
 g[root].mn=min(g[2*root].mn,g[2*root+1].mn);
 g[root].ans=g[2*root].ans*(g[2*root].mn==g[root].mn)+g[2*root+1].ans*(g[2*root+1].mn==g[root].mn);  
}
void pushdown(int root,int l,int r)
{
 if (add[root])
 {
  add[2*root]+=add[root];
  add[2*root+1]+=add[root];
  g[2*root].mn+=add[root];
  g[2*root+1].mn+=add[root];
  add[root]=0;
 }
}
void build(int root,int l,int r)
{
 if (l==r)
 {
  g[root].ans=1;
  return;
 }
 int mid = l+r >> 1;
    build(2*root,l,mid);
 build(2*root+1,mid+1,r);
 up(root); 
}
void update(int root,int l,int r,int x,int y,int p)
{
 if(x<=l && r<=y)
 {
  g[root].mn+=p;
  add[root]+=p;
  return;
 }
 pushdown(root,l,r);
 int mid = l+r >> 1;
 if (x<=mid) update(2*root,l,mid,x,y,p);
 if (y>mid) update(2*root+1,mid+1,r,x,y,p);
 up(root);
}
long long query(int root,int l,int r,int x,int y)
{
 if(x<=l && r<=y)
 {
  return g[root].ans*(g[root].mn==1); 
 }
 pushdown(root,l,r);
 int mid = l+r >> 1;
 long long ans=0;
 if (x<=mid) ans=ans+query(2*root,l,mid,x,y);
 if (y>mid) ans=ans+query(2*root+1,mid+1,r,x,y);
 return ans;
}

既然都做到这个程度了,不如就更毒瘤 一点

现在给定你一颗n个点的树,每个点都有一个编号,每条边的长度都是1,让你求有多少条路经满足最大编号-最小编号等于路径长度。

woc上树了....那该怎么做啊?

是不是可以考虑和序列上的相类似呢?

我们不妨对每个点维护一个\(dfn[x]\)表示这个点的\(dfs\)序。

然后依次枚举dfs序上的每个点,计算dfs序从1到当前点之前的所有点到当前点的合法路径条数。

类比序列

对于当前点来说,首先我们要让1到当前点之前所有的路径都+1,然后呢。
我们考虑前驱和后继的位置

这里需要讨论一个是否是祖先的关系(因为画个图就能发现,如果是祖先,那么这个点会影响的路径起点是1到\(dfn[x]\),不然就是他的子树内的所有点)

后继同样是如此

而且在计算完每个儿子的时候,记得加上当前儿子对其他儿子的贡献。

然后最后记得把一个点的贡献都还原,因为我们需要计算别的答案,而当前点就会变成起点之一,那么他作为终点的贡献,就是要去掉的。

QWQ有一些细节写到代码里面了

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
  int x=0,f=1;char ch=getchar();
  while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
  while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
  return x*f;
}
const int maxn = 1e5+1e2;
const int maxm =  2*maxn;
int point[maxn],nxt[maxm],to[maxm];
int cnt,n,m;
int dfn[maxn];
int ans;
void addedge(int x,int y)
{
 nxt[++cnt]=point[x];
 to[cnt]=y;
 point[x]=cnt;
}
struct Node{
 int mn,ans;
 int len;
};
Node g[4*maxn];
int add[4*maxn];
void up(int root)
{
 g[root].mn=min(g[2*root].mn,g[2*root+1].mn);
 g[root].ans=g[2*root].ans*(g[2*root].mn==g[root].mn)+g[2*root+1].ans*(g[2*root+1].mn==g[root].mn);  
}
void pushdown(int root,int l,int r)
{
 if (add[root])
 {
  add[2*root]+=add[root];
  add[2*root+1]+=add[root];
  g[2*root].mn+=add[root];
  g[2*root+1].mn+=add[root];
  add[root]=0;
 }
}
void build(int root,int l,int r)
{
 if (l==r)
 {
  g[root].ans=1;
  return;
 }
 int mid = l+r >> 1;
    build(2*root,l,mid);
 build(2*root+1,mid+1,r);
 up(root); 
}
void update(int root,int l,int r,int x,int y,int p)
{
 if (x>y) return;
 if(x<=l && r<=y)
 {
  g[root].mn+=p;
  add[root]+=p;
  return;
 }
 pushdown(root,l,r);
 int mid = l+r >> 1;
 if (x<=mid) update(2*root,l,mid,x,y,p);
 if (y>mid) update(2*root+1,mid+1,r,x,y,p);
 up(root);
}
long long query(int root,int l,int r,int x,int y)
{
 if (x>y) return 0;
 if(x<=l && r<=y)
 {
  return g[root].ans*(g[root].mn==1); 
 }
 pushdown(root,l,r);
 int mid = l+r >> 1;
 long long ans=0;
 if (x<=mid) ans=ans+query(2*root,l,mid,x,y);
 if (y>mid) ans=ans+query(2*root+1,mid+1,r,x,y);
 return ans;
}
int deep[maxn];
int f[maxn][21];
int a[maxn],b[maxn];
int size[maxn];
int tot;
int maxdfn[maxn]; //表示已经计算过的儿子的子树里面的最大的dfs序 
void dfs(int x,int fa,int dep)
{
 deep[x]=dep;
 dfn[x]=++tot;
 size[x]=1;
 for (int i=point[x];i;i=nxt[i])
 {
  int p = to[i];
  if (p==fa) continue; 
  f[p][0]=x;
  dfs(p,x,dep+1);
  size[x]+=size[p];
 }
}
void init()
{
 for (int j=1;j<=20;j++)
   for (int i=1;i<=n;i++)
     f[i][j]=f[f[i][j-1]][j-1];
}
int go_up(int x,int d)
{
 for (int i=0;i<=20;i++)
 if ((1<<i)&d) x=f[x][i];
 return x;
}
bool check(int x,int fa)
{
 if (fa==0 || fa==n+1) return 0;
 if(deep[x]<=deep[fa]) return 0;
 if (go_up(x,deep[x]-deep[fa])==fa) return 1;
 else return 0;
}
int dp(int x,int fa)//我们对于每个点,计算的是 dfs序上[i,r]的合法路径条数 
{
 maxdfn[x]=dfn[x];
    update(1,1,n,1,dfn[x],1); //首先把之前的全部+1 
 if (dfn[x-1]<dfn[x] && x!=1)
 {
    if (check(x,x-1))
      update(1,1,n,1,maxdfn[x-1],-1); //相当于这些点都是从x-1到达x,(相当于除去这个子树外所有的dfs小于当前点的点)所以应该-1,因为可以合并 
          else
      update(1,1,n,dfn[x-1],dfn[x-1]+size[x-1]-1,-1); //如果不是祖先关系,那么从x-1到达x的路径,一定是从他的子树里面出发的 (而且子树内的任何一个点的dfs序一定都在当前点之前) 
 }
 if (dfn[x+1]<dfn[x] && x!=n)
 {
    if (check(x,x+1))
      update(1,1,n,1,maxdfn[x+1],-1);  
          else
      update(1,1,n,dfn[x+1],dfn[x+1]+size[x+1]-1,-1);  
 }
 ans=ans+query(1,1,n,1,dfn[x]);
 for (int i=point[x];i;i=nxt[i])
 {
  int p = to[i];
  if (p==fa) continue;
     int now = dp(p,x);
     update(1,1,n,maxdfn[x]+1,now,1); //处理儿子之间的影响  (因为计算一个点的代价的时候,不仅有祖先或者是别的子树的,还要计算兄弟的) 
     if (x!=1 && (p==x-1 || check(x-1,p))) update(1,1,n,dfn[x-1],dfn[x-1]+size[x-1]-1,-1); //如果x-1在当前的儿子里面,那么他那个子树里到后面的点的代价就可以-1(理解成能够合并) 
     if (x!=n && (p==x+1 || check(x+1,p))) update(1,1,n,dfn[x+1],dfn[x+1]+size[x+1]-1,-1);
     maxdfn[x]=now; 
 }
 if (dfn[x-1]<dfn[x] && x!=1)
 {
    if (check(x,x-1))
      update(1,1,n,1,maxdfn[x-1],1);  
          else
      update(1,1,n,dfn[x-1],dfn[x-1]+size[x-1]-1,1);  
 }
 if (dfn[x+1]<dfn[x] && x!=n)
 {
    if (check(x,x+1))
      update(1,1,n,1,maxdfn[x+1],1);  
          else
      update(1,1,n,dfn[x+1],dfn[x+1]+size[x+1]-1,1);  
 }
 update(1,1,n,1,dfn[x]-1,-1); //还原所有的操作,因为要计算别的为1 的ans,之所以是dfn[x]-1 相当于给这个赋值为1.(至少一段) 
    return maxdfn[x];
}
signed main()
{
  n=read();
  for (int i=1;i<n;i++)
  {
   int x=read(),y=read();
   addedge(x,y);
 addedge(y,x); 
  }
  dfs(1,0,1);
  init();
  build(1,1,n); 
  for (int i=1;i<=n;i++) b[dfn[i]]=i;
  dp(1,0);
  cout<<ans;
  return 0;
}
posted @ 2018-12-22 18:13  y_immortal  阅读(292)  评论(0编辑  收藏  举报