loj#6518. 「雅礼集训 2018 Day11」序列(保序回归)
题目描述
保序回归
给出序列a和形如ai>=aj的限制条件,把x修改成y的代价为(|x-y|)^k,求最小代价
整体二分,对当前值域二分从而变成每个数修改为mid或mid+1,修改为mid的最终值<=mid,修改为mid+1的最终值>=mid+1,继续往下二分
证明感受一下
一般情况下用网络流来求
建图1
转化为最大权闭合子图,如果有a[u]>=a[v]则当v选了1u必须选1,一开始假设选0点权为(1的值-0的值)
最大权闭合子图:原图的边为inf,权值正则连S->u,为负则权值取反连u->T
假设全部选正的,割正边等于不选,割负边等于选,答案=∑正-最小割
建图2
自己想的,还没试过
S->i连选0的代价,i->T连选1的代价,如果a[u]>=a[v]则v->u连inf
割掉等于选,这样不能u选0v选1
本题
n=5000,建图跑网络流是n^3的所以用dp解决
设f[i][0/1]表示以i结尾,结尾为0/1的答案,考虑当前段的开头
一条限制可以拆成两个方向考虑,向前等于在线段树上把一段删掉,向后等于限制转移位置,用单调栈处理
要注意dp的是当前层的数,所以中间会有空,如果覆盖的不是这一层的就不用管,所以要考虑头/尾/间隔之类的东西,弹栈时根据是否覆盖下一位来弹
code
#include <bits/stdc++.h>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define abs(x) ((x)>0?(x):-(x))
#define min(a,b) (a<b?a:b)
#define max(a,b) (a>b?a:b)
#define inf 1000000000
#define ll long long
//#define file
using namespace std;
int a[5001],A[2][5001],S[5001][2][2],nx[5001],b[5001],Ans[5001],n,m,i,j,k,l,tp,x,y,z,tot,T[2];
ll f[5001][2],g[5001][2],ans;
struct tree{
ll tr[20001],Tr[20001];
int tr2[20001],tot,d[20001];
bool bz[20001];
void clear()
{
int i;
fo(i,1,tot) tr[d[i]]=inf,bz[d[i]]=Tr[d[i]]=0;tot=0;
}
void add(int t) {if (!bz[t]) bz[t]=1,d[++tot]=t;}
void down(int t,int len)
{
add(t);
if (Tr[t])
{
if (len>1) Tr[t*2]+=Tr[t],Tr[t*2+1]+=Tr[t],add(t*2),add(t*2+1);
tr[t]+=Tr[t],Tr[t]=0;
}
}
void up(int t) {if (tr[t*2]+Tr[t*2]<tr[t*2+1]+Tr[t*2+1]) tr[t]=tr[t*2]+Tr[t*2],tr2[t]=tr2[t*2]; else tr[t]=tr[t*2+1]+Tr[t*2+1],tr2[t]=tr2[t*2+1];}
void change(int t,int l,int r,int x,int y,ll s)
{
int mid=(l+r)/2;
down(t,r-l+1);
if (x<=l && r<=y) {Tr[t]=Tr[t]+s;down(t,r-l+1);return;}
if (x<=mid) change(t*2,l,mid,x,y,s);
if (mid<y) change(t*2+1,mid+1,r,x,y,s);
up(t);
}
void Change(int t,int l,int r,int x,ll s,int s2)
{
int mid=(l+r)/2;
down(t,r-l+1);
if (l==r) {if (tr[t]>s) tr[t]=s,tr2[t]=s2;return;}
if (x<=mid) Change(t*2,l,mid,x,s,s2);
else Change(t*2+1,mid+1,r,x,s,s2);
up(t);
}
pair<ll,int> find(int t,int l,int r,int x,int y)
{
pair<ll,int> ans,s;ans.first=inf;
int mid=(l+r)/2;
down(t,r-l+1);
if (x<=l && r<=y) return pair<ll,int>(tr[t],tr2[t]);
if (x<=mid) {s=find(t*2,l,mid,x,y);if (s.first<ans.first) ans=s;}
if (mid<y) {s=find(t*2+1,mid+1,r,x,y);if (s.first<ans.first) ans=s;}
return ans;
}
} t1,t2;
struct sta{
int d[5001],tot,tp;
void clear() {tot=0;}
void add(int t)
{
while (tot && S[d[tot]][tp][1]<=S[t][tp][1]) --tot;
d[++tot]=t;
}
} d1,d2;
void work(int l,int r,int st,int ed)
{
int i,j,k,mid=(l+r)/2,St1,Ed1,St2,Ed2;
pair<ll,int> s;
tot=0;
for (i=st; i!=ed; i=nx[i]) b[++tot]=i;
b[++tot]=ed;
t1.clear(),t2.clear();
d1.clear(),d2.clear();
t1.Change(1,1,n,1,0,0);
t2.Change(1,1,n,1,0,0);
fo(i,1,tot)
{
f[i][0]=f[i][1]=inf;
if (i>1 && S[b[i]][1][0]<b[i]) t1.change(1,1,n,S[b[i]][1][0]+1,b[i],inf);
if (i>1 && S[b[i]][0][0]<b[i]) t2.change(1,1,n,S[b[i]][0][0]+1,b[i],inf);
t1.change(1,1,n,1,b[i],abs(a[b[i]]-mid));
t2.change(1,1,n,1,b[i],abs(a[b[i]]-(mid+1)));
while (d1.tot && (i==tot || S[d1.d[d1.tot]][1][1]<=max(b[i+1]-1,S[b[i]][1][1]))) --d1.tot;if (i<tot && S[b[i]][1][1]>=b[i+1]) d1.d[++d1.tot]=b[i];
while (d2.tot && (i==tot || S[d2.d[d2.tot]][0][1]<=max(b[i+1]-1,S[b[i]][0][1]))) --d2.tot;if (i<tot && S[b[i]][0][1]>=b[i+1]) d2.d[++d2.tot]=b[i];
if (d1.d[d1.tot]<b[i])
{
s=t1.find(1,1,n,d1.d[d1.tot]+1,b[i]);
f[i][0]=s.first;g[i][0]=s.second;
}
if (d2.d[d2.tot]<b[i])
{
s=t2.find(1,1,n,d2.d[d2.tot]+1,b[i]);
f[i][1]=s.first;g[i][1]=s.second;
}
if (i<tot)
{
t1.Change(1,1,n,b[i]+1,f[i][1],i);
t2.Change(1,1,n,b[i]+1,f[i][0],i);
}
}
j=tot;k=(f[tot][1]<f[tot][0]);T[0]=T[1]=0;
while (j)
{
fd(i,j,g[j][k]+1) Ans[b[i]]=mid+k,A[k][++T[k]]=b[i];
j=g[j][k],k=!k;
}
fo(i,0,1) fo(j,2,T[i]) nx[A[i][j]]=A[i][j-1];
St1=(!T[0])?-1:A[0][T[0]];Ed1=A[0][1];
St2=(!T[1])?-1:A[1][T[1]];Ed2=A[1][1];
if (l<r)
{
if (St1>-1) work(l,mid,St1,Ed1);
if (St2>-1) work(mid+1,r,St2,Ed2);
}
}
int main()
{
#ifdef file
freopen("loj6518.in","r",stdin);
#endif
// freopen("b.out","w",stdout);
memset(t1.tr,1,sizeof(t1.tr));
memset(t2.tr,1,sizeof(t2.tr));
d1.tp=0,d2.tp=1;
scanf("%d%d",&n,&m);
fo(i,1,n) S[i][0][0]=S[i][1][0]=n+1,nx[i]=i+1;nx[0]=1;
fo(i,1,n) scanf("%d",&a[i]);
fo(i,1,m) scanf("%d%d%d%d",&tp,&x,&y,&z),S[z][tp][0]=min(S[z][tp][0],x),S[z][tp][1]=max(S[z][tp][1],y);
work(1,100000,1,n);
fo(i,1,n) ans+=abs(a[i]-Ans[i]);
printf("%lld\n",ans);
fclose(stdin);
fclose(stdout);
return 0;
}