计蒜客 2019南昌邀请网络赛J Distance on the tree(主席树)题解
题意:给出一棵树及每条边的权值,以及 m 个询问。每个询问给出 u、v、k,要求输出 u 到 v 的路径(树上两点间路径唯一)上边权 <= k 的边有几条。
思路:网络赛时还没学过主席树,现在补上。先在树上建主席树:把每条边的权值下放给它的子节点,这样节点 x 的版本 root[x] 就记录了从根到 x 路径上的所有边权。于是 u 到 v 路径上边权 <= k 的边数就是 cnt(u) + cnt(v) - 2 * cnt(lca)。专题里那道统计点权的题基本就是原题。1A = =,强行做模板题提高自信。
代码:
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 1e5 + 10;
const int M = maxn * 30;
const ull seed = 131;
const int INF = 0x3f3f3f3f;
const int MOD = 1e9 + 7;

int n, m;               // number of vertices / number of queries
int root[maxn], tot;    // root[u]: persistent-tree version for vertex u; tot: nodes allocated so far

// Linked-list adjacency storage; each undirected tree edge is stored twice.
struct Edge{
    int v, next;
    ll w;
}edge[maxn << 1];
int head[maxn], tol;

// Appends the directed edge u -> v with weight w.
void addEdge(int u, int v, ll w){
    edge[tol].v = v;
    edge[tol].w = w;
    edge[tol].next = head[u];
    head[u] = tol++;
}

// Node pool of the persistent (chairman) segment tree. Allocation starts at
// index 1, so T[0] stays all-zero and acts as the shared empty tree.
struct node{
    int lson, rson;
    int sum;
}T[maxn * 40];

// Resets the node pool, the version roots, the adjacency lists and both
// counters. Note: does NOT clear vv.
void init(){
    memset(T, 0, sizeof(T));
    memset(root, 0, sizeof(root));
    memset(head, -1, sizeof(head));
    tot = tol = 0;
}

// Sorted, de-duplicated list of every value that is inserted or queried
// (edge weights and query thresholds). Kept as ll so weights read with %lld
// are never truncated to int, which would scramble their relative order.
vector<ll> vv;

// 1-based rank of x in the sorted vv (index of the first element >= x).
int getid(ll x){
    return lower_bound(vv.begin(), vv.end(), x) - vv.begin() + 1;
}

// Creates a new version over the rank range [l, r] equal to the version rooted
// at `pre` with v added at rank `pos`; the new root index is stored in `now`.
// Only one node per level is allocated; untouched subtrees are shared with `pre`.
void update(int l, int r, int &now, int pre, int v, int pos){
    T[++tot] = T[pre];
    T[tot].sum += v;
    now = tot;
    if(l == r) return;
    int mid = (l + r) >> 1;
    if(pos <= mid) update(l, mid, T[now].lson, T[pre].lson, v, pos);
    else update(mid + 1, r, T[now].rson, T[pre].rson, v, pos);
}

// Builds root[] for `now` and every vertex below it. A vertex's version is its
// parent's version plus the weight of the edge to that parent, so root[x]
// holds the weights on the path from the start vertex to x (plus the weight
// `w` given to the start vertex itself). `pre` is the parent of `now`; when it
// is 0, root[0] is the empty tree.
// An explicit stack replaces recursion so a path-shaped tree with ~1e5
// vertices cannot overflow the call stack. A vertex is always popped before its
// children are pushed, so root[parent] exists when a child is processed.
void build(int now, int pre, ll w){
    struct Frame{ int u, parent; ll w; };
    vector<Frame> stk;
    stk.push_back({now, pre, w});
    while(!stk.empty()){
        Frame cur = stk.back();
        stk.pop_back();
        update(1, static_cast<int>(vv.size()), root[cur.u], root[cur.parent], 1, getid(cur.w));
        for(int i = head[cur.u]; i != -1; i = edge[i].next){
            int v = edge[i].v;
            if(v == cur.parent) continue;
            stk.push_back({v, cur.u, edge[i].w});
        }
    }
}

// Counts the stored values with rank <= k on the u-v path, given the versions
// now = root[u], pre = root[v], lca = root[lca(u, v)], over the rank range [l, r].
// Each vertex carries the weight of the edge to its parent, so
// path(u, v) = prefix(u) + prefix(v) - 2 * prefix(lca).
int query(int l, int r, int now, int pre, int lca, int k){
    if(l == r){
        if(k >= l) return T[now].sum + T[pre].sum - T[lca].sum * 2;
        return 0;
    }
    // The whole range is at or below the threshold: take the node totals.
    if(r <= k) return T[now].sum + T[pre].sum - T[lca].sum * 2;
    int mid = (l + r) >> 1;
    if(k <= mid) return query(l, mid, T[now].lson, T[pre].lson, T[lca].lson, k);
    // The left half lies entirely below k; only the right half needs recursion.
    int sum = query(mid + 1, r, T[now].rson, T[pre].rson, T[lca].rson, k);
    return sum + T[T[now].lson].sum + T[T[pre].lson].sum - T[T[lca].lson].sum * 2;
}

//lca
int fa[maxn][20];   // fa[u][i]: 2^i-th ancestor of u (0 once past the root)
int dep[maxn];      // depth of each vertex (the start vertex gets depth d)

// Records dep[] and the immediate parent fa[x][0] for every vertex reachable
// from u, where `pre` is u's parent and `d` its depth.
// Iterative for the same stack-depth reason as build().
void lca_dfs(int u, int pre, int d){
    struct Frame{ int u, parent, depth; };
    vector<Frame> stk;
    stk.push_back({u, pre, d});
    while(!stk.empty()){
        Frame cur = stk.back();
        stk.pop_back();
        dep[cur.u] = cur.depth;
        fa[cur.u][0] = cur.parent;
        for(int i = head[cur.u]; i != -1; i = edge[i].next){
            int v = edge[i].v;
            if(v != cur.parent) stk.push_back({v, cur.u, cur.depth + 1});
        }
    }
}
void lca_update(){ for(int i = 1; (1 << i) <= n; i++){ for(int u = 1; u <= n; u++){ fa[u][i] = fa[fa[u][i - 1]][i - 1]; } } } int lca_query(int u, int v){ if(dep[u] < dep[v]) swap(u, v); int d = dep[u] - dep[v]; for(int i = 0; (1 << i) <= d; i++){ if(d & (1 << i)){ u = fa[u][i]; } } if(u != v){ for(int i = (int)log2(n); i >= 0; i--){ if(fa[u][i] != fa[v][i]){ u = fa[u][i]; v = fa[v][i]; } } u = fa[u][0]; } return u; } int u1[maxn], v1[maxn]; ll k1[maxn]; int main(){ init(); vv.clear(); scanf("%d%d", &n, &m); vv.push_back(0); for(int i = 1; i <= n - 1; i++){ int u, v; ll w; scanf("%d%d%lld", &u, &v, &w); addEdge(u, v, w); addEdge(v, u, w); vv.push_back(w); } for(int i = 1; i <= m; i++){ scanf("%d%d%lld", &u1[i], &v1[i], &k1[i]); vv.push_back(k1[i]); } sort(vv.begin(), vv.end()); vv.erase(unique(vv.begin(), vv.end()), vv.end()); lca_dfs(1, 0, 1); lca_update(); build(1, 0, 0); for(int i = 1; i <= m; i++){ int lca = lca_query(u1[i], v1[i]); printf("%d\n", query(1, vv.size(), root[u1[i]], root[v1[i]], root[lca], getid(k1[i]))); } return 0; }