BZOJ1319Sgu261Discrete Roots——BSGS+exgcd+原根与指标+欧拉定理

题目描述

给出三个整数p,k,a,其中p为质数,求出所有满足x^k=a (mod p),0<=x<=p-1的x。

输入

三个整数p,k,a。

输出

第一行一个整数,表示符合条件的x的个数。 第二行开始每行一个数,表示符合条件的x,按从小到大的顺序输出。

样例输入

11 3 8

样例输出

1
2

提示

2<=p<p<=10^9
 2<=k<=100000,0<=a

 

首先求出$p$的原根$g$,再求出$a$的指标$b$,即$g^b\equiv a(mod\ p)$。我们知道对于$[0,p-1]$中任意数都能用原根的幂次表示,所以将$x$表示成$g^y$即$g^y\equiv x(mod\ p)$,那么原式就变成了$(g^y)^k\equiv g^b(mod\ p)->g^{yk}\equiv g^b(mod\ p)$。根据欧拉定理可知$g^{p-1}\equiv 1(mod\ p)$,所以$yk\equiv b(mod\ (p-1))$,只需要用$exgcd$求出$[0,p-1]$内所有的$y$即可。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<vector>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
ll p,k,a;
ll g,f;
ll prime[100010];
int tot;
ll q[100010];
int cnt;
map<ll,int>b;
ll quick(ll x,ll y)
{
	ll res=1ll;
	while(y)
	{
		if(y&1)
		{
			res=res*x%p;
		}
		y>>=1;
		x=x*x%p;
	}
	return res;
}
ll gcd(ll x,ll y)
{
	return y==0?x:gcd(y,x%y);
}
void exgcd(ll A,ll B,ll &x,ll &y)
{
    if(!B)
    {
        x=1;
        y=0;
        return ;
    }
    exgcd(B,A%B,y,x);
    y-=(A/B)*x;
}
int main()
{
	scanf("%lld%lld%lld",&p,&k,&a);
	if(!a)
	{
		printf("1\n0");
		return 0;
	}
	ll m=p-1;
	for(int i=2;1ll*i*i<=m;i++)
	{
		if(m%i==0)
		{
			prime[++tot]=i;
			while(m%i==0) 
			{
				m/=i;
			}
		}
	}
	if(m!=1) 
	{
		prime[++tot]=m;
	}
	for(int i=2;i<=p-1;i++)
	{
		bool flag=true;
		for(int j=1;j<=tot;j++) 
		{
			if(quick(i,(p-1)/prime[j])==1)
			{
				flag=false;
				break;
			}
		}
		if(flag)
		{
			g=i;
			break;
		}
	}
	int n=ceil(sqrt(p));
	ll sum=1ll;
	for(int i=1;i<=n;i++)
	{
		(sum*=g)%=p;
		b[sum]=i;
	}
	ll num=1ll;
	for(int i=0;i<=n;i++)
	{
		ll inv=quick(num,p-2);
		if(b[inv*a%p])
		{
			f=i*n+b[inv*a%p];
			break;
		}
		(num*=sum)%=p;
	}
	ll d=gcd(k,p-1);
	if(f%d)
	{
		printf("0");
		return 0;
	}
	f/=d;
	ll X=k/d;
	ll Y=(p-1)/d;
	ll x,y;
	exgcd(X,Y,x,y);
	(x*=f)%=(p-1);
	x%=Y;
	if(x<=0)
	{
		x+=Y;
	}
	while(x<=p-1)
	{
		q[++cnt]=quick(g,x);
		x+=Y;
	}
	sort(q+1,q+1+cnt);
	printf("%d\n",cnt);
	for(int i=1;i<=cnt;i++)
	{
		printf("%lld\n",q[i]);
	}
}
posted @ 2019-02-13 22:42  The_Virtuoso  阅读(336)  评论(0编辑  收藏  举报