[BJOI2018]求和

题目描述

master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k 次方和,而且每次的k可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?

输入格式

第一行包含一个正整数n,表示树的节点数。

之后n-1行每行两个空格隔开的正整数i, j,表示树上的一条连接点ii 和点jj 的边。

之后一行一个正整数m,表示询问的数量。

之后每行三个空格隔开的正整数i, j, k,表示询问从点i到点j的路径上所有节点深度的k次方和。由于这个结果可能非常大,输出其对998244353取模的结果。

树的节点从1开始标号,其中1号节点为树的根。

输出格式

对于每组数据输出一行一个正整数表示取模后的结果。

数据范围

对于30%的数据,1≤n,n≤100

对于60% 的数据,1≤n,m≤1000。

对于100%的数据,1≤n,m≤300000,1≤k≤50。


考虑30分的做法。先处理出深度,设dep(x)表示x的深度,那么dep(x)=dep(fa[x])+1。对于每次询问,我们可以暴力从询问的两个点LCA走,每次求出当前点的深度的k次方并加入答案中即可。求LCA可以用Tarjan做到O(N+M),每次询问可以做到O(NK),总共就是O(MNK+N+M)≈O(MNK),如果用倍增或者树剖求LCA就是O(M(NK+logN))。这种三次方的级别也就只能过30分了......

考虑60分的做法。根据k的范围为1~50,我们可以预处理出每个点的深度的1~50次方,设idep(x,i)表示x的深度的i次方,那么idep(x,i)=idep(x,i-1) * dep(x),初始化idep(x,0)=1,这个复杂度为O(NK)。然后每次询问还是一个个往上走并加起来即可。若用Tarjan求LCA,总共就是O(NK+MN+N+M)≈O((K+M)N),否则为O((K+M)N+MlogN),降到了平方级别。

然后考虑100分做法。既然每次都要一个个加上,我们为什么不用树上前缀和直接加一加减一减得出答案呢?设val(x,i)为根节点到x的路径上所有点深度的k次方和,那么val(x,i)=val(fa[x],i)+idep(x,i)。时间复杂度还是O(NK)。接下来对于每个询问,首先求出两个点u,v的LCA,设LCA(u,v)=w,那么ans=val(u,k)+val(v,k)-val(w,k)-val(fa[w],k),最后的val(fa[w],k)是因为LCA只计算一次。若用Tarjan求LCA,每次询问的复杂度就是O(1),总共就是O(NK+M+N)≈O(NK+M),否则每次询问就是O(logN),总共就是O(NK+MlogN)。表面上还是平方级别,但已经相比之前60分做法优化掉了一项NM,剩下的NK和MlogN由于K≤50,可以看成是log级别的,所以就是O(NlogN)级别的算法,分肯定拿满。

附上代码,LCA是用树剖求的

#include<iostream>
#include<cstring>
#include<cstdio>
#define maxn 300001
#define maxk 51
#define p 998244353
using namespace std;

struct edge{
    int to,next;
    edge(){}
    edge(const int &_to,const int &_next){ to=_to,next=_next; }
}e[maxn<<1];
int head[maxn],k;

long long dep[maxn],idep[maxn][maxk],val[maxn][maxk];
int size[maxn],fa[maxn],son[maxn],top[maxn];
int n,m;

inline int read(){
    register int x(0),f(1); register char c(getchar());
    while(c<'0'||'9'<c){ if(c=='-') f=-1; c=getchar(); }
    while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
    return x*f;
}
inline void add(const int &u,const int &v){
    e[k]=edge(v,head[u]);
    head[u]=k++;
}

void dfs_getson(int u){
    size[u]=1;
    for(register int i=head[u];~i;i=e[i].next){
        int v=e[i].to;
        if(v==fa[u]) continue;
        dep[v]=dep[u]+1,fa[v]=u;
        idep[v][0]=1;
        for(register int j=1;j<maxk;j++) idep[v][j]=idep[v][j-1]*dep[v]%p;
        for(register int j=1;j<maxk;j++) val[v][j]=(val[u][j]+idep[v][j])%p;
        dfs_getson(v);
        size[u]+=size[v];
        if(size[v]>size[son[u]]) son[u]=v;
    }
}
void dfs_rewrite(int u,int tp){
    top[u]=tp;
    if(son[u]) dfs_rewrite(son[u],tp);
    for(register int i=head[u];~i;i=e[i].next){
        int v=e[i].to;
        if(v!=son[u]&&v!=fa[u]) dfs_rewrite(v,v);
    }
}

inline int lca(int u,int v){
    while(top[u]!=top[v]){
        if(dep[top[u]]>dep[top[v]]) swap(u,v);
        v=fa[top[v]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    return u;
}

int main(){
    memset(head,-1,sizeof head);
    n=read();
    for(register int i=1;i<n;i++){
        int u=read(),v=read();
        add(u,v),add(v,u);
    }
    dfs_getson(1);
    dfs_rewrite(1,1);
    
    m=read();
    while(m--){
        int u=read(),v=read(),t=read();
        int w=lca(u,v);
        printf("%lld\n",(val[u][t]+val[v][t]+(p<<1)-val[w][t]-val[fa[w]][t])%p);
    }
    return 0;
}

*最后的地方相减时要判负数......或者直接加上两个p得了

posted @ 2019-05-09 10:14  修电缆的建筑工  阅读(95)  评论(0编辑  收藏  举报