[CSP-S模拟测试]:射手座之日(dsu on tree)
题目传送门(内部题103)
输入格式
第一行一个数$n$,表示结点的个数。
第二行$n–1$个数,第$i$个数是$p[i+1]$。$p[i]$表示结点$i$的父亲是$p[i]$。数据保证$p[i]<i$。
第三行$n$个数,$a[1],a[2],...,a[n]$,表示关卡表。数据保证这是一个排列。
第四行$n$个数,$x[1],x[2],...,x[n]$,表示结点的权值。
输出格式
输出一个数表示答案。即对于所有可能的回合,你们能获得的总收益是多少。
数据范围与提示
对于$20\%$的数据,满足$n\leqslant 100$。
对于$40\%$的数据,满足$n\leqslant 2,000$。
对于$60\%$的数据,满足$n\leqslant 50,000$。
对于另外$20\%$的数据,排列$a[i]$是用如下的算法生成的:从一号点开始对树做$dfs$,到达一个节点的时候输出这个结点。
对于$100\%$的数据,满足$1\leqslant n\leqslant 200,000,0\leqslant x[i]\leqslant 100,000,p[i]<i$,$a[i]$是一个排列。
题解
其实另外$20\%$的数据就是正解的一个引导。
显然对于这个部分分无非就是输出每一个点的子节点的$size$相互之间的乘积和乘上这个点的$x$值即可。
那么考虑正解。
用$dsu\ on\ tree$思想,每个节点继承儿子中最重的节点的信息,其他儿子暴力合并;用两个数组记录当前位置是否是一个极长区间的左(右)端点并记录另一个端点的位置,加入一个节点时尝试合并左右信息并统计合并后的方案数即可。
时间复杂度:$\Theta(n\log n)$。
期望得分:$100$分。
实际得分:$100$分。
代码时刻
#include<bits/stdc++.h>
using namespace std;
struct rec{int nxt,to;}e[200000];
int head[200001],cnt;
int n;
int a[200001],val[200001],size[200001],son[200001],l[200001],r[200001],vis[200001],dfn[200001],now;
long long ans;
void add(int x,int y)
{
e[++cnt].nxt=head[x];
e[cnt].to=y;
head[x]=cnt;
}
void dfs(int x)
{
size[x]=1;
for(int i=head[x];i;i=e[i].nxt)
{
dfs(e[i].to);
size[x]+=size[e[i].to];
if(size[son[x]]<size[e[i].to])son[x]=e[i].to;
}
}
long long insert(int x)
{
vis[a[x]]=now;
if(vis[a[x]+1]!=now)l[a[x]+1]=r[a[x]+1]=0;
if(vis[a[x]-1]!=now)l[a[x]-1]=r[a[x]-1]=0;
long long L=l[a[x]-1],R=r[a[x]+1],len=l[a[x]-1]+r[a[x]+1]+1;
l[a[x]]=L+1;r[a[x]]=R+1;l[a[x]+R]=r[a[x]-L]=len;
return len*(len+1)/2-L*(L+1)/2-R*(R+1)/2;
}
long long ask(int x)
{
long long res=insert(x);
for(int i=head[x];i;i=e[i].nxt)
res+=ask(e[i].to);
return res;
}
int dfs(int x,int opt)
{
now=dfn[x]=x;
long long res=0,num=0,flag=0;
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=son[x])res-=dfs(e[i].to,0);
if(son[x])
{
flag=dfs(son[x],1);
num+=flag;
now=dfn[x]=dfn[son[x]];
}
num+=insert(x);
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=son[x])num+=ask(e[i].to);
ans+=(res+num-flag)*val[x];
return num;
}
int main()
{
scanf("%d",&n);
for(int i=2;i<=n;i++){int x;scanf("%d",&x);add(x,i);}
for(int i=1;i<=n;i++){int x;scanf("%d",&x);a[x]=i;}
for(int i=1;i<=n;i++)scanf("%d",&val[i]);
dfs(1);
dfs(1,1);
printf("%lld",ans);
return 0;
}
rp++