求和VII 倍增法求LCA + 树上差分

6744: 求和VII

时间限制: 2 Sec  内存限制: 256 MB
提交: 578  解决: 80
[提交] [状态] [讨论版] [命题人:admin]

题目描述

master对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k次方和,而且每次的k可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil并不会这么复杂的操作,你能帮他解决吗?

 

输入

第一行包含一个正整数n,表示树的节点数。
之后n−1行每行两个空格隔开的正整数i,j,表示树上的一条连接点i和点j的边。
之后一行一个正整数m,表示询问的数量。
之后每行三个空格隔开的正整数i,j,k,表示询问从点i到点j的路径上所有节点深度的k次方和。由于这个结果可能非常大,输出其对998244353取模的结果。
树的节点从1开始标号,其中1号节点为树的根。

 

输出

对于每组数据输出一行一个正整数表示取模后的结果。

 

样例输入

5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45

 

样例输出

33
503245989

 

提示

以下用d(i)表示第i个节点的深度。
对于样例中的树,有d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。
因此第一个询问答案为(25+15+05) mod 998244353=33,第二个询问答案为(245+145+245) mod 998244353=503245989。

对于30%的数据,1≤n,m≤100;
对于60%的数据,1≤n,m≤1000;
对于100%的数据,1≤n,m≤300000,1≤k≤50。

 

来源/分类

北京OI2018 

 

[提交] [状态]

因为k比较小,所以预处理k

然后用lca求树上最近公共祖先l

计算时用 val(u,k) + val(v,k) - val(l,k) - val(fa(l),k) 公式

代码:

#include <bits/stdc++.h>
using namespace std;
const  int maxn=3e5+100;
#define INF 0x3f3f3f;
const int mod=998244353;
typedef  long long ll;
ll val[maxn][60];
int nextt[maxn<<1];
int head[maxn<<1];
int s[maxn<<1];
int cnt = 0;
int fa[maxn][50];
ll dep[maxn<<1];
ll c[100];
bool vis[maxn<<1];

void add(int u,int v)
{
    s[++cnt] = v;
    nextt[cnt] = head[u];
    head[u] = cnt;
    //cnt++;
}

void dfs(int u)
{
    vis[u] = 1;
    for(int i=0; fa[u][i]; i++){
        fa[u][i+1] = fa[fa[u][i]][i];
    }
    for(int i=head[u]; i; i=nextt[i]){
        int v = s[i];
        if(vis[v])  continue;
        dep[v] = dep[u] + 1;
        fa[v][0] = u;
        for(int j=1; j<=50; j++){
            c[j] = c[j-1] * dep[v] % mod;
        }
        for(int j=1; j<=50; j++){
            val[v][j] = (c[j] + val[u][j]) % mod;
        }
        dfs(v);
    }
}

int lca(int u,int v)
{
    if(dep[u] < dep[v])
        swap(u,v);
    int d = dep[u] - dep[v];
    for(int i=0; d; i++){
        if(d & 1){
            u = fa[u][i];
        }
        d >>= 1;
    }
    if(u == v)
        return u;
    for(int i=20; i>=0; i--){
        if(fa[u][i] != fa[v][i]){
            u = fa[u][i];
            v = fa[v][i];
        }
    }
    return fa[u][0];
}
int main()
{
    int n;
    scanf("%d",&n);
    int u,v;
    c[0] = 1;
    for(int i=1; i<n; i++){
        scanf("%d%d",&u,&v);
        add(u,v);
        add(v,u);
    }
    dfs(1);
    int m,w;
    scanf("%d",&m);
    while(m--){
        scanf("%d%d%d",&u,&v,&w);
        int l = lca(u,v);
        ll ans = (val[u][w] + val[v][w] - val[fa[l][0]][w] + mod - val[l][w] + mod) % mod;
        printf("%lld\n",ans);
    }
    return 0;
}

 

posted @ 2018-08-22 21:19  任小喵  阅读(193)  评论(0编辑  收藏  举报