Solution -「CF 1303G」Sum of Prefix Sums
Description
Link.
对于一棵树,选出一条链 \((u,v)\),把链上结点从 \(u\) 到 \(v\) 放成一个 长度 \(l\) 的数组,使得 \(\sum_{i=1}^{l}\sum_{j=1}^{i}a_{j}\) 最大,\(a\) 是点权。
Solution
可以发现那个式子等价于 \(\sum_{i=1}^{l}ia_{i}\)。
考虑点分,设当前根为 \(x\)。选出来的 \(u,v\) 一定是叶子(点权为正),因为没有什么本质差别,所以可以一起算。我们把 \(x\) 在 \((u,v)\) 中的位置记作 \(o\),\((u,v)\) 的权值就为 \(\sum_{i=1}^{l}ia_{i}=\sum_{i=1}^{o}ia_{i}+l\sum_{i=o+1}^{l}a_{i}+\sum_{i=o+1}^{l}(i-l)a_{i}\),这是个一次函数,令 \(b_{1}=\sum_{i=1}^{l}ia_{i}=\sum_{i=1}^{o}ia_{i},b_{2}=\sum_{i=o+1}^{l}(i-l)a_{i},k=l\),得 \(\sum_{i=1}^{l}ia_{i}=k\times\sum_{i=o+1}^{l}a_{i}+b_{1}+b_{2}\)。
#include<bits/stdc++.h>
typedef long long ll;
#define sf(x) scanf("%d",&x)
#define ssf(x) scanf("%lld",&x)
struct Line {
ll k,b;
Line():k(0),b(0){}
Line(ll _k,ll _b):k(_k),b(_b){}
}lns[10000010];
std::vector<int> G[200010];
ll a[200010],ans,stk[6][200010];
// stk[0]: sum(i=1~l)i*a[i]
// stk[1]: sum(i=o+1~l)(i-l)*a[i]
// stk[2]: sum(i=o+1~l)a[i]
// stk[3]: all the nodes we passed and possible to be the final node
// stk[4]: l
// stk[5]: where to belong to
int n,szf[200010],tot,tr[800010],top,rt,del[200010],siz[200010],mxdep,dep[200010];
ll ff(ll x,int i){return lns[i].k*x+lns[i].b;}
ll getk(int i){return lns[i].k;}
bool chk(ll x,int i,int j){return ff(x,i)>ff(x,j);}
void ins(int l,int r,int now,int t)
{
if(l^r)
{
if(chk(l,t,tr[now]) && chk(r,t,tr[now])) tr[now]=t;
else if(chk(l,t,tr[now]) || chk(r,t,tr[now]))
{
int mid=(l+r)>>1;
if(chk(mid,t,tr[now])) tr[now]^=t^=tr[now]^=t;
if(chk(l,t,tr[now])) ins(l,mid,now<<1,t);
else ins(mid+1,r,now<<1|1,t);
}
}
else if(chk(l,t,tr[now])) tr[now]=t;
}
int find(int l,int r,int now,int t) // query line id
{
if(l^r)
{
int mid=(l+r)>>1,res;
if(mid>=t) res=find(l,mid,now<<1,t);
else res=find(mid+1,r,now<<1|1,t);
if(chk(t,res,tr[now])) return res;
else return tr[now];
}
else return tr[now];
}
void clear(int l,int r,int now)
{
int mid=(l+r)>>1;
tr[now]=0;
if(l^r) clear(l,mid,now<<1),clear(mid+1,r,now<<1|1);
}
void get_root(int now,int las,int all)
{
siz[now]=1;
szf[now]=0;
for(int to:G[now])
{
if((to^las) && !del[to])
{
get_root(to,now,all);
siz[now]+=siz[to];
szf[now]=std::max(szf[now],siz[to]);
}
}
szf[now]=std::max(szf[now],all-siz[now]);
if(szf[now]<szf[rt]) rt=now;
}
void get_value(int now,ll prf0,ll prf1,ll prf2,int wr,int las)
{
if((now^rt) && !wr) wr=now;
mxdep=std::max(mxdep,dep[now]=dep[las]+1);
bool lef=1;
for(int to:G[now]) if((to^las) && !del[to])
lef=0,get_value(to,prf0+prf2+a[to],prf1+a[to]*dep[now],prf2+a[to],wr,now);
if(lef)
++top,stk[0][top]=prf0,stk[1][top]=prf1,stk[2][top]=prf2-a[rt],
stk[3][top]=now,stk[4][top]=dep[now],stk[5][top]=wr;
}
void get_ans(int now)
{
del[now]=1;
top=mxdep=0;
get_value(now,a[now],0,a[now],0,0);
++top;
stk[0][top]=a[now];
stk[1][top]=stk[2][top]=stk[5][top]=0;
stk[3][top]=now;
stk[4][top]=1;
stk[5][top+1]=stk[5][0]=-1;
clear(1,mxdep,1);
int i=1,j;
while(i<=top)
{
j=i;
while(stk[5][i]==stk[5][j]) ans=std::max(ff(stk[4][j],find(1,mxdep,1,stk[4][j]))+stk[0][j],ans),++j;
j=i;
while(stk[5][i]==stk[5][j]) lns[++tot]=Line(stk[2][j],stk[1][j]),ins(1,mxdep,1,tot),++j;
i=j;
}
clear(1,mxdep,1);
i=top;
while(i)
{
j=i;
while(stk[5][i]==stk[5][j]) ans=std::max(ff(stk[4][j],find(1,mxdep,1,stk[4][j]))+stk[0][j],ans),--j;
j=i;
while(stk[5][i]==stk[5][j]) lns[++tot]=Line(stk[2][j],stk[1][j]),ins(1,mxdep,1,tot),--j;
i=j;
}
for(int to:G[now]) if(!del[to]) rt=0,get_root(to,now,siz[to]),get_ans(rt);
}
int main()
{
sf(n);
for(int i=1,x,y;i<n;++i)
{
sf(x),sf(y);
G[x].emplace_back(y);
G[y].emplace_back(x);
}
for(int i=1;i<=n;++i) ssf(a[i]);
szf[0]=n+1;
get_root(1,0,n);
get_ans(rt);
printf("%lld\n",ans);
return 0;
}