bzoj4712 洪水(动态dp)

看起来很模板的一个题啊
qwq
但是我还是wei

题目要求的是一个把根节点和所有叶子断开连接的最小花费。

还是想一个比较\(naive\)的做法

我们令\(dp1[i]\)表示,在\(i\)的子树内,把叶子全都隔断的最小代价,那么

\[dp1[i]=max(\sum dp1[p],val[i]) \]

但是这样暴力并不能通过这个题。
考虑怎么来优化更新的过程呢。

由于是树上问题,根据套路,我们对原树进行树链剖分。
\(dp[i]\)表示除去重儿子的所有\(dp1[p]\)的和。
那么我们重新定义矩阵乘法\(c[i][j]=max(c[i][j],a[i][k]+b[k][j])\)之后,就可以通过矩阵来完成转移了

我们令\(g\)表示包含重儿子的\(ans\),然后令\(f\)表示上述的\(dp[i]\)
\(p\)表示重儿子。

那么不难发现

g[p] 0
g[p] 0

f[i] +inf
val[i] 0

做矩阵乘法之后,就能得到

g[i] 0
g[i] 0

那我们可以直接用线段树来维护矩阵乘法来进行快速修改和求值了。

但是有一个需要注意的地方就是,对于重链链尾的所有元素的对应转移矩阵要特殊处理,因为他们的\(g\)是等于\(f\)的。

那么修改的时候,先进行单点修改(要特判链尾)
然后依次修改每一条重链的\(fa\)的转移矩阵即可。

qwq一开始有很多地方都没有想明白,就很wei
细节就直接看代码吧

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<cmath>
#include<map>
#include<set>
#define pb push_back
#define mk make_pair
#define ll long long
#define lson ch[x][0]
#define rson ch[x][1]
#define int long long

using namespace std;

inline int read()
{
  int x=0,f=1;char ch=getchar();
  while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
  while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
  return x*f;
}

const int maxn = 2e5+1e2;
const int maxm = 2*maxn;
const int inf = 1e18;

struct Ju{
	int x,y;
	int a[3][3];
	Ju operator * (Ju b)
	{
		Ju ans;
		ans.x=ans.y=2;
		memset(ans.a,0x3f,sizeof(ans.a));
		for (int i=1;i<=2;i++)
		  for(int j=1;j<=2;j++)
		    for (int k=1;k<=y;k++)
		    {
		    	ans.a[i][j]=min(ans.a[i][j],a[i][k]+b.a[k][j]);
			}
		return ans;
	}
};

int point[maxn],nxt[maxm],to[maxm],val[maxn];
int cnt,n,m;
Ju pre[maxn];
Ju f[4*maxn];
int top[maxn],newnum[maxn],tail[maxn];
int fa[maxn],son[maxn],size[maxn];
int q;
int back[maxn];
int dp1[maxn],dp[maxn];
int sum[maxn];

//dp[x] doesn't include son[x]
//这个dp数组实质上就是一个sum的形式。

void addedge(int x,int y)
{
	nxt[++cnt]=point[x];
	to[cnt]=y;
	point[x]=cnt;
}

void up(int root)
{
	f[root]=f[2*root+1]*f[2*root];
}

void build(int root,int l,int r)
{
	if (l==r)
	{
		int ymh = back[l];
		f[root].x=f[root].y=2;
		if (tail[top[ymh]]==ymh)
		{
			f[root].a[1][1]=f[root].a[2][1]=dp1[ymh];
		}
		else
		{
			f[root].a[1][1]=dp[ymh];
			f[root].a[1][2]=inf;
			f[root].a[2][1]=val[ymh];
		}
		return;
	}
	int mid = l+r >> 1;
	build(2*root,l,mid);
	build(2*root+1,mid+1,r);
	up(root);
}

void update(int root,int l,int r,int x,Ju p)
{
	if(l==r)
	{
		f[root]=p;
		return;
	}
	int mid = l+r >> 1;
	if (x<=mid) update(2*root,l,mid,x,p);
	else update(2*root+1,mid+1,r,x,p);
	up(root);
}

Ju query(int root,int l,int r,int x,int y)
{
	if (x<=l && r<=y)
	{
		return f[root];
	}
	int mid = l+r >> 1;
	if (x>mid) return query(2*root+1,mid+1,r,x,y);
	if (y<=mid) return query(2*root,l,mid,x,y);
	return query(2*root+1,mid+1,r,x,y)*query(2*root,l,mid,x,y);
}

void dfs1(int x,int faa)
{
	size[x]=1;
	int maxson=-1;
	for (int i=point[x];i;i=nxt[i])
	{
		int p = to[i];
		if (p==faa) continue;
		fa[p]=x;
		dfs1(p,x);
		size[x]+=size[p];
		if (size[p]>maxson)
		{
			maxson=size[p];
			son[x]=p;
		}
	}
}

int tot;

void dfs2(int x,int chain)
{
	top[x]=chain;
	tail[chain]=x;
	newnum[x]=++tot;
	back[tot]=x;
	if (!son[x]) return;
	dfs2(son[x],chain);
	for (int i=point[x];i;i=nxt[i])
	{
		int p = to[i];
		if (!newnum[p]) dfs2(p,p);
	}
}

void solve(int x,int fa)
{
	for (int i=point[x];i;i=nxt[i])
	{
		int p = to[i];
		if (p==fa) continue;
		solve(p,x);
		sum[x]+=dp1[p];
	}
	if (!son[x]) dp1[x]=val[x];
	else dp1[x]=min(sum[x],val[x]);
	dp[x]=sum[x]-dp1[son[x]];
}

void modify(int x,int y)
{
	Ju tmp = query(1,1,n,newnum[x],newnum[x]);
	tmp.a[2][1]+=y;
	val[x]+=y;
	if (tail[top[x]]==x) tmp.a[1][1]+=y;
	update(1,1,n,newnum[x],tmp);
	for (int now = top[x];now!=1;now=top[now])
	{
		int faa = fa[now];
		Ju ymh = query(1,1,n,newnum[faa],newnum[faa]);
		Ju lyf = query(1,1,n,newnum[now],newnum[tail[top[now]]]);
		ymh.a[1][1]+=(lyf.a[1][1]-pre[now].a[1][1]);
		update(1,1,n,newnum[faa],ymh);
		pre[now]=lyf;
		now = fa[now]; 
	} 
}

signed main()
{
  n=read();
  for (int i=1;i<=n;i++) val[i]=read();
  for (int i=1;i<n;i++)
  {
  	int x=read(),y=read();
  	addedge(x,y);
  	addedge(y,x);
  }
  dfs1(1,0);
  dfs2(1,1);
  solve(1,0);
  build(1,1,n);
  for (int i=1;i<=n;i++)
  {
  	 pre[i]=query(1,1,n,newnum[i],newnum[tail[top[i]]]);
  }
  q=read(); 
  for (int i=1;i<=q;i++)
  {
  	 char s[10];
  	 scanf("%s",s+1);
  	 if (s[1]=='Q')
  	 {
  	 	int x=read();
  	 	Ju now = query(1,1,n,newnum[x],newnum[tail[top[x]]]);
  	 	cout<<now.a[1][1]<<"\n";
	 }
	 else
	 {
	 	int x=read(),y=read();
	 	modify(x,y);
	 }
  }
  return 0;
}

posted @ 2019-01-21 09:15  y_immortal  阅读(238)  评论(0编辑  收藏  举报