冷静 清醒 直视内心.|

艾特玖

园龄:3年11个月粉丝:12关注:7

P3925 aaa被续

P3925 aaa被续

分析

这题不算很麻烦,我们直接来说思路。

算法:树剖+线段树

首先不难想到,我们肯定是从最大值开始算起,接下来考虑该点对答案的贡献

对于枚举的点,我们假设点为i

其对答案的贡献是,i1的路径中,所经过的所有节点的子树中的i在从小到大的排位*vali的和

转化一下,则对于枚举的点,我们要干的事情是知道从i1的路径中val[i],在路径中节点的子树下的排位的和。

这个问题不难解决,我们可以初始化线段树中的每个节点的sumsz[i],这样我们从大到小枚举点的时候,直接算出i1的路径权值和即可。在枚举该点后,要将路径中的所有点点权-1,因为此时从对于路径中每个点的子树中的点的点权从小到大的排位要再少一个了。

总结下,我们要干的是两件事:

  • i1的路径中的权值和*vali即可以得到i对答案的贡献
  • 接着不要忘记将位次-1,即为将从i1的路径中的所有点点权-1

则,我们对线段树的操作时,求区间和,和区间加

来看代码,不难理解

Ac_code

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int,int> PII;
const int N = 5e5 + 10,mod = 1e9 + 7;
struct Node
{
    int l,r;
    LL sum;
    int tag;
}tr[N<<2];
int h[N],ne[N<<1],e[N<<1],idx;
int sz[N],fa[N],son[N],dep[N];
int top[N],id[N],nw[N],ts;
PII val[N];
int n;

void add(int a,int b)
{
    e[idx] = b,ne[idx] = h[a],h[a] = idx++;
}

void dfs1(int u,int pa,int depth)
{
    sz[u] = 1,dep[u] = depth;
    for(int i=h[u];~i;i=ne[i])
    {
        int j = e[i];
        if(j==pa) continue;
        fa[j] = u;
        dfs1(j,u,depth+1);
        sz[u] += sz[j];
        if(sz[j]>sz[son[u]]) son[u] = j;
    }
}

void dfs2(int u,int tp)
{
    top[u] = tp,id[u] = ++ts,nw[ts] = sz[u];
    if(!son[u]) return ;
    dfs2(son[u],tp);
    for(int i=h[u];~i;i=ne[i])
    {
        int j = e[i];
        if(j==fa[u]||j==son[u]) continue;
        dfs2(j,j);
    }
}

void pushup(int u)
{
    tr[u].sum = (tr[u<<1].sum + tr[u<<1|1].sum)%mod;
}

void build(int u,int l,int r)
{
    if(l==r)
    {
        tr[u] = {l,r,nw[l],0};
        return ;
    }
    tr[u] = {l,r};
    int mid = l + r >> 1;
    build(u<<1,l,mid),build(u<<1|1,mid+1,r);
    pushup(u);
}

void pushdown(int u)
{
    auto &root = tr[u],&left = tr[u<<1],&right = tr[u<<1|1];
    if(root.tag)
    {
        left.tag += root.tag;
        right.tag += root.tag;
        left.sum -= 1ll*root.tag*(left.r-left.l+1);
        right.sum -= 1ll*root.tag*(right.r-right.l+1);
        root.tag = 0;
    }
}

void modify(int u,int l,int r)
{
    if(l<=tr[u].l&&tr[u].r<=r)
    {
        tr[u].tag ++;
        tr[u].sum -= tr[u].r-tr[u].l+1;
        return ;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if(l<=mid) modify(u<<1,l,r);
    if(r>mid) modify(u<<1|1,l,r);
    pushup(u);
}

LL query(int u,int l,int r)
{
    if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sum;
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);
    LL res = 0;
    if(l<=mid) res += query(u<<1,l,r);
    if(r>mid) res += query(u<<1|1,l,r);
    return res;
}

int main()
{
    scanf("%d",&n);
    memset(h,-1,sizeof h);
    for(int i = 0;i < n-1; i ++ )
    {
        int u,v;scanf("%d%d",&u,&v);
        add(u,v),add(v,u);
    }
    dfs1(1,-1,1);
    dfs2(1,1);
    build(1,1,n);
    for(int i=1;i<=n;i++)
    {
        int x;cin>>x;
        val[i] = {x,i};
    }
    sort(val+1,val+1+n,greater<PII>());
    LL ans = 0;
    for(int i=1;i<=n;i++)
    {
        int x = val[i].second;
        LL res = 0;
        while(top[x]!=1)
        {
            res += query(1,id[top[x]],id[x]);
            modify(1,id[top[x]],id[x]);
            x = fa[top[x]];
        }
        res += query(1,1,id[x]);
        modify(1,1,id[x]);
        ans = (ans+res*val[i].first%mod)%mod;
    }
    printf("%lld\n",ans);
    return 0;
}

本文作者:艾特玖

本文链接:https://www.cnblogs.com/aitejiu/p/16191964.html

版权声明:本作品采用艾特玖许可协议进行许可。

posted @   艾特玖  阅读(21)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
收起