「CodePlus 2018 4 月赛」Tommy 的结合(斜率优化)
题意:
思路:
n<=66 \(dp[i][j]\)表示匹配到i,j的时候最大值,转移枚举下一个点,总复杂度\(O(n^4)\)
为了降低复杂度
\(dp[l][r]=max(dp[i][j]+c[l][r]-dis[i][l]^2-dis[j][r]^2)\)
如果想只枚举\(r\) ,可以把\(dp[i][j]-dis[i][l]^2\)给处理出来,就多开一个数组
\(dp[l][r]\)表示恰好到 \(G[l'][r]\)表示往后多走了一个\(l\)
\(dp[i][j]=max(G[i][k]-(dis[j-1]-dis[k])^2)+C[i][j]\)
\(G[i][j]=max(dp[l][j]-(dis[i-1]-dis[l])^2)\)
这样复杂度就是\(O(n^3)\)
看这样的式子可以想到斜率优化(链上的时候)
\(dp[i][j]=max(G[i][k]-dis[k]^2+2*dis[j-1]*dis[k])+c[i][j]-dis[j-1]^2\)
dis是单调递增的
\(G[l]-dis_l^2+2*dis_{j-1}*dis_l>=G[r]-dis_r^2+2*dis_{j-1}*dis_r\)
\(dis_{j-1}>=\frac{G_r-G_l+dis_l^2-dis_r^2}{dis_l-dis_r}\)
维护斜率递增的凸包,G同理
上面可以解决链上的情况
然后就是树上的斜率优化,每次从父亲节点继承\(l,r\),可以二分这时候加入的右端点,记录原来栈里的值用于还原。
这样子总的复杂度就是\(O(n^2logn)\)
#include<bits/stdc++.h>
#define M 2705
#define ll long long
using namespace std;
void Rd(int &res) {
res=0;
char c;
int fl=1;
while(c=getchar(),c<48)if(c=='-')fl=-1;
do res=(res<<1)+(res<<3)+(c^48);
while(c=getchar(),c>=48);
res*=fl;
}
struct Node {
int tot,n,to[M],pr[M],la[M],fa[M],dep[M],a[M],L[M],R[M],Id[M],id,dis[M];
void add(int x,int y) {
to[++tot]=y,pr[tot]=la[x],la[x]=tot;
}
void dfs(int x,int f) {
dep[x]=dep[f]+a[x],L[x]=++id,Id[id]=x,dis[x]=dis[f]+a[x];
for(int i=la[x]; i; i=pr[i])if(to[i]!=f)dfs(to[i],x);
R[x]=id;
}
} A[2];
int C[M][M];
struct P2 {
ll dis[2][M],dp[M][M],G[M][M];
struct node {
ll y,x;
} stk[M],stk2[M][M],od[M][M],Od[M];
int L1[M],R1[M],L2[M][M],R2[M][M];//前面针对dp的,后面针对G的
ll up(node l,node r) {//l在r的右面
return r.y-l.y+l.x*l.x-r.x*r.x;
}
ll down(node l,node r) {
return l.x-r.x;
}
ll calc(ll x) {
return 1ll*x*x;
}
void Dfs(int x,int f,int rt) {
int L,R,p;
L1[x]=L1[f],R1[x]=R1[f];
int &l=L1[x],&r=R1[x];
if(x!=1) {
p=l,L=l+1,R=r;
while(L<=R) {
int mid=(L+R)>>1;
if(2.0*dis[1][f]*down(stk[mid],stk[mid-1])>=1.0*up(stk[mid],stk[mid-1]))p=mid,L=mid+1;
else R=mid-1;
}
l=p,dp[rt][x]=stk[l].y-calc(stk[l].x-dis[1][f])+C[rt][x];
}
L=l,R=r-1,p=r;
node now=(node)<%G[rt][x],dis[1][x]%>;
while(L<=R) {
int mid=(L+R)>>1;
if(1.0*up(stk[mid+1],stk[mid])*down(now,stk[mid+1])>=1.0*up(now,stk[mid+1])*down(stk[mid+1],stk[mid]))p=mid,R=mid-1;
else L=mid+1;
}
r=p,Od[x]=stk[r+1],stk[++r]=now;
for(int y,i=A[1].la[x]; i; i=A[1].pr[i])if((y=A[1].to[i])!=f)Dfs(y,x,rt);
if(x!=1)stk[R1[x]]=Od[x];
}
void dfs(int x,int f) {
if(x!=1) {
for(int y,i=1; i<=A[1].n; i++) {
y=A[1].Id[i],L2[x][y]=L2[f][y],R2[x][y]=R2[f][y];
int &l=L2[x][y],&r=R2[x][y],L=l+1,R=r,p=l;
while(L<=R) {
int mid=(L+R)>>1;
if(2.0*dis[0][f]*down(stk2[y][mid],stk2[y][mid-1])>=1.0*up(stk2[y][mid],stk2[y][mid-1]))p=mid,L=mid+1;
else R=mid-1;
}
l=p,G[x][y]=stk2[y][p].y-calc(dis[0][f]-stk2[y][p].x);
}
L1[0]=0,R1[0]=-1,Dfs(1,0,x);
}
for(int i=1; i<=A[1].n; i++) {
int y=A[1].Id[i],&l=L2[x][y],&r=R2[x][y],L=l,R=r-1,p=r;
node now=(node)<%dp[x][y],dis[0][x]%>;
while(L<=R) {
int mid=(L+R)>>1;
if(1.0*up(stk2[y][mid+1],stk2[y][mid])*down(now,stk2[y][mid+1])>=1.0*up(now,stk2[y][mid+1])*down(stk2[y][mid+1],stk2[y][mid]))p=mid,R=mid-1;
else L=mid+1;
}
r=p,od[x][y]=stk2[y][r+1],stk2[y][++r]=now;
}
for(int y,i=A[0].la[x]; i; i=A[0].pr[i])if((y=A[0].to[i])!=f)dfs(y,x);
for(int y,i=1; i<=A[1].n; i++)y=A[1].Id[i],stk2[y][R2[x][y]]=od[x][y];
}
void solve() {
A[0].dfs(1,0),A[1].dfs(1,0);
for(int i=1; i<=A[0].n; i++)dis[0][i]=A[0].dis[i];
for(int i=1; i<=A[1].n; i++)dis[1][i]=A[1].dis[i];
memset(dp,-63,sizeof(dp)),memset(G,-63,sizeof(G)),memset(R2,-1,sizeof(R2)),memset(R1,-1,sizeof(R1));
dp[1][1]=G[1][1]=0,dfs(1,0);
ll ans=-1e18;
for(int i=1; i<=A[0].n; i++)for(int j=1; j<=A[1].n; j++)ans=max(ans,dp[i][j]);
printf("%lld\n",ans);
}
} p2;
int main() {
Rd(A[0].n),Rd(A[1].n);
for(int i=2; i<=A[0].n; i++)Rd(A[0].a[i]);
for(int i=2; i<=A[1].n; i++)Rd(A[1].a[i]);
for(int i=2; i<=A[0].n; i++)Rd(A[0].fa[i]),A[0].add(A[0].fa[i],i);
for(int i=2; i<=A[1].n; i++)Rd(A[1].fa[i]),A[1].add(A[1].fa[i],i);
for(int i=2; i<=A[0].n; i++)for(int j=2; j<=A[1].n; j++)Rd(C[i][j]);
p2.solve();
return 0;
}