2020牛客寒假算法基础集训营2 J-求函数(线段树维护矩阵乘法)

传送门

思路

其实可以维护两棵线段树,一颗维护区间k的乘积,一颗维护区间的结果,可以推一下公式解决。
但更方便的使用一颗线段树直接维护矩阵的乘积。
因为 \(f_2(f_1(x))=k_1k_2x+b_1k_2+b_2\)
根据矩阵乘法:

\[\left[ \begin{matrix} 1 & 1 \\ 0 & 0 \end{matrix} \right]\times\left[ \begin{matrix} k_1 & 0 \\ b_1 & 1 \end{matrix} \right]=\left[ \begin{matrix} k_1+b_1 & 1 \\ 0 & 0 \end{matrix} \right]\\\left[ \begin{matrix}   k_1+b_1 & 1  \\   0 & 0   \end{matrix}  \right]\times\left[ \begin{matrix}   k_2 & 0  \\   b_2 & 1  \end{matrix}  \right]=\left[ \begin{matrix}   k_1k_2+b_1k_2+b2 & 1  \\   0 & 0   \end{matrix}  \right] \]

可以让线段树的叶节点维护矩阵 \(\left[ \begin{matrix} k_i & 0 \\ b_i & 1 \end{matrix} \right]\),每个节点维护矩阵 \(l-r\) 的乘积,每次更改就只更改矩阵里的参数,每次查询就查询 \(l-r\) 的乘积,再用矩阵 \(\left[ \begin{matrix} 1 & 1 \\ 0 & 0 \end{matrix} \right]\) 乘以查询的结果,结果矩阵的[0][0]就是答案了。

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MAXN=2e5+10;
const int mod=1e9+7;
int n,m;
LL k[MAXN],b[MAXN];
LL md(LL x){
    return (x%mod+mod)%mod;
}
struct Mx{
    LL m[2][2];
    Mx(){memset(m,0,sizeof(m));}
    friend Mx operator * (const Mx& a,const Mx& b){
        Mx res;
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++){
                for(int k=0;k<2;k++)
                    res.m[i][j]+=a.m[i][k]*b.m[k][j];
                res.m[i][j]%=mod;
            }
        return res;
    }
};
 
struct SegTree{
    #define mid ((l+r)>>1)
    Mx mx[MAXN*4];
    void update(int id,int l,int r,int pos){
        if(l==r) {mx[id].m[0][0]=k[l],mx[id].m[1][0]=b[l],mx[id].m[1][1]=1;return;}
        if(pos<=mid) update(id<<1,l,mid,pos);
        else update(id<<1|1,mid+1,r,pos);
        mx[id]=mx[id<<1]*mx[id<<1|1];
    }
    Mx ask(int id,int l,int r,int L,int R){
        if(L<=l&&r<=R) return mx[id];
        Mx res;res.m[0][0]=res.m[1][1]=1;
        if(L<=mid) res=res*ask(id<<1,l,mid,L,R);
        if(R>mid) res=res*ask(id<<1|1,mid+1,r,L,R);
        return res;
    }
    #undef mid
}tr;
 
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%lld",&k[i]);
    for(int i=1;i<=n;i++) scanf("%lld",&b[i]);
    for(int i=1;i<=n;i++) tr.update(1,1,n,i);
    for(int i=1,opt,l,r;i<=m;i++){
        scanf("%d",&opt);
        if(opt==1){
            scanf("%d",&l);
            scanf("%lld%lld",&k[l],&b[l]);
            tr.update(1,1,n,l);
        }
        else{
            scanf("%d%d",&l,&r);
            Mx res;res.m[0][0]=res.m[0][1]=1;
            res=res*tr.ask(1,1,n,l,r);
            printf("%lld\n",res.m[0][0]);
        }
    }
    return 0;
}
posted @ 2020-02-06 20:53  BakaCirno  阅读(238)  评论(0编辑  收藏  举报