[代码]HDU 4335 What is N?

Abstract

HDU 4335 What is N?

数论

 

Body

Source

http://acm.hdu.edu.cn/showproblem.php?pid=4335

Description

给定p, b(0<=b<p<=10^5)和m(1<=m<2^64),问有多少个n满足n^(n!)=b (mod p)。

Solution

首先要知道这个结论:(反正比赛时我是不知道,只好orz福大核武景润后人了)

a^x = a^(x mod phi(c)+phi(c)) (mod c), x>=phi(c)

然后就很简单了。把[0,m]分成三部分:

1. 对于n!<phi(p)的就枚举;

2. phi(p)<=n!<t!, t=min{x|x! mod phi(p)=0}

套公式就是

n^(n! mod phi(p)+phi(p))=b (mod p)

同样也是枚举,弄个变量存n! mod phi(p)就行。

3. n>=t, t=min{x|x! mod phi(p)=0}

这时候公式就变为

n^phi(p)=b (mod p)

也就是

(n mod p)^phi(p)=b (mod p)

于是n mod p就循环了……

Code

#include <iostream>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

typedef unsigned long long LL;

const int MAX = 200020;
int p[MAX], pcnt=0;
int minf[MAX];
int phi[MAX];
void initprime() {
    for (int i = 2; i < MAX; ++i) {
        if (!minf[i]) {
            p[pcnt++] = i;
            phi[i] = i-1;
        }
        for (int j = 0; j < pcnt && p[j]*i < MAX; ++j) {
            minf[p[j]*i] = p[j];
            if (i%p[j])
                phi[p[j]*i] = phi[i]*(p[j]-1);
            else {
                phi[p[j]*i] = phi[i]*p[j];
                break;
            }
        }
    }
}

LL mod;

LL pow(LL x, LL p) {
    LL res = 1;
    for (; p; p>>=1) {
        if (p&1) res = (res*x)%mod;
        x = (x*x)%mod;
    }
    return res;
}

int T;
int b, c;
LL a;
LL ring[MAX], ringcnt;

int main() {
    initprime();
    cin>>T;
    LL i, j, k, n;
    LL php;
    for (int t = 1; t <= T; ++t) {
        cin>>b>>c>>a;
        mod = c;
        printf("Case #%d: ", t);
        if (c==1) {
            if (b==0) {
                if (a==18446744073709551615ULL)
                    cout<<"18446744073709551616"<<endl;
                else
                    cout<<a+1<<endl;
            }
            else cout<<0<<endl;
            continue;
        }
        LL ans = 0;
        LL fac = 1;
        php = phi[c];

        for (n = 0; n<=a && fac<php; ++n) {
            if (pow(n, fac)==b) ++ans;
            fac *= (n+1);
        }

        fac %= php;
        for (; n<=a && fac; ++n) {
            if (pow(n, fac+php)==b) ++ans;
            fac = fac*(n+1)%php;
        }

        if (n<=a) {
            ringcnt = 0;
            for (i = 0; i < c; ++i) {
                ring[i] = pow(n+i, php);
                if (ring[i]==b) ++ringcnt;
            }
            LL group = (a-n+1)/c;
            ans += group*ringcnt;
            LL left = (a-n+1)-group*c;
            for (i = 0; i < left; ++i)
                ans += (ring[i]==b);
        }

        cout<<ans<<endl;
    }
    return 0;
}
posted @ 2012-08-03 20:49  杂鱼  阅读(510)  评论(0编辑  收藏  举报