BZOJ 1129 exgcd+CRT+线段树
思路:
先copy一下百度百科 作为预备知识吧
多重全排列定义:求r1个1,r2个2,…,rt个t的排列数,设r1+r2+…+rt=n,设此排列数称为多重全排列,表示为$P(n;r1,r2,…,rt)$
$P(n;r1,r2,…,rt)=\frac{n!}{(r1!r2!...rt!)}$
题目是让求s的排名mod m
我们就可以从前往后枚举
前$(i-1)$位跟给出的排列一样 第i位填小于s[i]的数
后面i到n位可以随便填的方案数
(有点像数位DP最后统计的那种感觉.)
设calc[x]是串s中i到n位 x出现的次数
这样枚举到第i位的答案就是$(\Sigma_{j=i}^n{s[j]<s[i]})*\frac{(n-i)!}{(cnt[1]!cnt[2]!...cnt[max—s[x]])}$
(离散化什么的就不用我说了吧)
m不是质数 怎么办
把m拆成$m={p_{1}}^{q_{1}}*{p_{2}}^{q_{2}}...{p_{cnt}}^{q_{cnt}}$
用中国剩余定理搞一搞
把不与$p^q$互质的数单独拎出来算
整体复杂度是$O(nlognlogm)$的
(最好别用线段树.. 常数大)
(CRT和exgcd的时候要时刻注意负数)
//By SiriusRen #include <cstdio> #include <algorithm> using namespace std; const int N=300050; int n,M,u,xx,yy,s[N],cpy[N],tree[N*8],Ans[N]; int p[N],ps[N],cnt,fac[N],numb[N],calc[N],pw[N]; void exgcd(int a,int b,int &x,int &y){ if(!b){x=1,y=0;return;} exgcd(b,a%b,x,y); int temp=x;x=y;y=temp-a/b*y; } int CRT(int *a,int *m){ int ans=0; for(int i=1;i<=cnt;i++){ exgcd(M/m[i],m[i],xx,yy); ans=(ans+1ll*M/m[i]*xx%M*a[i])%M; }return (ans+M)%M; } void insert(int l,int r,int pos,int num,int wei){ if(l==r){tree[pos]+=wei;return;} int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1; if(mid<num)insert(mid+1,r,rson,num,wei); else insert(l,mid,lson,num,wei); tree[pos]=tree[lson]+tree[rson]; } int query(int l,int r,int pos,int L,int R){ if(l>=L&&r<=R)return tree[pos]; int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1; if(mid<L)return query(mid+1,r,rson,L,R); else if(mid>=R)return query(l,mid,lson,L,R); else return query(l,mid,lson,L,R)+query(mid+1,r,rson,L,R); } void Dec(int m){ for(int i=2;i*i<=m;i++)if(m%i==0){ p[++cnt]=i,ps[cnt]=1; while(m%i==0)m/=i,ps[cnt]*=i; }if(m!=1)p[++cnt]=m,ps[cnt]=m; } int inv(int a,int b){exgcd(a,b,xx,yy);return (xx+b)%b;} void solve(){ for(int T=1;T<=cnt;T++){ for(int i=1;i<=n;i++){ int temp=i,jy=0;while(temp%p[T]==0)temp/=p[T],jy++; fac[i]=1ll*fac[i-1]*temp%ps[T]; pw[i]=pw[i-1]*p[T]%ps[T]; numb[i]=numb[i-1]+jy; insert(0,u,1,s[i],1),calc[s[i]]++; }int sum=0,sum_inv=1; for(int i=1;i<=u;i++)sum+=numb[calc[i]],sum_inv=1ll*sum_inv*inv(fac[calc[i]],ps[T])%ps[T]; for(int i=1;i<=n;i++){ Ans[T]=(Ans[T]+1ll*query(0,u,1,0,s[i]-1)*fac[n-i]%ps[T]*sum_inv%ps[T]*pw[numb[n-i]-sum])%ps[T]; sum_inv=1ll*sum_inv*fac[calc[s[i]]]%ps[T]*inv(fac[calc[s[i]]-1],ps[T])%ps[T]; sum-=numb[calc[s[i]]],calc[s[i]]--,sum+=numb[calc[s[i]]]; insert(0,u,1,s[i],-1); }(Ans[T]+=1)%=ps[T]; } } signed main(){ scanf("%d%d",&n,&M);Dec(M);pw[0]=fac[0]=1; for(int i=1;i<=n;i++)scanf("%d",&s[i]),cpy[i]=s[i]; sort(cpy+1,cpy+1+n);u=unique(cpy+1,cpy+1+n)-cpy-1; for(int i=1;i<=n;i++)s[i]=lower_bound(cpy+1,cpy+1+u,s[i])-cpy; solve();printf("%d\n",CRT(Ans,ps)); }