线段树

线段树的各类应用

  1. 直接用线段树维护信息。
  2. 线段树合并维护信息。
  3. 线段树结合势能分析
  4. 线段树分治

应用1:直接用线段树维护信息

例题1(所谓热身题):P7735 [NOI2021] 轻重边

题意

有一棵大小为 \(n\) 的树,现在有 \(m\) 个操作:

  1. 给定两个点 \(a\)\(b\),首先对于 \(a\)\(b\) 路径上的所有点 \(x\)(包含 \(a\)\(b\)),你要将与 \(x\) 相连的所有边变为轻边。然后再将 \(a\)\(b\) 路径上包含的所有边变为重边。(起初树上的边均为轻边)
  2. 给定两个点 \(a\)\(b\),你需要计算当前 \(a\)\(b\) 的路径上一共包含多少条重边。

\(t\) 组数据。\(t\le 3;n,m\le 10^5\)

解法

直接维护边的状态显然很难,并且将一条路径上每个点的边进行修改更难。(可能有高级算法)

有一种别样的做法:考虑给第 \(i\) 个点赋一个点权 \(a_i\),然后每条边是否为重边即为其所连的两个点 \(a\) 值是否相同。在进行如此转化后,两种操作均能使用树剖实现:(起初 \(\forall i,a_i=i\),需要保证起初两两 \(a\) 值不同)对于第 \(j\) 个操作,若其为操作 1,则将路径上的点的 \(a\) 值均赋为之前未曾出现过的 \(a\) 值(如 \(q+j\));否则直接查询路径上有多少对相邻的点的 \(a\) 值相同即可。

代码

注意路径上两段的拼接方式。

点此查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxn=100010;
int n,q,t,i,u,v,o,a,b,tim,tot;
int h[maxn],fa[maxn],son[maxn],dep[maxn];
int top[maxn],dfn[maxn],rnk[maxn];
struct edge{int to,nxt;}E[maxn<<1];
int dfs1(int p,int f){
    int lp,to,mx=0,mv=0,nw,sz=1;
    for(lp=h[p];lp;lp=E[lp].nxt){
        to=E[lp].to;
        if(to==f) continue;
        fa[to]=p;
        dep[to]=dep[p]+1;
        nw=dfs1(to,p);
        sz+=nw;
        if(nw>mx){
            mx=nw;
            mv=to;
        }
    }
    son[p]=mv;
    return sz;
}
void dfs2(int p,int f,int t){
    int lp,to;
    top[p]=t;
    dfn[p]=++tim;
    rnk[tim]=p;
    if(!son[p]) return;
    dfs2(son[p],p,t);
    for(lp=h[p];lp;lp=E[lp].nxt){
        to=E[lp].to;
        if(to==f) continue;
        if(to==son[p]) continue;
        dfs2(to,p,to);
    }
}
int l[maxn<<2],r[maxn<<2],m[maxn<<2],len[maxn<<2];
struct seg{
    int lv,rv,tv,sum;
    seg():lv(0),rv(0),tv(0),sum(0){}
    seg(int x,int y,int z,int w):lv(x),rv(y),tv(z),sum(w){}
    inline seg operator +(seg rt){return seg(lv,rt.rv,0,sum+rt.sum+(rv==rt.lv));}
}tr[maxn<<2],lp,rp;
#define ls(p) p<<1
#define rs(p) p<<1|1 
#define lv(p) tr[p].lv
#define rv(p) tr[p].rv
#define tv(p) tr[p].tv
#define sum(p) tr[p].sum
void build(int p,int lt,int rt){
    l[p]=lt;r[p]=rt;
    m[p]=(lt+rt)>>1;
    len[p]=rt-lt+1;
    sum(p)=0;
    if(lt==rt){
        lv(p)=rv(p)=rnk[lt];
        return;
    }
    build(ls(p),lt,m[p]);
    build(rs(p),m[p]+1,rt);
    tr[p]=tr[ls(p)]+tr[rs(p)];
}
inline void pushdown(int p){
    if(!tv(p)) return;
    sum(ls(p))=len[ls(p)]-1;
    sum(rs(p))=len[rs(p)]-1;
    lv(ls(p))=rv(ls(p))=
    lv(rs(p))=rv(rs(p))=
    tv(ls(p))=tv(rs(p))=tv(p);
    tv(p)=0;
}
void change(int p,int lt,int rt){
    if(lt<=l[p]&&rt>=r[p]){
        lv(p)=rv(p)=tv(p)=tim;
        sum(p)=len[p]-1;
        return;
    }
    pushdown(p);
    if(lt<=m[p]) change(ls(p),lt,rt);
    if(rt>m[p]) change(rs(p),lt,rt);
    tr[p]=tr[ls(p)]+tr[rs(p)];
}
void query(int p,int lt,int rt,seg &qt){
    if(lt<=l[p]&&rt>=r[p]){
        qt=tr[p]+qt;
        return;
    }
    pushdown(p);
    if(rt>m[p]) query(rs(p),lt,rt,qt);
    if(lt<=m[p]) query(ls(p),lt,rt,qt);
}
int main(){
    scanf("%d",&t); 
    while(t--){
        scanf("%d%d",&n,&q);
        for(i=1;i<n;++i){
            scanf("%d%d",&u,&v);
            E[++tot]={v,h[u]};h[u]=tot;
            E[++tot]={u,h[v]};h[v]=tot;
        }
        dep[1]=1;
        dfs1(1,0);
        dfs2(1,0,1);
        build(1,1,n);
        while(q--){
            scanf("%d%d%d",&o,&a,&b);
            if(o==1){
                ++tim;
                while(top[a]!=top[b]){
                    if(dep[top[a]]>dep[top[b]]){
                        change(1,dfn[top[a]],dfn[a]);
                        a=fa[top[a]];
                    }
                    else{
                        change(1,dfn[top[b]],dfn[b]);
                        b=fa[top[b]];
                    }
                }
                if(dfn[a]>dfn[b]) swap(a,b);
                change(1,dfn[a],dfn[b]);
            }
            else{
                lp=seg(-1,-2,0,0);
                rp=seg(-3,-4,0,0);
                while(top[a]!=top[b]){
                    if(dep[top[a]]>dep[top[b]]){
                        query(1,dfn[top[a]],dfn[a],lp);
                        a=fa[top[a]];
                    }
                    else{
                        query(1,dfn[top[b]],dfn[b],rp);
                        b=fa[top[b]];
                    }
                }
                if(dfn[a]<dfn[b]) query(1,dfn[a],dfn[b],rp);
                else query(1,dfn[b],dfn[a],lp);
                swap(rp.lv,rp.rv);
                printf("%d\n",(rp+lp).sum);
            }
        }
        for(i=1;i<=n;++i) h[i]=son[i]=top[i]=0;
        tim=tot=0;
    }
    return 0;
}

应用2:线段树合并维护信息

例题2:P5327 [ZJOI2019] 语言

题意

有一棵大小为 \(n\) 的树,且给定树上的 \(m\) 条路径。两点间连通当且仅当有 同一条路径 经过这两个点。求连通的点对数。\(n,m\le 10^5\)。(可加强到 \(n,m\le10^6\)。)

解法

考虑每个点能够和哪些点连通。可以发现如果知道了某个点所在的路径,则只需要维护每个点所在的每条路径的端点,每个端点在这棵树上的虚树就会包括这个点能到达的所有点,且这棵虚树一定会包括这个节点本身。

考虑维护这些端点。可以使用 树上差分 + 权值线段树合并 以维护每个点对应的所有路径的所有端点。同时需要考虑如何能得到虚树的大小。有一个经典结论:将它们按照 dfn 序排序,每个点与排序后的下一个点(令最后一个点在序列中的下一个点为第一个点)的所有简单路径中,这棵虚树的每条边刚好均被包括了两次。

故而直接维护这些点的权值线段树,然后进行线段树合并即可。代码比树剖简单。至于要做到 \(O(n\log n)\),可以使用欧拉序 + st 表求 LCA 实现单次 \(O(1)\) 查询。

代码

点此查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxl=22;
const int maxn=100010;
const int maxt=200010;
int n,m,i,u,v,j,t,dt=1,tot,tim1,tim2;
long long ans;
int h[maxn],nd[maxl][maxt],dep[maxn];
int fa[maxn],dfn1[maxn],dfn2[maxn];
int lb[maxt],stk[30],rnk[maxn];
struct edge{int to,nxt;}E[maxt];
void dfs1(int p,int f,int d){
    dfn1[p]=++tim1;
    dfn2[p]=++tim2;
    dep[p]=d;
    rnk[tim2]=p;
    nd[0][tim1]=p;
    int lp,to;
    for(lp=h[p];lp;lp=E[lp].nxt){
        to=E[lp].to;
        if(to==f) continue;
        fa[to]=p;
        dfs1(to,p,d+1);
        nd[0][++tim1]=p;
    }
}
inline int lca(int x,int y){
    x=dfn1[x]; y=dfn1[y];
    if(x>y) swap(x,y); dt=lb[y-x+1];
    int lt=nd[dt][x],rt=nd[dt][y-(1<<dt)+1];
    if(dep[lt]<dep[rt]) return lt;
    return rt;
}
inline int dis(int x,int y){return dep[x]+dep[y]-(dep[lca(x,y)]<<1);}
struct node{
    int ls,rs,lt,rt,sum;
    long long psum;
}tr[maxn*maxl*10];
#define ls(p) tr[p].ls
#define rs(p) tr[p].rs
#define lt(p) tr[p].lt
#define rt(p) tr[p].rt 
#define sum(p) tr[p].sum
#define psum(p) tr[p].psum
inline void pushup(int p){
    lt(p)=lt(ls(p)); if(!lt(p)) lt(p)=lt(rs(p));
    rt(p)=rt(rs(p)); if(!rt(p)) rt(p)=rt(ls(p));
    psum(p)=psum(ls(p))+psum(rs(p));
    if(rt(ls(p))&&lt(rs(p))) psum(p)+=dis(rt(ls(p)),lt(rs(p)));
}
inline void insert(int pt,int p,int d){
    int ps=dfn2[p];
    int lt=1,rt=n,mt;
    while(lt<rt){
        stk[++t]=pt;
        mt=(lt+rt)>>1;
        if(ps<=mt){
            if(!ls(pt)) ls(pt)=++tot;
            rt=mt; pt=ls(pt);
        }
        else{
            if(!rs(pt)) rs(pt)=++tot;
            lt=mt+1; pt=rs(pt);
        }
    }
    sum(pt)+=d; lt(pt)=rt(pt)=0;
    if(sum(pt)>0) lt(pt)=rt(pt)=p;
    while(t) pushup(stk[t--]);
}
int merge(int x,int y,int lt,int rt){
    if(!(x&&y)) return x|y;
    if(lt==rt){
        sum(x)+=sum(y); lt(x)=rt(x)=0;
        if(sum(x)>0) lt(x)=rt(x)=rnk[lt];
        return x;
    }
    int mt=(lt+rt)>>1;
    ls(x)=merge(ls(x),ls(y),lt,mt);
    rs(x)=merge(rs(x),rs(y),mt+1,rt);
    pushup(x); return x;
}
void dfs2(int p,int f){
    int lp,to;
    for(lp=h[p];lp;lp=E[lp].nxt){
        to=E[lp].to;
        if(to==f) continue;
        dfs2(to,p);
        merge(p,to,1,n);
    }
    ans+=psum(p)+dis(lt(p),rt(p));
}
int main(){
    for(i=2;i<maxt;++i) lb[i]=lb[i>>1]+1;
    scanf("%d%d",&n,&m);
    for(i=1;i<n;++i){
        scanf("%d%d",&u,&v);
        E[++tot]={v,h[u]};h[u]=tot;
        E[++tot]={u,h[v]};h[v]=tot;
    }
    tot=n;
    dfs1(1,0,1);
    for(i=1;i<maxl;++i,dt<<=1){
        for(j=dt+1;j<=tim1;++j){
            u=nd[i-1][j]; v=nd[i-1][j-dt];
            if(dep[u]<dep[v]) nd[i][j-dt]=u;
            else nd[i][j-dt]=v;
        }
    }
    while(m--){
        scanf("%d%d",&u,&v);
        insert(u,u,1);insert(u,v,1);
        insert(v,u,1);insert(v,v,1);
        i=lca(u,v);
        insert(i,u,-1);insert(i,v,-1);
        if(fa[i]){
            insert(fa[i],u,-1);
            insert(fa[i],v,-1);
        }
    }
    dfs2(1,0);
    printf("%lld",ans>>2);
    return 0;
}

例题3:P6773 [NOI2020] 命运

题意

给定一棵大小为 \(n\) 的树,边权可以是 \(0\)\(1\)。同时给定 \(m\) 条限制,每条限制的形式是树上的一条 祖先后代链 中至少有一条边权为 \(1\)。求有多少种赋边权的方案使得所有限制都被满足。\(1\le n,m\le 5\times 10^5\)

解法

考虑 dfs 整棵树,同时确定知道某条边赋边权为 \(1\) 时会造成怎样的影响。显然在自下而上确定边权时,在包括某条边 \(u\rightarrow v\) 的,且尚未包括边权为 \(1\) 的祖先后代链中,若此边边权为 \(1\) 时,则不需要考虑这所有的限制;而若这些祖先后代链中,最深的祖先深度为 \(dep_u\)(也就是最深的祖先为 \(u\)),则必须要将这条边边权赋值为 \(1\)

故而可以考虑用树形 dp 维护有多少种在 \(u\) 的子树内赋边权的方案,满足所有未满足的限制中,最深的祖先深度为 \(i\),记方案数为 \(dp_{u,i}\)。考虑如何转移:

对于 \(u\) 节点,在合并其子节点 \(v\) 时,若 \(u\rightarrow v\) 边边权为 \(1\) 时,则下端点在 \(v\) 子树内,上端点在 \(u\) 及其祖先的所有限制均将满足,有 \(dp_{u,i}\leftarrow dp_{u,i}\sum_{j=0}^{dep_u}dp_{v,j}\)。若 \(u\rightarrow v\) 边边权为 \(0\),则仍然需要考虑后代在 \(v\) 子树的祖先后代链,有 \(dp_{u,i}\leftarrow \sum_{j=0}^i dp_{u,i}dp_{v,j}+\sum_{j=0}^idp_{u,j}dp_{v,i}-dp_{u,i}dp_{v,i}\),也就是 \(dp_{u,i}\leftarrow dp_{u,i}\sum_{j=0}^idp_{v,j}+dp_{v,i}\sum_{j=0}^{i-1}dp_{u,j}\)

如果直接这样转移,则每合并一个子节点,均要枚举当前 \(u\) 节点深度范围的 dp 值,时空复杂度为 \(O(n^2)\),使用 std::vector 优化空间可得 \(64\ pts\)。(如果只考虑 \(m\) 条限制涉及的祖先后代链的端点,则可以建立这些点组成的虚树,将某条边边权赋为 \(1\) 等效于在虚树上等效于原树 \(k\) 条边的边使用 \(2^k-1\) 种方案赋值,可多得 \(8\ pts\))记 \(S_{u,i}=\sum_{j=0}^i dp_{u,j}\),则转移形式变成了 \(dp_{u,i}\leftarrow dp_{u,i}(S_{v,dep_u}+S_{v,i})+dp_{v,i}S_{u,i-1}\)。(特别地,\(\forall u,S_{u,-1}=0\)

考虑 dp 的初值。显然在 \(u\) 子树内在怎样赋边权,始终不会影响后代为 \(u\) 的祖先后代链。故记后代为 \(u\) 的所有祖先后代链的最深祖先为 \(f_u\),则 \(dp_{u,dep_{f_u}}=1\)(如果 \(u\) 为叶节点,则此为唯一 dp 值)。特别地,若无满足要求的链,则 \(f_u=dep_{f_u}=0\)。同时考虑 \(dp_u\) 中的每个值均需要由 \(u\) 的子节点的 有意义的 dp 值 转移而来,导致我们只需要考虑目前节点和其子节点的所有 dp 值而不用考虑其他 dp 值。故而考虑线段树合并维护 dp 转移。

具体地,在树上每个节点建线段树以维护当前的 dp 值和某一段 dp 值的和,在线段树合并的同时维护目前的 dp 前缀和。具体只需要在线段树合并到两棵线段树中一者节点不存在、或合并到叶节点再累加前缀和。操作就是将当前节点对应的线段树的某个区间乘上一个数,需要 pushdown,和普通线段树同样只需要在两棵线段树中同时遍历到的非叶子节点处 pushdown 即可。

像这样的虽然 dp 值多但是真正有用的 dp 值很少/值域很小的 dp,可以使用 整体 dp 的方式优化,具体实现和上述相似。

代码比较简洁易调,甚至不需要专门去过大样例。

代码

点此查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxl=20;
const int maxn=500010;
const int md=998244353;
int n,m,i,u,v,tot;
int h[maxn],pt[maxn],dep[maxn];
struct edge{int to,nxt;}E[maxn<<1],Pt[maxn];
struct node{int ls,rs,sum,mul;}tr[maxn*maxl];
#define ls(p) tr[p].ls
#define rs(p) tr[p].rs
#define sum(p) tr[p].sum
#define mul(p) tr[p].mul 
inline void pushup(int p){
    sum(p)=sum(ls(p))+sum(rs(p));
    if(sum(p)>=md) sum(p)-=md;
}
inline void pushdown(int p){
    if(mul(p)==1) return;
    mul(ls(p))=(1LL*mul(ls(p))*mul(p))%md;
    mul(rs(p))=(1LL*mul(rs(p))*mul(p))%md;
    sum(ls(p))=(1LL*sum(ls(p))*mul(p))%md;
    sum(rs(p))=(1LL*sum(rs(p))*mul(p))%md;
    mul(p)=1;
}
void query(int p,int lt,int rt,int rp,int &ret){
    if(rt<=rp){
        ret+=sum(p);
        if(ret>=md) ret-=md;
        return;
    }
    pushdown(p);
    int mt=(lt+rt)>>1;
    if(ls(p)) query(ls(p),lt,mt,rp,ret);
    if(rp>mt&&rs(p)) query(rs(p),mt+1,rt,rp,ret);
}
void merge(int &x,int y,int lt,int rt,int &s1,int &s2){
    if(!(x||y)) return;
    if(!x){
        s1+=sum(y); if(s1>=md) s1-=md;
        sum(y)=(1LL*sum(y)*s2)%md;
        mul(y)=(1LL*mul(y)*s2)%md;
        x=y;
    }
    else if(!y){
        s2+=sum(x); if(s2>=md) s2-=md;
        sum(x)=(1LL*sum(x)*s1)%md;
        mul(x)=(1LL*mul(x)*s1)%md;
    }
    else if(lt==rt){
        s1+=sum(y); if(s1>=md) s1-=md; int s3=sum(x);
        sum(x)=(((1LL*sum(x)*s1)%md)+((1LL*sum(y)*s2)%md))%md;
        s2+=s3; if(s2>=md) s2-=md;
    }
    else{
        pushdown(x); pushdown(y);
        int mt=(lt+rt)>>1;
        merge(ls(x),ls(y),lt,mt,s1,s2);
        merge(rs(x),rs(y),mt+1,rt,s1,s2);
        pushup(x);
    }
}
void dfs(int p,int f,int d){
    dep[p]=d;
    int lp,to=0,mp;
    for(lp=pt[p];lp;lp=Pt[lp].nxt){
        mp=Pt[lp].to;
        if(dep[mp]>to) to=dep[mp]; 
    }
    int lt=0,rt=n,mt,pt=p,s1,s2;
    while(lt<rt){
        sum(pt)=mul(pt)=1;
        mt=(lt+rt)>>1;
        ++tot;
        if(to<=mt){
            ls(pt)=tot;
            rt=mt;
        }
        else{
            rs(pt)=tot;
            lt=mt+1;
        }
        pt=tot;
    }
    sum(pt)=mul(pt)=1;
    for(lp=h[p];lp;lp=E[lp].nxt){
        to=E[lp].to;
        if(to==f) continue;
        dfs(to,p,d+1);
        s1=s2=0;
        query(to,0,n,d,s1);
        merge(p,to,0,n,s1,s2);
    }
}
int main(){
    scanf("%d",&n);
    for(i=1;i<n;++i){
        scanf("%d%d",&u,&v);
        E[++tot]={v,h[u]};h[u]=tot;
        E[++tot]={u,h[v]};h[v]=tot;
    }
    tot=n;
    scanf("%d",&m);
    for(i=1;i<=m;++i){
        scanf("%d%d",&u,&v);
        Pt[i]={u,pt[v]};pt[v]=i;
    }
    dfs(1,0,1);
    u=1;
    while(ls(u)) u=ls(u);
    printf("%d",sum(u));
    return 0;
}

拓展题1:P5298 [PKUWC2018]Minimax

题意

有一棵大小为 \(n\) 的根为 \(1\) 号节点的树,每个节点最多有两个子节点。对于叶子节点 \(i\),其权值为 \(w_i\),且保证 \(w_i\) 各不相同;而对于非叶子节点 \(i\),其权值会有 \(P_i\) 的概率成为子节点权值中最大值,有 \(1-p_i\) 的概率成为子节点权值中最小值。

\(1\) 号点的权值的所有可能的取值有 \(m\) 种,\(\boldsymbol i\) 小的权值\(v_i\),概率为 \(d_i\),求 \(\sum_{i=1}^m iv_id_i^2\),对 \(998244353\) 取模。

\(n\le 3\times 10^5,1\le w_i\le 10^9,0<P_i<1\)

解法

(这不就是 P6773 [NOI2020] 命运 吗?)

(但是为什么要把最后概率大小关系求出来?)

(哦,是权值大小关系,那没事了)

我们在不讨论空间等条件时,可以有如下 dp 方程:

\(dp_{u,x}\) 表示 \(u\) 节点 权值为 \(\boldsymbol x\) 时的概率,则转移如下:

  • \(u\) 节点为叶节点,则显然 \(dp_{u,w_u}=1;\forall p\ne w_u,dp_{u,p}=0\)

  • \(u\) 节点只有一个子节点 \(son_u\),则 \(\forall p,dp_{u,p}=dp_{son_u,p}\)\(u\) 节点的取值只能是其子节点取值。

  • \(u\) 节点有两个子节点 \(ls_u,rs_u\),则

    \[\begin{aligned}\forall p,dp_{u,p}&=P_u\left(\left(dp_{rs_u,p}\sum_{i=0}^p dp_{ls_u,i}\right)+\left(dp_{ls_u,p}\sum_{i=0}^p dp_{rs_u,i}\right)\right)+\left(1-P_u\right)\left(\left(dp_{rs_u,p}\sum_{i=p}dp_{ls_u,i}\right)+\left(dp_{ls_u,p}\sum_{i=p}dp_{rs_u,i}\right)\right)\\&=dp_{ls_u,p}\left(\left(P_u\sum_{i=0}^pdp_{rs_u,i}\right)+\left(\left(1-P_u\right)\sum_{i=p}dp_{rs_u,i}\right)\right)+dp_{rs_u,p}\left(\left(P_u\sum_{i=0}^pdp_{ls_u,i}\right)+\left(\left(1-P_u\right)\sum_{i=p}dp_{ls_u,i}\right)\right)\end{aligned} \]

虽然 dp 第二维值域很大,但是由于每个节点的子树内叶节点有限,故这个节点权值的所有取值有限,实际用到的 dp 值也很有限。故此时可以用线段树合并优化 dp 统计。具体做法和 P6773 [NOI2020] 命运 类似。统计后缀和就是总和减去前缀和。虽然是黑题,但是码长和调试难度较为清新。

代码

点此查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxl=35;
const int maxn=300010;
const int maxx=1000000000;
const int md=998244353;
int n,i,f,lx,rx,mx,pt,tot,cnt,ans;
int ls[maxn],rs[maxn],pv[maxn];
inline int Pow(int d,int z){
    int ret=1;
    do{
        if(z&1) ret=(1LL*ret*d)%md;
        d=(1LL*d*d)%md;
    }while(z>>=1);
    return ret;
}
const int Inv=Pow(10000,md-2);
struct node{
    int ls,rs;
    int sum,mul;
}tr[maxn*maxl];
#define ls(p) tr[p].ls
#define rs(p) tr[p].rs
#define sum(p) tr[p].sum
#define mul(p) tr[p].mul
inline void pushup(int p){
    sum(p)=sum(ls(p))+sum(rs(p));
    if(sum(p)>=md) sum(p)-=md;
}
inline void pushdown(int p){
    if(mul(p)==1) return;
    if(ls(p)){
        mul(ls(p))=(1LL*mul(p)*mul(ls(p)))%md;
        sum(ls(p))=(1LL*mul(p)*sum(ls(p)))%md;
    } 
    if(rs(p)){
        mul(rs(p))=(1LL*mul(p)*mul(rs(p)))%md;
        sum(rs(p))=(1LL*mul(p)*sum(rs(p)))%md;
    }
    mul(p)=1;
}
void merge(int &x,int y,int lt,int rt,int &s1,int &s2,int &s3,int &s4,int px,int py){
    if(!(x||y)) return;
    int mt;
    if(!x){
        mt=sum(y);
        sum(y)=(1LL*sum(y)*((1LL*s1*px+1LL*s2*py)%md))%md;
        mul(y)=(1LL*mul(y)*((1LL*s1*px+1LL*s2*py)%md))%md;
        s3+=mt; if(s3>=md) s3-=md; s4-=mt; if(s4<0) s4+=md;
        x=y;
    }
    else if(!y){
        mt=sum(x);
        sum(x)=(1LL*sum(x)*((1LL*s3*px+1LL*s4*py)%md))%md;
        mul(x)=(1LL*mul(x)*((1LL*s3*px+1LL*s4*py)%md))%md;
        s1+=mt; if(s1>=md) s1-=md; s2-=mt; if(s2<0) s2+=md;
    }
    else{
        pushdown(x); pushdown(y); mt=(lt+rt)>>1;
        merge(ls(x),ls(y),lt,mt,s1,s2,s3,s4,px,py);
        merge(rs(x),rs(y),mt+1,rt,s1,s2,s3,s4,px,py);
        pushup(x);
    }
}
void dfs(int p){
    if(!ls[p]) return;
    dfs(ls[p]);
    if(rs[p]){
        dfs(rs[p]);int s1=0,s2=sum(rs[p]),s3=0,s4=sum(ls[p]);
        merge(ls[p],rs[p],1,maxx,s1,s2,s3,s4,pv[p],md+1-pv[p]);
    }
    tr[p]=tr[ls[p]];
}
void dfst(int p,int lt,int rt){
    if(lt==rt){
        ans=(ans+(((1LL*(++cnt)*lt)%md)*
                  ((1LL*sum(p)*sum(p))%md)))%md;
        return;
    }
    pushdown(p);
    int mt=(lt+rt)>>1;
    if(ls(p)) dfst(ls(p),lt,mt);
    if(rs(p)) dfst(rs(p),mt+1,rt);
}
int main(){
    scanf("%d%d",&n,&i);
    for(i=2;i<=n;++i){
        scanf("%d",&f);
        if(!ls[f]) ls[f]=i;
        else rs[f]=i;
    }
    tot=n;
    for(i=1;i<=n;++i){
        scanf("%d",&f);
        if(ls[i]) pv[i]=(1LL*f*Inv)%md;
        else{
            lx=1;rx=maxx;pt=i;
            while(lx<rx){
                ++tot; mx=(lx+rx)>>1;
                sum(pt)=mul(pt)=1;
                if(f<=mx){
                    ls(pt)=tot;
                    rx=mx;
                }
                else{
                    rs(pt)=tot;
                    lx=mx+1;
                }
                pt=tot;
            }
            sum(pt)=mul(pt)=1;
        }
    }
    dfs(1);
    dfst(1,1,maxx);
    printf("%d",ans);
    return 0;
}

应用3:线段树结合势能分析


引入:势能分析

势能分析是 OI 中一项十分重要的内容,主要用于计算某种算法/数据结构在若干次操作后的 整体 时间复杂度,例如 SAM/PAM/Manacher/Splay。

势能分析的核心操作是找到某个核心量(又称势能函数),满足某个操作在对时间复杂度造成怎样的贡献时,对这个量造成了怎样的对应的贡献。

例题1:P4145 上帝造题的七分钟 2 / 花神游历各国

题意

给出一个长为 \(n\) 的序列 \(a\),支持 \(m\) 次操作,操作包括区间开根(下取整)和区间求和。\(n,m\le 10^5,a_i\le 10^{12}\)

解法

考虑有且只有对 \(1\) 进行开根后会得到本身,而形如 \(2^{2^k}(k\in N_+)\) 的数在 \(k+1\) 次开根后会得到 \(1\),同时 \(\forall a<b<c\)\(b\) 得到 \(1\) 的开根次数在 \(a\)\(c\) 之间。故某个数 \(a\)\(O(\log\log a)\) 次开根后会变成 \(1\)。定义势能函数为 \(\sum_{i=1}^n \log\log a_i\),则在对一个数开根后势能函数会减少 \(O(1)\),时间复杂度会增加最多 \(O(\log n)\),故最后开根的时间复杂度不会超过 \(O(n\log n\log\log 10^{12})\);而区间求和总时间复杂度显然不超过 \(O(m\log n)\)。故最后总时间复杂度为 \(O(m\log n+n\log n\log\log 10^{12})\)

当然可以使用并查集跳过所有等于 \(1\) 的数(每次一定会对某个区间内所有非 \(1\) 的数进行开根),用树状数组统计前缀和,可以使常数更小。

代码

点此查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxn=100010;
#define ll long long
int n,m,i,o,lt,rt;
ll a[maxn];
struct node{
    int l,r,m,cnt;
    ll sum;
}tr[maxn<<2];
#define l(p) tr[p].l
#define r(p) tr[p].r
#define m(p) tr[p].m
#define ls(p) p<<1
#define rs(p) p<<1|1 
#define cnt(p) tr[p].cnt
#define sum(p) tr[p].sum 
inline void pushup(int p){
    sum(p)=sum(ls(p))+sum(rs(p));
    cnt(p)=cnt(ls(p))+cnt(rs(p));
}
void build(int p,int l,int r){
    l(p)=l;r(p)=r; 
    m(p)=(l+r)>>1;
    if(l==r){
        sum(p)=a[l];
        if(sum(p)>1) cnt(p)=1;
        return;
    }
    build(ls(p),l,m(p));
    build(rs(p),m(p)+1,r);
    pushup(p);
}
void change(int p,int l,int r){
    if(!cnt(p)) return;
    if(l(p)==r(p)){
        sum(p)=sqrt(sum(p));
        cnt(p)=(sum(p)>1);
        return;
    }
    if(l<=m(p)) change(ls(p),l,r);
    if(r>m(p)) change(rs(p),l,r);
    pushup(p);
}
ll query(int p,int l,int r){
    if(l<=l(p)&&r>=r(p)) return sum(p);
    ll ret=0;
    if(l<=m(p)) ret=query(ls(p),l,r);
    if(r>m(p)) ret+=query(rs(p),l,r);
    return ret;
}
int main(){
    scanf("%d",&n);
    for(i=1;i<=n;++i) scanf("%lld",a+i);
    build(1,1,n);
    scanf("%d",&m);
    while(m--){
        scanf("%d%d%d",&o,&lt,&rt);
        if(lt>rt) swap(lt,rt);
        if(o) printf("%lld\n",query(1,lt,rt));
        else change(1,lt,rt);
    }
    return 0;
}

例题2:CF438D The Child and Sequence

题意

给定长为 \(n\) 的序列 \(a\),支持 \(m\) 次操作,包括区间对正整数取模、单点修改为不大于 \(10^9\) 的正整数、区间查询和。\(n,m\le 10^5,a_i\le 10^9\)

解法

考虑某个数 \(a\) 在对 \(b\) 取模后,若 \(b>a\)\(a\) 不变,否则 \(a\) 至少减半。故某个数 \(a\)\(O(\log a)\) 次对更小的数取模后会变为 \(0\)。定义势能函数为 \(\sum_{i=1}^n \log a_i\),则单次修改最多会使势能增加 \(\log 10^9\),而区间取模时每对 \(k\) 个非一的数取模,时间复杂度会增加最多 \(O(k\log n)\),势能函数至少会减小 \(k\)。综上,时间复杂度不会超过 \(O(m\log n+n\log 10^9\log n)\)

代码

点此查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxn=100010;
int n,m,i,o,d,lt,rt;
int a[maxn];
struct node{
    int l,r,m,mx;
    long long sum;
}tr[maxn<<2];
#define l(p) tr[p].l
#define r(p) tr[p].r
#define m(p) tr[p].m
#define ls(p) p<<1
#define rs(p) p<<1|1
#define mx(p) tr[p].mx
#define sum(p) tr[p].sum
inline void pushup(int p){
    sum(p)=sum(ls(p))+sum(rs(p));
    mx(p)=max(mx(ls(p)),mx(rs(p)));
} 
void build(int p,int l,int r){
    l(p)=l;r(p)=r;
    m(p)=(l+r)>>1;
    if(l==r){
        mx(p)=sum(p)=a[l];
        return;
    }
    build(ls(p),l,m(p));
    build(rs(p),m(p)+1,r);
    pushup(p); 
}
void md(int p,int l,int r,int x){
    if(mx(p)<x) return;
    if(l(p)==r(p)){
        mx(p)%=x;
        sum(p)=mx(p);
        return;
    }
    if(l<=m(p)) md(ls(p),l,r,x);
    if(r>m(p)) md(rs(p),l,r,x);
    pushup(p);
}
void change(int p,int x,int v){
    if(l(p)==r(p)){
        mx(p)=sum(p)=v;
        return;
    }
    if(x<=m(p)) change(ls(p),x,v);
    else change(rs(p),x,v);
    pushup(p);
}
long long query(int p,int l,int r){
    if(l<=l(p)&&r>=r(p)) return sum(p);
    long long ret=0;
    if(l<=m(p)) ret=query(ls(p),l,r);
    if(r>m(p)) ret+=query(rs(p),l,r);
    return ret;
}
int main(){
    scanf("%d%d",&n,&m);
    for(i=1;i<=n;++i) scanf("%d",a+i);
    build(1,1,n);
    while(m--){
        scanf("%d",&o);
        if(o==1){
            scanf("%d%d",&lt,&rt);
            printf("%lld\n",query(1,lt,rt));
        }
        else if(o==2){
            scanf("%d%d%d",&lt,&rt,&d);
            md(1,lt,rt,d);
        }
        else{
            scanf("%d%d",&rt,&d);
            change(1,rt,d);
        } 
    }
    return 0;
}

例题3:UOJ #228 基础数据结构练习题

题意

给出一个长为 \(n\) 的序列 \(a\),进行 \(m\) 次操作,包括区间加、区间开根下取整、区间求和。\(n,m\le 10^5\)\(a_i\) 在任何时刻均 \(\le 10^5\)

解法

这道题和 P4145 很像,做法看似与 P4145 相同。但是区间加会影响时间复杂度:

假设某个区间为若干个 \(2^{16}-1\)\(2^{16}\) 交替组合起来。如果只在区间内的数相同的情况下停止往下递归,则区间的值域会先后变为 \([2^8-1,2^8]\)\([2^4-1,2^4]\)\([2^2-1,2^2]\),最后变成 \([1,2]\)。此时区间加 \(2^{16}-2\) 又会变成原样,不得不每次操作均要暴力开根修改到线段树叶节点上。

可以证明这种情况只会在区间极差为 \(1\) 的情况出现。若区间值域形如 \([x^2-1,(x+c)^2]\ (c,x>0)\)\((x+c)^2-(x^2-1)\le c\),则 \(x^2+2cx+1\le c\),由 \(x,c\ge 1\) 可知 \(x^2+2cx+1>2cx\ge 2c>c\);故这种情况只会在区间极差为 \(1\) 的情况下出现。此时可以把开根操作变为区间减操作。

同时定义区间势能为 区间开根前后极差之差。证明复杂度方式和 P4145 相似。

代码

注意开 long long

点此查看代码
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int maxn=100010;
int n,m,i,o,lt,rt,dt; 
ll ans;
struct seg{
    int l,r,m,len;
    ll mn,mx,add,sum;
}tr[maxn<<2];
#define l(p) tr[p].l
#define r(p) tr[p].r
#define m(p) tr[p].m
#define ls(p) p<<1
#define rs(p) p<<1|1
#define mn(p) tr[p].mn
#define mx(p) tr[p].mx
#define len(p) tr[p].len
#define sum(p) tr[p].sum
#define add(p) tr[p].add
inline void pushup(int p){
    mn(p)=min(mn(ls(p)),mn(rs(p)));
    mx(p)=max(mx(ls(p)),mx(rs(p)));
    sum(p)=sum(ls(p))+sum(rs(p));
}
void build(int p,int l,int r){
    l(p)=l;r(p)=r;
    m(p)=(l+r)>>1;
    len(p)=r-l+1; 
    if(l==r){
        scanf("%lld",&mn(p));
        mx(p)=sum(p)=mn(p);
        return;
    }
    build(ls(p),l,m(p));
    build(rs(p),m(p)+1,r);
    pushup(p);
}
inline void pushdown(int p){
    if(!add(p)) return;
    mn(ls(p))+=add(p);mn(rs(p))+=add(p);
    mx(ls(p))+=add(p);mx(rs(p))+=add(p);
    sum(ls(p))+=add(p)*len(ls(p));
    sum(rs(p))+=add(p)*len(rs(p));
    add(ls(p))+=add(p);add(rs(p))+=add(p);
    add(p)=0; 
}
void change(int p){
    if(mx(p)==1) return;
    if((lt<=l(p)&&rt>=r(p))&&
       ((mx(p)-mn(p))==((ll)(sqrt(mx(p)))-(ll)(sqrt(mn(p)))))){
        ll d=mx(p)-(ll)(sqrt(mx(p)));
        add(p)-=d; mx(p)-=d; mn(p)-=d;
        sum(p)-=d*len(p); return;
    }
    pushdown(p);
    if(lt<=m(p)) change(ls(p));
    if(rt>m(p)) change(rs(p));
    pushup(p);
}
void addt(int p){
    if(lt<=l(p)&&rt>=r(p)){
        mn(p)+=dt; mx(p)+=dt; add(p)+=dt;
        sum(p)+=1LL*len(p)*dt; return;
    }
    pushdown(p);
    if(lt<=m(p)) addt(ls(p));
    if(rt>m(p)) addt(rs(p));
    pushup(p);
}
void query(int p){
    if(lt<=l(p)&&rt>=r(p)){
        ans+=sum(p);
        return;
    }
    pushdown(p);
    if(lt<=m(p)) query(ls(p));
    if(rt>m(p)) query(rs(p));
}
int main(){
    scanf("%d%d",&n,&m); build(1,1,n);
    while(m--){
        scanf("%d%d%d",&o,&lt,&rt);
        if(o==1){
            scanf("%d",&dt);
            addt(1);
        }
        else if(o==2) change(1);
        else{
            ans=0; query(1);
            printf("%lld\n",ans);
        }
    }
    return 0;
}

例题4

题意

有一个长为 \(n\) 的序列 \(a\),有 \(m\) 次操作,包括区间加 highbit,区间加 lowbit,区间求和。\(n,m\le 5\times 10^5\)

解法

感谢 @w23c3c3 和 @CE_WA_TLE 对该题的做法的贡献。

显然加 hightbit 不会影响到某个数中间的一部分 \(1\),加 lowbit 会使某个数的 \(1\) 减少直到只剩一个 \(1\)。但是可能会有对某个数轮流加 highbit 和 lowbit,此时这个数内一直仍然有两个 \(1\)。此时则不能在某个区间内每个数只有一个 \(1\) 时停止向下修改,而是某个区间内每个数的 \(1\) 的个数不超过两个 \(1\) 时停止向下修改。

在线段树中需要对每个节点维护 hightbit 和 lowbit 之和以及 \(\text{min}_{\text{highbit}\ne \text{lowbit}}\text{(hightbit-lowbit)}\),在区间加 hightbit 和 lowbit 时需要对 \(\min\) 进行处理,如果 \(\min\)\(0\) 则需要向下修改。


Segment Tree Beats

可以用于解决区间取最值和区间历史和问题。

例题1:HDU5306 Gorgeous Sequence

题意

有一个长为 \(n\) 的序列 \(a\)。现在需要进行下面三种操作共 \(m\) 次:

  • 区间对 \(t\)\(\min\)
  • 查询区间最大值。
  • 查询区间和。

解法

考虑在只改变区间最大值时不向下操作。我们分三种情况考虑:

  • \(t\) 不小于区间最大值。此时不会进行操作,直接返回。
  • \(t\) 大于区间次大值。此时可以只改变最大值,并且改变区间和。
  • \(t\) 不大于区间次大值。此时需要找到新的区间次大值和区间和,所以需要向下操作。

定义区间势能函数为区间内不同值的个数。此时只有在某个区间内不只有一个值时,才会从这个区间向下操作,且最后一定会使得区间内的不同值个数减小。线段树内节点对应的区间长度和为 \(O(n\log n)\) 级别,所以总复杂度为 \(O((m+n)\log n)\)

代码

点此查看代码
#include <bits/stdc++.h>
using namespace std;
const int MXBUF=1<<21;
const int BRDBUF=MXBUF-1;
char bufr[MXBUF],*_head,*_tail;
inline char GetChar(){
	if(_head!=_tail) goto End;
	_head=_tail=bufr;
	_tail+=fread(bufr,1,MXBUF,stdin);
	End:return *(_head++);
}
int wr=-1;
char bufw[MXBUF];
inline void Flush(){
	fwrite(bufw,1,wr+1,stdout);
	wr=-1;
}
inline void PutChar(const char c){
	if(wr==BRDBUF) Flush();
	bufw[++wr]=c;
}
template<typename T>
inline void Read(T &a){
	a=0; char ch=GetChar();
	while((ch^'0')>9) ch=GetChar();
	while((ch^'0')<10){
		a=a*10+(ch^'0');
		ch=GetChar();
	}
}
template<typename T,typename... A>
void Read(T &t,A &...a){
	Read(t);
	Read(a...);
}
template<typename T>
inline void Write(T a){
	char BUF[130]; int top=-1;
	do{BUF[++top]=(a%10)|'0';}while(a/=10);
	do{PutChar(BUF[top]);}while(top--);
}
#define ll long long
const int maxn=1000010;
const int maxb=maxn<<2;
int n,m,o,l,r,t,T,am,sub[maxb];
ll as,sum[maxb];
inline void cmax(int &x,int y){if(x<y) x=y;}
struct seg{int mx,cnt,sx;}tr[maxb];
inline void Pushup(seg &ret,const seg &a,const seg &b){
	ret.cnt=0;
	if(a.mx>b.mx){
		memcpy(&ret,&a,8);
		ret.sx=max(a.sx,b.mx);
	}
	else{
		memcpy(&ret,&b,12);
		if(a.mx==b.mx){
			ret.cnt+=a.cnt;
			cmax(ret.sx,a.sx);
		}
		else cmax(ret.sx,a.mx);
	}
}
#define ls(p) p<<1
#define rs(p) p<<1|1
#define mx(p) tr[p].mx
#define sx(p) tr[p].sx
#define cnt(p) tr[p].cnt
void build(int p,int l,int r){
	sub[p]=0;
	if(l==r){
		Read(mx(p));
		sx(p)=-1; sum[p]=mx(p);
		cnt(p)=1; return;
	}
	int m=(l+r)>>1;
	build(ls(p),l,m);
	build(rs(p),m+1,r);
	Pushup(tr[p],tr[ls(p)],tr[rs(p)]);
	sum[p]=sum[ls(p)]+sum[rs(p)];
} 
inline void pushtag(int p,int t){
	sum[p]-=1LL*t*cnt(p);
	mx(p)-=t; sub[p]+=t;
}
inline void pushdown(int p){
	if(!sub[p]) return;
	int mt=max(mx(ls(p)),mx(rs(p)));
	if(mx(ls(p))==mt) pushtag(ls(p),sub[p]);
	if(mx(rs(p))==mt) pushtag(rs(p),sub[p]);
	sub[p]=0;
}
void cover(int p,int l,int r){
	if(t>=mx(p)) return;
	if(::l<=l&&::r>=r&&t>sx(p)){
		pushtag(p,mx(p)-t);
		return;
	}
	pushdown(p); int m=(l+r)>>1;
	if(::l<=m) cover(ls(p),l,m);
	if(::r>m) cover(rs(p),m+1,r);
	Pushup(tr[p],tr[ls(p)],tr[rs(p)]);
	sum[p]=sum[ls(p)]+sum[rs(p)];
}
void quemax(int p,int l,int r){
	if(::l<=l&&::r>=r){
		cmax(am,mx(p));
		return;
	}
	pushdown(p); int m=(l+r)>>1;
	if(::l<=m) quemax(ls(p),l,m);
	if(::r>m) quemax(rs(p),m+1,r);
}
void quesum(int p,int l,int r){
	if(::l<=l&&::r>=r){
		as+=sum[p];
		return;
	}
	pushdown(p); int m=(l+r)>>1; 
	if(::l<=m) quesum(ls(p),l,m);
	if(::r>m) quesum(rs(p),m+1,r);
}
int main(){
	Read(T);
	while(T--){
		Read(n,m);build(1,1,n);
		while(m--){
			Read(o,l,r);
			if(!o){Read(t);cover(1,1,n);}
			else{
				if(o==1){am=0;quemax(1,1,n);Write(am);}
				else{as=0;quesum(1,1,n);Write(as);}
				PutChar(10);
			}
		}
	}
	Flush();
	return 0;
}

拓展题1:P3747 [六省联考 2017] 相逢是问候

题意

有一个长为 \(n\) 的序列 \(a\),有 \(m\) 种操作:

  • \(\forall i\in[l,r]\)\(a_i\leftarrow c^{a_i}\)不取模
  • \(\sum_{i=l}^r a_i\),对 \(p\) 取模。

其中每个操作的 \(c\)\(p\) 相同。\(n,m\le 5\times 10^4,p\le 10^8,0<c<p,0\le a_i<p\)

解法


引入:扩展欧拉定理

欧拉定理:\(\forall a,b\in N_+\),若 \(\gcd(a,b)=1\),则有 \(a^{\phi(b)}\equiv1\pmod b\)\(\phi\) 为欧拉函数,\(\phi(a)\) 就是小于 \(a\) 且与 \(a\) 互质的正整数的数量)。

扩展欧拉定理:\(\forall a,b,c\in N_+,b>\phi(c),a^b\equiv a^{\phi(c)+(b\bmod\phi(c))}\pmod c\)

证明:参见 OI Wiki 相关内容


\(f(a,0)=a,f(a,i)=c^{f(a,i-1)}\),若 \(f(a,i-1)>\phi(p)\),则 \(f(a,i)\bmod p= c^{(f(a,i-1)\bmod\phi(p))+\phi(p)}\bmod p\);以此类推,若 \(f(a,i-2)>\phi(\phi(p))\),则 \(f(a,i)\bmod p=c^{(f(a,i-1)\bmod\phi(p))+\phi(p)}\bmod p=c^{(c^{(f(a,i-2)\bmod \phi(\phi(p)))+\phi(\phi(p))}\bmod \phi(p))+\phi(p)}\bmod p\)。(若 \(f\) 值不够大则不再往下计算)最后若每个 \(f\) 值足够大,则在计算某个 \(f(a,k)\bmod \phi(\phi(\cdots\phi(p)\cdots))\) 时会有 \(\phi(\phi(\cdots\phi(p)\cdots))\)\(1\) 的情况,这时该层指数值为 \(1\),不必往下计算。\(i\) 更大时,可以发现最后 \(f(a,i)\bmod p\) 同样可以表达成形如 \(c^{(c^{\cdots (c\bmod \phi(\phi(\cdots\phi(p)\cdots)))+\phi(\phi(\cdots\phi(p)\cdots))\cdots}\bmod \phi(p))+\phi(p)}\bmod p\) 的形式,且指数层数相同。

也就是说,对于某个 \(a\),如果 \(i\) 值足够大,则会存在一个 \(k_a\),满足 \(\forall i\ge k_a,f(a,i)\equiv f(a,i+1)\pmod p\)。我们就可以定义势能函数为每个数 \(a\) 被赋为 \(c^a\) 的次数,势能函数最大为 \(\sum k_a\)

显然我们只需要讨论 \(c>1\) 的情况。即使 \(c=2,a=0\),也有 \(f(a,6)=2^{65536}\),故 \(a\) 值小于 \(p\) 的情况可不必考虑,其会在很有限次操作后变得大于 \(p\),也就是说:

\(P(p,0)=p,P(p,i)=\phi(P(p,i-1))\),则 \(k_a\) 值约为使得 \(P(p,x)=1\) 的最小 \(x\),也就是上述形式的指数层数。

我们可以证明 \(k_a\) 值为 \(O(\log p)\) 级别的。

证明:由于欧拉函数是积性函数,且对于任意一个奇质数 \(a\),必有 \(\phi(a)=a-1\) 为偶数,故任意一个大于 \(1\) 的奇数的欧拉函数一定为偶数。同时对于任何一个偶数 \(B=2^k(2f+1)(f\in N)\),由于 \(\phi(2^k)=2^{k-1}\),故 \(\phi(B)=2^{k-1}\phi(2f+1)\)。又因为 \(\forall c\in N,\phi(c)\le c\),故而 \(\phi(B)\le 2^{k-1}(2f+1)=\frac 12B\),欧拉函数不会超过原来偶数的一半。故 \(k_a\)\(O(\log p)\) 级别的。

故直接维护前 \(O(\log p)\) 次的修改,按照上述的方式计算 \(c^{c^{\cdots a}}\) 即可。直接调用快速幂的时间复杂度,计算单个 \(c^{c^{\cdots a}}\) 的时间复杂度为 \(O(\log^2)\) 的,同时由 \(k_a\) 值很少可得模数很少,直接预处理 \(c^uP(p,v)\)\(c^{2^{16}u}P(p,v)\),然后拼凑即可去掉一个 \(\log\)。注意维护当前指数是否大于目前的 \(P(p)\) 值。

代码

点此查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxl=40;
const int maxv=10010;
const int maxd=1<<15;
const int maxn=50010;
int n,m,P,c,i,j,t,o,tp,lt,rt,ts,tmp,phi,lim,ans;
int v[maxv],pr[maxv],pm[maxl];
int pwn[maxd][maxl],pwx[maxd][maxl];
int val[maxn],brd[maxl];
long long tl;
bool fl;
inline int Pow(int z,int md){
    fl|=(z>=brd[md]);
    return (1LL*pwn[z&(maxd-1)][md]*pwx[z>>15][md])%pm[md];
}
int Calc(int d,int a,int md){
    fl=0;
    if(!pm[md]) return 0;
    if(!d){
        if(a>=pm[md]) a%=pm[md],fl=1;
        return a;
    }
    int ret=Calc(d-1,a,md+1);
    if(fl){
        fl=0;
        ret+=pm[md+1];
    }
    return Pow(ret,md);
}
struct seg{
    int l,r,m;
    int sum,mn;
}tr[maxn<<2];
#define l(p) tr[p].l
#define r(p) tr[p].r
#define m(p) tr[p].m
#define ls(p) p<<1
#define rs(p) p<<1|1
#define mn(p) tr[p].mn
#define sum(p) tr[p].sum
inline void pushup(int p){
    sum(p)=sum(ls(p))+sum(rs(p));
    if(sum(p)>=P) sum(p)-=P;
    mn(p)=min(mn(ls(p)),mn(rs(p)));
}
void build(int p,int l,int r){
    l(p)=l;r(p)=r;
    m(p)=(l+r)>>1;
    if(l==r){
        scanf("%d",&sum(p));
        val[l(p)]=sum(p);
        return;
    }
    build(ls(p),l,m(p));
    build(rs(p),m(p)+1,r);
    pushup(p);
}
void change(int p){
    if(mn(p)==lim) return;
    if(l(p)==r(p)){
        ++mn(p);
        sum(p)=Calc(mn(p),val[l(p)],0);
        return;
    }
    if(lt<=m(p)) change(ls(p));
    if(rt>m(p)) change(rs(p));
    pushup(p);
}
void query(int p){
    if(lt<=l(p)&&rt>=r(p)){
        ans+=sum(p);
        if(ans>=P) ans-=P;
        return;
    }
    if(lt<=m(p)) query(ls(p));
    if(rt>m(p)) query(rs(p));
}
int main(){
    for(i=2;i<maxv;++i){
        if(!v[i]) v[i]=pr[++t]=i;
        for(j=1;j<=t;++j){
            if((v[i]<pr[j])||(i*pr[j]>=maxv)) break;
            v[i*pr[j]]=pr[j];
        }
    }
    t=0;
    scanf("%d%d%d%d",&n,&m,&P,&c);
    build(1,1,n); tp=phi=pm[0]=P;
    while(tp!=1){
        tmp=tp; ts=sqrt(tp);
        for(i=1;pr[i]<=ts;++i){
            if(tmp%pr[i]) continue;
            while(!(tmp%pr[i])) tmp/=pr[i];
            phi/=pr[i]; phi*=pr[i]-1;
            if(tmp==1) break;
        }
        if(tmp!=1) phi/=tmp,phi*=tmp-1;
        tp=pm[++t]=phi;
    }
    for(i=0;i<t;++i){
        tp=pm[i]; pwn[0][i]=pwx[0][i]=1;
        for(j=1;j<maxd;++j){
            tl=1LL*pwn[j-1][i]*c;
            if(tl>=tp){
                tl%=tp;
                if(!brd[i]) brd[i]=j;
            }
            pwn[j][i]=(1LL*pwn[j-1][i]*c)%tp;
        }
        if(!brd[i]) brd[i]=INT_MAX;
        pwx[1][i]=(1LL*pwn[maxd-1][i]*c)%tp;
        for(j=2;j<maxd;++j) pwx[j][i]=(1LL*pwx[j-1][i]*pwx[1][i])%tp;
    }
    lim=t+10;
    while(m--){
        scanf("%d%d%d",&o,&lt,&rt);
        if(!o) change(1);
        else{
            ans=0; query(1);
            printf("%d\n",ans);
        }
    }
    return 0;
} 

拓展题2:CF679E Bear and Bad Powers of 42

题意

定义一个正整数是坏的,当且仅当它是 \(42\) 的次幂,否则它是好的。

给定一个长度为 \(n\) 的序列 \(a_i\),保证初始时所有数都是好的。

\(q\) 次操作,每次操作有三种可能:

  • 1 i 查询 \(a_i\)
  • 2 l r x\(a_{l\dots r}\) 赋值为一个好的数 \(x\)
  • 3 l r x\(a_{l \dots r}\) 都加上 \(x\),重复这一过程直到所有数都变好。

\(n,q \le 10^5\)\(a_i,x \le 10^9\)

解法

考虑维护当前每个数和最近的(大于本身的\(42\) 次幂的差值,记 \(a_i\) 对应的差值为 \(d_i\)。当前主要难以实现的即为操作 \(3\)

我们有如下的初始策略:\(3\) 操作即为某个区间区间加直到某次区间加后区间没有为 \(0\)\(d_i\)。具体实现上会有所差异,因为对某个 \(a_i\) 加上某个数后其对应的 \(a_i+d_i\) 可能会变,\(\lfloor\log_{42}a_i\rfloor\) 会增加。对于包含了对应的 \(a_i\) 的区间,直接暴力往下修改即可。在不考虑操作 \(2\) 的情况下,由于每次暴力向下操作都会使叶节点对应的 \(a_i\)\(\lfloor\log_{42}a_i\rfloor\) 增加,故时间复杂度为 \(O((n+q)\log n \log V)\) 的(其中 \(V\) 为值域,估算可证其不会超过 \(10^{14}\),下同)。

由于操作 \(2\) 的存在,考虑另外一种情况:数据形如如下形式:

100000 100000
41 41 41(×100000)
3 1 100000 1
2 1 100000 41
3 1 100000 1
2 1 100000 41
…………(重复这两种操作)

这样每次操作均会暴力 dfs 整棵线段树,时间复杂度不低于 \(O(n^2)\)

考虑在操作 \(2\) 存在之后,可以把值相同的一整段一起操作,由势能分析可得单次操作时间复杂度仍是 \(O(\log n\log V)\) 的。维护操作 \(3\) 方法的其他部分相似。

至于实现上,需要维护 \(d_i\) 的区间减标记(\(a_i\) 的区间加标记),区间 \(d_i\) 的最小值,区间覆盖的 \(a_i\) 和区间覆盖的 \(d_i\) 最小值;操作就是区间加和区间覆盖。

我的实现方式中有以下需要注意的地方:

注实现上一定要注意区间加和区间覆盖的问题。在对某个节点 pushdown 区间覆盖标记时,子节点的区间加标记要置零;pushdown 区间加标记时,若子节点有区间覆盖标记则直接累加到区间覆盖标记上。

代码

点此查看代码
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int maxl=20;
const int maxn=100010;
int n,q,o,v,lt,rt;
ll vn;
bool fl;
struct seg{
    int l,r,m;
    ll val,mn,cov,covn;
}tr[maxn<<2];
#define l(p) tr[p].l
#define r(p) tr[p].r
#define m(p) tr[p].m
#define ls(p) p<<1
#define rs(p) p<<1|1
#define mn(p) tr[p].mn
#define val(p) tr[p].val
#define cov(p) tr[p].cov
#define covn(p) tr[p].covn
inline void pushup(int p){
    mn(p)=min(mn(ls(p)),mn(rs(p)));
}
void build(int p,int l,int r){
    l(p)=l;r(p)=r;
    m(p)=(l+r)>>1;
    if(l==r){
        scanf("%lld",&val(p));
        for(vn=1;vn<=val(p);vn*=42);
        mn(p)=vn-val(p); return;
    }
    build(ls(p),l,m(p));
    build(rs(p),m(p)+1,r);
    pushup(p);
}
inline void pushdown(int p){
    if(cov(p)){
        cov(ls(p))=cov(rs(p))=cov(p);
        covn(ls(p))=mn(ls(p))=
        covn(rs(p))=mn(rs(p))=covn(p);
        val(ls(p))=val(rs(p))=0;
        cov(p)=covn(p)=0;
    }
    if(val(p)){
        mn(ls(p))-=val(p);
        mn(rs(p))-=val(p);
        if(cov(ls(p))){
            cov(ls(p))+=val(p);
            covn(ls(p))-=val(p);
        } 
        else val(ls(p))+=val(p);
        if(cov(rs(p))){
            cov(rs(p))+=val(p);
            covn(rs(p))-=val(p);
        }
        else val(rs(p))+=val(p);
        val(p)=0;
    }
    if(l(ls(p))==r(ls(p))){
        if(cov(ls(p))){
            val(ls(p))=cov(ls(p));
            cov(ls(p))=0;
        }
    } 
    if(l(rs(p))==r(rs(p))){
        if(cov(rs(p))){
            val(rs(p))=cov(rs(p));
            cov(rs(p))=0;
        }
    }
}
void change(int p){
    if(lt<=l(p)&&rt>=r(p)){
        if(l(p)==r(p)){
            if(cov(p)){
                val(p)=cov(p);
                cov(p)=0;
            }
            val(p)+=v;
            for(vn=1;vn<val(p);vn*=42);
            if(vn==val(p)) fl=1,vn*=42;
            mn(p)=vn-val(p);
            return;
        }  
        else if(cov(p)){
            cov(p)+=v;
            for(vn=1;vn<cov(p);vn*=42);
            if(vn==cov(p)) fl=1,vn*=42;
            mn(p)=covn(p)=vn-cov(p);
            return;
        }
        else if(mn(p)>v){
            mn(p)-=v;
            val(p)+=v;
            return;
        }
    }
    pushdown(p);
    if(lt<=m(p)) change(ls(p));
    if(rt>m(p)) change(rs(p));
    pushup(p); 
}
void cover(int p){
    if(lt<=l(p)&&rt>=r(p)){
        cov(p)=v; val(p)=0;
        covn(p)=mn(p)=vn;
        if(l(p)==r(p)){
            val(p)=v;
            cov(p)=0;
        }
        return;
    }
    pushdown(p);
    if(lt<=m(p)) cover(ls(p));
    if(rt>m(p)) cover(rs(p));
    pushup(p);
}
int main(){
    scanf("%d%d",&n,&q);
    build(1,1,n);
    while(q--){
        scanf("%d",&o);
        if(o==1){
            scanf("%d",&lt);
            rt=1;
            while(l(rt)!=r(rt)){
                pushdown(rt);
                if(lt<=m(rt)) rt=ls(rt);
                else rt=rs(rt);
            }
            if(cov(rt)){
                val(rt)=cov(rt);
                cov(rt)=0; 
            }
            printf("%lld\n",val(rt));
        } 
        else{
            scanf("%d%d%d",&lt,&rt,&v);
            if(o==2){
                for(vn=1;vn<=v;vn*=42);
                vn-=v; cover(1);
            }
            else{
                do{
                    fl=0;
                    change(1);
                }while(fl);
            }
        }
    }
    return 0;
}

应用4:线段树分治

某些时候我们会遇到按时间进行的一系列操作,形如造成某个影响/复原之前的某个操作带来的影响;或是在某个序列的多个区间上查询信息。

对于前者,可以以操作时间轴建立线段树,将每个影响对应起效的时间段插入到线段树上;最后 dfs 一遍线段树,在每个节点上加入对应影响,在代表查询的叶节点上查询,在从某个节点回溯时删除这个影响。对于后者,操作方式相似。

引入:CF601E A Museum Robbery (时间轴分治)

题解

例题:P7470 [NOI Online 2021 提高组] 岛屿探险 (询问区间分治)

题意

\(n\) 个物品,每个物品有两个权值 \(a_i,b_i\)\(q\) 次询问,每次询问给出 \(l,r,c,d\),求 \([l,r]\) 中有多少个物品满足 \((a\oplus c)\le\min(b,d)\)\(1\le n,q\le 10^5,1\le a,b,c,d\le 2^{24}-1\)

解法

考虑 \([l,r]=[1,n]\) 时的做法。

考虑分开处理 \(\min(b,d)=b\)\(\min(b,d)=d\) 的情况。此时需要将所有物品按照 \(b\) 值排序,将所有询问按照 \(d\) 值排序。

若对于某个询问 \(\{1,n,c,d\}\),对某些物品有 \(\min(b,d)=d\),则可以对这些物品的 \(a\) 值建立 Trie 树,然后查询 Trie 树上有多少个满足异或 \(c\) 不大于 \(d\) 的值。实现上,考虑按二进制位匹配 \(c\oplus d\),记在匹配前 \(i\) 位时匹配到节点 \(p\),若第 \(d\) 的第 \(i-1\) 位为 \(0\),(令个位为第 \(0\) 位)则 \(p\) 对应的前缀后只能填 \(0\),将 \(p\) 跳转至 \(son_{p,\left\lfloor\frac {d}{2^{i-1}}\right\rfloor\bmod 2}\);否则若 \(p\) 对应的前缀后填 \(0\),则显然之后无论接什么后缀,均会不大于 \(d\),可以直接将 \(son_{p,\left\lfloor\frac {d}{2^{i-1}}\right\rfloor\bmod 2}\) 对应的贡献加上,也就是此节点对应的子树的所有 \(a\) 值当前被插入次数总数;然后将 \(p\) 跳转至 \(son_{p,1-\left(\left\lfloor\frac {d}{2^{i-1}}\right\rfloor\bmod 2\right)}\) 处以继续匹配。

若对于某个物品 \(\{a,b\}\),对某些 询问\(\min(b,d)=b\),则将这些询问对应的 \(c\) 值插入 Trie 中,则需要按照上述方式寻找满足异或 \(a\) 不大于 \(b\) 的值,对这些值对应的所有询问的答案加一。方法与上述类似。如果使用点差分的方式进行子树加,则需要在新插入一个 \(d\) 值时,预先处理当前询问的 \(\boldsymbol d\) 值对应的节点的整条链上的点差分之和 \(\boldsymbol S\)然后将当前询问对应的答案减去 \(\boldsymbol S\),有可能插入了跟之前某个询问一样的 \(\boldsymbol d\) 值,此时不能仅仅在 Trie 的节点上删除之前的操作造成的影响。(这个菜鸡现在总算看出来了)

时间复杂度为 \(O((n+q)\log V)\)\(V\) 为值域。

考虑在线的情况。可以对 \(n\) 个物品建立线段树,把每个询问对应的区间插入线段树上,最后 dfs 一遍整棵线段树,在进入某个节点时统计其对对应询问的贡献。区间对物品排序可以归并。时间复杂度约为 \(O((n+q)\log n\log V)\)

代码

点此查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxl=24;
const int maxn=100010;
const int INF=1145141919;
int n,q,i,j,k,c,v,lt,rt,ct,dt,pt,at,tot;
int a[maxn],b[maxn],pe[maxn],tmp[maxn];
int siz[maxn*maxl],trie[maxn*maxl][2];
int h[maxn*maxl],nxt[maxn],ans[maxn];
bool f;
void dfst(int p,int d){
    if(!p) return;
    d+=siz[p]; 
    v=h[p]; h[p]=0;
    while(v){
        ans[v]+=d;
        at=v; v=nxt[v];
        nxt[at]=0;
    }
    dfst(trie[p][0],d);
    dfst(trie[p][1],d);
    siz[p]=trie[p][0]=trie[p][1]=0;
}
struct query{
    int ct,dt,idx;
    inline bool operator <(const query &a)const{return dt<a.dt;}
};
struct seg{
    int l,r,m;
    vector<query> que;
}tr[maxn<<2];
#define l(p) tr[p].l
#define r(p) tr[p].r
#define m(p) tr[p].m
#define ls(p) p<<1
#define rs(p) p<<1|1
#define que(p) tr[p].que
void build(int p,int l,int r){
    l(p)=l;r(p)=r;
    m(p)=(l+r)>>1;
    if(l==r) return;
    build(ls(p),l,m(p));
    build(rs(p),m(p)+1,r);
}
void insert(int p){
    if(lt<=l(p)&&rt>=r(p)){
        que(p).push_back((query){ct,dt,i});
        return;
    }
    if(lt<=m(p)) insert(ls(p));
    if(rt>m(p)) insert(rs(p));
}
void dfs(int p){
    if(l(p)!=r(p)){
        dfs(ls(p));
        dfs(rs(p));
        j=m(p)+1; c=l(p)-1; 
        for(i=l(p);i<=m(p);++i){
            while(j<=r(p)&&b[pe[j]]>b[pe[i]]) tmp[++c]=pe[j++];
            tmp[++c]=pe[i];
        }
        while(j<=r(p)) tmp[++c]=pe[j++];
        for(c=l(p);c<=r(p);++c) pe[c]=tmp[c]; 
    }
    else tmp[l(p)]=pe[l(p)];
    tot=1; j=que(p).size()-1;
    if(j<0) return;
    sort(que(p).begin(),que(p).end());
    while(j>=0&&que(p)[j].dt>b[tmp[l(p)]]) --j;
    if(j>=0){
        for(c=l(p);c<=r(p);++c){
            v=a[tmp[c]]; 
            pt=1;++siz[1];
            for(i=maxl-1;i>=0;--i){
                f=(v>>i)&1;
                if(!trie[pt][f]) trie[pt][f]=++tot;
                pt=trie[pt][f]; ++siz[pt]; 
            }
            if(b[tmp[c]]==b[tmp[c+1]]) continue;
            while(j>=0&&que(p)[j].dt>b[tmp[c+1]]){
                pt=1; at=0; 
                dt=que(p)[j].dt;
                ct=que(p)[j].ct;
                for(i=maxl-1;i>=0;--i){
                    f=(ct>>i)&1;
                    if((dt>>i)&1){
                        at+=siz[trie[pt][f]];
                        pt=trie[pt][!f];
                    }
                    else pt=trie[pt][f];
                }
                at+=siz[pt];
                ans[que(p)[j--].idx]+=at;
            }
        }
        while(tot){
            trie[tot][0]=trie[tot][1]=siz[tot]=0;
            --tot;
        }
        tot=1; 
    }  
    j=l(p); v=que(p)[que(p).size()-1].dt;
    while(j<=r(p)&&b[pe[j]]>=v) ++j;
    if(j>r(p)) return;
    for(i=que(p).size()-1;i>=0;){
        v=que(p)[i].ct;
        pt=1; ct=siz[1]; 
        for(c=maxl-1;c>=0;--c){
            f=(v>>c)&1; 
            if(!trie[pt][f]) trie[pt][f]=++tot;
            pt=trie[pt][f]; ct+=siz[pt];
        }
        ans[que(p)[i].idx]-=ct;
        nxt[que(p)[i].idx]=h[pt];
        h[pt]=que(p)[i].idx;
        if(!(i--)) v=-1; 
        else v=que(p)[i].dt;
        if(que(p)[i+1].dt==v) continue;
        while(j<=r(p)&&b[pe[j]]>=v){
            ct=a[pe[j]]; dt=b[pe[j]]; pt=1;
            for(c=maxl-1;c>=0;--c){
                f=(ct>>c)&1;
                if((dt>>c)&1){
                    ++siz[trie[pt][f]];
                    pt=trie[pt][!f];
                }
                else pt=trie[pt][f];
            }
            ++siz[pt]; ++j;
        }
    }
    siz[0]=0; dfst(1,0); 
}
int main(){
    scanf("%d%d",&n,&q);
    build(1,1,n);
    for(i=1;i<=n;++i) scanf("%d%d",a+i,b+i),pe[i]=i;
    for(i=1;i<=q;++i){
        scanf("%d%d%d%d",&lt,&rt,&ct,&dt);
        insert(1);
    }
    dfs(1);
    for(i=1;i<=q;++i) printf("%d\n",ans[i]);
    return 0;
}

点击点赞为 Fran-Cen 助力!

posted @ 2022-10-03 23:17  Fran-Cen  阅读(38)  评论(0编辑  收藏  举报