BZOJ 1129 exgcd+CRT+线段树

思路:

先copy一下百度百科 作为预备知识吧
多重全排列定义:求r1个1,r2个2,…,rt个t的排列数,设r1+r2+…+rt=n,设此排列数称为多重全排列,表示为$P(n;r1,r2,…,rt)$
$P(n;r1,r2,…,rt)=\frac{n!}{(r1!r2!...rt!)}$

题目是让求s的排名mod m
我们就可以从前往后枚举
前$(i-1)$位跟给出的排列一样 第i位填小于s[i]的数
后面i到n位可以随便填的方案数
(有点像数位DP最后统计的那种感觉.)
设calc[x]是串s中i到n位 x出现的次数
这样枚举到第i位的答案就是$(\Sigma_{j=i}^n{s[j]<s[i]})*\frac{(n-i)!}{(cnt[1]!cnt[2]!...cnt[max—s[x]])}$
(离散化什么的就不用我说了吧)
m不是质数 怎么办
把m拆成$m={p_{1}}^{q_{1}}*{p_{2}}^{q_{2}}...{p_{cnt}}^{q_{cnt}}$
用中国剩余定理搞一搞
把不与$p^q$互质的数单独拎出来算
整体复杂度是$O(nlognlogm)$的
(最好别用线段树.. 常数大)
(CRT和exgcd的时候要时刻注意负数)

//By SiriusRen
#include <cstdio>
#include <algorithm>
using namespace std;
const int N=300050;
int n,M,u,xx,yy,s[N],cpy[N],tree[N*8],Ans[N];
int p[N],ps[N],cnt,fac[N],numb[N],calc[N],pw[N];
void exgcd(int a,int b,int &x,int &y){
    if(!b){x=1,y=0;return;}
    exgcd(b,a%b,x,y);
    int temp=x;x=y;y=temp-a/b*y;
}
int CRT(int *a,int *m){
    int ans=0;
    for(int i=1;i<=cnt;i++){
        exgcd(M/m[i],m[i],xx,yy);
        ans=(ans+1ll*M/m[i]*xx%M*a[i])%M;
    }return (ans+M)%M;
}
void insert(int l,int r,int pos,int num,int wei){
    if(l==r){tree[pos]+=wei;return;}
    int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1;
    if(mid<num)insert(mid+1,r,rson,num,wei);
    else insert(l,mid,lson,num,wei);
    tree[pos]=tree[lson]+tree[rson];
}
int query(int l,int r,int pos,int L,int R){
    if(l>=L&&r<=R)return tree[pos];
    int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1;
    if(mid<L)return query(mid+1,r,rson,L,R);
    else if(mid>=R)return query(l,mid,lson,L,R);
    else return query(l,mid,lson,L,R)+query(mid+1,r,rson,L,R);
}
void Dec(int m){
    for(int i=2;i*i<=m;i++)if(m%i==0){
        p[++cnt]=i,ps[cnt]=1;
        while(m%i==0)m/=i,ps[cnt]*=i;
    }if(m!=1)p[++cnt]=m,ps[cnt]=m;
}
int inv(int a,int b){exgcd(a,b,xx,yy);return (xx+b)%b;}
void solve(){
    for(int T=1;T<=cnt;T++){
        for(int i=1;i<=n;i++){
            int temp=i,jy=0;while(temp%p[T]==0)temp/=p[T],jy++;
            fac[i]=1ll*fac[i-1]*temp%ps[T];
            pw[i]=pw[i-1]*p[T]%ps[T];
            numb[i]=numb[i-1]+jy;
            insert(0,u,1,s[i],1),calc[s[i]]++;
        }int sum=0,sum_inv=1;
        for(int i=1;i<=u;i++)sum+=numb[calc[i]],sum_inv=1ll*sum_inv*inv(fac[calc[i]],ps[T])%ps[T];
        for(int i=1;i<=n;i++){
            Ans[T]=(Ans[T]+1ll*query(0,u,1,0,s[i]-1)*fac[n-i]%ps[T]*sum_inv%ps[T]*pw[numb[n-i]-sum])%ps[T];
            sum_inv=1ll*sum_inv*fac[calc[s[i]]]%ps[T]*inv(fac[calc[s[i]]-1],ps[T])%ps[T];
            sum-=numb[calc[s[i]]],calc[s[i]]--,sum+=numb[calc[s[i]]];
            insert(0,u,1,s[i],-1);
        }(Ans[T]+=1)%=ps[T];
    }
}
signed main(){
    scanf("%d%d",&n,&M);Dec(M);pw[0]=fac[0]=1;
    for(int i=1;i<=n;i++)scanf("%d",&s[i]),cpy[i]=s[i];
    sort(cpy+1,cpy+1+n);u=unique(cpy+1,cpy+1+n)-cpy-1;
    for(int i=1;i<=n;i++)s[i]=lower_bound(cpy+1,cpy+1+u,s[i])-cpy;
    solve();printf("%d\n",CRT(Ans,ps));
}

 

posted @ 2017-04-09 22:05  SiriusRen  阅读(370)  评论(0编辑  收藏  举报