bzoj5293: [Bjoi2018]求和

题目链接

bzoj5293: [Bjoi2018]求和

题解

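暴力预处理：设 $P[d][k]=\sum_{t=0}^{d} t^k \bmod 998244353$（$k\le 50$，对每个深度 $d$ 预处理），用倍增求 $l=\operatorname{lca}(a,b)$，则路径 $a\to b$ 上的答案为
$$P[\mathrm{dep}_a][k]+P[\mathrm{dep}_b][k]-P[\mathrm{dep}_l][k]-P[\mathrm{dep}_l-1][k].$$
坑点：当 LCA 为根 1（深度为 0）时，$P[\mathrm{dep}_l-1]$ 会访问下标 $-1$，这一项必须按 0 处理....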

代码

#include<cmath> 
#include<cstdio> 
#include<algorithm> 
// Reads the next integer from stdin, skipping any non-digit characters before it.
// A '-' immediately preceding the first digit makes the value negative.
// Returns 0 if end of input is reached before any digit is found.
inline int read() {
    int x = 0, f = 1;
    int c = getchar();  // int, not char: it must be able to hold EOF
    while (c != EOF && (c < '0' || c > '9')) {
        f = (c == '-') ? -1 : 1;  // only a '-' directly before the digits counts
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// From here on every `int` is really `long long`; read() above is declared
// before this macro, so it still returns a plain int. The `%lld` printf calls
// in main() rely on this macro: without it their arguments would be plain int.
#define int long long 			
const int mod = 998244353;  // every answer is reduced modulo this prime
const int maxn = 300007;    // array bound; presumably chosen just above the max n — confirm against the problem limits
// Forward-star adjacency list: head[u] holds the index of the edge most recently
// added out of u, edge[e].next chains to the previous one, and index 0 ends the
// chain. Slots are 1-based; main() stores both directions of every tree edge,
// hence the maxn << 1 capacity.
struct node {
    int v;     // vertex this edge points to
    int next;  // next edge out of the same source vertex (0 = end of chain)
} edge[maxn << 1];
int head[maxn], num = 0;

// Prepends the directed edge u -> v to u's chain.
inline void add_edge(int u, int v) {
    const int id = ++num;
    edge[id].v = v;
    edge[id].next = head[u];
    head[u] = id;
}

// P[d][j]: prefix sums over depth d of d^j modulo `mod`, filled in main().
// main() only ever touches columns 0..50, so 51 columns are enough.
int P[maxn][51];
int n, m;
// deep[x] = depth of x (root has depth 0). deep[0] = -1 is a sentinel: vertex 0
// stands for "above the root", so in lca() a jump that overshoots the root
// always lands on a depth smaller than that of any real vertex.
int deep[maxn] = {-1};
// dad[x][i] = 2^i-th ancestor of x, or 0 when it does not exist; filled in dfs().
int dad[maxn][27];
void dfs(int x,int fa) { 
    for(int i = 0; dad[x][i]; ++ i) dad[x][i + 1] = dad[dad[x][i]][i]; 
    for(int i = head[x];i ;i = edge[i].next) { 
        int v = edge[i].v; 
        if(v == fa)continue; 
        deep[v] = deep[x] + 1; dad[v][0] = x; 
        dfs(v,x); 
    } 
} 
int k;  // highest jump exponent tried by lca(); main() sets it to floor(log2(n)) + 1

// Lowest common ancestor of x and y by binary lifting over dad[][].
// Relies on deep[0] == -1 so that jumps past the root (dad == 0) are rejected.
int lca(int x, int y) {
    int lo = x, hi = y;  // after the swap, hi is the deeper vertex
    if (deep[lo] > deep[hi]) std::swap(lo, hi);
    // Raise the deeper vertex to the depth of the shallower one.
    for (int step = k; step >= 0; --step) {
        const int up = dad[hi][step];
        if (deep[up] >= deep[lo]) hi = up;
    }
    if (hi == lo) return lo;
    // Raise both vertices to just below their common ancestor.
    for (int step = k; step >= 0; --step) {
        if (dad[lo][step] == dad[hi][step]) continue;
        lo = dad[lo][step];
        hi = dad[hi][step];
    }
    return dad[lo][0];
}
main() { 
    n = read(); 
    for(int i = 1;i <= n;++ i) { 
        P[i][0] = 1,P[i][1] = i; 
        for(int j = 2;j <= 50;++ j)  
            P[i][j] = (long long) P[i][j - 1] * i % mod; 
    } 
    for(int i = 1;i <= n;++ i) for(int j = 1;j <= 50;++ j) P[i][j] = (P[i][j] + P[i - 1][j]) % mod; 
    k = log2(n) + 1; 
    for(int u,v, i = 1;i < n;++ i) { 
        u = read(), v = read(); 
        add_edge(u,v); add_edge(v,u); 
    } 
    int m = read(); 
    dfs(1,0); 
    while(m --) { 
        int a = read(),b = read(),k = read(); 
        int LCA = lca(a,b); 
        if(LCA == 1) { 
            printf("%lld\n",(P[deep[a]][k] + P[deep[b]][k]) % mod); 
        } 
        else if(LCA == b || LCA == a) { 	
            if(LCA == a)  std::swap(a,b);  
                printf("%lld\n",(P[deep[a]][k] - P[deep[b] - 1][k] + mod ) % mod);  
        } 
        else printf("%lld\n",((P[deep[a]][k] + P[deep[b]][k] - P[deep[LCA]][k] - P[deep[LCA] - 1][k]) % mod + mod) % mod); 
    } 
    return 0; 
}  
  
posted @ 2018-08-13 06:06  zzzzx  阅读(157)  评论(0编辑  收藏  举报