浏览器标题切换
浏览器标题切换end
把博客园图标替换成自己的图标
把博客园图标替换成自己的图标end

BZOJ5293: [Bjoi2018]求和 树上差分

Description

master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k
 次方和,而且每次的k 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给
了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
 

Input

第一行包含一个正整数n ,表示树的节点数。
之后n-1 行每行两个空格隔开的正整数i,j ,表示树上的一条连接点i 和点j 的边。
之后一行一个正整数m ,表示询问的数量。
之后每行三个空格隔开的正整数i,j,k ,表示询问从点i 到点j 的路径上所有节点深度的k 次方和。
由于这个结果可能非常大,输出其对998244353 取模的结果。
树的节点从1 开始标号,其中1 号节点为树的根。
 

Output

对于每组数据输出一行一个正整数表示取模后的结果。
1≤n,m≤300000,1≤k≤50
 

Sample Input

5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45

Sample Output

33
503245989
说明
样例解释
以下用d(i) 表示第i 个节点的深度。
对于样例中的树,有d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。
因此第一个询问答案为(2^5 + 1^5 + 0^5) mod 998244353 = 33 
第二个询问答案为(2^45 + 1^45 + 2^45) mod 998244353 = 503245989。

Solution

因为这个k很小,所以就是个树上差分板子

预处理出深度之后弄个树上前缀和(对于不同的k暴力存就好,因为k最大50)

预处理复杂度是$O(nk)$的

然后每次询问的答案其实就是两个点的前缀和减掉他们的$LCA$的前缀和再加上他们的$LCA$的k次方

#include <bits/stdc++.h>

using namespace std ;

#define N 300010
#define mod 998244353
#define inf 0x3f3f3f3f
#define ll long long

int n , m ;
int dep[ N ] , siz[ N ] , top[ N ] , fa[ N ] ;
int head[ N ] , cnt ;
ll c[ N ][ 60 ] ; 
struct node {
    int to , nxt ;
} e[ N << 1 ] ;

void ins( int u , int v ) {
    e[ ++ cnt ].to = v ;
    e[ cnt ].nxt = head[ u ] ; 
    head[ u ] = cnt ;
}

void dfs1( int u ) {
    siz[ u ] = 1 ; 
    for( int i = head[ u ] ; i ; i = e[ i ].nxt ) {
        int v = e[ i ].to ;
        if( v == fa[ u ] ) continue ;
        fa[ v ] = u ; 
        dep[ v ] = dep[ u ] + 1 ;
        dfs1( v ) ;
        siz[ u ] += siz[ v ] ;
    }
}

void dfs2( int u , int topf ) {
    top[ u ] = topf ;
    int k = 0 ;
    for( int i = head[ u ] ; i ; i = e[ i ].nxt ) {
        if( e[ i ].to != fa[ u ] && siz[ e[ i ].to ] > siz[ k ] ) 
            k = e[ i ].to ; 
    }
    if( !k ) return ;
    dfs2( k , topf ) ;
    for( int i = head[ u ] ; i ; i = e[ i ].nxt ) {
        if( e[ i ].to != k && e[ i ].to != fa[ u ] ) 
            dfs2( e[ i ].to , e[ i ].to ) ;
    }
}

void dfs3( int u ) {
    ll x = 1 ;
    for( int k = 1 ; k <= 50 ; k ++ ) {
        x = 1ll * x * dep[ u ] % mod ;
        c[ u ][ k ] = ( 1ll * c[ fa[ u ] ][ k ] + x ) % mod ;
    }
    for( int i = head[ u ] ; i ; i = e[ i ].nxt ) {
        if( e[ i ].to == fa[ u ] ) continue ;
        dfs3( e[ i ].to ) ;
    }
}

int lca( int x , int y ) {
    while( top[ x ] != top[ y ] ) {
        if( dep[ top[ x ] ] < dep[ top[ y ] ] ) swap( x , y ) ;
        x = fa[ top[ x ] ] ;
    }
    if( dep[ x ] > dep[ y ] ) swap( x , y ) ;
    return x ;
}

ll power( ll a , ll b ) {
    ll ans = 1 , base = a ;
    while( b ) {
        if( b&1 ) ans = ans * base % mod ;
        base = base * base % mod ;
        b >>= 1 ;
    }
    return ans % mod ;
}

int main() {
    scanf( "%d" , &n ) ;
    for(int i = 1 , u , v ; i < n ; i ++ ) {
        scanf( "%d%d" ,  &u , &v ) ;
        ins( u ,  v ) ; ins( v , u ) ;
    }
    dfs1( 1 ) ; dfs2( 1 , 1 ) ; dfs3( 1 ) ;
    scanf( "%d" , &m ) ;
    while( m -- ) {
        int u , v , k ;
        scanf( "%d%d%d" , &u , &v , &k ) ;
        int l = lca( u , v ) ;
        printf( "%lld\n" , ( c[ u ][ k ] % mod + c[ v ][ k ] % mod - 2 * c[ l ][ k ] % mod + power( dep[ l ] , k ) + mod ) % mod ) ;
    }
    return 0 ; 
} 

 

posted @ 2018-10-25 21:14  henry_y  阅读(354)  评论(0编辑  收藏  举报