nc多校2021-9E.Eyjafjalla
nc多校2021-9E.Eyjafjalla
链接:E-Eyjafjalla_2021牛客暑期多校训练营9 (nowcoder.com)
菜狗的人生第一道主席树题目(自主思考、码代码)
相关知识点:树上倍增、主席树(可持久化线段树)
题意:给定n个节点和n-1条边,每个点有一个温度,1号点温度最高,距离1号点越远,温度越低。有q次查询,每次查询询问从x点爆发病毒,病毒的存活温度区间为[l,r],相邻且温度适宜的点会被传染,一共有多少个点会被传染。
分析:
- 抽象一下题意:给定一棵以1号点为根,有n个节点的树,树上每个节点有一个温度,满足父节点温度大于其所有子节点、根节点温度最高。
- 抽象一下问题:给一个点x和区间[l, r],求包含x在内的且所有点温度在区间内的最大连通块点数。在树上分析可转化为:求以某个点为根的子树中,有多少个点再区间内。
- 进一步分析可得,病毒传播可分两个阶段:
- 从x点向上传播到可存活的最高祖宗节点
- 然后从该点往下传播到所有可存活的子节点中
做法:
- 可以采用树上倍增的方法,在
O(logn)
的复杂度内求的最高的祖宗节点 - 根据树的dfs序建立可持久化线段树(主席树),对每一个询问可做到
O(logn)
内,查询总复杂度:O(nlogn)
具体步骤:
- 读入边,用邻接表存储
- 读入每个点的温度,由于最多有100000个点,而温度范围是1~1e9,所以需要将温度先离散化
- 从1号节点开始dfs遍历,求出dfs序(用
ord[]
存储)、所有子树根节点在dfs序中对应的位置(用in[]
存储)、所有子树遍历的最后一个儿子节点在dfs序中对应的位置(用out[]
存储)、祖宗倍增数组(用fa[][]
存储,第二维开到20即可)。 - 按照dfs序建立可持久化线段树
- 对每个查询,先求最大祖先,然后用可持久化线段树查询点数
树上倍增模板
int fa[N][20]
void dfs(int u, int f)
{
fa[u][0] = f;
for(int i = 1; i < 20; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for(int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if(j != f) dfs(j, u);
}
}
//找x的第k个父亲节点
int find_fa(int i,int k)
{
for(int x = 0; x <= int(log2(k)); x++)
if((1 << x) & k) //(1<<x)&k可以判断k的二进制表示中,第(x-1)位上是否为1
i = fa[i][x]; //把i往上提
return i;
}
可持久化线段树分析
以题目给出的样例为例,如下图:
节点中存三个信息:左儿子、右儿子、区间内的点的数量
按照dfs序,每次加入一个节点构成一个新版本,一共n+1个版本(初始为0)
每次查询加入该节点后,在区间[l, r]中的点数有多少个。由此可得,如果想获得子树的点数,可以用两个版本(子树根节点的前一个版本和子树最后一个节点的版本)的差值求出。
因为:在dfs序中,这两个点之间的所有点就是该子树的所有节点,加入点之前[l, r]的点数减去加入点之后[l, r]的点数就等于加入的点中在[l, r]的点数。
关于开多少节点:线段树本身的4*N
个节点+每次更新版本会新开logn
个节点,一共(n + 1)个版本,一共开4 * N + 17 * N
个节点即可,空间复杂度为O(4n + (n + 1)logn)
解释一下in[]
和out[]
:
以样例为例,可求得:
ord[] = 1 3 2 4
in[] = 1 3 2 4
out[] = 4 4 2 4
以1号点作为根的子树的根在dfs序中的编号为1,该子树最后一个节点在dfs序中的编号为4
AC代码:
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
const int N = 1e5+10, INF = 1e9+10;
int n, q;
int h[N], e[N * 2], ne[N * 2], idx;
int tem[N], fa[N][20];
//dfs序,子树根节点对应的dfs下标,对应子树的最后一个儿子节点的dfs下标
int ord[N], in[N], out[N], po;
vector<int> nums;
//主席树节点
struct Node
{
//左右儿子,大小,区间需要通过参数传下去
int l, r;
int si;
}tr[N * 4 + 17 * N];
//不同版本的根节点,点数
int root[N], tot;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
//离散化
int find(int x)
{
return lower_bound(nums.begin(), nums.end(), x) - nums.begin();
}
//创建版本0
int build(int l, int r)
{
int p = ++tot;
if(l == r) return p;
int mid = l + r >> 1;
tr[p].l = build(l, mid), tr[p].r = build(mid + 1, r);
return p;
}
//上一版本,范围,插入的值
int insert(int p, int l, int r, int x)
{
int q = ++tot;
tr[q] = tr[p];
if(l == r)
{
tr[q].si++;
return q;
}
int mid = l + r >> 1;
if(x <= mid) tr[q].l = insert(tr[p].l, l, mid, x);
else tr[q].r = insert(tr[p].r, mid + 1, r, x);
tr[q].si = tr[tr[q].l].si + tr[tr[q].r].si;
return q;
}
//版本,范围, 温度范围
int query(int p, int l, int r, int min_te, int max_te)
{
if(p == 1 || nums[l] > max_te || nums[r] < min_te) return 0;
if(nums[l] >= min_te && nums[r] <= max_te) return tr[p].si;
int mid = l + r >> 1;
return query(tr[p].l, l, mid, min_te, max_te) + query(tr[p].r, mid + 1, r, min_te, max_te);
}
//倍增记录所有点的父节点,并求dfs序
void dfs(int u, int f)
{
ord[++po] = u;
in[u] = po;
fa[u][0] = f;
for(int i = 1; i < 20; i++)
fa[u][i] = fa[fa[u][i-1]][i-1];
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j != f) dfs(j, u);
}
out[u] = po;
}
//倍增找区间内的最大祖先
int find_fa(int u, int max_te)
{
for(int i = 19; ~i; i--)
if(fa[u][i] && tem[fa[u][i]] <= max_te)
u = fa[u][i];
return u;
}
int main()
{
scanf("%d", &n);
memset(h, -1, sizeof h);
for(int i = 0; i < n - 1; i++)
{
int u, v;
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
for(int i = 1; i <= n; i++)
{
scanf("%d", &tem[i]);
nums.push_back(tem[i]);
}
//离散化
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
//dfs求dfs序和祖宗倍增数组
dfs(1, 0);
//根据dfs序建立主席树
root[0] = build(0, nums.size() - 1);
for(int i = 1; i <= po; i++)
root[i] = insert(root[i - 1], 0, nums.size() - 1, find(tem[ord[i]]));
scanf("%d", &q);
while(q--)
{
int x, l, r;
scanf("%d%d%d", &x, &l, &r);
if(tem[x] > r || tem[x] < l)
{
puts("0");
continue;
}
int ro = find_fa(x, r);;
printf("%d\n", query(root[out[ro]], 0, nums.size() - 1, l, r) - query(root[in[ro] - 1], 0, nums.size() - 1, l, r));
}
return 0;
}