题解 Decompose

传送门

首先写出转移方程（\(v\) 取遍 \(i\) 的所有儿子）：

\[f_{i, 1}=\sum\limits_v\max\limits_{1\le k\le m}f_{v, k}+w_{i, 1} \]

\[f_{i, j}=\Big(\sum\limits_v\max\limits_{1\le k\le m}f_{v, k}\Big)-\min\limits_v\Big(\max\limits_{1\le k\le m}f_{v, k}-f_{v, j-1}\Big)+w_{i, j}\qquad(j>1) \]

注意到第二个转移式中带有 min，而其余部分都是 max，不便于直接写成 (max, +) 矩阵乘法。
因此把它改写成只含 max 的形式：

  • 不要看到动态 DP 里同时出现 min 和 max 就觉得没法写，先看看能不能统一成同一种运算。

\[f_{i, j}=\max\limits_v\Big\{f_{v, j-1}-\max\limits_{1\le k\le m}f_{v, k}\Big\}+\sum\limits_v\max\limits_{1\le k\le m}f_{v, k}+w_{i, j}\qquad(j>1) \]

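将一个点的轻儿子视作常数。设 \(h\) 为 \(i\) 的重儿子，记

\[S_i=\sum\limits_{v\in\text{light}(i)}\max\limits_{1\le k\le m}f_{v, k},\qquad g_{i, j}=\max\limits_{v\in\text{light}(i)}\Big(f_{v, j}-\max\limits_{1\le k\le m}f_{v, k}\Big)\]

把 \(F_v=(f_{v,1},\dots,f_{v,m})\) 看作行向量，定义 \((F\otimes A)_j=\max_k(F_k+A_{k,j})\)，则 \(F_i=F_h\otimes A_i\)，其中

\[A_{i;k,1}=S_i+w_{i,1},\qquad A_{i;k,j}=S_i+w_{i,j}+\begin{cases}\max(g_{i,j-1},0)&k=j-1\\ g_{i,j-1}&k\ne j-1\end{cases}\quad(j>1)\]

那么就是要用矩阵实现重儿子和一些常数比较大小的过程了。
没有轻儿子时，可以将轻儿子的值视为 -inf（即 \(g_{i,j}=-\infty\)）。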
复杂度 \(O(n\log^2nL^3)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f3f3f3f3f
#define N 100010
#define ll long long
//#define int long long

// Fast input: stdin is pulled into a 2 MiB buffer with fread and handed out one
// byte at a time. getchar() is shadowed by a macro that yields each byte as an
// unsigned char value (0..255), or EOF once fread returns no more data.
char buf[1<<21], *p1=buf, *p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:static_cast<unsigned char>(*p1++))
/**
 * Reads the next decimal integer from the buffered stdin.
 * Non-digit characters before the number are skipped; every '-' among them
 * flips the sign.
 * @return the parsed value, or 0 if end of input is reached before any digit
 *         (previously this case looped forever).
 */
inline int read() {
	int ans=0, f=1;
	int c=getchar();   // int, not char, so EOF stays distinguishable from data bytes
	while (c!=EOF && !isdigit(c)) {if (c=='-') f=-f; c=getchar();}
	while (isdigit(c)) {ans=(ans<<3)+(ans<<1)+(c^48); c=getchar();}   // isdigit(EOF) is false
	return ans*f;
}

int n, q, l;                  // node count, number of updates, DP states per node
ll w[N][5];                   // w[u][1..l]: weights of node u
int head[N], back[N], ecnt;   // list heads (-1 = empty), parent of each node, edge counter
struct edge{
	int to;     // endpoint of the edge
	int next;   // next edge leaving the same node
};
edge e[N<<1];
// Prepends the directed edge s -> t to s's adjacency list.
inline void add(int s, int t) {
	const int slot=++ecnt;
	e[slot].to=t;
	e[slot].next=head[s];
	head[s]=slot;
}

namespace force{
	ll f[N][5];
	void dfs(int u) {
		// cout<<"dfs: "<<u<<endl;
		if (head[u]==-1) {
			f[u][1]=w[u][1];
			for (int i=2; i<=l; ++i) f[u][i]=-INF;
			return ;
		}
		ll sum=0; f[u][1]=0;
		for (int i=2; i<=l; ++i) f[u][i]=INF;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			dfs(v);
			ll maxn=-INF;
			for (int j=1; j<=l; ++j) maxn=max(maxn, f[v][j]);
			for (int j=1; j<l; ++j) f[u][j+1]=min(f[u][j+1], maxn-f[v][j]);
			sum+=maxn;
		}
		for (int i=1; i<=l; ++i) f[u][i]=sum-f[u][i]+w[u][i];
	}
	void solve() {
		for (int i=1,u; i<=q; ++i) {
			u=read();
			for (int j=1; j<=l; ++j) w[u][j]=read();
			dfs(1);
			ll ans=-INF;
			for (int j=1; j<=l; ++j) ans=max(ans, f[1][j]);
			printf("%lld\n", ans);
		}
	}
}

namespace task1{
	// Special case for a path (node i's parent is i-1, cf. the ischain check in
	// main): a segment tree over node indices multiplies all n transition
	// matrices, and each update rewrites a single matrix.
	// Random source; only used by matrix::random().
	random_device seed;
	mt19937 rnd(seed());
	// n x m matrix over the (max,+) semiring, 1-based; a[5][5] caps it at 4 x 4.
	// The constructors and resize() set every byte of a to -0x3f, a large
	// negative value that serves as -inf.
	// f[t]: transition matrix of node t; val: segment-tree products; tem: scratch row vector.
	struct matrix{
		int n, m;
		ll a[5][5];
		matrix() {n=0; m=0; memset(a, -0x3f, sizeof(a));}
		matrix(int x, int y) {n=x; m=y; memset(a, -0x3f, sizeof(a));}
		void resize(int x, int y) {n=x; m=y; memset(a, -0x3f, sizeof(a));}
		inline ll* operator [] (int t) {return a[t];}
		// Debug printer (not called anywhere in this file).
		void put() {for (int i=1; i<=n; ++i) {for (int j=1; j<=m; ++j) cout<<setw(3)<<a[i][j]<<' '; cout<<endl;}cout<<endl;}
		// Fills the used entries with random values (testing helper; not called in this file).
		void random() {for (int i=1; i<=n; ++i) for (int j=1; j<=m; ++j) a[i][j]=rnd();}
		// (max,+) product: ans[i][j] = max over k of a[i][k] + b[k][j].
		// ans starts at the -inf fill, so no entry of the result is below that value.
		matrix operator * (matrix b) {
			matrix ans(n, b.m);
			for (int i=1; i<=n; ++i)
				for (int k=1; k<=m; ++k)
					for (int j=1; j<=b.m; ++j)
						ans[i][j]=max(ans[i][j], a[i][k]+b[k][j]);
			return ans;
		}
		// Exact equality of dimensions and used entries (testing helper; not called in this file).
		bool operator == (matrix b) {
			if (n!=b.n||m!=b.m) return 0;
			for (int i=1; i<=n; ++i)
				for (int j=1; j<=m; ++j)
					if (a[i][j]!=b[i][j]) return 0;
			return 1;
		}
	}f[N], val[N<<2], tem;
	// Segment tree over node indices 1..n; node p covers [tl(p), tr(p)].
	int tl[N<<2], tr[N<<2];
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	// Right child times left child: larger indices (deeper on the path) are
	// applied first, because the row vector travels from the path's end upward.
	#define pushup(p) val[p]=val[p<<1|1]*val[p<<1]
	// Builds the tree; leaf i holds f[i]. (Parameter l shadows the global l here.)
	void build(int p, int l, int r) {
		// cout<<"build: "<<p<<' '<<l<<' '<<r<<endl;
		tl(p)=l; tr(p)=r;
		if (l==r) {val[p]=f[l]; return ;}
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
		pushup(p);
	}
	// Copies f[pos] into its leaf and refreshes the ancestors.
	void upd(int p, int pos) {
		if (tl(p)==tr(p)) {val[p]=f[tl(p)]; return ;}
		int mid=(tl(p)+tr(p))>>1;
		if (pos<=mid) upd(p<<1, pos);
		else upd(p<<1|1, pos);
		pushup(p);
	}
	// Transition matrix of node t on a path (no light children):
	//   column 1:   w[t][1] in every row, so state 1 is the child's best value + w[t][1];
	//   column i>1: w[t][i] only in row i-1, so state i extends the child's state i-1.
	void rebuild(int t) {
		f[t].resize(l, l);
		for (int i=1; i<=l; ++i) f[t][i][1]=w[t][1];
		for (int i=2; i<=l; ++i) f[t][i-1][i]=w[t][i];
	}
	// Builds every matrix, then answers each update "u w[u][1..l]" by rebuilding
	// u's matrix and pushing the start vector (-inf, ..., -inf, 0) through the path.
	void solve() {
		// cout<<double(sizeof(f))/1000/1000<<endl;
		for (int i=1; i<=n; ++i) rebuild(i);
		build(1, 1, n);
		for (int i=1,u; i<=q; ++i) {
			u=read();
			for (int j=1; j<=l; ++j) w[u][j]=read();
			rebuild(u); upd(1, u);
			tem.resize(1, l); tem[1][l]=0;
			tem=tem*val[1];
			ll ans=-INF;
			for (int j=1; j<=l; ++j) ans=max(ans, tem[1][j]);
			printf("%lld\n", ans);
		}
	}
}

namespace task2{
	ll f[N][5];
	void dfs(int u) {
		// cout<<"dfs: "<<u<<endl;
		if (head[u]==-1) {
			f[u][1]=w[u][1];
			for (int i=2; i<=l; ++i) f[u][i]=-INF;
			return ;
		}
		ll sum=0; f[u][1]=0;
		for (int i=2; i<=l; ++i) f[u][i]=INF;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			dfs(v);
			ll maxn=-INF;
			for (int j=1; j<=l; ++j) maxn=max(maxn, f[v][j]);
			for (int j=1; j<l; ++j) f[u][j+1]=min(f[u][j+1], maxn-f[v][j]);
			sum+=maxn;
		}
		for (int i=1; i<=l; ++i) f[u][i]=sum-f[u][i]+w[u][i];
	}
	void rebuild(int u) {
		if (head[u]==-1) {
			f[u][1]=w[u][1];
			for (int i=2; i<=l; ++i) f[u][i]=-INF;
			return ;
		}
		ll sum=0; f[u][1]=0;
		for (int i=2; i<=l; ++i) f[u][i]=INF;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			ll maxn=-INF;
			for (int j=1; j<=l; ++j) maxn=max(maxn, f[v][j]);
			for (int j=1; j<l; ++j) f[u][j+1]=min(f[u][j+1], maxn-f[v][j]);
			sum+=maxn;
		}
		for (int i=1; i<=l; ++i) f[u][i]=sum-f[u][i]+w[u][i];
	}
	void solve() {
		dfs(1);
		for (int i=1,u; i<=q; ++i) {
			u=read();
			for (int j=1; j<=l; ++j) w[u][j]=read();
			while (u) rebuild(u), u=back[u];
			ll ans=-INF;
			for (int j=1; j<=l; ++j) ans=max(ans, f[1][j]);
			printf("%lld\n", ans);
		}
	}
}

namespace task{
	// Final solution: dynamic DP on a heavy-light decomposition (HLD).
	// Every node t owns an l x l (max,+) matrix f[t]: if G is the DP row vector
	// (1 x l) of t's heavy child, then G * f[t] is t's DP row vector. The light
	// children of t enter f[t] only through the constants sum[t] and lit[t].
	// A segment tree over the HLD order stores products of these matrices, so a
	// whole heavy chain is evaluated with one range query.

	// sum[u]: sum over the light children c of u of max_k F_c[k], where F_c is
	// the DP row vector of c (maintained in dfs3 and rebuild).
	ll sum[N];
	// lit[u][j]: multiset of F_c[j] - max_k F_c[k] over the light children c of u,
	// plus a -INF sentinel inserted in solve() so *rbegin() is always valid.
	// build(int) reads indices 1..l-1.
	multiset<ll> lit[N][5];
	// HLD data: subtree size, largest child-subtree size, heavy child, depth,
	// chain top, btm[t] = bottom node of the chain whose top is t,
	// id[u] = position of u in HLD order, rk = inverse of id, tot = position counter.
	int siz[N], msiz[N], mson[N], dep[N], top[N], btm[N], id[N], rk[N], tot;
	// n x m matrix over the (max,+) semiring, 1-based; a[5][5] caps it at 4 x 4.
	// The constructors and resize() set every byte of a to -0x3f, a large
	// negative value that serves as -inf.
	// f[t]: transition matrix of node t; val: segment-tree products; tem: scratch row vector.
	struct matrix{
		int n, m;
		ll a[5][5];
		matrix() {n=0; m=0; memset(a, -0x3f, sizeof(a));}
		matrix(int x, int y) {n=x; m=y; memset(a, -0x3f, sizeof(a));}
		void resize(int x, int y) {n=x; m=y; memset(a, -0x3f, sizeof(a));}
		inline ll* operator [] (int t) {return a[t];}
		// Debug printer; only referenced from commented-out code.
		void put() {for (int i=1; i<=n; ++i) {for (int j=1; j<=m; ++j) cout<<setw(3)<<a[i][j]<<' '; cout<<endl;}cout<<endl;}
		// (max,+) product: ans[i][j] = max over k of a[i][k] + b[k][j].
		// ans starts at the -inf fill, so no result entry is ever below that
		// value; -inf-like entries cannot drift further down across products.
		// NOTE(review): b is passed by value, copying the whole matrix per product.
		matrix operator * (matrix b) {
			matrix ans(n, b.m);
			for (int i=1; i<=n; ++i)
				for (int k=1; k<=m; ++k)
					for (int j=1; j<=b.m; ++j)
						ans[i][j]=max(ans[i][j], a[i][k]+b[k][j]);
			return ans;
		}
	}f[N], val[N<<2], tem;
	// Segment tree over HLD positions 1..n; node p covers [tl(p), tr(p)].
	int tl[N<<2], tr[N<<2];
	#define tl(p) tl[p]
	#define tr(p) tr[p]
	// Right child times left child: positions deeper in a chain are applied
	// first, because the row vector is carried from the chain bottom upward.
	#define pushup(p) val[p]=val[p<<1|1]*val[p<<1]
	// Lays out the tree shape only; leaves keep default matrices until upd()
	// fills them. (Parameter l shadows the global l here.)
	void build(int p, int l, int r) {
		tl(p)=l; tr(p)=r;
		if (l==r) return ;
		int mid=(l+r)>>1;
		build(p<<1, l, mid);
		build(p<<1|1, mid+1, r);
		pushup(p);
	}
	// Copies f[rk[pos]] into the leaf for HLD position pos and refreshes its ancestors.
	void upd(int p, int pos) {
		if (tl(p)==tr(p)) {val[p]=f[rk[tl(p)]]; return ;}
		int mid=(tl(p)+tr(p))>>1;
		if (pos<=mid) upd(p<<1, pos);
		else upd(p<<1|1, pos);
		pushup(p);
	}
	// Product of the matrices at positions [l, r], in the same order as pushup.
	matrix query(int p, int l, int r) {
		if (l<=tl(p)&&r>=tr(p)) return val[p];
		int mid=(tl(p)+tr(p))>>1;
		if (l<=mid&&r>mid) return query(p<<1|1, l, r)*query(p<<1, l, r);
		else if (l<=mid) return query(p<<1, l, r);
		else return query(p<<1|1, l, r);
	}
	// Recomputes node t's transition matrix and stores it in the segment tree.
	// With G the heavy child's row vector, (G * f[t])[j] = max_k G[k] + f[t][k][j]:
	//   column 1:   sum[t] + w[t][1] in every row;
	//   column j>1: sum[t] + w[t][j] + best light value for state j-1, except
	//               row j-1, which adds max(best, 0) so the heavy child's own
	//               state j-1 can also be extended.
	void build(int t) {
		f[t].resize(l, l);
		for (int i=1; i<=l; ++i) for (int j=1; j<=l; ++j) f[t][i][j]=sum[t]+w[t][j];
		// i = column (state), j = row.
		for (int i=2; i<=l; ++i)
			for (int j=1; j<=l; ++j)
				if (j==i-1) f[t][j][i]+=max(*lit[t][i-1].rbegin(), 0ll);
				else f[t][j][i]+=*lit[t][i-1].rbegin();
		upd(1, id[t]);
	}
	// Applies a change of w[u] (u on the chain with top t) and updates the
	// light-child constants of t's parent back[t]:
	//   1. evaluate the chain with the old f[u] and remove its contribution;
	//   2. rebuild f[u];
	//   3. evaluate the chain again and add the new contribution.
	// Step 1 has to precede build(u): erase(find(x)) relies on x being exactly
	// the value inserted earlier (find() returning end() would make erase undefined).
	// For the root chain back[t] is 0; sum[0]/lit[0] do not feed any matrix.
	void rebuild(int u) {
		int t=top[u];
		// Start vector (-inf, ..., -inf, 0). Applied to the chain bottom (a leaf,
		// whose sum is 0 and whose lit holds only the -INF sentinel), it yields
		// the leaf's DP values.
		tem.resize(1, l); tem[1][l]=0;
		tem=tem*query(1, id[t], id[btm[t]]);
		ll maxn=-INF;
		for (int i=1; i<=l; ++i) maxn=max(maxn, tem[1][i]);
		sum[back[t]]-=maxn;
		for (int i=1; i<=l; ++i) lit[back[t]][i].erase(lit[back[t]][i].find(tem[1][i]-maxn));
		build(u);
		tem.resize(1, l); tem[1][l]=0;
		tem=tem*query(1, id[t], id[btm[t]]);
		maxn=-INF;
		for (int i=1; i<=l; ++i) maxn=max(maxn, tem[1][i]);
		sum[back[t]]+=maxn;
		for (int i=1; i<=l; ++i) lit[back[t]][i].insert(tem[1][i]-maxn);
	}
	// Computes subtree sizes and the heavy child (largest subtree) of every node.
	// NOTE(review): dfs1/dfs2/dfs3 recurse once per tree level; a path-shaped
	// tree with ~1e5 nodes may need a larger stack — confirm the judge's limit.
	void dfs1(int u, int fa) {
		siz[u]=1;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa) continue;
			dep[v]=dep[u]+1;
			dfs1(v, u);
			siz[u]+=siz[v];
			if (siz[v]>msiz[u]) msiz[u]=siz[v], mson[u]=v;
		}
	}
	// Assigns chain tops and HLD positions. The heavy child is visited first,
	// so every chain occupies consecutive positions from its top to btm[top].
	void dfs2(int u, int fa, int t) {
		// cout<<"dfs2: "<<u<<' '<<fa<<' '<<t<<endl;
		top[u]=t;
		rk[id[u]=++tot]=u;
		// A node without a heavy child ends its chain.
		if (!mson[u]) {btm[t]=u; return ;}
		dfs2(mson[u], u, t);
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa || v==mson[u]) continue;
			dfs2(v, u, v);
		}
	}
	// Post-order initialization: light subtrees first (each finished light chain
	// adds its value into sum[u]/lit[u]), then the heavy child, then u's matrix.
	// Once a chain top is built its whole chain is ready, so its value is added
	// to the light-child constants of back[u].
	void dfs3(int u, int fa) {
		// cout<<"dfs3: "<<u<<' '<<fa<<endl;
		for (int i=head[u],v; ~i; i=e[i].next) {
			v = e[i].to;
			if (v==fa || v==mson[u]) continue;
			dfs3(v, u);
		}
		if (mson[u]) dfs3(mson[u], u);
		build(u);
		// cout<<"u: "<<u<<endl;
		// cout<<"sum: "<<sum[u]<<endl;
		// cout<<"mx: "<<*lit[u][1].rbegin()<<endl;
		// f[u].put();
		if (u==top[u]) {
			tem.resize(1, l); tem[1][l]=0;
			tem=tem*query(1, id[u], id[btm[u]]);
			ll maxn=-INF;
			for (int i=1; i<=l; ++i) maxn=max(maxn, tem[1][i]);
			sum[back[u]]+=maxn;
			for (int i=1; i<=l; ++i) lit[back[u]][i].insert(tem[1][i]-maxn);
		}
	}
	// Reads the updates "u w[u][1..l]" and prints the root's best value after each.
	void solve() {
		// cout<<double(sizeof(f))/1000/1000<<endl;
		// -INF sentinels keep *rbegin() valid for nodes without light children.
		for (int i=1; i<=n; ++i) for (int j=0; j<=l; ++j) lit[i][j].insert(-INF);
		// build(1, 1, n) only lays out the segment tree; matrices are filled in dfs3.
		dep[1]=1; dfs1(1, 0); build(1, 1, n); dfs2(1, 0, 1); dfs3(1, 0);
		// cout<<"top: "; for (int i=1; i<=n; ++i) cout<<top[i]<<' '; cout<<endl;
		// cout<<"id: "; for (int i=1; i<=n; ++i) cout<<id[i]<<' '; cout<<endl;
		for (int i=1,u; i<=q; ++i) {
			u=read();
			for (int j=1; j<=l; ++j) w[u][j]=read();
			// Climb chain by chain: u, then the parent of u's chain top, and so on.
			for (; u; u=back[top[u]]) rebuild(u);
			// Evaluate the root chain; the answer is its best state.
			tem.resize(1, l); tem[1][l]=0;
			tem=tem*query(1, id[1], id[btm[1]]);
			ll ans=-INF;
			for (int j=1; j<=l; ++j) ans=max(ans, tem[1][j]);
			printf("%lld\n", ans);
		}
	}
}

signed main()
{
	// The judge uses file I/O.
	freopen("decompose.in", "r", stdin);
	freopen("decompose.out", "w", stdout);

	// Header: node count, number of updates, DP states per node.
	n=read();
	q=read();
	l=read();

	// Every adjacency list starts empty (-1 terminates a list).
	memset(head, -1, sizeof(head));

	// Parents of nodes 2..n. Also record whether the tree is the path 1-2-...-n,
	// which only the commented-out task1 dispatch below would use.
	bool ischain=true;
	for (int v=2; v<=n; ++v) {
		back[v]=read();
		add(back[v], v);
		ischain=ischain&&(back[v]==v-1);
	}
	(void)ischain;

	// Initial weights, node by node.
	for (int v=1; v<=n; ++v)
		for (int k=1; k<=l; ++k)
			w[v][k]=read();

	// Alternative solvers kept for reference / cross-checking:
	//   force::solve();  task1::solve();  task2::solve();
	//   if (ischain) task1::solve(); else task2::solve();
	task::solve();

	return 0;
}
posted @ 2022-04-03 20:58  Administrator-09  阅读(2)  评论(0编辑  收藏  举报