题目链接
- 理论上答案是可以超过int存储范围的,反正没有这种数据,我不管了
点击查看代码
#include <bits/stdc++.h>
using namespace std;
vector<int>a[200005];
int c[200005],w[200005];
int f[200005],g[200005];
int s[200005],n;
struct t1
{
int l,r;
int sum,bj;
}t[200000*19+5];
void spread(int p,int l,int r)
{
int mid=(l+r)>>1;
if(t[p].l)
{
if(l==mid)
{
t[t[p].l].sum+=t[p].bj;
}
else
{
t[t[p].l].bj+=t[p].bj;
}
}
if(t[p].r)
{
if(mid+1==r)
{
t[t[p].r].sum+=t[p].bj;
}
else
{
t[t[p].r].bj+=t[p].bj;
}
}
t[p].bj=0;
}
int ask(int p,int l,int r,int x)
{
if(p==0)
{
return INT_MIN;
}
spread(p,l,r);
if(l==r)
{
return t[p].sum;
}
int mid=(l+r)>>1;
if(x<=mid)
{
return ask(t[p].l,l,mid,x);
}
else
{
return ask(t[p].r,mid+1,r,x);
}
}
int cnt;
int New()
{
cnt++;
t[cnt].l=t[cnt].r=0;
t[cnt].sum=INT_MIN;
t[cnt].bj=0;
return cnt;
}
void change(int &p,int l,int r,int id,int y)
{
if(p==0)
{
p=New();
}
if(l==r)
{
t[p].sum=y;
return;
}
int mid=(l+r)>>1;
if(id<=mid)
{
change(t[p].l,l,mid,id,y);
}
else
{
change(t[p].r,mid+1,r,id,y);
}
}
int merge(int p,int q,int l,int r)
{
if(!q)
{
return p;
}
spread(q,l,r);
if(!p)
{
return q;
}
spread(p,l,r);
if(l==r)
{
t[p].sum=max(t[p].sum,t[q].sum);
}
else
{
int mid=(l+r)>>1;
t[p].l=merge(t[p].l,t[q].l,l,mid);
t[p].r=merge(t[p].r,t[q].r,mid+1,r);
}
return p;
}
int v;
void dfs(int p,int n1,int fa)
{
int va=ask(p,1,n,c[n1]);
if(va!=INT_MIN)
{
f[p]=max(f[p],w[n1]+v+g[n1]+va);
}
for(int i=0;i<a[n1].size();i++)
{
if(a[n1][i]!=fa)
{
v=v+(g[n1]-f[a[n1][i]]);
dfs(p,a[n1][i],n1);
v=v-(g[n1]-f[a[n1][i]]);
}
}
}
void dp(int n1,int fa)
{
s[n1]=1;
int h=0;
for(int i=0;i<a[n1].size();i++)
{
if(a[n1][i]!=fa)
{
dp(a[n1][i],n1);
s[n1]+=s[a[n1][i]];
g[n1]+=f[a[n1][i]];
if(s[a[n1][i]]>s[h])
{
h=a[n1][i];
}
}
}
f[n1]=g[n1];
change(n1,1,n,c[n1],w[n1]+g[n1]);
if(h==0)
{
return;
}
f[n1]=max(f[n1],ask(h,1,n,c[n1])+w[n1]+g[n1]-f[h]);
t[h].bj+=(g[n1]-f[h]);
n1=merge(n1,h,1,n);
for(int i=0;i<a[n1].size();i++)
{
if(a[n1][i]!=fa&&a[n1][i]!=h)
{
v=-f[a[n1][i]];
dfs(n1,a[n1][i],n1);
t[a[n1][i]].bj+=(g[n1]-f[a[n1][i]]);
n1=merge(n1,a[n1][i],1,n);
}
}
}
int main()
{
int T;
cin>>T;
while(T--)
{
cin>>n;
cnt=0;
for(int i=1;i<=n;i++)
{
a[i].clear();
c[i]=read1();
g[i]=0;
i=New();
}
for(int i=1;i<=n;i++)
{
w[i]=read1();
}
for(int i=1;i<n;i++)
{
int u,v;
u=read1();
v=read1();
a[u].push_back(v);
a[v].push_back(u);
}
dp(1,0);
cout<<f[1]<<endl;
}
return 0;
}