2021牛客多校9 E、Eyjafjalla(主席树合并或dfs序建主席树)

想了个sb主席树合并的做法发现没有离散化,MLE一下午

正解是dfs序建主席树,口糊一下做法:

每次查询[l,r],先倍增找到最接近r的祖先x,然后找x的子树中大于l的节点个数。查子树可以dfs序建个主席树,对于x有in[x]和out[x],用这两个时间戳上的主席树做差查询即可。

下面是类似主席树合并的做法:

同样是倍增找到祖先x,然后找x的子树中大于l的节点个数。可以通过dfs时线段树向上合并,每次查询x节点上权值线段树区间[l,r]的值,但直接线段树合并MLE(听说还会TLE),所以要离散化同时加一些可持久化的操作:

//合并!!!!!!!!!!!!!!!!
int merge(int x, int y)
{
    if (!x || !y)return x | y;
    int o = ++cnt;
    t[o].sum = t[x].sum + t[y].sum;
    t[o].ls = merge(t[x].ls, t[y].ls);
    t[o].rs = merge(t[x].rs, t[y].rs);
    return o;
}

对于树上的一个节点u,它与子树中的某些点共用一些节点比在以自己为根的线段树开点的空间更有(为啥?因为直接套个动态开点线段树会MLE)

如果u的线段树节点x存在并且v的节点y都存在,由于不知道这个主席树的x和y有多少个父节点共用,所以不能在这上面直接修改,得从新开一个点(事实证明,这样空间是够得)
image

对于其中一个存在,另一个不存在,那直接连就好了

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define fastio ios::sync_with_stdio(false),cin.tie(NULL),cout.tie(NULL)
#define pii pair<int,int>
#define pll pair<ll,ll>
const double PI = acos(-1);
const int inf = 1e9;
const ll lnf = 1e18 + 7;
const int maxn = 1e5 + 10;
const int N = 1 << 30 - 1;
ll mod = 998244353;
double eps = 1e-8;

int MAX;//权值线段树值域

int w[maxn], b[maxn];//每个点的权值

int edge_cnt = 0, head[maxn];

struct edge {
    int to, next;
}e[maxn << 1];

void add(int from, int to)
{
    e[++edge_cnt] = { to,head[from] };
    head[from] = edge_cnt;
}

struct tree
{
    int ls, rs;
    int sum;
}t[maxn * 50];

int cnt = 0, root[maxn];//节点编号,每个点对应线段树的根编号

//主席树动态开点插入
int change(int& x, int l, int r, int p)
{
    int to = ++cnt;
    t[to] = t[x], t[to].sum++;
    if (l >= r)return to;
    int mid = l + r >> 1;
    if (p <= mid)t[to].ls = change(t[to].ls, l, mid, p);
    else t[to].rs = change(t[to].rs, mid + 1, r, p);
    return to;
}

//主席树合并!!!!!!!!!!!!!!!!
int merge(int x, int y)
{
    //cout << x << " " << y << endl;
    if (!x || !y)return x | y;
    int o = ++cnt;
    t[o].sum = t[x].sum + t[y].sum;
    t[o].ls = merge(t[x].ls, t[y].ls);
    t[o].rs = merge(t[x].rs, t[y].rs);
    return o;
}

//查询以p为根的主席树
int query(int p, int x, int y, int l, int r)
{
    if (!p)return 0;
    //cout<<x<<" "<<y<<" " << l << " " << r<<" "<<t[p].sum << endl;
    if (x <= l && r <= y)
        return t[p].sum;
    int mid = l + r >> 1;
    ll res = 0;
    if (y > mid)
        res += query(t[p].rs, x, y, mid + 1, r);
    if (x <= mid)
        res += query(t[p].ls, x, y, l, mid);
    return res;
}


int f[32][maxn], deep[maxn];

void dfs1(int from, int fa)
{
    f[0][from] = fa;
    siz[from] = 1;
    for (int i = 1; (1 << i) <= deep[from]; i++)
        f[i][from] = f[i - 1][f[i - 1][from]];
    for (int i = head[from]; ~i; i = e[i].next)
    {
        int to = e[i].to;
        if (to == f[0][from])
            continue;
        deep[to] = deep[from] + 1;
        dfs1(to, from);
        root[from] = merge(root[from], root[to]);
    }
}

//倍增
int find(int u, int r)
{
    //cout << u << " " << v << endl;
    int tmp = deep[u];
    for (int i = 30; i >= 0; i--)
        if ((1 << i) <= tmp)
        {
            if (w[f[i][u]] <= r)
                u = f[i][u], tmp = deep[u];
        }
    return u;
}

int main()
{
    fastio;
    memset(head, -1, sizeof(head));
    int n;
    cin >> n;
    for (int i = 1; i < n; i++)
    {
        int x, y;
        cin >> x >> y;
        add(x, y);
        add(y, x);
    }

    int S = 0;
    vector<int>a;
    for (int i = 1; i <= n; i++)
    {
        cin >> w[i];
        a.push_back(w[i]);
        if (w[i] > MAX)
            S = i, MAX = w[i];
    }
	
    // 离散化
    a.push_back(0);
    a.push_back(1e9 + 7);
    sort(a.begin(), a.end());
    int tot = unique(a.begin(), a.end()) - a.begin();
    for (int i = 1; i <= n; i++)
        b[i] = lower_bound(a.begin(), a.begin() + tot, w[i]) - a.begin();

    MAX = tot;
    for (int i = 1; i <= n; i++)
        root[i] = change(root[i], 1, MAX, b[i]);

    //预处理lca
    deep[S] = 0;
    dfs1(S, 0);

    //查询
    int q;
    cin >> q;
    while (q--)
    {
        int x, l, r;
        cin >> x >> l >> r;
        if (w[x] < l || w[x] > r)
            cout << 0 << "\n";
        else
        {
            int p = find(x, r);//通过倍增上跳找到w[i]最接近r的祖先
            l = lower_bound(a.begin(), a.begin() + tot, l) - a.begin();
            r = upper_bound(a.begin(), a.begin() + tot, r) - a.begin() - 1;
            cout << query(root[p], l, r, 1, MAX) << "\n";
        }
    }

    return 0;

}
posted @ 2021-08-14 20:54  Lecoww  阅读(59)  评论(0编辑  收藏  举报