线段树+crt

题目地址
这道题主要就是要先想到暴力的处理模数内所有可能的x对应的式子的计算结果,再用线段树维护这个暴力数组来实现 \(\Theta(n log n)\) 建树,\(\Theta(log n)\) 修改,\(\Theta(1)\)查询,然后这样如果模数很大我们肯定既存不下又会T,那我们观察大部分模数发现他们可以拆成很多小质数的积,那我们就可以把模数拆掉然后分别记答案,最后用CRT合并即可,这样模数的各种因子在复杂度里的影响就从原来的所有乘起来变成了加起来,就不会T和爆空间了,对于小于1000的数据我们可以直接暴力求解防止模数出大质数,为了优化快速幂的复杂度,我们还可以用扩展欧拉定理把乘方的幂次控制在\(\varphi(s)\)的数量级内

#include<bits/stdc++.h>
using namespace std;
#define ll long long
ll n,m,p,nowx,a[1000000],op[1000000],ty,ca,cop,mod[100],cnt,cp,ex,ey,phi[10000],prime[10000],v[10000],tot;
char c;
ll ksm(ll x,ll ci,ll mo){
    if(ci==0)return 1;
    if(ci==1)return x;
    if(ci==0)return 0;
    if(ci%2==0)return ksm(x*x%mo,ci/2,mo);
    return ksm(x*x%mo,ci/2,mo)*x%mo;
}
struct tree{
    int xta[50];
}t[10][800000];
void build(ll x,ll l,ll r){
    if(l==r){
        for(int i=1;i<=cnt;i++){
            for(int j=0;j<mod[i];j++){
                if(op[l]==1){
                    t[i][x].xta[j]=(j+a[l])%mod[i];
                }
                else if(op[l]==2){
                    t[i][x].xta[j]=(j*a[l])%mod[i];
                }
                else{
                    t[i][x].xta[j]=ksm(j,a[l]%phi[mod[i]]+phi[mod[i]],mod[i])%mod[i];
                }
            }
        }
        return;
    }
    build(x<<1,l,(l+r)>>1);
    build((x<<1)+1,((l+r)>>1)+1,r);
    for(int i=1;i<=cnt;i++){
        for(int j=0;j<mod[i];j++){
            t[i][x].xta[j]=t[i][(x<<1)+1].xta[t[i][x<<1].xta[j]];
        }
    }
}
void change(ll x,ll l,ll r){
    if(l==r&&l==cp){
        for(int i=1;i<=cnt;i++){
            for(int j=0;j<mod[i];j++){
                if(op[l]==1){
                    t[i][x].xta[j]=(j+a[l])%mod[i];
                }
                else if(op[l]==2){
                    t[i][x].xta[j]=(j*a[l])%mod[i];
                }
                else{
                    t[i][x].xta[j]=ksm(j,a[l]%phi[mod[i]]+phi[mod[i]],mod[i])%mod[i];
                }
            }
        }
        return;
    }
    if(l==r)return ;
    if(cp<=(l+r)>>1){
        change(x<<1,l,(l+r)>>1);
    }
    else{
        change((x<<1)+1,((l+r)>>1)+1,r);
    }
    for(int i=1;i<=cnt;i++){
        for(int j=0;j<mod[i];j++){
            t[i][x].xta[j]=t[i][(x<<1)+1].xta[t[i][x<<1].xta[j]];
        }
    }
}
ll exgcd(ll xa,ll xb){
    if(xb==0){
        ex=1,ey=0;
        return xa;
    }
    ll res=exgcd(xb,xa%xb);
    ll x=ex,y=ey;
    ey=x-(ll)(xa/xb)*y;
    ex=y;
    return res;
}
ll excrt(){
    ll xa[cnt],xm[cnt],ans=0,M=1;
    for(int i=1;i<=cnt;i++){
        M*=mod[i];
    }
    for(int i=1;i<=cnt;i++){
        ans=(ans+(M/mod[i]*t[i][1].xta[nowx%mod[i]]/exgcd(M/mod[i],mod[i])*ex)%M)%M;
    }
    return (ans+M)%M;
}
int main(){
    for(int i=2;i<=1000;i++){
        if(v[i]==0){
            phi[i]=i-1;
            prime[++tot]=i;
            v[i]=i;
        }
        for(int j=1;j<=tot;j++){
            if(prime[j]*i>1000||v[i]<prime[j])break;
            v[i*prime[j]]=prime[j];
            if(v[i]==prime[j])phi[i*prime[j]]=phi[i]*prime[j];
            else phi[i*prime[j]]=phi[i]*(prime[j]-1);
        }
    }
    cin>>n>>n>>m>>p;
    for(int i=1;i<=n;i++){
        cin>>c;
        scanf("%lld",&a[i]);
        if(c=='+')op[i]=1;
        else if(c=='*')op[i]=2;
        else op[i]=3;
    }
    if(n<=1000&&m<=1000){
        for(int i=0;i<m;i++){
            scanf("%lld",&ty);
            if(ty==1){
                scanf("%lld",&nowx);
                for(int j=1;j<=n;j++){
                    if(op[j]==1){
                        nowx+=a[j];
                        nowx%=p;
                    }
                    else if(op[j]==2){
                        nowx*=a[j];
                        nowx%=p;
                    }
                    else{
                        nowx=ksm(nowx,a[j],p);
                    }
                }
                cout<<nowx<<endl;
            }
            else{
                scanf("%lld ",&cp);
                cin>>c;
                scanf("%lld",&a[cp]);
                if(c=='+')op[cp]=1;
                else if(c=='*')op[cp]=2;
                else op[cp]=3;
            }
        }
    }
    else{
        for(int i=2;i<=p;i++){
            if(p%i==0){
                cnt++;
                mod[cnt]=1;
                while(p%i==0)p/=i,mod[cnt]*=i;
            }
        }
        build(1,1,n);
        for(int i=0;i<m;i++){
            scanf("%lld",&ty);
            if(ty==1){
                scanf("%lld",&nowx);
                printf("%lld\n",excrt());
            }
            else{
                scanf("%lld ",&cp);
                cin>>c;
                scanf("%lld",&a[cp]);
                if(c=='+')op[cp]=1;
                else if(c=='*')op[cp]=2;
                else op[cp]=3;
                change(1,1,n);
            }
        }
    }
    return 0;
}
posted @ 2022-10-13 10:23  fluffy_stoat  阅读(25)  评论(0编辑  收藏  举报