LOJ 6053 简单的函数——min_25筛

题目:https://loj.ac/problem/6053

min_25筛:https://www.cnblogs.com/cjyyb/p/9185093.html

这里把计算 s( n , j ) 需要的“质数部分的贡献”分成两部分算,令 \( g(n,j)=\sum\limits_{i=1}^{n}[i \in P or min_i > p_j]i \) , \( h(n,j)=\sum\limits_{i=1}^{n}[i \in P or min_i > p_j]1 \) ,其中 P 表示质数集合,\( min_i \) 表示 i 的最小质因子。

注意空间是两倍 sqrt(n) 。注意有些地方是 long long 。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define ll long long
using namespace std;
const int N=1e5+5,mod=1e9+7;
//int upt(int x){if(x>=mod)x-=mod;if(x<0)x+=mod;return x;}
int upt(ll x){if(x>=mod)x-=mod;if(x<0)x+=mod;return x;}

int m,g[N<<1],h[N<<1],s[N<<1],p[N],cnt,sm[N],base;//[N<<1]!!!!!not[N]
ll n,w[N<<1],p2[N];bool vis[N];//w[N<<1]!!!!!not [N]
void get_pri(int n)
{
  for(int i=2;i<=n;i++)
    {
      if(!vis[i])
    {
      p[++cnt]=i;p2[cnt]=(ll)i*i;
      sm[cnt]=upt(sm[cnt-1]+i);
    }
      for(int j=1,d;j<=cnt&&(d=i*p[j])<=n;j++)
    {vis[d]=1; if(i%p[j]==0)break;}
    }
}
int Id(ll x){if(x<=base)return m-x+1; else return n/x;}//<= not <
void cz1()
{
  for(int j=1;j<=cnt;j++)
    for(int i=1;i<=m&&w[i]>=p2[j];i++)
      {
    int k=Id(w[i]/p[j]);
    g[i]=upt( g[i]-(ll)p[j]*upt(g[k]-sm[j-1])%mod );
    h[i]=upt( h[i]-upt(h[k]-(j-1)) );
      }
}
int S(ll x,int y)
{
  if(x<=1||p[y]>x)return 0;
  int k=Id(x),ret=upt(upt(g[k]-h[k])-upt(sm[y-1]-(y-1)));
  if(y==1)ret=upt(ret+2);
  for(int i=y;i<=cnt&&p2[i]<=x;i++)
    {
      ll m1=p[i],m2=p2[i];
      for(int t=1;m2<=x;t++,m1=m2,m2*=p[i])
    ret=( ret + (ll)(p[i]^t)*S(x/m1,i+1) + (p[i]^(t+1)) )%mod;//i+1!!
    }
  return ret;
}
int main()
{
  scanf("%lld",&n);base=sqrt(n);
  get_pri(base);
  for(ll i=1,j;i<=n;i=n/j+1)
    {
      j=n/i; w[++m]=j;
      g[m]=(j-1)%mod*((j+2)%mod)%mod;
      if(g[m]&1)g[m]+=mod; g[m]>>=1;
      h[m]=(j-1)%mod;/////////
    }
  cz1();printf("%d\n",upt(S(n,1)+1));
  return 0;
}
View Code

UPD(2019.4.3):

一些理解:之所以只用 \( <= \sqrt{n} \) 的质数可以得到所有质数的答案,是靠初值。初值里包含了 \( > \sqrt{n} \) 的质数的答案,之后再把合数的答案筛掉,剩下的就是所有质数的答案。

     所以一开始算所有质数答案的时候,给合数赋错误的值也可以,只要满足质数上的值是正确的,并且赋值函数满足积性。一般把合数当做质数看待来赋初值。

     筛合数也只用到 \( <= \sqrt{n} \) 的内容。因为一个合数的 mindiv 一定是 \( <= \sqrt{n} \) 的质数。

 并且尝试了另一种写法。

在计算 s( n , j ) 的时候,可以写成非递归的,式子就是 \( s(n,j)=s(n,j+1)+ f(p_j) + \sum\limits_{t=1}^{p_j^{t+1}<=n} f(p_j^{t+1}) * s( \frac{n}{p_j^t} , j+1 ) \) 。

当 \( p_j^2 > n \) 的时候 s( n , j ) 是没有赋值的。这时候如果要用,就判断一下;因为没有赋值说明此时它的值只有质数部分的,所以用 g 和 h 拼一下即可。

随着 j 变小,有赋值的 n 可以越来越小,用一个指针指一下即可。

好像比递归版慢。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define ll long long
using namespace std;
const int N=2e5+5,mod=1e9+7;//2e5 not 1e5!!
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

ll n,w[N],p2[N];int m,bs,p[N],cnt;
int sm[N],g[N],h[N],s[N]; bool vis[N];
void init(int n)
{
  ll d;
  for(int i=2;i<=n;i++)
    {
      if(!vis[i])
    {
      p[++cnt]=i; p2[cnt]=(ll)i*i;
      sm[cnt]=upt(sm[cnt-1]+i);
    }
      for(int j=1;j<=cnt&&(d=(ll)i*p[j])<=n;j++)
    { vis[d]=1; if(i%p[j]==0)break;}
    }
}
int Id(ll x){return x>bs?n/x:m-x+1;}
int cz(int i,int j)
{
  if(i==m)return 0;//
  if(vis[i])return s[i];
  return (upt(g[i]-h[i]+(j==1?2:0))-sm[j-1]+(j-1));
}
void solve()
{
  for(int j=1;j<=cnt;j++)
    for(int i=1;i<=m&&w[i]>=p2[j];i++)
      {
    int k=Id(w[i]/p[j]);
    g[i]=upt(g[i]-(ll)p[j]*(g[k]-sm[j-1])%mod);
    h[i]=upt(upt(h[i]-h[k])+(j-1));
      }

  memset(vis,0,sizeof vis);
  int p0=0;
  for(int j=cnt;j;j--)
    {
      while(p0<m&&w[p0+1]>=p2[j])
    {
      p0++;vis[p0]=1;
      s[p0]=upt(upt(g[p0]-h[p0])-sm[j]+j);//S(..,j+1)
    }
      for(int i=1;i<=p0;i++)
    {
      ll m1=p[j], m2=p2[j]; s[i]=upt(s[i]+(p[j]^1));
      for(int t=1;m2<=w[i];t++,m1=m2,m2*=p[j])
        s[i]=(s[i]+(ll)(p[j]^t)*cz(Id(w[i]/m1),j+1)+(p[j]^(t+1)))%mod;
    }
    }
}
int S(ll n,int j)
{
  if(p[j]>n||n==1)return 0; int cr=Id(n);
  int ret=upt(upt(g[cr]-h[cr]+(j==1?2:0))-sm[j-1]+(j-1));
  for(int k=j;k<=cnt&&p2[k]<=n;k++)//k<=cnt
    {
      ll m1=p[k], m2=p2[k];
      for(int t=1;m2<=n;t++,m1=m2,m2*=p[k])
    {
      ret=(ret+(ll)(p[k]^t)*S(n/m1,k+1)+(p[k]^(t+1)))%mod;
    }
    }
  return ret;
}
int main()
{
  scanf("%lld",&n); bs=sqrt(n);
  init(bs); int iv2=pw(2,mod-2);
  for(ll i=1,j;i<=n;i=n/j+1)
    {
      j=n/i; w[++m]=j;
      g[m]=(2+j)%mod*((j-1)%mod)%mod*iv2%mod;
      ///// not (2+j)%mod*(j-1)%mod*iv2%mod!!!
      h[m]=(j-1)%mod;
    }
  solve(); printf("%d\n",upt(s[1]+1));
  return 0;
}
View Code

 

posted on 2019-01-17 11:05  Narh  阅读(273)  评论(0编辑  收藏  举报

导航