题解 Decompose
首先写出转移方程:
\[f_{i, 1}=\sum\limits_v\max\{f_{v, 1...m}\}+w_{i, 1}
\]
\[f_{i, j(j>1)}=(\sum\limits_v\max\{f_{v, 1...m}\})-\min\limits_v(\max\{f_{v, 1...m}\}-f_{v, j-1})+w_{i, j}
\]
发现下面这个转移带个 min 很讨厌
那么可以换一种使用 max 表示的写法
- 别觉得动态 DP 里又有 min 又有 max 就没法写了,看看能不能统一用同一种表示
\[f_{i, j(j>1)}=\max\limits_v\{f_{v, j-1}-\max\limits_k f_{v, k}\}+\sum\limits_v\max\limits_k f_{v, k}+w_{i, j}
\]
将一个点的轻儿子视作常数
那么就是要用 \((\max, +)\) 矩阵乘法实现"重儿子的值与这些常数取 max"的过程了
没有轻儿子可以将轻儿子的值视为 -inf
复杂度 \(O((n+q)\log^2 n\cdot m^3)\),其中 \(m\) 为每个点的状态数(代码中的 `l`)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
// Sentinel "infinity" for long long DP values; 2*INF (0x7e7e...7e) still fits
// in a signed 64-bit integer, so adding two of them does not overflow.
#define INF 0x3f3f3f3f3f3f3f3f
// Capacity of every per-node array; presumably the maximum n plus slack -- confirm against the problem limits.
#define N 100010
#define ll long long
//#define int long long
// Fast input: a 2 MiB buffer refilled with fread() whenever it is exhausted.
// p1 points at the next unread byte, p2 one past the last valid byte.
char buf[1<<21], *p1=buf, *p2=buf;
// Textually replaces stdio getchar() for all code below this line. It yields
// EOF once fread() returns 0 (and retries fread() on every later call);
// otherwise it yields the next byte as a plain char (negative for bytes >= 0x80
// where char is signed).
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++)
// Reads the next decimal integer (an odd number of '-' before it negates it)
// from the input.
// Returns 0 if the input ends before any digit is found; previously that case
// spun forever because getchar() keeps returning EOF.
inline int read() {
    int ans=0, f=1;
    // int (not char) so that the EOF comparison works where char is unsigned.
    int c=getchar();
    // Explicit range checks instead of isdigit(): isdigit() on a negative
    // char value (any byte >= 0x80 with signed char) is undefined behaviour.
    while (c<'0' || c>'9') {
        if (c==EOF) return 0;
        if (c=='-') f=-f;
        c=getchar();
    }
    while (c>='0' && c<='9') {
        ans=ans*10+(c-'0');
        c=getchar();
    }
    return ans*f;
}
// n nodes, q update queries, l values per node (all read in main).
int n, q, l;
// w[u][1..l]: the l per-node values of node u (1-based).
ll w[N][5];
// Child adjacency lists stored as singly linked lists in e[] (1-based; head[]
// is filled with -1 in main, so -1 terminates a list). back[u] holds u's
// parent as read in main; back[1] is never assigned and stays 0.
int head[N], back[N], ecnt;
struct edge{int to, next;}e[N<<1];
// Prepends the directed edge s -> t to s's adjacency list.
inline void add(int s, int t) {
    edge &slot=e[++ecnt];
    slot.to=t;
    slot.next=head[s];
    head[s]=ecnt;
}
namespace force{
ll f[N][5];
void dfs(int u) {
// cout<<"dfs: "<<u<<endl;
if (head[u]==-1) {
f[u][1]=w[u][1];
for (int i=2; i<=l; ++i) f[u][i]=-INF;
return ;
}
ll sum=0; f[u][1]=0;
for (int i=2; i<=l; ++i) f[u][i]=INF;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dfs(v);
ll maxn=-INF;
for (int j=1; j<=l; ++j) maxn=max(maxn, f[v][j]);
for (int j=1; j<l; ++j) f[u][j+1]=min(f[u][j+1], maxn-f[v][j]);
sum+=maxn;
}
for (int i=1; i<=l; ++i) f[u][i]=sum-f[u][i]+w[u][i];
}
void solve() {
for (int i=1,u; i<=q; ++i) {
u=read();
for (int j=1; j<=l; ++j) w[u][j]=read();
dfs(1);
ll ans=-INF;
for (int j=1; j<=l; ++j) ans=max(ans, f[1][j]);
printf("%lld\n", ans);
}
}
}
namespace task1{
random_device seed;
mt19937 rnd(seed());
struct matrix{
int n, m;
ll a[5][5];
matrix() {n=0; m=0; memset(a, -0x3f, sizeof(a));}
matrix(int x, int y) {n=x; m=y; memset(a, -0x3f, sizeof(a));}
void resize(int x, int y) {n=x; m=y; memset(a, -0x3f, sizeof(a));}
inline ll* operator [] (int t) {return a[t];}
void put() {for (int i=1; i<=n; ++i) {for (int j=1; j<=m; ++j) cout<<setw(3)<<a[i][j]<<' '; cout<<endl;}cout<<endl;}
void random() {for (int i=1; i<=n; ++i) for (int j=1; j<=m; ++j) a[i][j]=rnd();}
matrix operator * (matrix b) {
matrix ans(n, b.m);
for (int i=1; i<=n; ++i)
for (int k=1; k<=m; ++k)
for (int j=1; j<=b.m; ++j)
ans[i][j]=max(ans[i][j], a[i][k]+b[k][j]);
return ans;
}
bool operator == (matrix b) {
if (n!=b.n||m!=b.m) return 0;
for (int i=1; i<=n; ++i)
for (int j=1; j<=m; ++j)
if (a[i][j]!=b[i][j]) return 0;
return 1;
}
}f[N], val[N<<2], tem;
int tl[N<<2], tr[N<<2];
#define tl(p) tl[p]
#define tr(p) tr[p]
#define pushup(p) val[p]=val[p<<1|1]*val[p<<1]
void build(int p, int l, int r) {
// cout<<"build: "<<p<<' '<<l<<' '<<r<<endl;
tl(p)=l; tr(p)=r;
if (l==r) {val[p]=f[l]; return ;}
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
pushup(p);
}
void upd(int p, int pos) {
if (tl(p)==tr(p)) {val[p]=f[tl(p)]; return ;}
int mid=(tl(p)+tr(p))>>1;
if (pos<=mid) upd(p<<1, pos);
else upd(p<<1|1, pos);
pushup(p);
}
void rebuild(int t) {
f[t].resize(l, l);
for (int i=1; i<=l; ++i) f[t][i][1]=w[t][1];
for (int i=2; i<=l; ++i) f[t][i-1][i]=w[t][i];
}
void solve() {
// cout<<double(sizeof(f))/1000/1000<<endl;
for (int i=1; i<=n; ++i) rebuild(i);
build(1, 1, n);
for (int i=1,u; i<=q; ++i) {
u=read();
for (int j=1; j<=l; ++j) w[u][j]=read();
rebuild(u); upd(1, u);
tem.resize(1, l); tem[1][l]=0;
tem=tem*val[1];
ll ans=-INF;
for (int j=1; j<=l; ++j) ans=max(ans, tem[1][j]);
printf("%lld\n", ans);
}
}
}
namespace task2{
ll f[N][5];
void dfs(int u) {
// cout<<"dfs: "<<u<<endl;
if (head[u]==-1) {
f[u][1]=w[u][1];
for (int i=2; i<=l; ++i) f[u][i]=-INF;
return ;
}
ll sum=0; f[u][1]=0;
for (int i=2; i<=l; ++i) f[u][i]=INF;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
dfs(v);
ll maxn=-INF;
for (int j=1; j<=l; ++j) maxn=max(maxn, f[v][j]);
for (int j=1; j<l; ++j) f[u][j+1]=min(f[u][j+1], maxn-f[v][j]);
sum+=maxn;
}
for (int i=1; i<=l; ++i) f[u][i]=sum-f[u][i]+w[u][i];
}
void rebuild(int u) {
if (head[u]==-1) {
f[u][1]=w[u][1];
for (int i=2; i<=l; ++i) f[u][i]=-INF;
return ;
}
ll sum=0; f[u][1]=0;
for (int i=2; i<=l; ++i) f[u][i]=INF;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
ll maxn=-INF;
for (int j=1; j<=l; ++j) maxn=max(maxn, f[v][j]);
for (int j=1; j<l; ++j) f[u][j+1]=min(f[u][j+1], maxn-f[v][j]);
sum+=maxn;
}
for (int i=1; i<=l; ++i) f[u][i]=sum-f[u][i]+w[u][i];
}
void solve() {
dfs(1);
for (int i=1,u; i<=q; ++i) {
u=read();
for (int j=1; j<=l; ++j) w[u][j]=read();
while (u) rebuild(u), u=back[u];
ll ans=-INF;
for (int j=1; j<=l; ++j) ans=max(ans, f[1][j]);
printf("%lld\n", ans);
}
}
}
namespace task{
ll sum[N];
multiset<ll> lit[N][5];
int siz[N], msiz[N], mson[N], dep[N], top[N], btm[N], id[N], rk[N], tot;
struct matrix{
int n, m;
ll a[5][5];
matrix() {n=0; m=0; memset(a, -0x3f, sizeof(a));}
matrix(int x, int y) {n=x; m=y; memset(a, -0x3f, sizeof(a));}
void resize(int x, int y) {n=x; m=y; memset(a, -0x3f, sizeof(a));}
inline ll* operator [] (int t) {return a[t];}
void put() {for (int i=1; i<=n; ++i) {for (int j=1; j<=m; ++j) cout<<setw(3)<<a[i][j]<<' '; cout<<endl;}cout<<endl;}
matrix operator * (matrix b) {
matrix ans(n, b.m);
for (int i=1; i<=n; ++i)
for (int k=1; k<=m; ++k)
for (int j=1; j<=b.m; ++j)
ans[i][j]=max(ans[i][j], a[i][k]+b[k][j]);
return ans;
}
}f[N], val[N<<2], tem;
int tl[N<<2], tr[N<<2];
#define tl(p) tl[p]
#define tr(p) tr[p]
#define pushup(p) val[p]=val[p<<1|1]*val[p<<1]
void build(int p, int l, int r) {
tl(p)=l; tr(p)=r;
if (l==r) return ;
int mid=(l+r)>>1;
build(p<<1, l, mid);
build(p<<1|1, mid+1, r);
pushup(p);
}
void upd(int p, int pos) {
if (tl(p)==tr(p)) {val[p]=f[rk[tl(p)]]; return ;}
int mid=(tl(p)+tr(p))>>1;
if (pos<=mid) upd(p<<1, pos);
else upd(p<<1|1, pos);
pushup(p);
}
matrix query(int p, int l, int r) {
if (l<=tl(p)&&r>=tr(p)) return val[p];
int mid=(tl(p)+tr(p))>>1;
if (l<=mid&&r>mid) return query(p<<1|1, l, r)*query(p<<1, l, r);
else if (l<=mid) return query(p<<1, l, r);
else return query(p<<1|1, l, r);
}
void build(int t) {
f[t].resize(l, l);
for (int i=1; i<=l; ++i) for (int j=1; j<=l; ++j) f[t][i][j]=sum[t]+w[t][j];
for (int i=2; i<=l; ++i)
for (int j=1; j<=l; ++j)
if (j==i-1) f[t][j][i]+=max(*lit[t][i-1].rbegin(), 0ll);
else f[t][j][i]+=*lit[t][i-1].rbegin();
upd(1, id[t]);
}
void rebuild(int u) {
int t=top[u];
tem.resize(1, l); tem[1][l]=0;
tem=tem*query(1, id[t], id[btm[t]]);
ll maxn=-INF;
for (int i=1; i<=l; ++i) maxn=max(maxn, tem[1][i]);
sum[back[t]]-=maxn;
for (int i=1; i<=l; ++i) lit[back[t]][i].erase(lit[back[t]][i].find(tem[1][i]-maxn));
build(u);
tem.resize(1, l); tem[1][l]=0;
tem=tem*query(1, id[t], id[btm[t]]);
maxn=-INF;
for (int i=1; i<=l; ++i) maxn=max(maxn, tem[1][i]);
sum[back[t]]+=maxn;
for (int i=1; i<=l; ++i) lit[back[t]][i].insert(tem[1][i]-maxn);
}
void dfs1(int u, int fa) {
siz[u]=1;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa) continue;
dep[v]=dep[u]+1;
dfs1(v, u);
siz[u]+=siz[v];
if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v;
}
}
void dfs2(int u, int fa, int t) {
// cout<<"dfs2: "<<u<<' '<<fa<<' '<<t<<endl;
top[u]=t;
rk[id[u]=++tot]=u;
if (!mson[u]) {btm[t]=u; return ;}
dfs2(mson[u], u, t);
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa || v==mson[u]) continue;
dfs2(v, u, v);
}
}
void dfs3(int u, int fa) {
// cout<<"dfs3: "<<u<<' '<<fa<<endl;
for (int i=head[u],v; ~i; i=e[i].next) {
v = e[i].to;
if (v==fa || v==mson[u]) continue;
dfs3(v, u);
}
if (mson[u]) dfs3(mson[u], u);
build(u);
// cout<<"u: "<<u<<endl;
// cout<<"sum: "<<sum[u]<<endl;
// cout<<"mx: "<<*lit[u][1].rbegin()<<endl;
// f[u].put();
if (u==top[u]) {
tem.resize(1, l); tem[1][l]=0;
tem=tem*query(1, id[u], id[btm[u]]);
ll maxn=-INF;
for (int i=1; i<=l; ++i) maxn=max(maxn, tem[1][i]);
sum[back[u]]+=maxn;
for (int i=1; i<=l; ++i) lit[back[u]][i].insert(tem[1][i]-maxn);
}
}
void solve() {
// cout<<double(sizeof(f))/1000/1000<<endl;
for (int i=1; i<=n; ++i) for (int j=0; j<=l; ++j) lit[i][j].insert(-INF);
dep[1]=1; dfs1(1, 0); build(1, 1, n); dfs2(1, 0, 1); dfs3(1, 0);
// cout<<"top: "; for (int i=1; i<=n; ++i) cout<<top[i]<<' '; cout<<endl;
// cout<<"id: "; for (int i=1; i<=n; ++i) cout<<id[i]<<' '; cout<<endl;
for (int i=1,u; i<=q; ++i) {
u=read();
for (int j=1; j<=l; ++j) w[u][j]=read();
for (; u; u=back[top[u]]) rebuild(u);
tem.resize(1, l); tem[1][l]=0;
tem=tem*query(1, id[1], id[btm[1]]);
ll ans=-INF;
for (int j=1; j<=l; ++j) ans=max(ans, tem[1][j]);
printf("%lld\n", ans);
}
}
}
signed main()
{
    // Contest I/O files.
    freopen("decompose.in", "r", stdin);
    freopen("decompose.out", "w", stdout);
    n=read(); q=read(); l=read();
    // Every adjacency list starts empty (-1 terminates a list).
    memset(head, -1, sizeof(head));
    // True iff every node i >= 2 has parent i-1 (the tree is a path). It is
    // only consulted by the commented-out solver dispatch below.
    bool ischain=true;
    for (int child=2; child<=n; ++child) {
        int parent=read();
        back[child]=parent;
        add(parent, child);
        if (parent!=child-1) ischain=false;
    }
    // Initial per-node values: l of them for each node.
    for (int u=1; u<=n; ++u)
        for (int j=1; j<=l; ++j)
            w[u][j]=read();
    // Alternative solvers kept for testing:
    // force::solve();
    // task1::solve();
    // task2::solve();
    // if (ischain) task1::solve();
    // else task2::solve();
    task::solve();
    return 0;
}