【BZOJ3672】【NOI2014】购票(线段树,斜率优化,动态规划)
【BZOJ3672】【NOI2014】购票(线段树,斜率优化,动态规划)
题解
首先考虑\(dp\)的方程,设\(f[i]\)表示\(i\)的最优值
很明显的转移\(f[i]=min(f[j]+(dep[i]-dep[j])·p[i])+q[i]\)
其中满足\(dep[i]-dep[j]\le L[i]\)
然后就可以写出一个\(O(n^2)\)的做法啦
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define ll long long
#define RG register
#define MAX 222222
inline ll read()
{
RG ll x=0,t=1;RG char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=-1,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return x*t;
}
struct Line{int v,next,w;}e[MAX];
int h[MAX],cnt=1;
inline void Add(int u,int v,int w){e[cnt]=(Line){v,h[u],w};h[u]=cnt++;}
int n,type,fa[MAX];
ll S[MAX],P[MAX],Q[MAX],L[MAX];
ll dep[MAX],f[MAX];
void dfs(int u)
{
for(int i=h[u];i;i=e[i].next)
dep[e[i].v]=dep[u]+e[i].w,dfs(e[i].v);
}
double Slope(int a,int b){return (f[a]-f[b])*1.0/(dep[a]-dep[b]);}
namespace Brute
{
void DFS(int u)
{
if(u!=1)f[u]=1e18;
for(int i=fa[u];i&&dep[u]-dep[i]<=L[u];i=fa[i])
f[u]=min(f[u],f[i]+(dep[u]-dep[i])*P[u]+Q[u]);
for(int i=h[u];i;i=e[i].next)DFS(e[i].v);
}
void Solve()
{
DFS(1);
for(int i=2;i<=n;++i)printf("%lld\n",f[i]);
}
}
int main()
{
n=read();type=read();
for(int i=2;i<=n;++i)fa[i]=read(),S[i]=read(),P[i]=read(),Q[i]=read(),L[i]=read();
for(int i=2;i<=n;++i)Add(fa[i],i,S[i]);
dfs(1);
if(n<=2000){Brute::Solve();return 0;}
}
如果没有下面的那条限制,我们可以很容易的写出一个效率优化
设\(k<j<i\),\(j\)的转移优于\(k\),那么有
\(f[j]-dep[j]*p[i]<f[k]-dep[k]*p[i]\)
移项得到\(p[i]>\frac{f[j]-f[k]}{dep[j]-dep[k]}\)
很明显的斜率优化。
就这样我们就写出了一个\(t=0\)的\(20\)代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define ll long long
#define RG register
#define MAX 222222
inline ll read()
{
RG ll x=0,t=1;RG char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=-1,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return x*t;
}
struct Line{int v,next,w;}e[MAX];
int h[MAX],cnt=1;
inline void Add(int u,int v,int w){e[cnt]=(Line){v,h[u],w};h[u]=cnt++;}
int n,type,fa[MAX];
ll S[MAX],P[MAX],Q[MAX],L[MAX];
ll dep[MAX],f[MAX];
void dfs(int u)
{
for(int i=h[u];i;i=e[i].next)
dep[e[i].v]=dep[u]+e[i].w,dfs(e[i].v);
}
double Slope(int a,int b){return (f[a]-f[b])*1.0/(dep[a]-dep[b]);}
namespace Task0
{
int St[MAX],top;
int check(double K)
{
int l=2,r=top,ret=1;
while(l<=r)
{
int mid=(l+r)>>1;
if(Slope(St[mid],St[mid-1])>=K)r=mid-1;
else l=mid+1,ret=mid;
}
return St[ret];
}
void Solve()
{
St[top=1]=1;
for(int i=2;i<=n;++i)
{
int j=check(P[i]);
f[i]=f[j]+Q[i]+(dep[i]-dep[j])*P[i];
while(top>1&&Slope(i,St[top-1])<Slope(St[top],St[top-1]))--top;
St[++top]=i;
printf("%lld\n",f[i]);
}
}
}
int main()
{
n=read();type=read();
for(int i=2;i<=n;++i)fa[i]=read(),S[i]=read(),P[i]=read(),Q[i]=read(),L[i]=read();
for(int i=2;i<=n;++i)Add(fa[i],i,S[i]);
dfs(1);
if(type==0){Task0::Solve();return 0;}
}
当然,再把暴力给加上就有\(40\)分了。
(自己去拼接啊,要不然代码太多了)
剩下的部分分。
对于\(t=1\),我们发现是没有距离限制的
显然是对于每一条链维护一个凸包,所以只需要维护一个可持久化栈然后在上面二分就好了。
这个东西似乎可以用主席树维护,然后在主席树上面二分就行了。
(我就懒得写了)
对于\(t=2\),是一条链,但是有距离限制,
我们发现我们会从前面开始删去凸包上的一些点,这似乎非常不好做,
但是我们似乎可以用\(CDQ\)分治来做?按照能够用来更新的区间排序
每次\(CDQ\)的时候维护一段区间的凸包,然后更新一下答案就好了
(似乎可以这样做吧。。。)
现在终于可以来讲正解之一啦
现在的问题主要是如何动态维护凸包
其实我们把树直接树链剖分之后,对应的满足条件的是\(dfs\)序上的多段区间
所以我们用线段树暴力维护凸包,每个节点开一个\(vector\)
然后暴力维护这段区间的凸包
每个点最多会被\(log\)个线段树节点所包含,复杂度\(O(nlogn)\)
然后每次询问的时候在跳重链+线段树+二分
似乎是三个\(log\),但是我觉得只有两个\(log\)
似乎二分的\(log\)是加起来才有一个\(log\)吧。。。
代码如下:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define ll long long
#define RG register
#define MAX 222222
#define lson (now<<1)
#define rson (now<<1|1)
inline ll read()
{
RG ll x=0,t=1;RG char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=-1,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return x*t;
}
struct Line{int v,next;ll w;}e[MAX];
int h[MAX],cnt=1;
inline void Add(int u,int v,ll w){e[cnt]=(Line){v,h[u],w};h[u]=cnt++;}
int n,type,fa[MAX];
ll S[MAX],P[MAX],Q[MAX],L[MAX];
ll dis[MAX],f[MAX];
int size[MAX],top[MAX],dfn[MAX],low[MAX],ln[MAX],tim,hson[MAX];
double Slope(int a,int b){return (f[a]-f[b])*1.0/(dis[a]-dis[b]);}
void dfs(int u)
{
size[u]=1;
for(int i=h[u];i;i=e[i].next)
{
int v=e[i].v;dis[v]=dis[u]+e[i].w;
dfs(v);size[u]+=size[v];
if(size[v]>size[hson[u]])hson[u]=v;
}
}
void dfs(int u,int tp)
{
top[u]=tp;dfn[u]=++tim;ln[tim]=u;
if(hson[u])dfs(hson[u],tp);
for(int i=h[u];i;i=e[i].next)
if(e[i].v!=hson[u])dfs(e[i].v,e[i].v);
}
vector<int> t[MAX<<2];
void Modify(int now,int l,int r,int p)
{
int tp=t[now].size();
while(tp>1&&Slope(ln[p],t[now][tp-2])<Slope(t[now][tp-1],t[now][tp-2]))
--tp,t[now].pop_back();
t[now].push_back(ln[p]);
if(l==r)return;
int mid=(l+r)>>1;
if(p<=mid)Modify(lson,l,mid,p);
else Modify(rson,mid+1,r,p);
}
ll Calc(vector<int> t,int u)
{
int l=1,r=t.size()-1,ret=0;
while(l<=r)
{
int mid=(l+r)>>1;
if(Slope(t[mid],t[mid-1])<1.0*P[u])l=mid+1,ret=mid;
else r=mid-1;
}
int v=t[ret];
return f[v]+(dis[u]-dis[v])*P[u]+Q[u];
}
ll Query(int now,int l,int r,int L,int R,int u)
{
if(L<=l&&r<=R)return Calc(t[now],u);
int mid=(l+r)>>1;ll ret=1e18;
if(L<=mid)ret=min(ret,Query(lson,l,mid,L,R,u));
if(R>mid)ret=min(ret,Query(rson,mid+1,r,L,R,u));
return ret;
}
int Top=0;
void Jump(int u,int anc)
{
f[u]=5e18;int U=u;u=fa[u];
while(top[u]^top[anc])
{
f[U]=min(f[U],Query(1,1,n,dfn[top[u]],dfn[u],U));
u=fa[top[u]];
}
f[U]=min(f[U],Query(1,1,n,dfn[anc],dfn[u],U));
}
void DFS(int u)
{
S[++Top]=u;
if(u!=1)
{
int l=1,r=Top-1,v=u;
while(l<=r)
{
int mid=(l+r)>>1;
if(dis[u]-dis[S[mid]]<=L[u])v=S[mid],r=mid-1;
else l=mid+1;
}
Jump(u,v);
}
Modify(1,1,n,dfn[u]);
for(int i=h[u];i;i=e[i].next)DFS(e[i].v);
--Top;
}
int main()
{
n=read();type=read();
for(int i=2;i<=n;++i)fa[i]=read(),S[i]=read(),P[i]=read(),Q[i]=read(),L[i]=read();
for(int i=2;i<=n;++i)Add(fa[i],i,S[i]);
dfs(1);dfs(1,1);memset(S,0,sizeof(S));DFS(1);
for(int i=2;i<=n;++i)printf("%lld\n",f[i]);
return 0;
}