[ABC133F] Colorful Tree (可持久化线段树 or 树上差分)

trick: 可持久化线段树 or 树上差分
题意:有一棵树,树上有 N 个顶点,编号为 1 至 N 。这棵树上的第 i 条边连接顶点 u 和 v,这条边的颜色为 Ci 长度为 Di 。
有Q 个询问,每次询问中把把所有是 x 颜色的边的值改成 y,问 u 到 v 的距离是多少。上个询问的修改不影响后面的。

做法1:u 到 v 的距离是 dis[u]+dis[v] - 2*dis[LCA(u,v)].
现在就可以考虑 根节点到 某一个点的距离怎么计算。
这个可以使用可持久化线段树来做。
具体做法,每个点维护的是其到根节点的信息,这里的信息指的是根节点到当前节点的每个颜色出现的次数和总和。
维护方法就是每个节点都是通过其父亲转移过来。
那么对于一次询问中 根节点到 u节点的距离就是 = dis[u]+y * cnt-sum(cnt就是根节点到当前节点 x 颜色的数量,sum就是根节点到当前节点 x颜色的总和).

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int>pir;
const int N = 1e5 + 5;
pir tree[N * 30];
int lson[N * 30], rson[N * 30], rt[N], tot;
void update(int l, int r, int &rt, int pre, int col, int val)
{
    rt = ++tot;
    tree[rt] = tree[pre];
    lson[rt] = lson[pre];
    rson[rt] = rson[pre];
    tree[rt].first++;
    tree[rt].second += val;
    if(l == r)return;
    int mid = l + r >> 1;
    if(col <= mid)update(l, mid, lson[rt], lson[pre], col, val);
    else update(mid + 1, r, rson[rt], rson[pre], col, val);
}
pir query(int l, int r, int rt, int col)
{
    if(l == r)
    {
        return tree[rt];
    }
    int mid = l + r >> 1;
    if(col <= mid)return query(l, mid, lson[rt], col);
    else return query(mid + 1, r, rson[rt], col);
}
struct edge
{
    int v, col, val;
};
std::vector<edge> g[N];
struct  HLD
{
    int fa[N], sz[N], wc[N], dep[N], vistime;
    void dfs1(int u, int f)// 获得 子树大小,深度,父节点,重儿子等信息
    {
        fa[u] = f;
        sz[u] = 1;
        dep[u] = dep[f] + 1;
        for(auto [v, c, a] : g[u])
        {
            if(v == f)continue;
            dfs1(v, u);
            sz[u] += sz[v];
            if(sz[v] > sz[wc[u]])wc[u] = v;
        }
    }
    int dfn[N], rdfn[N], top[N];
    void dfs2(int u, int Top)// 获取 dfs 序当前重链的链头
    {
        dfn[u] = ++vistime;
        rdfn[vistime] = u;
        top[u] = Top;
        if(wc[u] != 0)
        {
            dfs2(wc[u], Top);
            for(auto [v, c, a] : g[u])
            {
                if(v == fa[u] || v == wc[u])continue;
                dfs2(v, v);
            }
        }
    }
    int lca(int u, int v)
    {
        while (top[u] != top[v])
        {
            if (dep[top[u]] > dep[top[v]])
            {
                u = fa[top[u]];
            }
            else
            {
                v = fa[top[v]];
            }
        }
        return dep[u] < dep[v] ? u : v;
    }
} hld;
int dis[N], n, q;
void dfs(int u, int fa)
{
    for(auto [v, col, a] : g[u])
    {
        if(v == fa)continue;
        update(1, n, rt[v], rt[u], col, a);
        dis[v] = dis[u] + a;
        dfs(v, u);
    }
}
ll get_val(int u, int col, int val)
{
    auto [cnt, sum] = query(1, n, rt[u], col);
    ll ans = dis[u] + 1ll * cnt * val - sum;
    return ans;
}
void solve()
{
    cin >> n >> q;
    for(int i = 1; i < n; i++)
    {
        int u, v, c, a;
        cin >> u >> v >> c >> a;
        g[u].push_back({v, c, a});
        g[v].push_back({u, c, a});
    }
    hld.dfs1(1, 0);
    hld.dfs2(1, 0);
    dfs(1, 0);
    while(q--)
    {
        int x, y, u, v;
        cin >> x >> y >> u >> v;
        int L = hld.lca(u, v);
        ll ans = get_val(u, x, y) + get_val(v, x, y) - 2 * get_val(L, x, y);
        cout << ans << '\n';
    }
}
int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
    std::cout.tie(0);

    solve();
}
/*

*/

做法2:我们可以把询问拆贡献,拆到 u,v,lca(u,v) 这三个点上。
把问题转化成求根节点到 u节点的在把颜色为 x的改成 y后的距离。
那么其实我们维护的是一条到根节点的链。
这个就可以通过 DFS中维护每个颜色的出现次数和总和来实现。
具体的就是 进入到一个点的时候 加上这条边的贡献,回溯的时候减去这条边的贡献。
这样就可以计算这一条链的贡献。
我一开始想的是让 lca(u,v) 子树的贡献 - u子树的贡献 -v子树的贡献,这显然是错的。很可惜。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int>pir;
const int N = 1e5 + 5;
struct node
{
    int v, w, c;
};
std::vector<node> g[N];
struct zzzzz
{
    int fa[N], sz[N], wc[N], vistime, dep[N], dis[N];
    int dfn[N], rdfn[N], top[N];
    ll tree[N << 2], tag[N << 2];

    void dfs1(int u, int f)// 获得 子树大小,深度,父节点,重儿子等信息
    {
        fa[u] = f;
        sz[u] = 1;
        dep[u] = dep[f] + 1;
        for(auto [v, w, c] : g[u])
        {
            if(v == f)continue;
            dis[v] = dis[u] + w;
            dfs1(v, u);
            sz[u] += sz[v];
            if(sz[v] > sz[wc[u]])wc[u] = v;
        }
    }
    void dfs2(int u, int Top)// 获取 dfs 序当前重链的链头
    {
        dfn[u] = ++vistime;
        rdfn[vistime] = u;
        top[u] = Top;
        if(wc[u] != 0)
        {
            dfs2(wc[u], Top);
            for(auto [v, w, c] : g[u])
            {
                if(v == fa[u] || v == wc[u])continue;
                dfs2(v, v);
            }
        }
    }


    int lca(int u, int v)
    {
        while (top[u] != top[v])
        {
            if (dep[top[u]] > dep[top[v]])
            {
                u = fa[top[u]];
            }
            else
            {
                v = fa[top[v]];
            }
        }
        return dep[u] < dep[v] ? u : v;
    }
} T;

struct query
{
    int id, x, y, tag;
};
std::vector<query> Q[N];
ll ans[N], cnt[N], sum[N];

void dfs(int u, int fa)
{
    for(auto [id, x, y, tag] : Q[u])
    {
        ans[id] += 1ll * tag * (T.dis[u] + 1ll * cnt[x] * y - sum[x]);
    }
    for(auto [v, w, c] : g[u])
    {
        if(v == fa)continue;
        cnt[c]++, sum[c] += w;
        dfs(v, u);
        cnt[c]--, sum[c] -= w;
    }
}

void solve()
{
    int n, m;
    cin >> n >> m;
    std::vector<node> b(n);
    for(int i = 1; i < n; i++)
    {
        int u, v, c, d;
        cin >> u >> v >> c >> d;
        g[u].push_back({v, d, c});
        g[v].push_back({u, d, c});
    }
    T.dfs1(1, 0);
    T.dfs2(1, 0);
    for(int i = 1; i <= m; i++)
    {
        int x, y, u, v;
        cin >> x >> y >> u >> v;
        if(T.dep[u] > T.dep[v])swap(u, v);
        int L = T.lca(u, v);
        if(L != u)
        {
            Q[L].push_back({i, x, y, -2});
            Q[u].push_back({i, x, y, 1});
            Q[v].push_back({i, x, y, 1});
        }
        else
        {
            Q[L].push_back({i, x, y, -1});
            Q[v].push_back({i, x, y, 1});
        }
    }
    dfs(1, 0);
    for(int i = 1; i <= m; i++)cout << ans[i] << '\n';

}
int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
    std::cout.tie(0);

    solve();
}
/*

*/
posted @   pipipipipi43  阅读(10)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
点击右上角即可分享
微信分享提示