Bzoj 3122 [Sdoi2013]随机数生成器(BSGS+exgcd)
Input
输入含有多组数据,第一行一个正整数T,表示这个测试点内的数据组数。
接下来T行,每行有五个整数p,a,b,X1,t,表示一组数据。保证X1和t都是合法的页码。
注意:P一定为质数
Output
共T行,每行一个整数表示他最早读到第t页是哪一天。如果他永远不会读到第t页,输出-1。
Sample Input
3
7 1 1 3 3
7 2 2 2 0
7 2 2 2 1
Sample Output
1
3
-1
HINT
0<=a<=P-1,0<=b<=P-1,2<=P<=10^9
/*
考试的时候没想出来然后打的暴力orz.
其实式子还是挺好推的.
然后用BSGS和exgcd搞.
还有各种可判.
*/
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<map>
#define LL long long
#define MAXN 101
using namespace std;
LL T,p[MAXN],a[MAXN],b[MAXN],x1[MAXN],t[MAXN];
map<LL,int>s;
bool vis[MAXN];
bool flag1=true,flag2=true,flag3=true;
LL read()
{
LL x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9') x=x*10+ch-48,ch=getchar();
return x*f;
}
void slove1(int i)
{
LL tot,sum;
tot=x1[i],sum=1;
while(true)
{
if(tot==t[i]) {printf("%d\n",sum);break;}
if(!vis[tot]) vis[tot]=true;
else {printf("-1\n");break;}
tot=(a[i]*tot+b[i])%p[i];
sum++;
}
for(int j=0;j<p[i];j++) vis[j]=0;
}
void exgcd(LL a1,LL b1,LL &x,LL &y)
{
if(!b1){x=1;y=0;return ;}
exgcd(b1,a1%b1,y,x),y-=(a1/b1)*x;
return ;
}
void slove2(int i)
{
if(!a[i])
{
if(b[i]==t[i]) printf("2\n");
else printf("-1\n");
return ;
}
LL c,g,tot,x,y;
c=(t[i]-x1[i]+p[i])%p[i];
if(!c){cout<<1<<endl;return;}
g=__gcd(b[i],p[i]);
if(c%g) {cout<<-1<<endl;return ;}
exgcd(b[i],p[i],x,y);
x=(x*c/g)%p[i];x++;
x=(x+p[i])%p[i];
cout<<x<<endl;
}
LL mi(LL a1,LL b1,LL p1)
{
LL tot1=1;//a1%=p1;
while(b1)
{
if(b1&1) tot1=tot1*a1%p1;
a1=a1*a1%p1;
b1>>=1;
}
return tot1;
}
void slove3(int i)
{
LL c,g,y;bool flag;
s.clear();flag=false;
c=mi(a[i]-1,p[i]-2,p[i]);
exgcd((b[i]*c+x1[i])%p[i],p[i],g,y);
if(g<p[i]) g=g%p[i]+p[i];
if(a[i]==1)
{
cout<<g+1<<endl;return ;
}
LL tmp1=(b[i]*c+t[i])%p[i],tmp2=__gcd(b[i]*c+x1[i],p[i]);
if(tmp1%tmp2) {cout<<-1<<endl;return ;}
g=((g*(tmp1/tmp2)+p[i])%p[i]+p[i])%p[i];
LL m=ceil(sqrt(p[i])),tot=1,tt;
for(int j=1;j<=m-1;j++)
{
tot=tot*a[i]%p[i];
if(!s[tot]) s[tot]=j;
}
tot=1;tmp1=mi(a[i],p[i]-m-1,p[i]);s[1]=m+1;
for(int k=0;k<=m-1;k++)
{
tt=s[tot*g%p[i]];
if(tt)
{
if(tt==m+1) tt=0;
flag=true;
cout<<k*m+tt+1<<endl;
break;
}
tot=tot*tmp1%p[i];
}
if(!flag) cout<<-1<<endl;
}
int main()
{
T=read();
for(int i=1;i<=T;i++)
{
p[i]=read(),a[i]=read(),b[i]=read(),
x1[i]=read(),t[i]=read();
if(x1[i]==t[i]) {printf("1\n");continue;}
if(p[i]<=200) slove1(i);
else if(a[i]<2) slove2(i);
else slove3(i);
}
return 0;
}