2020牛客寒假算法基础集训营2 J-求函数 (线段树维护矩阵乘法)
题目链接:https://ac.nowcoder.com/acm/contest/3003/J
思路:
方法①
f1(1)=k1+b1=(k1)+(b1)
f2(f1(1))=k2∗(f1(1))+b2=k2∗k1+k2∗b1+b2=(k2∗k1)+(k2∗b1+b2)
f3(f2(f1(1)))=(k3∗k2∗k1)+(k3∗k2∗b1+k3∗b2+b3)
通过上面的展开,我们可以发现一个式子可以分成两部分:∏Ki 与 ∑ri=l(bi*∏rj=i+1Kj)
分别用线段树维护这两部分即可,现在考虑如果合并区[l1 , r1 ] 与 [r1+1 ,r2 ]
假设左区间的第一部分为 N1 第二部分为 M1
右区间的第一部分为 N2 第二部分为 M2
合并后区间的第一部分为N1*N2,第二部分为N2 * M1 + M2
#include<iostream> #include<algorithm> #include<cstring> using namespace std; typedef long long ll; const int mod=1e9+7; const int maxn=2e5+10; struct node{ ll l,r,k,b; }tree[maxn<<2]; ll k[maxn],b[maxn],n,m,op,l1,r1,po,k1,b1; void pushup(int rt) { tree[rt].k=(tree[rt<<1].k*tree[rt<<1|1].k)%mod; tree[rt].b=((tree[rt<<1|1].k*tree[rt<<1].b)%mod+tree[rt<<1|1].b)%mod; } void build(ll rt,ll l,ll r) { tree[rt].l=l; tree[rt].r=r; if(l==r){ tree[rt].k=k[l],tree[rt].b=b[l]; return; } ll mid=(l+r)>>1; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } void update(ll rt,ll pos) { if(tree[rt].l==pos&&tree[rt].r==pos){ tree[rt].k=k[pos],tree[rt].b=b[pos]; return; } ll mid=(tree[rt].l+tree[rt].r)>>1; if(pos<=mid) update(rt<<1,pos); else update(rt<<1|1,pos); pushup(rt); } typedef pair<ll,ll> p; p query(int rt,int l,int r,int ll,int rr) { if(ll>r || rr<l) return p(-1,-1); if(l>=ll && r<=rr) return p(tree[rt].k,tree[rt].b); int mid=(l+r)>>1; p p1=query(rt<<1,l,mid,ll,rr); p p2=query(rt<<1|1,mid+1,r,ll,rr); if(p1.first==-1) return p2; if(p2.first==-1) return p1; int k1=p1.first,b1=p1.second; int k2=p2.first,b2=p2.second; return p(1ll*k1*k2%mod,(1ll*b1*k2+b2)%mod); } int main() { scanf("%lld%lld",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&k[i]); for(int i=1;i<=n;i++) scanf("%d",&b[i]); build(1,1,n); for(int i=1;i<=m;i++){ scanf("%lld",&op); if(op==1){ scanf("%lld%lld%lld",&po,&k1,&b1); k[po]=k1; b[po]=b1; update(1,po); } else{ scanf("%lld%lld",&l1,&r1); p p1=query(1,1,n,l1,r1); int k=p1.first,b=p1.second; printf("%d\n",((k+b)%mod+mod)%mod); } } return 0; }
方法②矩阵乘法
这篇博客讲的不错:https://www.cnblogs.com/BakaCirno/p/12270838.html
#include<iostream> #include<iostream> #include<cstring> #define mid ((l+r)>>1) typedef long long ll; using namespace std; const int maxn=2e5+10; const int mod=1e9+7; int n,m; ll k[maxn],b[maxn]; struct MX{ ll m[2][2]; MX(){memset(m,0,sizeof(m));} friend MX operator *(const MX&a,const MX&b){ MX res; for(int i=0;i<2;i++) for(int j=0;j<2;j++){ for(int k=0;k<2;k++) res.m[i][j]+=a.m[i][k]*b.m[k][j]; res.m[i][j]%=mod; } return res; } }mx[maxn<<2]; void update(int rt,int l,int r,int pos) { if(l==r){mx[rt].m[0][0]=k[l],mx[rt].m[1][0]=b[l],mx[rt].m[1][1]=1;return;} if(pos<=mid) update(rt<<1,l,mid,pos); else update(rt<<1|1,mid+1,r,pos); mx[rt]=mx[rt<<1]*mx[rt<<1|1]; } MX query(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R) return mx[rt]; MX res; res.m[0][0]=res.m[1][1]=1; if(L<=mid) res=res*query(rt<<1,l,mid,L,R); if(R>mid) res=res*query(rt<<1|1,mid+1,r,L,R); return res; } int main() { cin>>n>>m; for(int i=1;i<=n;i++) scanf("%lld",&k[i]); for(int i=1;i<=n;i++) scanf("%lld",&b[i]); for(int i=1;i<=n;i++) update(1,1,n,i); for(int i=1,opt,l,r;i<=m;i++){ cin>>opt; if(opt==1){ scanf("%d",&l); scanf("%lld%lld",&k[l],&b[l]); update(1,1,n,l); } else{ scanf("%d%d",&l,&r); MX res; res.m[0][0]=res.m[0][1]=1; res=res*query(1,1,n,l,r); printf("%lld\n",res.m[0][0]%mod); } } }