【JZOJ5402】God Knows
Description
Solution
澄清(说明):以下的
Pi
为题目中的
pi
,以下的
p
(如果有的话)并不指代
考虑dp入手,设
fi
表示当前做到
i
且选择
我们考虑它会从哪些点转移过来,以下:
黄点是当前做到的点,绿点和蓝点都是可能转移的位置,红点为不能转移的位置。
这里蓝点的意思:设蓝点为
为什么可以从这些点转移,我们首先有个结论,组成答案(方案)的线段一定是互不相交的。那么首先蓝点可以转移,因为它的线段可以切断所有的 (j,Pj)(j<i 且 Pj<Pi) ,那么这些绿点也可以转移,因为它们都具有蓝点的性质。
于是我们就有一个成型的 O(n2) 做法,枚举转移点然后转移。
我们考虑如何用数据结构优化,那么对于右边这个坐标系的图,我们相当于把经过黄点(包括这个点)水平直线以上的点全部忽略,然后用剩下的点(以
i
为下标)形成的单调递减的栈,那么栈中的点就是可以转移的点。然而这是在值域区间中忽略掉一些点,我们很难处理。那么,我们将
那么,此时忽略过黄点(包括黄点)的垂直与
Pi
轴的线的右边的点,剩余的能形成单调递减的栈的点就是转移点(即绿点和蓝点,图为上文的图翻转过来),而且这样做的优点在于可以在定义域区间上处理,那么就可以上数据结构了。
然后我们怎么维护这个东西呢?(以下的操作以正上方这幅图为基准,用线段树维护)
设函数
g(l,r,p)
表示在
[l,r]
这个区间内加入
p
后所形成的单调栈
再考虑查询:
当
当
l<r
且
[l,r]
为完整的线段树区间时,设
mid=l+r2
,考虑将
p
加入当前区间的单调栈,设
但是这样的复杂度为
我们发现,一个区间 [l,r] 对应许多线段树区间,找到这些区间的复杂度是 O(log2n) 的,对于一个区间,我们可能会递归两个区间下去求答案,但我们发现,当rmx>p时,左边区间的答案为 g(l,mid,rmx) ,这就可以预处理。我们设 flmin(l,r)=g(l,mid,rmx) ,那么每次插入时维护 flmin ,即可。
那么这样复杂度为:对于一个区间 [l,r] ,它所包含的线段树区间有 O(log2n) 个,每个完整的线段树区间会一直走到底,复杂度为 O(log2n) ,于是单次查询的复杂度为 O(log22n) 。对于 fminl 的计算,每次插入操作会涉及 O(log2n) 个区间,每个完整区间的 fminl 计算复杂度是 O(log2n) 的,所以单次修改复杂度于是 O(log22n) 。
总的复杂度即为 O(nlog22n) 。
我们还可以发现,每次加入的
i
是单调递增的,所以该方法还可以优化成但由于本人姿势水平不够,所以这里暂且留坑。
Code
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define fo(i,j,k) for(int i=j;i<=k;i++)
#define fd(i,j,k) for(int i=j;i>=k;i--)
#define N 200010
#define inf 1000000000
using namespace std;
int p[N],c[N],f[N];
void read(int &n){
char ch=' ';
int t=0,c=1;
for(;(ch!='-')&&((ch<'0')||(ch>'9'));ch=getchar());
if(ch=='-') c=-1,ch=getchar();
for(;ch>='0' && ch<='9';ch=getchar()) t=t*10+ch-48;
n=t*c;
}
struct node{
int mf,mx,f;
}tr[N*4];
int max(int x,int y){
return x>y?x:y;
}
int min(int x,int y){
return x<y?x:y;
}
int rmx=0;
int find(int v,int l,int r,int x,int y,int p){
if(tr[v].mx<p) return inf;
if(l==r) {rmx=max(rmx,tr[v].mx);return tr[v].f;}
int mid=(l+r)/2;
if(l==x && r==y)
{
if(tr[v*2+1].mx<p) return find(v*2,l,mid,x,mid,p);
rmx=max(rmx,tr[v].mx);
return min(tr[v].mf,find(v*2+1,mid+1,r,mid+1,y,p));
tr[v].mf=find(v*2,l,mid,x,mid,tr[v*2+1].mx);
}
if(y<=mid) return find(v*2,l,mid,x,y,p);
else if(x>mid) return find(v*2+1,mid+1,r,x,y,p);
else
{
rmx=0;
int tmp=find(v*2+1,mid+1,r,mid+1,y,p);
return min(tmp,find(v*2,l,mid,x,mid,rmx));
}
}
void insert(int v,int l,int r,int x,int p){
if(l==r) {tr[v].f=f[p],tr[v].mx=p;return;}
int mid=(l+r)/2;
if(x<=mid) insert(v*2,l,mid,x,p);
else insert(v*2+1,mid+1,r,x,p);
tr[v].mf=find(v*2,l,mid,l,mid,tr[v*2+1].mx);
tr[v].f=min(tr[v*2].f,tr[v*2+1].f);
tr[v].mx=max(tr[v*2].mx,tr[v*2+1].mx);
}
int main()
{
freopen("knows.in","r",stdin);
freopen("knows.out","w",stdout);
int n;
scanf("%d",&n);
fo(i,1,n) read(p[i]);
fo(i,1,n) read(c[i]);
fo(v,1,n*4) tr[v].f=tr[v].mf=inf;
for(int v=1;v<=n*4;v*=2) tr[v].f=tr[v].mf=0;
fo(i,1,n)
{
f[i]=find(1,0,n,0,p[i]-1,0)+c[i];
insert(1,0,n,p[i],i);
}
int ans=inf,o=0;
fd(i,n,1) if(o<p[i]) o=p[i],ans=min(ans,f[i]);
printf("%d",ans);
}