CF1140G Double Tree


题解

首先如果我们要确定出每个\(dis_{i \to i+1 , i \in odd}\)
这个可以用两遍树形\(DP\)来解决
一遍是考虑走子树子树绕过来的
一遍是考虑从走祖先绕过来的
然后就可以考虑用倍增来解决了
\(st1[u][i][0/1][0/1]\)表示从点\(u\)开始向上跳\(2^j\)步,开始的位置位于左/右边的树,结束的位置位于左/右边的树
倍增的时候就用两个数组\(dp1[0/1]/dp2[0/1]\)表示从\(u/v\)到左/右边树的\(LCA\)
不断往上跳着更新即可
注意要先把\(dp1/dp2\)的数组的值记录下来再用记录的值更新==

代码

#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
# define LL long long
const int M = 300005 ;
const LL INF = 1e18 ;
using namespace std ;

inline LL read() {
	char c = getchar() ; LL x = 0 , w = 1 ;
	while(c>'9'||c<'0') { if(c=='-') w = -1 ; c = getchar() ; }
	while(c>='0'&&c<='9') { x = x*10+c-'0' ; c = getchar() ; }
	return x*w ;
}

int n , lg[M] , dep[M] ; 
LL w[M] , st[M][22] , st1[M][22][2][2] , dp1[2] , dp2[2] ;
struct Node { int v ; LL w1 , w2 ; } ;
vector < Node > vec[M] ;
void dfs1(int u , int father) {
	LL w1 , w2 ;
	st[u][0] = father ; dep[u] = dep[father] + 1 ;
	for(int i = 0 , v ; i < vec[u].size() ; i ++) {
		v = vec[u][i].v , w1 = vec[u][i].w1 , w2 = vec[u][i].w2 ;
		if(v == father) continue ; dfs1(v , u) ;
		w[u] = min( w[u] , w[v] + w1 + w2 ) ;
	}
}
void dfs2(int u , int father) {
	LL w1 , w2 ;
	for(int i = 0 , v ; i < vec[u].size() ; i ++) {
		v = vec[u][i].v , w1 = vec[u][i].w1 , w2 = vec[u][i].w2 ;
		if(v == father) continue ;
		w[v] = min( w[v] , w[u] + w1 + w2 ) ;
		dfs2(v , u) ;
	}
}
void dfs(int u , int father) {
	LL w1 , w2 ;
	for(int i = 0 , v ; i < vec[u].size() ; i ++) {
		v = vec[u][i].v , w1 = vec[u][i].w1 , w2 = vec[u][i].w2 ;
		if(v == father) continue ;
		dfs(v , u) ;
		st1[v][0][0][0] = min( w1 , w[v] + w2 + w[u] ) ;
		st1[v][0][1][1] = min( w2 , w[u] + w1 + w[v] ) ;
		st1[v][0][0][1] = min( w[v] + w2 , w1 + w[u] ) ;
		st1[v][0][1][0] = min( w2 + w[u] , w[v] + w1 ) ;
	}
}
inline int LCA(int u , int v) {
	if(dep[u] < dep[v]) swap(u , v) ;
	for(int i = lg[n] ; i >= 0 ; i --) if(dep[st[u][i]] >= dep[v]) u = st[u][i] ;
	if(u == v) return u ;
	for(int i = lg[n] ; i >= 0 ; i --) if(st[u][i] != st[v][i]) u = st[u][i] , v = st[v][i] ;
	return st[u][0] ;
}
inline LL Solve(int x , int y) {
	int u = ((x + 1) >> 1) , v = ((y + 1) >> 1) , lcap ;
	if(dep[u] < dep[v]) {
		swap(u , v) ; 
		swap(x , y) ;
	}
	lcap = LCA(u , v) ; 
	LL tp0 , tp1 ;
	dp1[(x + 1) & 1] = 0 ; dp1[((x + 1) & 1) ^ 1] = w[u] ;
	for(int i = lg[n] ; i >= 0 ; i --) {
		if(dep[st[u][i]] >= dep[lcap]) {
			tp0 = dp1[0] ; tp1 = dp1[1] ;
			dp1[0] = min( tp0 + st1[u][i][0][0] , tp1 + st1[u][i][1][0] ) ;
			dp1[1] = min( tp0 + st1[u][i][0][1] , tp1 + st1[u][i][1][1] ) ;
			u = st[u][i] ;
		}
	}
	u = v ;
	dp2[(y + 1) & 1] = 0 ; dp2[((y + 1) & 1) ^ 1] = w[u] ;
	for(int i = lg[n] ; i >= 0 ; i --) {
		if(dep[st[u][i]] >= dep[lcap]) {
			tp0 = dp2[0] , tp1 = dp2[1] ;
			dp2[0] = min( tp0 + st1[u][i][0][0] , tp1 + st1[u][i][1][0] ) ;
			dp2[1] = min( tp0 + st1[u][i][0][1] , tp1 + st1[u][i][1][1] ) ;
			u = st[u][i] ;
		}
	}
	return min(dp1[0] + dp2[0] , dp1[1] + dp2[1]) ;
}
int main() {
	n = read() ;
	for(int i = 2 ; i <= n ; i ++) lg[i] = lg[i >> 1] + 1 ;
	for(int i = 1 ; i <= n ; i ++) w[i] = read() ;
	LL w1 , w2 ;
	for(int i = 1 , u , v ; i < n ; i ++) {
		u = read() ; v = read() ; w1 = read() ; w2 = read() ;
		vec[u].push_back( (Node) { v , w1 , w2 } ) ;
		vec[v].push_back( (Node) { u , w1 , w2 } ) ;
	}
	memset(st1 , 31 , sizeof(st1)) ;
	dfs1(1 , 0) ; dfs2(1 , 0) ; dfs(1 , 0) ;
	for(int j = 1 ; j <= lg[n] ; j ++)
		for(int u = 1 ; u <= n ; u ++) {
			st[u][j] = st[st[u][j - 1]][j - 1] ;
			st[u][j] = st[st[u][j - 1]][j - 1] ;
			st1[u][j][0][0] = min( st1[u][j - 1][0][0] + st1[st[u][j - 1]][j - 1][0][0] , st1[u][j - 1][0][1] + st1[st[u][j - 1]][j - 1][1][0] ) ;
			st1[u][j][0][1] = min( st1[u][j - 1][0][0] + st1[st[u][j - 1]][j - 1][0][1] , st1[u][j - 1][0][1] + st1[st[u][j - 1]][j - 1][1][1] ) ;
			st1[u][j][1][0] = min( st1[u][j - 1][1][0] + st1[st[u][j - 1]][j - 1][0][0] , st1[u][j - 1][1][1] + st1[st[u][j - 1]][j - 1][1][0] ) ;
			st1[u][j][1][1] = min( st1[u][j - 1][1][0] + st1[st[u][j - 1]][j - 1][0][1] , st1[u][j - 1][1][1] + st1[st[u][j - 1]][j - 1][1][1] ) ;
		}
	int Q = read() , x , y ;
	while(Q --) {
		x = read() ; y = read() ;
		printf("%lld\n",Solve(x , y)) ;
	}
	return 0 ;
}
posted @ 2019-04-14 21:47  beretty  阅读(273)  评论(0编辑  收藏  举报