保序回归学习笔记
既然 Tutte
矩阵的坑填了但是没有完全填上,保序回归也要打算开始填了。
定义
给定一张 DAG,每个点 \(u\) 可以表示成一个三元组 \((x_{u},y_{u},w_{u})\)
对于一条边 \((u,v)\) 表示 \(x_{u}\geq x_{v}\)
给出一组 \(x\) 满足 \(\sum^{n}_{i=1}w_{i}|(x_{i}-y_{i})|^k\) 最小
然而这个形式碰到的并不多,实际上最多的还是 \(L_{k}\) 类问题,其中的 \(L\) 指的是链的意思。
\(L_{k}的性质\)
首先可以发现,如果 \(y\) 也是单调不减的,那么答案为 \(0\)
否则找到 \(y_{i}>y_{i+1}\) 的位置 \(i\) ,最优解一定是 \(x_{i}=x_{i+1}\)
-
证明
-
$x_{i+1} \leq y_{i+1} $
此时让 \(x_{i}=x_{i+1}=y_{i+1}\) 一定最优
-
\(x_{i+1}\geq y_{i+1}\)
同1
-
\(y_{i+1} < x_{i}\leq x_{i+1} < y_{i}\)
显然如果 \(x_{i}\not = x_{i+1}\) ,那么让 \(x_{i} \leftarrow x_{i}+d\) 更优。
综述当 \(x_{i}=x_{i+1}\) 的最优。
-
得到这个性质之后,我们首先可以解决其中的一类特殊情况了。
\(k=2\)
我们找到一个 \(i\) 使得 \(y_{i}>y_{i+1}\) ,那么有 \(x_{i}=x_{i+1}=x\)
后面一部分是一个常数,可以不管,直接加到答案中,前面则可以看作是一个子问题,这样就可以解决 \(L_{2}\) 问题了。
例题 [HNOI2019]序列
一眼看上去似乎不是很好做,但是注意到修改是独立的,这类操作的套路是分为 \([1,x-1]\) 和 \([x+1,n]\) 两段来维护,那么问题就在于怎么合并这三个部分了。
根据保序回归得到的结论,对于剩下来的序列我们再做一次暴力得到的答案是正确,因为最后的答案和合并过程没有关系。
假设最后 \(x\) 合并出来的区间为 \((L,R)\) ,那么一定满足 \(y_{L}<y_{(L,R)}<y_{R}\) 。
同时,最后的 \((L,R)\) 一定是唯一的,所以等价于找到最小的 \(R'\) 满足 \(y_{(L,R')}<y_{R'}\) ,\(L\) 同理,这两个东西都可以二分。
总之迷迷糊糊就写完了,细节还是挺多的。
#include<cstdio>
#include<iostream>
#include<vector>
#include<algorithm>
#include<cstring>
#include<queue>
using namespace std;
#define ll long long
#define ri int
#define pii pair<int,int>
#define px pair<int,node>
int n,m;
const int MAXN=1e5+7;
const ll mod=998244353;
ll add(ll x,ll y){return (x+=y)<mod?x:x-mod;}
ll dec(ll x,ll y){return (x-=y)<0?x+mod:x;}
ll ksm(ll d,ll t,ll res=1){for(;t;t>>=1,d=d*d%mod) if(t&1) res=res*d%mod;return res;}
struct Z{ll w;};
Z operator+(const Z &a,const Z &b){return (Z){add(a.w,b.w)};}
Z operator-(const Z &a,const Z &b){return (Z){dec(a.w,b.w)};}
Z operator*(const Z &a,const Z &b){return (Z){a.w*b.w%mod};}
Z operator/(const Z &a,const Z &b){return (Z){a.w*ksm(b.w,mod-2)%mod};}
ll a[MAXN],B,inv[MAXN];
Z ansx[MAXN],ansy[MAXN],ans[MAXN];
struct node{
Z w,y;
ll up,down;
int l,r;
};
struct G{ll up,down;};
bool operator<=(const G &p,const G &q){return p.up*q.down<=q.up*p.down;}
bool operator<(const G &p,const G &q){return p.up*q.down<q.up*p.down;}
bool operator<=(const node &p,const node &q){return p.up*q.down<=q.up*p.down;}
Z operator^(const node &p,const node &q){
Z mid=(p.w*p.y+q.w*q.y)/(p.w+q.w);
return p.w*(mid-p.y)*(mid-p.y)+q.w*(mid-q.y)*(mid-q.y);
}
node operator+(const node &p,const node &q){return (node){p.w+q.w,(p.w*p.y+q.w*q.y)/(p.w+q.w),p.up+q.up,p.down+q.down,min(p.l,q.l),max(p.r,q.r)};}
struct Stack{
node w[MAXN];
int Top;
void push(const node &W){w[++Top]=W;}
void pop(){Top--;}
int size(){return Top;}
bool empty(){return Top==0;}
node top(){return w[Top];}
}Sl,Sr;
vector<pii> Q[MAXN];
vector<node> vec[MAXN];
ll sum[MAXN],sumX[MAXN];
int X;
G V(int l,int r){return (G){sum[r]-sum[l-1]-a[X]+B,r-l+1};}
int findL(int R){
int l=1,r=Sl.size(),res=1;
while(l<=r){
int mid=(l+r>>1);
G now=V(Sl.w[mid].r+1,R);
if(now<=(G){Sl.w[mid].up,Sl.w[mid].down}) r=mid-1;
else l=mid+1,res=Sl.w[mid].r+1;
}
return res;
}
int findR(){
int l=1,r=Sr.size(),res=n;
while(l<=r){
int mid=l+r>>1,p=findL(Sr.w[mid].l-1);
if((G){Sr.w[mid].up,Sr.w[mid].down}<=V(p,Sr.w[mid].l-1)) r=mid-1;
else l=mid+1,res=Sr.w[mid].l-1;
}
return res;
}
int main(){
scanf("%d%d",&n,&m);
for(ri i=1;i<=n;++i) scanf("%lld",&a[i]),sum[i]=sum[i-1]+a[i],sumX[i]=add(sumX[i-1],a[i]*a[i]%mod),inv[i]=ksm(i,mod-2);
for(ri i=1;i<=m;++i){
int x,y;scanf("%d%d",&x,&y);
Q[x].push_back((pii){i,y});
}
for(ri i=n;i;--i){
node cur=(node){(Z){1},(Z){a[i]},a[i],1,i,i};
ansy[i]=ansy[i+1];
while(!Sr.empty()&&Sr.top()<=cur){
ansy[i]=ansy[i]+(cur^Sr.top());
cur=cur+Sr.top();
vec[i].push_back(Sr.top());
Sr.pop();
}
Sr.push(cur);
}
printf("%lld\n",ansy[1].w);
for(X=1;X<=n;++X){
Sr.pop();
while(!vec[X].empty()) Sr.push(vec[X].back()),vec[X].pop_back();
for(auto x:Q[X]){
B=x.second;
int r=findR(),l=findL(r);
ans[x.first]=ansx[l-1]+ansy[r+1];
Z A=(Z){add(dec(dec(sumX[r],sumX[l-1]),1ll*a[X]*a[X]%mod),1ll*B*B%mod)};
Z C=(Z){add(dec((sum[r]-sum[l-1])%mod,a[X]),B)};
C=C*C*(Z){inv[r-l+1]};
ans[x.first]=ansx[l-1]+ansy[r+1]+A-C;
}
node cur=(node){(Z){1},(Z){a[X]},a[X],1,X,X};
ansx[X]=ansx[X-1];
while(!Sl.empty()&&cur<=Sl.top()){
ansx[X]=ansx[X]+(cur^Sl.top());
cur=cur+Sl.top();
Sl.pop();
}
Sl.push(cur);
}
for(ri i=1;i<=m;++i) printf("%lld\n",ans[i].w);
}