Solution -「CF 1303G」Sum of Prefix Sums

Description

Link.

对于一棵树,选出一条链 \((u,v)\),把链上结点从 \(u\)\(v\) 放成一个 长度 \(l\) 的数组,使得 \(\sum_{i=1}^{l}\sum_{j=1}^{i}a_{j}\) 最大,\(a\) 是点权。

Solution

可以发现那个式子等价于 \(\sum_{i=1}^{l}ia_{i}\)

考虑点分,设当前根为 \(x\)。选出来的 \(u,v\) 一定是叶子(点权为正),因为没有什么本质差别,所以可以一起算。我们把 \(x\)\((u,v)\) 中的位置记作 \(o\)\((u,v)\) 的权值就为 \(\sum_{i=1}^{l}ia_{i}=\sum_{i=1}^{o}ia_{i}+l\sum_{i=o+1}^{l}a_{i}+\sum_{i=o+1}^{l}(i-l)a_{i}\),这是个一次函数,令 \(b_{1}=\sum_{i=1}^{l}ia_{i}=\sum_{i=1}^{o}ia_{i},b_{2}=\sum_{i=o+1}^{l}(i-l)a_{i},k=l\),得 \(\sum_{i=1}^{l}ia_{i}=k\times\sum_{i=o+1}^{l}a_{i}+b_{1}+b_{2}\)

#include<bits/stdc++.h>
typedef long long ll;
#define sf(x) scanf("%d",&x)
#define ssf(x) scanf("%lld",&x)
struct Line {
	ll k,b;
	Line():k(0),b(0){}
	Line(ll _k,ll _b):k(_k),b(_b){}
}lns[10000010];
std::vector<int> G[200010];
ll a[200010],ans,stk[6][200010];
// stk[0]: sum(i=1~l)i*a[i]
// stk[1]: sum(i=o+1~l)(i-l)*a[i]
// stk[2]: sum(i=o+1~l)a[i]
// stk[3]: all the nodes we passed and possible to be the final node
// stk[4]: l
// stk[5]: where to belong to
int n,szf[200010],tot,tr[800010],top,rt,del[200010],siz[200010],mxdep,dep[200010];
ll ff(ll x,int i){return lns[i].k*x+lns[i].b;}
ll getk(int i){return lns[i].k;}
bool chk(ll x,int i,int j){return ff(x,i)>ff(x,j);}
void ins(int l,int r,int now,int t)
{
	if(l^r)
	{
		if(chk(l,t,tr[now]) && chk(r,t,tr[now]))	tr[now]=t;
		else if(chk(l,t,tr[now]) || chk(r,t,tr[now]))
		{
			int mid=(l+r)>>1;
			if(chk(mid,t,tr[now]))	tr[now]^=t^=tr[now]^=t;
			if(chk(l,t,tr[now]))	ins(l,mid,now<<1,t);
			else	ins(mid+1,r,now<<1|1,t); 
		}
	}
	else if(chk(l,t,tr[now]))	tr[now]=t;
}
int find(int l,int r,int now,int t) // query line id
{
	if(l^r)
	{
		int mid=(l+r)>>1,res;
		if(mid>=t)	res=find(l,mid,now<<1,t);
		else	res=find(mid+1,r,now<<1|1,t);
		if(chk(t,res,tr[now]))	return res;
		else	return tr[now];
	}
	else	return tr[now];
}
void clear(int l,int r,int now)
{
	int mid=(l+r)>>1;
	tr[now]=0;
	if(l^r)	clear(l,mid,now<<1),clear(mid+1,r,now<<1|1);
}
void get_root(int now,int las,int all)
{
	siz[now]=1;
	szf[now]=0;
	for(int to:G[now])
	{
		if((to^las) && !del[to])
		{
			get_root(to,now,all);
			siz[now]+=siz[to];
			szf[now]=std::max(szf[now],siz[to]);
		}
	}
	szf[now]=std::max(szf[now],all-siz[now]);
	if(szf[now]<szf[rt])	rt=now;
}
void get_value(int now,ll prf0,ll prf1,ll prf2,int wr,int las)
{
	if((now^rt) && !wr)	wr=now;
	mxdep=std::max(mxdep,dep[now]=dep[las]+1);
	bool lef=1;
	for(int to:G[now])	if((to^las) && !del[to])
		lef=0,get_value(to,prf0+prf2+a[to],prf1+a[to]*dep[now],prf2+a[to],wr,now);
	if(lef)
		++top,stk[0][top]=prf0,stk[1][top]=prf1,stk[2][top]=prf2-a[rt],
		stk[3][top]=now,stk[4][top]=dep[now],stk[5][top]=wr;
}
void get_ans(int now)
{
	del[now]=1;
	top=mxdep=0;
	get_value(now,a[now],0,a[now],0,0);
	++top;
	stk[0][top]=a[now];
	stk[1][top]=stk[2][top]=stk[5][top]=0;
	stk[3][top]=now;
	stk[4][top]=1;
	stk[5][top+1]=stk[5][0]=-1;
	clear(1,mxdep,1);
	int i=1,j;
	while(i<=top)
	{
		j=i;
		while(stk[5][i]==stk[5][j])	ans=std::max(ff(stk[4][j],find(1,mxdep,1,stk[4][j]))+stk[0][j],ans),++j;
		j=i;
		while(stk[5][i]==stk[5][j])	lns[++tot]=Line(stk[2][j],stk[1][j]),ins(1,mxdep,1,tot),++j;
		i=j;
	}
	clear(1,mxdep,1);
	i=top;
	while(i)
	{
		j=i;
		while(stk[5][i]==stk[5][j])	ans=std::max(ff(stk[4][j],find(1,mxdep,1,stk[4][j]))+stk[0][j],ans),--j;
		j=i;
		while(stk[5][i]==stk[5][j])	lns[++tot]=Line(stk[2][j],stk[1][j]),ins(1,mxdep,1,tot),--j;
		i=j;
	}
	for(int to:G[now])	if(!del[to])	rt=0,get_root(to,now,siz[to]),get_ans(rt);
}
int main()
{
	sf(n);
	for(int i=1,x,y;i<n;++i)
	{
		sf(x),sf(y);
		G[x].emplace_back(y);
		G[y].emplace_back(x);
	}
	for(int i=1;i<=n;++i)	ssf(a[i]);
	szf[0]=n+1;
	get_root(1,0,n);
	get_ans(rt);
	printf("%lld\n",ans);
	return 0;
}
posted @ 2021-05-14 12:26  cirnovsky  阅读(48)  评论(0编辑  收藏  举报