【BZOJ4916】神犇和蒟蒻-欧拉函数+杜教筛

测试地址:神犇和蒟蒻
做法:本题需要用到杜教筛。
啊,差不多一年没碰过这东西了,想当初学这个东西学出心理阴影了都……然而不能因为菜就停下自己的脚步,所以先做一道杜教筛基础题复健一下。
对于这道题目,第一问就是玩的,显然当i>1μ(i2)=0,仅有μ(1)=1,所以答案就是1
对于第二问,根据欧拉函数的公式,我们知道φ(i2)=iφ(i),因此要求的就是这样一个积性函数前缀和:i=1niφ(i)。按照杜教筛的套路,要找到一个好求前缀和的积性函数g,使得它和要求的函数f(这道题中f(n)=nφ(n))的狄利克雷卷积也是一个好求前缀和的函数。这里我们找的函数是g(n)=n(在某些地方也写作id),因为显然这个函数是完全积性函数,而它和f的狄利克雷卷积:(fg)(n)=d|ndφ(d)nd=nd|nφ(d)=n2,也显然是一个完全积性函数,而且这两个函数都可以O(1)求出前缀和,那么令S(n)=i=1nf(i),套上杜教筛的公式:
g(1)S(n)=i=1n(fg)(i)i=2ng(i)S(ni)
直接杜教筛即可,注意杜教筛要预处理前n23项前缀和,要用哈希表处理记忆化,这样就可以做到O(n23)的复杂度了。
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=1000000007;
const ll hashsiz=2000003;
ll limit,n,phi[1000010],sum[1000010],prime[1000010];
ll hashlist[2000010]={0},hashval[2000010];
bool vis[1000010]={0};

void calc()
{
    phi[1]=1;
    prime[0]=0;
    for(ll i=2;i<=limit;i++)
    {
        if (!vis[i])
        {
            prime[++prime[0]]=i;
            phi[i]=i-1;
        }
        for(ll j=1;j<=prime[0]&&i*prime[j]<=limit;j++)
        {
            vis[i*prime[j]]=1;
            if (i%prime[j]==0)
            {
                phi[i*prime[j]]=phi[i]*prime[j];
                break;
            }
            phi[i*prime[j]]=phi[i]*(prime[j]-1);
        }
    }
    sum[0]=0;
    for(ll i=1;i<=limit;i++)
        sum[i]=(sum[i-1]+i*phi[i])%mod;
}

ll sumg(ll n)
{
    ll inv=500000004;
    return n*(n+1)%mod*inv%mod;
}

ll sumfg(ll n)
{
    ll inv=166666668;
    return n*(n+1)%mod*(2*n+1)%mod*inv%mod;
}

void hashinsert(ll x,ll v)
{
    ll pos=x%hashsiz;
    while(hashlist[pos]&&hashlist[pos]!=x) pos++;
    hashlist[pos]=x;
    hashval[pos]=v;
}

ll hashfind(ll x)
{
    ll pos=x%hashsiz;
    while(hashlist[pos]&&hashlist[pos]!=x) pos++;
    if (hashlist[pos]==x) return pos;
    else return -1;
}

ll solve(ll n)
{
    ll pos=hashfind(n);
    if (n<=limit) return sum[n];
    if (pos!=-1) return hashval[pos]; 
    ll ans=sumfg(n);
    for(ll i=n;i>=2;i=n/(n/i+1))
    {
        ll l=max(2ll,n/(n/i+1)+1),r=i;
        ans-=(solve(n/i)*(sumg(r)-sumg(l-1))%mod+mod)%mod;
        ans=(ans+mod)%mod;
    }
    hashinsert(n,ans);
    return ans;
}

int main()
{
    scanf("%lld",&n);

    printf("1\n");

    for(ll i=1;i*i*i<=n;i++)
        if ((i+1)*(i+1)*(i+1)>n)
        {
            limit=i*i;
            break;
        }

    calc();
    printf("%lld",solve(n));

    return 0;
}
posted @ 2018-05-03 15:54  Maxwei_wzj  阅读(110)  评论(0编辑  收藏  举报