Curso de Criptografía Post-Cuántica

Un enfoque introductorio a la seguridad en la era cuántica

ML-DSA: Firma Digital Post-Cuántica

ML-DSA (Module-Lattice Digital Signature Algorithm), anteriormente conocido como CRYSTALS-Dilithium, es el algoritmo de firma digital post-cuántica seleccionado por el NIST como estándar (FIPS 204). Este algoritmo se basa en problemas matemáticos de retículos modulares, específicamente en el problema Learning With Errors (LWE) y su variante Ring-LWE.

En este ejemplo, exploraremos cómo funciona ML-DSA y veremos una implementación simplificada que ilustra sus operaciones principales: generación de claves, firma y verificación.

Implementación Simplificada de ML-DSA

El siguiente código es una implementación didáctica simplificada de ML-DSA. No es criptográficamente segura y solo tiene fines educativos para ilustrar los conceptos básicos del algoritmo.

"""
Implementación didáctica simplificada de ML-DSA (FIPS 204)
NOTA: Esta implementación es solo para fines educativos y no es criptográficamente segura.
"""
import numpy as np
import hashlib
from numpy.polynomial import polynomial as poly

# Parámetros simplificados para ML-DSA
q = 8380417  # Módulo primo
n = 256      # Grado del polinomio
k = 4        # Número de vectores polinomiales en la matriz A
l = 4        # Número de vectores polinomiales en los vectores s1, s2
eta = 2      # Parámetro de distribución
gamma1 = 2**17  # Parámetro de rechazo
gamma2 = 95232  # Parámetro de rechazo
tau = 39     # Número de coeficientes distintos de cero en c
d = 13       # Bits truncados en t1

# Funciones auxiliares
def ntt(f):
    """Transformada Número-Teórica simplificada (simulada)"""
    # En una implementación real, esto sería una NTT eficiente
    return np.fft.rfft(f)

def intt(F):
    """Transformada Número-Teórica Inversa simplificada (simulada)"""
    # En una implementación real, esto sería una INTT eficiente
    return np.fft.irfft(F, n=n).real.astype(int)

def sample_in_ball(seed, tau):
    """Muestrea un polinomio con exactamente tau coeficientes ±1"""
    np.random.seed(int.from_bytes(seed, byteorder='big'))
    c = np.zeros(n, dtype=int)
    positions = np.random.choice(n, tau, replace=False)
    signs = np.random.choice([-1, 1], tau)
    c[positions] = signs
    return c

def high_bits(r, d):
    """Extrae los bits de mayor orden de r"""
    return (r + 2**(d-1)) // 2**d

def low_bits(r, d):
    """Extrae los bits de menor orden de r"""
    return r - high_bits(r, d) * 2**d

def hash_to_point(message, w1):
    """Hash del mensaje y w1 a un polinomio con tau coeficientes ±1"""
    h = hashlib.sha3_256()
    h.update(message.encode())
    h.update(w1.tobytes())
    return sample_in_ball(h.digest(), tau)

# Generación de claves
def keygen():
    """Genera un par de claves ML-DSA"""
    # Generar semilla aleatoria para A
    seed_a = np.random.bytes(32)
    
    # Generar matriz A usando la semilla (en una implementación real, esto sería expandido)
    np.random.seed(int.from_bytes(seed_a, byteorder='big'))
    A = np.random.randint(-q//2, q//2, (k, l, n))
    
    # Muestrear vectores secretos s1, s2 con coeficientes pequeños
    s1 = np.random.randint(-eta, eta+1, (l, n))
    s2 = np.random.randint(-eta, eta+1, (k, n))
    
    # Calcular t = A·s1 + s2
    t = np.zeros((k, n), dtype=int)
    for i in range(k):
        for j in range(l):
            t[i] = (t[i] + poly.polymul(A[i, j], s1[j]) % poly.polydiv([1] + [0]*n + [1], [1])[1]) % q
        t[i] = (t[i] + s2[i]) % q
    
    # Dividir t en t1 (bits altos) y t0 (bits bajos)
    t1 = np.array([high_bits(t[i], d) for i in range(k)])
    t0 = np.array([low_bits(t[i], d) for i in range(k)])
    
    # Clave pública: (seed_a, t1)
    # Clave privada: (seed_a, s1, s2, t0)
    return {
        'public': {'seed_a': seed_a, 't1': t1},
        'private': {'seed_a': seed_a, 's1': s1, 's2': s2, 't0': t0}
    }

# Firma
def sign(message, private_key):
    """Firma un mensaje usando ML-DSA"""
    seed_a = private_key['seed_a']
    s1 = private_key['s1']
    s2 = private_key['s2']
    t0 = private_key['t0']
    
    # Reconstruir A a partir de la semilla
    np.random.seed(int.from_bytes(seed_a, byteorder='big'))
    A = np.random.randint(-q//2, q//2, (k, l, n))
    
    # Reconstruir t
    t = np.zeros((k, n), dtype=int)
    for i in range(k):
        t[i] = (high_bits(t[i], d) * 2**d + t0[i]) % q
    
    while True:
        # Generar vector aleatorio y
        y = np.random.randint(-gamma1+1, gamma1, (l, n))
        
        # Calcular w = A·y
        w = np.zeros((k, n), dtype=int)
        for i in range(k):
            for j in range(l):
                w[i] = (w[i] + poly.polymul(A[i, j], y[j]) % poly.polydiv([1] + [0]*n + [1], [1])[1]) % q
        
        # Extraer bits altos de w
        w1 = np.array([high_bits(w[i], d) for i in range(k)])
        
        # Calcular c = H(message, w1)
        c = hash_to_point(message, w1)
        
        # Calcular z = y + c·s1
        z = np.zeros((l, n), dtype=int)
        for i in range(l):
            z[i] = (y[i] + poly.polymul(c, s1[i]) % poly.polydiv([1] + [0]*n + [1], [1])[1]) % q
        
        # Verificar si z está en el rango correcto
        if np.any(np.abs(z) >= gamma1 - 32):
            continue  # Rechazar y reintentar
        
        # Calcular r0 = w - c·s2
        r0 = np.zeros((k, n), dtype=int)
        for i in range(k):
            r0[i] = (w[i] - poly.polymul(c, s2[i]) % poly.polydiv([1] + [0]*n + [1], [1])[1]) % q
            # Normalizar al rango [-q/2, q/2)
            r0[i] = (r0[i] + q//2) % q - q//2
        
        # Verificar si los bits bajos de r0 están en el rango correcto
        r0_low = np.array([low_bits(r0[i], d) for i in range(k)])
        if np.any(np.abs(r0_low) >= gamma2 - 32):
            continue  # Rechazar y reintentar
        
        # Calcular hints h para la verificación
        h = np.zeros((k, n), dtype=bool)
        for i in range(k):
            for j in range(n):
                # Verificar si los bits altos de r0 y w difieren
                if high_bits(r0[i][j], d) != high_bits(w[i][j], d):
                    h[i][j] = True
        
        # La firma es (z, h, c)
        return {'z': z, 'h': h, 'c': c}

# Verificación
def verify(message, signature, public_key):
    """Verifica una firma ML-DSA"""
    seed_a = public_key['seed_a']
    t1 = public_key['t1']
    z = signature['z']
    h = signature['h']
    c = signature['c']
    
    # Verificar si z está en el rango correcto
    if np.any(np.abs(z) >= gamma1 - 32):
        return False
    
    # Reconstruir A a partir de la semilla
    np.random.seed(int.from_bytes(seed_a, byteorder='big'))
    A = np.random.randint(-q//2, q//2, (k, l, n))
    
    # Calcular Az
    Az = np.zeros((k, n), dtype=int)
    for i in range(k):
        for j in range(l):
            Az[i] = (Az[i] + poly.polymul(A[i, j], z[j]) % poly.polydiv([1] + [0]*n + [1], [1])[1]) % q
    
    # Calcular ct1*2^d
    ct1 = np.zeros((k, n), dtype=int)
    for i in range(k):
        ct1[i] = (poly.polymul(c, t1[i]) % poly.polydiv([1] + [0]*n + [1], [1])[1] * 2**d) % q
    
    # Calcular w' = Az - ct1*2^d (con ajustes por hints)
    w_prime = np.zeros((k, n), dtype=int)
    for i in range(k):
        w_prime[i] = (Az[i] - ct1[i]) % q
        # Normalizar al rango [-q/2, q/2)
        w_prime[i] = (w_prime[i] + q//2) % q - q//2
        
        # Aplicar hints para ajustar los bits altos
        for j in range(n):
            if h[i][j]:
                if w_prime[i][j] > 0:
                    w_prime[i][j] -= 2**d
                else:
                    w_prime[i][j] += 2**d
    
    # Extraer bits altos de w'
    w1_prime = np.array([high_bits(w_prime[i], d) for i in range(k)])
    
    # Calcular c' = H(message, w1')
    c_prime = hash_to_point(message, w1_prime)
    
    # Verificar si c = c'
    return np.array_equal(c, c_prime)

# Ejemplo de uso
if __name__ == "__main__":
    # Generar par de claves
    keys = keygen()
    print("Par de claves generado.")
    
    # Firmar un mensaje
    message = "Hola, mundo! Este es un mensaje de prueba para ML-DSA."
    signature = sign(message, keys['private'])
    print("Mensaje firmado.")
    
    # Verificar la firma
    result = verify(message, signature, keys['public'])
    print(f"Verificación de firma: {'Exitosa' if result else 'Fallida'}")
    
    # Intentar verificar con un mensaje modificado
    modified_message = message + " Texto añadido."
    result = verify(modified_message, signature, keys['public'])
    print(f"Verificación con mensaje modificado: {'Exitosa' if result else 'Fallida'}")

Explicación del Algoritmo

Fundamentos Matemáticos

ML-DSA se basa en el problema Learning With Errors (LWE) en anillos de polinomios. La seguridad del esquema depende de la dificultad de encontrar vectores cortos en retículos, un problema que se considera resistente a ataques cuánticos.

Componentes Principales

Características de ML-DSA

Comparación con Otros Esquemas de Firma

Esquema Base Matemática Tamaño de Clave Pública Tamaño de Firma Velocidad de Firma Velocidad de Verificación
ML-DSA-44 (FIPS 204) Retículos Modulares 1.3 KB 2.4 KB Rápida Rápida
SLH-DSA-SHAKE (FIPS 205) Hash 32 bytes 7.9 KB Lenta Lenta
RSA-2048 Factorización 256 bytes 256 bytes Lenta Rápida
ECDSA P-256 Curvas Elípticas 32 bytes 64 bytes Rápida Moderada

Aplicaciones Prácticas

ML-DSA está diseñado para reemplazar los esquemas de firma digital actuales (RSA, DSA, ECDSA) en aplicaciones como:

Demostración Interactiva: ML-DSA

Estado

Esperando acción...

Clave Pública

Clave Privada

Firma