BZOJ1129 : [POI2008]Per

枚举LCP,假设前$i-1$个都相同。
那么后面$n-i$个数可以随意排列,第$i$个位置可以填的方案数为后面小于$a_i$的数字个数,树状数组维护。

同时为了保证本质不同,方案数需要除以每个数字的个数的阶乘。

将$m$分解质因数,然后CRT合并即可。

可以先用树状数组处理出所有贡献。

同时在分开计算答案的时候,除了某个超过$\sqrt{m}$的大因子之外,其它模数的逆元都可以线性预处理。

所以总时间复杂度为$O(n\log n)$。

 

#include<cstdio>
#include<algorithm>
#define N 300010
typedef long long ll;
int n,m,i,a[N],b[N],c[N],bit[N],f[N],g[N],ans,flag,K;ll B,P,x,y,inv[N];
inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';}
inline int lower(int x){
  int l=1,r=n,mid,t;
  while(l<=r)if(b[mid=(l+r)>>1]<=x)l=(t=mid)+1;else r=mid-1;
  return t;
}
ll exgcd(ll a,ll b){
  if(!b)return x=1,y=0,a;
  ll d=exgcd(b,a%b),t=x;
  return x=y,y=t-a/b*y,d;
}
inline ll rev(ll a){
  if(flag&&a<P)return inv[a];
  exgcd(a,P);
  return (x+P)%P;
}
inline void add(int x,int y){for(;x<=n;x+=x&-x)bit[x]+=y;}
inline int ask(int x){int t=0;for(;x;x-=x&-x)t+=bit[x];return t;}
struct Num{
  ll a,b;
  Num(){a=1,b=0;}
  Num(ll _a,ll _b){a=_a,b=_b;}
  Num operator*(const Num&x){return Num(a*x.a%P,b+x.b);}
  Num operator/(const Num&x){return Num(a*rev(x.a)%P,b-x.b);}
  ll val(){
    if(!a||b>=K)return 0;
    ll t=a,x=B,k=b;
    for(;k;k>>=1,x=x*x%P)if(k&1)t=t*x%P;
    return t;
  }
  void set(ll n){
    a=n,b=0;
    while(a%B==0)a/=B,b++;
    a%=P;
  }
}w[N],t;
void solve(ll _B,ll _P,int _K){
  B=_B,P=_P,K=_K;
  if(P<=n){
    flag=1;
    for(inv[0]=inv[1]=1,i=2;i<P;i++)inv[i]=(P-inv[P%i])*(P/i)%P;
  }else flag=0;
  ll tmp=1LL*(m/P)*rev(m/P)%m;
  int i;
  for(w[0].set(i=1);i<=n;i++)w[i].set(i);
  for(i=1;i<=n;i++)w[i]=w[i]*w[i-1];
  for(i=1;i<=n;i++)c[i]=0;
  for(i=1;i<=n;i++)c[a[i]]++;
  Num all(1,0);
  for(i=1;i<=n;i++)all=all*w[c[i]];
  for(i=1;i<=n;i++){
    if(g[i])t.set(g[i]),f[i]=(tmp*(t*w[n-i]/all).val()+f[i])%m;
    c[a[i]]--;
    all=all/w[c[a[i]]+1]*w[c[a[i]]];
  }
}
void divide(int n){
  int i=1;
  for(int i=2;i*i<=n;i++)if(n%i==0){
    int x=1,k=0;
    while(n%i==0)n/=i,x*=i,k++;
    solve(i,x,k);
  }
  if(n>1)solve(n,n,1);
}
int main(){
  read(n),read(m);
  for(i=1;i<=n;i++)read(a[i]),b[i]=a[i];
  std::sort(b+1,b+n+1);
  for(i=1;i<=n;i++)a[i]=lower(a[i]);
  for(i=1;i<=n;i++)c[a[i]]++;
  for(i=1;i<=n;i++)add(i,c[i]);
  for(i=1;i<=n;i++)g[i]=ask(a[i]-1),add(a[i],-1);
  divide(m);
  for(ans=i=1;i<=n;i++)ans=(ans+f[i])%m;
  return printf("%d",ans),0;
}

  

posted @ 2016-12-14 01:56  Claris  阅读(567)  评论(0编辑  收藏  举报