由Polya定理可得到最后结果等于1/N*∑N^gcd(i,n);

可是,N≤10^9,枚举i明显会超时,但gcd(i,n)最后得到的结果很少,最多1000多个,于是反过来枚举gcd(i,n)的值L,L即n的某个约数,那么我们需要找到0~n-1中有多少个数与n的约数是L,由扩展欧几里得可以知道,必然存在x,y使得

i*x+n*y=L,由于L是i,n最大公约数,所以可以变成(i/L)*x+(n/L)*y=1,同时mod(n/L),(i/L)*x≡1(mod n/L),即,要找与n/L互质的i/L有多少个,变成欧拉函数了!

于是,最后答案就变成了∑φ(n/L)*N^(L-1),dfs+快速幂取模搞定

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int mr=100001;
bool notp[mr];
int pr[mr];
int pn,mod,n;
void getp()
{
	memset(notp,0,sizeof(notp));
	pn=0;
	for(int i=2;i<mr;i++)
	{
  		if(!notp[i])
			pr[pn++]=i;
		for(int j=0;j<pn&&i*pr[j]<mr;j++)
		{
			int k=i*pr[j];
			notp[k]=1;
			if(i%pr[j]==0)
				break;
		}
	}
}
int fac[1025],tot[1025],num;
void div(int n)
{
    num=0;
    for(int i=0;i<pn&&pr[i]*pr[i]<=n;i++)
    {
        if(n%pr[i]==0)
        {
            tot[num]=0;
            fac[num]=pr[i];
            while(n%pr[i]==0)
            {
                n/=pr[i];
                tot[num]++;
            }
            num++;
        }
    }
    if(n>1)
    {
        fac[num]=n;
        tot[num]=1;
        num++;
    }
}
int getphi(int n)
{
    int tp=n;
    for(int i=0;i<pn&&pr[i]*pr[i]<=n;i++)
    {
        if(n%pr[i]==0)
        {
            tp=tp/pr[i]*(pr[i]-1);
            while(n%pr[i]==0)
                n/=pr[i];
        }
    }
    if(n>1)
        tp=tp/n*(n-1);
    return tp;
}
int mulmod(int a, int b)//a*b%m 避免高精度计算
{
    a=a%mod,b=b%mod;
    int re=0;
    while(b)
    {
        if(b&1) re=(re+a)%mod;
        a=(a<<1)%mod;
        b>>= 1;
    }
    return re;
}
int fastmod(int a,int b)//a^b %m 注意可能要用long long时用long long
{
    int re=1,y=a%mod;
    for(;b;b>>=1,y=mulmod(y, y))
        if(b&1)re=mulmod(y,re);
    return re;
}
int ans;
void solve(int L)
{
    int ph=getphi(n/L)%mod,po=fastmod(n%mod,L-1);
    ans+=mulmod(ph,po);
    if(ans>=mod)
        ans-=mod;
}
void dfs(int k,int L)
{
    if(k==num)
        solve(L);
    else
    {
        int tp=1;
        dfs(k+1,L);
        for(int i=1;i<=tot[k];i++)
        {
            tp*=fac[k];
            dfs(k+1,L*tp);
        }
    }
}
int main()
{
    getp();
    int T;
    for(scanf("%d",&T);T;T--)
    {
        scanf("%d%d",&n,&mod);
        if(mod==1)
        {
            printf("0\n");
            continue;
        }
        else if(n==1)
        {
            printf("1\n");
            continue;
        }
        ans=0;
        div(n);
        dfs(0,1);
        printf("%d\n",ans);
    }
    return 0;
}