[BJOI2018]求和(树链剖分)

题目描述

master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 kkk 次方和,而且每次的 kkk 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?

输入输出格式

输入格式:

第一行包含一个正整数 nnn ,表示树的节点数。

之后 n−1n-1n1 行每行两个空格隔开的正整数 i,ji, ji,j ,表示树上的一条连接点 iii 和点 jjj 的边。

之后一行一个正整数 mmm ,表示询问的数量。

之后每行三个空格隔开的正整数 i,j,ki, j, ki,j,k ,表示询问从点 iii 到点 jjj 的路径上所有节点深度的 kkk 次方和。由于这个结果可能非常大,输出其对 998244353998244353998244353 取模的结果。

树的节点从 111 开始标号,其中 111 号节点为树的根。

输出格式:

对于每组数据输出一行一个正整数表示取模后的结果。

输入输出样例

输入样例#1: 复制
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
输出样例#1: 复制
33
503245989

说明

样例解释

以下用 d(i)d (i)d(i) 表示第 iii 个节点的深度。

对于样例中的树,有 d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2d (1) = 0, d (2) = 1, d (3) = 1, d (4) = 2, d (5) = 2d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2 。

因此第一个询问答案为 (25+15+05) mod 998244353=33(2^5 + 1^5 + 0^5)\ mod\ 998244353 = 33(25+15+05) mod 998244353=33 ,第二个询问答案为 (245+145+245) mod 998244353=503245989(2^{45} + 1^{45} + 2^{45})\ mod\ 998244353 = 503245989(245+145+245) mod 998244353=503245989 。

数据范围

对于 30%30\%30% 的数据, 1≤n,m≤1001 \leq n,m \leq 1001n,m100 。

对于 60%60\%60% 的数据, 1≤n,m≤10001 \leq n,m \leq 10001n,m1000 。

对于 100%100\%100% 的数据, 1≤n,m≤300000,1≤k≤501 \leq n,m \leq 300000, 1 \leq k \leq 501n,m300000,1k50 。

另外存在5个不计分的hack数据

提示

数据规模较大,请注意使用较快速的输入输出方式。


先抱怨一波好吧,
  bjoi多么良心,出了那么多数据结构,mmp,再看看hnoi,woc,真的不要脸
  所以我这种菜鸡就只能打一打树剖这种傻逼题来泄愤了
 
题解:
 
  1)50次方,就是把每一个次方的值建一个线段树,后面就可以直接询问了
  2)取mod是一个很玄学的东西
  3)注意一下空间
  4)所有树剖题目都是一个思路,不难想但是难调
 

#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int N=300001;
const int mod=998244353;
int n,m,dep[51][N],fa[N],size[N],son[N],maxx,fid[N];
int l[N],tot,head[N*2],num,top[N];
struct node{
    int to,next;
}e[N*2];
struct tr{
    int sum[N*4]; 
}t[51];
int read()
{
    int x=0,w=1;char ch=getchar();
    while(ch>'9'||ch<'0'){if(ch=='-')w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
    return x*w;
}

void add(int from,int to)
{    
    num++;
    e[num].to=to;
    e[num].next=head[from];
    head[from]=num;
}

void dfs1(int x)
{
    size[x]=1;
    for(int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(!dep[1][v]&&v!=1)
        {    
            dep[1][v]=dep[1][x]+1;fa[v]=x;
            maxx=max(maxx,dep[1][v]);
            dfs1(v);
            size[x]+=size[v];
            if(size[v]>size[son[x]])son[x]=v;
        }
    }
}

void dfs2(int x,int tp)
{
    l[x]=++tot;top[x]=tp;fid[tot]=x;
    if(son[x])dfs2(son[x],tp);
    for(int i=head[x];i;i=e[i].next)
    {
        int v=e[i].to;
        if(v!=fa[x]&&v!=son[x])
        dfs2(v,v);
    }
}

void init()
{
    n=read();
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs1(1);
    dep[1][1]=0;fa[1]=1;
    dfs2(1,1);
}

void build(int rt,int root,int left,int right)
{
    
    if(left==right){
        t[rt].sum[root]=dep[rt][fid[left]];
        return ;
    }
    int mid=(left+right)>>1;
    build(rt,root<<1,left,mid);
    build(rt,root<<1|1,mid+1,right);
    t[rt].sum[root]=(ll)(t[rt].sum[root<<1]+t[rt].sum[root<<1|1])%mod;
    if(t[rt].sum[root]>=mod)t[rt].sum[root]-=mod;
}

int query(int rt,int root,int left,int right,int L,int R)
{
    if(left>R||right<L)return 0;
    if(L<=left&&right<=R)return t[rt].sum[root]%mod;
    int mid=(left+right)>>1;
    ll a=0,b=0;
    if(mid>=L)a=query(rt,root<<1,left,mid,L,R)%mod;
    if(mid<R) b=query(rt,root<<1|1,mid+1,right,L,R)%mod;
    return (a+b)%mod;
}

int cal(int rt,int x,int y)
{
    ll ans=0;
    int fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(dep[fx]<dep[fy]){swap(fx,fy);swap(x,y);}
        ans+=query(rt,1,1,n,l[fx],l[x]);
        ans%=mod;
     x=fa[fx],fx=top[x];
    }
    if(l[x]>l[y])swap(x,y);
    ans+=query(rt,1,1,n,l[x],l[y]);
    ans%=mod;
    return ans;
}

void perp()
{    
    for(int i=2;i<=50;i++)
        for(int j=1;j<=n;j++)
        {dep[i][j]=((ll)dep[i-1][j]%mod*(ll)dep[1][j]%mod)%mod;}
    for(int i=1;i<=50;i++)
    build(i,1,1,n);
}

void solve()
{
    m=read();
    for(int i=1;i<=m;i++)
        {int x=read(),y=read(),k=read();
        printf("%d\n",cal(k,x,y));}
}

int main()
{
    init();
    perp();
    solve();
    return 0;
}

 

posted @ 2018-04-25 22:40  Epiphyllum_thief  阅读(320)  评论(0编辑  收藏  举报