[CISCN 2022 东北赛区]math

chall.py

import gmpy2
from Crypto.Util.number import *
from flag import flag
assert flag.startswith(b"flag{")
assert flag.endswith(b"}")
message=bytes_to_long(flag)
def keygen(nbit, dbit):
    if 2*dbit < nbit:
        while True:
            a1 = getRandomNBitInteger(dbit) # 256bit
            b1 = getRandomNBitInteger(nbit//2-dbit) # 256bit

            n1 = a1*b1+1 # 512bit

            if isPrime(n1):
                break
        while True:
            a2 = getRandomNBitInteger(dbit)
            b2 = getRandomNBitInteger(nbit//2-dbit)

            n2=a2*b2+1 # 512bit

            n3=a1*b2+1 # 512bit
            if isPrime(n2) and isPrime(n3):
                break
        while True:
            a3=getRandomNBitInteger(dbit) # 256bit
            if gmpy2.gcd(a3,a1*b1*a2*b2)==1: # gcd(e,phi1) == 1
                v1=(n1-1)*(n2-1) # v1=phi1
                k=(a3*inverse(a3,v1)-1)//v1 # k*v1+1 = k*phi1+1 = a3*d1
                v2=k*b1+1
                if isPrime(v2):
                    return a3,n1*n2,n3*v2
def encrypt(msg, pubkey):
    return pow(msg, pubkey[0], pubkey[1])

nbit = 1024
dbit = 256
e, n1, n2=keygen(nbit, dbit)
print('e =', e)
print('n1 =', n1)
print('n2 =', n2)
c1 = encrypt(message, [e, n1])
c2 = encrypt(message, [e, n2])
print('enc1 =', c1)
print('enc2 =', c2)
# e = 86905291018330218127760596324522274547253465551209634052618098249596388694529
# n1 = 112187114035595515717020336420063560192608507634951355884730277020103272516595827630685773552014888608894587055283796519554267693654102295681730016199369580577243573496236556117934113361938190726830349853086562389955289707685145472794173966128519654167325961312446648312096211985486925702789773780669802574893
# n2 = 95727255683184071257205119413595957528984743590073248708202176413951084648626277198841459757379712896901385049813671642628441940941434989886894512089336243796745883128585743868974053010151180059532129088434348142499209024860189145032192068409977856355513219728891104598071910465809354419035148873624856313067
# enc1 = 71281698683006229705169274763783817580572445422844810406739630520060179171191882439102256990860101502686218994669784245358102850927955191225903171777969259480990566718683951421349181856119965365618782630111357309280954558872160237158905739584091706635219142133906953305905313538806862536551652537126291478865
# enc2 = 7333744583943012697651917897083326988621572932105018877567461023651527927346658805965099102481100945100738540533077677296823678241143375320240933128613487693799458418017975152399878829426141218077564669468040331339428477336144493624090728897185260894290517440392720900787100373142671471448913212103518035775

很好的一道题
拿到题 分析了一下 还是得分解n1或者得到d1 or phi1的值才能做
观察 n1,n2,n3 注意到n3的生成方式:a1 * b2+1那么跟v3相乘后与n1 * n2其实有很多相似的项
做一个分式 (n1xn2)/(n3xv2) = (a1xb1+1)x(a2xb2+1)/(a1xb2+1)x(kxb1+1)约等于a2/k
所以连分数展开渐进分数逼近可以得到a2,k的值
关键是得到a2,k后怎么处理
注意到 kxφ(n1)+1=0(mod e) φ(n1)=0(mod a2)
由CRT就可以解出φ(n1) mod (lcm(e,a2))的值
大致估算一下bit位: φ在1024bit左右(与n1接近) lcm(e,a2)大约在511~512bit
由于φ = N1-(n1+n2)+1 (n1+n2)约等于512bit 所以可以大致确定一下范围

φ : φ'+(N1//lcm) * lcm - 100 * lcm ~ N1
这里的100取得很宽了 反正爆破快 这样只要枚举100次 check一下即可得到flag
solution.py

e = 86905291018330218127760596324522274547253465551209634052618098249596388694529
n1 = 112187114035595515717020336420063560192608507634951355884730277020103272516595827630685773552014888608894587055283796519554267693654102295681730016199369580577243573496236556117934113361938190726830349853086562389955289707685145472794173966128519654167325961312446648312096211985486925702789773780669802574893
n2 = 95727255683184071257205119413595957528984743590073248708202176413951084648626277198841459757379712896901385049813671642628441940941434989886894512089336243796745883128585743868974053010151180059532129088434348142499209024860189145032192068409977856355513219728891104598071910465809354419035148873624856313067
c1 = 71281698683006229705169274763783817580572445422844810406739630520060179171191882439102256990860101502686218994669784245358102850927955191225903171777969259480990566718683951421349181856119965365618782630111357309280954558872160237158905739584091706635219142133906953305905313538806862536551652537126291478865
c2 = 7333744583943012697651917897083326988621572932105018877567461023651527927346658805965099102481100945100738540533077677296823678241143375320240933128613487693799458418017975152399878829426141218077564669468040331339428477336144493624090728897185260894290517440392720900787100373142671471448913212103518035775

# num = n1/n2
# cf = continued_fraction(num)
# alist = cf.convergents()
# for i in alist:
#     a = str(i).split('/')
#     if(len(a)>1):
#         x, y = int(a[0]), int(a[1])
#         if(x.bit_length()==256):
#             print(f'a2= {x}\nk= {y}')
a2= 77847068777976205641001327374129461800550106446456531629024890581516199408639
k= 66425509927381621759828972423608734959235404098353997682249694809002263660141
a2= 98480275536866812059232182842816874069457047339437560195078186997171442641348
k= 84031455814764285608244298716519389101911648259508762219157077342974897619853

from sympy.ntheory.modular import crt
from libnum import *
from primefac import *
import gmpy2

invk = modinv(k,e)
_phi = (crt([e,a2],[-invk,a2])[0])
mod = (a2*e)//gmpy2.gcd(a2,e)
_phi %= mod
print(_phi)
print(_phi.bit_length())

left = _phi + (n1 // mod) * mod - 100*mod
for i in range(105):
    phi = left + i*mod
    try:
        d = modinv(e,phi)
        m = pow(c1,d,n1)
        m = n2s(int(m))
        if b'flag' in m:
            print(m)
    except:
        pass

# b'flag{b5073f3d774c460ae2b714010cc69435}'

回顾整道题发现最关键的一步在于 想到利用CRT算mode 和 moda2(!!!)

posted @ 2023-11-11 23:10  N0zoM1z0  阅读(52)  评论(0编辑  收藏  举报