Bzoj 3122 [Sdoi2013]随机数生成器(BSGS+exgcd)

这里写图片描述
Input
输入含有多组数据,第一行一个正整数T,表示这个测试点内的数据组数。
接下来T行,每行有五个整数p,a,b,X1,t,表示一组数据。保证X1和t都是合法的页码。
注意:P一定为质数
Output
共T行,每行一个整数表示他最早读到第t页是哪一天。如果他永远不会读到第t页,输出-1。
Sample Input
3
7 1 1 3 3
7 2 2 2 0
7 2 2 2 1
Sample Output
1
3
-1
HINT
0<=a<=P-1,0<=b<=P-1,2<=P<=10^9

/*
考试的时候没想出来然后打的暴力orz.
其实式子还是挺好推的.
然后用BSGS和exgcd搞.
还有各种可判.
*/
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<map>
#define LL long long
#define MAXN 101
using namespace std;
LL T,p[MAXN],a[MAXN],b[MAXN],x1[MAXN],t[MAXN];
map<LL,int>s;
bool vis[MAXN];
bool flag1=true,flag2=true,flag3=true;
LL read()
{
    LL x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') x=x*10+ch-48,ch=getchar();
    return x*f;
}
void slove1(int i)
{
    LL tot,sum;
    tot=x1[i],sum=1;
    while(true)
    {
        if(tot==t[i]) {printf("%d\n",sum);break;}
        if(!vis[tot]) vis[tot]=true;
        else {printf("-1\n");break;}
        tot=(a[i]*tot+b[i])%p[i];
        sum++;
    }
    for(int j=0;j<p[i];j++) vis[j]=0;
}
void exgcd(LL a1,LL b1,LL &x,LL &y)
{
    if(!b1){x=1;y=0;return ;}
    exgcd(b1,a1%b1,y,x),y-=(a1/b1)*x;
    return ;
}
void slove2(int i)
{
    if(!a[i])
    {
        if(b[i]==t[i]) printf("2\n");
        else printf("-1\n");
        return ;
    }
    LL c,g,tot,x,y;
    c=(t[i]-x1[i]+p[i])%p[i];
    if(!c){cout<<1<<endl;return;} 
    g=__gcd(b[i],p[i]);
    if(c%g) {cout<<-1<<endl;return ;}
    exgcd(b[i],p[i],x,y);
    x=(x*c/g)%p[i];x++;
    x=(x+p[i])%p[i];
    cout<<x<<endl;
}
LL mi(LL a1,LL b1,LL p1)
{
    LL tot1=1;//a1%=p1;
    while(b1)
    {
        if(b1&1) tot1=tot1*a1%p1;
        a1=a1*a1%p1;
        b1>>=1;
    }
    return tot1;
}
void slove3(int i)
{
    LL c,g,y;bool flag;
    s.clear();flag=false;
    c=mi(a[i]-1,p[i]-2,p[i]);
    exgcd((b[i]*c+x1[i])%p[i],p[i],g,y);
    if(g<p[i]) g=g%p[i]+p[i];
    if(a[i]==1) 
    {
        cout<<g+1<<endl;return ;
    }
    LL tmp1=(b[i]*c+t[i])%p[i],tmp2=__gcd(b[i]*c+x1[i],p[i]);
    if(tmp1%tmp2) {cout<<-1<<endl;return ;}
    g=((g*(tmp1/tmp2)+p[i])%p[i]+p[i])%p[i];
    LL m=ceil(sqrt(p[i])),tot=1,tt;
    for(int j=1;j<=m-1;j++)
    {
        tot=tot*a[i]%p[i];
        if(!s[tot]) s[tot]=j;
    }
    tot=1;tmp1=mi(a[i],p[i]-m-1,p[i]);s[1]=m+1;
    for(int k=0;k<=m-1;k++)
    {
        tt=s[tot*g%p[i]];
        if(tt)  
        {
            if(tt==m+1) tt=0;
            flag=true;
            cout<<k*m+tt+1<<endl;   
            break;  
        }
        tot=tot*tmp1%p[i];
    }
    if(!flag) cout<<-1<<endl;
}
int main()
{
    T=read();
    for(int i=1;i<=T;i++)
    {
        p[i]=read(),a[i]=read(),b[i]=read(),
        x1[i]=read(),t[i]=read();
        if(x1[i]==t[i]) {printf("1\n");continue;}
        if(p[i]<=200) slove1(i);
        else if(a[i]<2) slove2(i);
        else slove3(i);
    }
    return 0;
}
posted @ 2017-02-24 17:27  nancheng58  阅读(135)  评论(0编辑  收藏  举报