【bzoj4012】[HNOI2015]开店 动态点分治+STL-vector
题目描述
风见幽香有一个好朋友叫八云紫,她们经常一起看星星看月亮从诗词歌赋谈到人生哲学。最近她们灵机一动,打算在幻想乡开一家小店来做生意赚点钱。这样的想法当然非常好啦,但是她们也发现她们面临着一个问题,那就是店开在哪里,面向什么样的人群。很神奇的是,幻想乡的地图是一个树形结构,幻想乡一共有 n个地方,编号为 1 到 n,被 n-1 条带权的边连接起来。每个地方都住着一个妖怪,其中第 i 个地方的妖怪年龄是 x_i。妖怪都是些比较喜欢安静的家伙,所以它们并不希望和很多妖怪相邻。所以这个树所有顶点的度数都小于或等于 3。妖怪和人一样,兴趣点随着年龄的变化自然就会变化,比如我们的 18 岁少女幽香和八云紫就比较喜欢可爱的东西。幽香通过研究发现,基本上妖怪的兴趣只跟年龄有关,所以幽香打算选择一个地方 u(u为编号),然后在 u开一家面向年龄在 L到R 之间(即年龄大于等于 L、小于等于 R)的妖怪的店。也有可能 u这个地方离这些妖怪比较远,于是幽香就想要知道所有年龄在 L 到 R 之间的妖怪,到点 u 的距离的和是多少(妖怪到 u 的距离是该妖怪所在地方到 u 的路径上的边的权之和) ,幽香把这个称为这个开店方案的方便值。幽香她们还没有决定要把店开在哪里,八云紫倒是准备了很多方案,于是幽香想要知道,对于每个方案,方便值是多少呢。
输入
第一行三个用空格分开的数 n、Q和A,表示树的大小、开店的方案个数和妖怪的年龄上限。
接下来 n-1 行,每行三个用空格分开的数 a、b、c,表示树上的顶点 a 和 b 之间有一条权为c(1 <= c <= 1000)的边,a和b 是顶点编号。
接下来Q行,每行三个用空格分开的数 u、 a、 b。对于这 Q行的每一行,用 a、b、A计算出 L和R,表示询问“在地方 u开店,面向妖怪的年龄区间为[L,R]的方案的方便值是多少”。对于其中第 1 行,L 和 R 的计算方法为:L=min(a%A,b%A), R=max(a%A,b%A)。对于第 2到第 Q行,假设前一行得到的方便值为 ans,那么当前行的 L 和 R 计算方法为: L=min((a+ans)%A,(b+ans)%A), R=max((a+ans)%A,(b+ans)%A)。
输出
对于每个方案,输出一行表示方便值。
样例输入
10 10 10
0 0 7 2 1 4 7 7 7 9
1 2 270
2 3 217
1 4 326
2 5 361
4 6 116
3 7 38
1 8 800
6 9 210
7 10 278
8 9 8
2 8 0
9 3 1
8 0 8
4 2 7
9 7 3
4 7 0
2 2 7
3 2 1
2 3 4
样例输出
1603
957
7161
9466
3232
5223
1879
1669
1282
0
题解
动态点分治+STL-vector
查找距离和的方法参考 bzoj3924 。即维护子树中所有点到该点的距离和以及所有点到该点父亲节点的距离和。
但是本题带了“年龄”的限制,因此不能直接维护数组。
考虑到本题没有修改操作,因此可以直接把维护的距离和数组改为vector,记录上面的信息,同时记录“年龄”。然后把每个vector按照“年龄”从小到大排序,并求出前缀和,然后上二分查找即可。注意此时需要加上两个虚拟的“哨兵节点”以防止越界。
时间复杂度$O(n\log^2n)$
#include <cstdio> #include <vector> #include <algorithm> #define N 150010 using namespace std; typedef long long ll; struct data { int val; ll dis , sum; data() {} data(int a , ll b , ll c) {val = a , dis = b , sum = c;} bool operator<(const data &a)const {return val == a.val ? dis == a.dis ? sum < a.sum : dis < a.dis : val < a.val;} }; vector<data> va[N] , vb[N]; int a[N] , head[N] , to[N << 1] , len[N << 1] , next[N << 1] , cnt , pos[N] , log[N << 1] , tot; int si[N] , mx[N] , sum , root , vis[N] , fa[N]; ll deep[N] , md[20][N << 1]; void add(int x , int y , int z) { to[++cnt] = y , len[cnt] = z , next[cnt] = head[x] , head[x] = cnt; } void dfs(int x , int fa) { int i; pos[x] = ++tot , md[0][tot] = deep[x]; for(i = head[x] ; i ; i = next[i]) if(to[i] != fa) deep[to[i]] = deep[x] + len[i] , dfs(to[i] , x) , md[0][++tot] = deep[x]; } ll calc(int x , int y) { ll t = deep[x] + deep[y]; x = pos[x] , y = pos[y]; if(x > y) swap(x , y); int k = log[y - x + 1]; return t - 2 * min(md[k][x] , md[k][y - (1 << k) + 1]); } void getroot(int x , int fa) { int i; si[x] = 1 , mx[x] = 0; for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]] && to[i] != fa) getroot(to[i] , x) , si[x] += si[to[i]] , mx[x] = max(mx[x] , si[to[i]]); mx[x] = max(mx[x] , sum - si[x]); if(mx[x] < mx[root]) root = x; } void solve(int x) { int i; vis[x] = 1; for(i = head[x] ; i ; i = next[i]) if(!vis[to[i]]) sum = si[to[i]] , root = 0 , getroot(to[i] , 0) , fa[root] = x , solve(root); } ll query(int x , int p) { int i , t; ll ans = 0; for(i = x ; i ; i = fa[i]) { t = lower_bound(va[i].begin() , va[i].end() , data(p , 0 , 0)) - va[i].begin() - 1; ans += va[i][t].sum + t * calc(i , x); } for(i = x ; fa[i] ; i = fa[i]) { t = lower_bound(vb[i].begin() , vb[i].end() , data(p , 0 , 0)) - vb[i].begin() - 1; ans -= vb[i][t].sum + t * calc(fa[i] , x); } return ans; } int main() { int n , m , k , i , j , x , y , z; ll last = 0; scanf("%d%d%d" , &n , &m , &k); for(i = 1 ; i <= n ; i ++ ) scanf("%d" , &a[i]); for(i = 1 ; i < n ; i ++ ) scanf("%d%d%d" , &x , &y , &z) , add(x , y , z) , add(y , x , z); dfs(1 , 0); for(i = 2 ; i <= tot ; i ++ ) log[i] = log[i >> 1] + 1; for(i = 1 ; (1 << i) <= tot ; i ++ ) for(j = 1 ; j <= tot - (1 << i) + 1 ; j ++ ) md[i][j] = min(md[i - 1][j] , md[i - 1][j + (1 << (i - 1))]); mx[0] = 1 << 30 , sum = n , getroot(1 , 0) , solve(root); for(i = 1 ; i <= n ; i ++ ) for(j = i ; j ; j = fa[j]) va[j].push_back(data(a[i] , calc(i , j) , 0)) , vb[j].push_back(data(a[i] , calc(i , fa[j]) , 0)); for(i = 1 ; i <= n ; i ++ ) { va[i].push_back(data(-1 , 0 , 0)) , va[i].push_back(data(1 << 30 , 0 , 0)); vb[i].push_back(data(-1 , 0 , 0)) , vb[i].push_back(data(1 << 30 , 0 , 0)); sort(va[i].begin() , va[i].end()) , sort(vb[i].begin() , vb[i].end()); for(j = 1 ; j < (int)va[i].size() ; j ++ ) va[i][j].sum = va[i][j - 1].sum + va[i][j].dis; for(j = 1 ; j < (int)vb[i].size() ; j ++ ) vb[i][j].sum = vb[i][j - 1].sum + vb[i][j].dis; } while(m -- ) { scanf("%d%d%d" , &x , &y , &z) , y = (y + last) % k , z = (z + last) % k; if(y > z) swap(y , z); printf("%lld\n" , last = query(x , z + 1) - query(x , y)); } return 0; }