密码学实战 - HTB Shamir's Secret

930 阅读3分钟

概述

Shamir's Secret是来自于HTB(hackthebox.com)的一个中级密码学挑战,完成该挑战所需要掌握的知识点在于模数及其线性方程的计算。

题目分析

相关的任务文件包括server.py源代码和一个在线环境。

server.py内容节选如下

from Crypto.Util.number import *
from secret import flag
import random
import os
def getrandbits(n):
    return bytes_to_long(os.urandom(n // 8))

N = 2**1024

# Generate random key(64-bit number of which 32 of those bits are 1)


key = 0
rem = list(range(64))
for _ in range(32):
    bitpos = random.choice(rem)
    rem.remove(bitpos)
    key |= 1 << bitpos

def doeval(poly, x):
    # Given polynomial and x value, generates y modulo N
    ans = 0
    for i, coeff in enumerate(poly):
        ans += x**i * coeff
        ans %= N
    return ans

def encrypt(msg, key):
    out = ()
    msg = bytes_to_long(msg)
    poly = [msg] + [getrandbits(1024) for _ in range(31)]

    for bitpos in range(64):
        if key & 1 << bitpos != 0:
            # Real
            x = getrandbits(1024)
            out += ((x, doeval(poly, x)),)
        else:
            # Fake
            x = getrandbits(1024)
            y = getrandbits(1024)
            out += ((x,y),)
    return out

def printenc(data, key):
    for pair in encrypt(data, key):
        print(pair)

def menu():
    print("[1]: Get encrypted flag")
    print("[2]: Encrypt your own message")
    return int(input("> "))

doneflagenc = False
try:
    while True:
        try:
            option = menu()
        except:
            print("Invalid menu item")
            continue
        if option == 1:
            if doneflagenc:
                print("Nope")
                continue
            printenc(flag, key)
            doneflagenc = True
        elif option == 2:
            try:
                msg = bytes.fromhex(input("Input message as hex: "))
            except:
                print("Invalid message format")
                continue
            printenc(msg, key)
        else:
            print("Unknown option")
            continue
except:
    print("Unknown error ocurred")

以上代码实现的加密算法包括两个步骤,第一步在于随机生成一个64比特位的key,其中32个比特是0,另32个是1. 第二步则首先生成一个包含32个元素的poly数组,其中poly[0]是明文,而其后的31个元素是随机数, 然后遍历key中的各个比特位,如果该位是1,那么首先生成一个随机数x, 然后返回x(x**0 * poly[0] + x**1 * poly[1] ... +x**31 * poly[31]) mod N的结果。 而如果该位是0, 则返回的两个数都是随机数。

以上代码的运行环境提供两个输入选项,选项1对flag进行加密并返回结果,但只允许使用一次,选项2对用户提供的明文进行加密并返回结果,该选项没有使用次数限制。

解题过程

首先必须获取key, 根据代码提供的算法及数值,我们可以知道当明文和x都是偶数时,(x^0 * poly[0] + x^1 * poly[1] ... +x^31 * poly[31]) mod N的结果必然也是偶数, 因此我们可以使用选项2并输入一个偶数,然后观察返回的结果,如果某个数值对中x是偶数,而另一个是奇数, 那就可以判断该数值对所对应的比特位是0, 反复循环直到发现32个0比特位, 我们就得到了key

然后我们可以使用选项1获取flag的加密结果,因为key中有32个比特位是1,所以我们可以得到32个线性方程,而其未知量就是poly数值中的32个元素,使用sage我们可以对基于N的模数线性方程求解,而结果中的poly[0]就是flag的明文。

解题代码如下

from pwn import remote
from sage.all import *
from Crypto.Util.number import bytes_to_long, long_to_bytes 

conn = remote('165.232.104.184', 32248, level = 'error')

def tesBit():
  conn.recvuntil(">")
  conn.sendline("2")
  conn.recvuntil("Input message as hex: ")
  conn.sendline("00")
  pairs = []
  for i in range(64):
    line = conn.recvline()
    pairs.append(eval(line))
  return pairs
  
def getFlagPairs():
  conn.recvuntil(">")
  conn.sendline("1")
  pairs = []
  for i in range(64):
    line = conn.recvline()
    pairs.append(eval(line))
  return pairs

## recover key
bits = [1] * 64
count_of_zero_bits = 0

while (count_of_zero_bits != 32):
  print(".")
  key_pairs = tesBit()
  for i in range(64):
    x, y = key_pairs[i]
    if ((x % 2 == 0) and (y % 2 == 1)):
      bits[i] = 0       
     
  count_of_zero_bits = sum(bits)

bits.reverse()

key_bits = ""
for i in range(64):
  key_bits = key_bits + str(bits[i])

print("key_bits =", key_bits)

key = int(key_bits, 2)

## recover flag
flag_pairs = getFlagPairs()

R = IntegerModRing(2**1024)

x_values = []
y_values = []

for bitpos in range(64):
  if key & 1 << bitpos != 0:
    # Real
    x, y = flag_pairs[bitpos]
    
    x_list = []
    for j in range(32):
      x_list.append(x**j)    
    
    x_values.append(x_list)    
    y_values.append(y)
      
x_matrix = Matrix(R, x_values)
y_vector = vector(R, y_values)

poly = x_matrix.solve_right(y_vector)

print("flag = ", long_to_bytes(int(poly[0])))