自己的题 - 时空树 [树链剖分]

题意 - 时空树

这个是我做了[BJOI2018]求和后想到的一个改编版,不同之处在于我将深度改成了权值。

因为每层树的结构一样,且权值是有规律的,并且k也只有50,所以最朴素的做法就是建一棵树,把它用树链剖分剖了之后,动态开点,建50棵线段树,暴力维护即可。可是这样空间以及常数会及其的大。所以有以下两种优化做法。

  • hdxrieLCA做法。

因为只有求和,所以对于每个节点开一个大小为50的数组A,求出该点到根节点的第ki层的权值和,询问u>v时,答案就为A[u]+A[v]A[LCA]A[fa[LCA]],空间复杂度为大概点数乘以55左右(加上建树等数组),时间复杂度为O(n×50+nlogn)

  • 我的前缀和做法。

这个常数可能比hdxrie的大。我是剖了树之后,维护对于每条剖出来的链的前缀和,交界处暴力算一下即可。空间差不多,时间复杂度为O(n×50+nlog2n)但不知为啥跑的比hdxrie快一点

  • 我的做法:

code

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int M=4e5+10,K=52;
const ll mod=201806147ll;
int n,m,KW;
ll val[M];
ll pow(ll a,int k){
    ll ans=1ll;
    for(;k;k>>=1,(a*=a)%=mod){if(k&1) (ans*=a)%=mod;}
    return ans%mod;
}
struct ss{
    int to,last;
    ss(int a=0,int b=0)
    :to(a),last(b){}
}g[M<<1];
int head[M],cnt;
void add(int a,int b){
    g[++cnt]=ss(b,head[a]);head[a]=cnt;
    g[++cnt]=ss(a,head[b]);head[b]=cnt;
}
int f[M],sze[M],son[M],top[M],dep[M],num[M],tim;
ll sum[K][M];
void dfs1(int a){
    sze[a]=1;
    for(int i=head[a];i;i=g[i].last){
        if(g[i].to==f[a]) continue;
        dep[g[i].to]=dep[a]+1;
        f[g[i].to]=a;
        dfs1(g[i].to);
        sze[a]+=sze[g[i].to];
        if(!son[a]||sze[son[a]]<sze[g[i].to])
        son[a]=g[i].to;
    }
}

void dfs2(int a,int b,bool iso){
    top[a]=b;num[a]=++tim;
    if(iso){
    sum[0][tim]=(sum[0][tim-1]+(val[a]!=0))%mod;
    for(int i=1;i<KW;i++){sum[i][tim]=(sum[i][tim-1]+pow(val[a],i))%mod;}
    }else{
    sum[0][tim]=(val[a]!=0);
    for(int i=1;i<KW;i++){sum[i][tim]=pow(val[a],i);}
    }
    if(!son[a])return;
    dfs2(son[a],b,1);
    for(int i=head[a];i;i=g[i].last){
        if(g[i].to==f[a]||g[i].to==son[a]) continue;
        dfs2(g[i].to,g[i].to,0);
    }
}

ll ask(int a,int b,int k){
    ll ans=0;
    while(top[a]!=top[b]){
        if(dep[top[a]]<dep[top[b]]) swap(a,b);
        if(a==top[a]) ans=(ans+sum[k][num[a]])%mod;
        else ans=ans+(((sum[k][num[a]]-sum[k][num[top[a]]])%mod+mod)%mod+pow(val[top[a]],k)+mod)%mod;
        a=f[top[a]];
    }
    if(dep[a]>dep[b]) swap(a,b);
    if(a==b) ans=(ans+pow(val[a],k))%mod;
    else ans=ans+(((sum[k][num[b]]-sum[k][num[a]])%mod+mod)%mod+pow(val[a],k)+mod)%mod;
    return ans%mod;
}
int a,b,k;
int main(){
    scanf("%d%d%d",&n,&m,&KW);
    for(int i=1;i<=n;i++)scanf("%lld",&val[i]);
    for(int i=1;i<n;i++){
        scanf("%d%d",&a,&b);
        add(a,b);
    }
    dfs1(1);dfs2(1,1,0);
    for(int i=1;i<=m;i++){
        scanf("%d%d%d",&k,&a,&b);
        printf("%lld\n",(ask(a,b,k)+mod)%mod);
    }
    return 0;
}
  • hdxrie的做法
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define LL long long
using namespace std;
const int N=300010;
const LL mod=201806147;
LL val[N][51];
int n,m,p,k,l1,l2,l3,head[N],deep[N],size[N],son[N],fa[N],top[N];
struct line
{
    int to,bef;
}ln[N*2];
void add(int a,int b)
{
    ln[++k].to=b;ln[k].bef=head[a];head[a]=k;
    ln[++k].to=a;ln[k].bef=head[b];head[b]=k;
}
void dfs1(int a,int f)
{
    fa[a]=f;deep[a]=deep[f]+1;
    for(int i=0;i<=p;i++)
     val[a][i]=(val[a][i]+val[f][i])%mod;
    size[a]=1;int maxf=0;
    for(int i=head[a];i;i=ln[i].bef)
     if(ln[i].to!=f)
     {
        dfs1(ln[i].to,a);
        size[a]+=size[ln[i].to];
        if(maxf<size[ln[i].to])
         maxf=size[ln[i].to],son[a]=ln[i].to;
     }
}
void dfs2(int a,int up)
{
    top[a]=up;if(son[a])dfs2(son[a],up);
    for(int i=head[a];i;i=ln[i].bef)
     if(ln[i].to!=fa[a]&&ln[i].to!=son[a])
      dfs2(ln[i].to,ln[i].to);
}
int getpos(int a,int b)
{
    for(;top[a]!=top[b];a=fa[top[a]])
     if(deep[top[a]]<deep[top[b]])swap(a,b);
    if(deep[a]>deep[b])return b;return a;
}
int main()
{
    scanf("%d%d%d",&n,&m,&p);
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&val[i][1]);val[i][0]=1;
        for(int j=2;j<=p;j++)
         val[i][j]=(val[i][j-1]*val[i][1])%mod;
    }
    for(int i=1;i<n;i++)
    {
        scanf("%d%d",&l1,&l2);
        add(l1,l2);
    }
    dfs1(1,0);dfs2(1,1);
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d%d",&l1,&l2,&l3);
        int pos=getpos(l2,l3);
        printf("%lld\n",((val[l2][l1]+val[l3][l1]-val[pos][l1]-val[fa[pos]][l1])%mod+mod)%mod);
    }
    return 0;
}
posted @ 2018-06-16 21:58  VictoryCzt  阅读(142)  评论(0编辑  收藏  举报