Ressenger/ressenger_cryptography.py

201 lines
7.0 KiB
Python
Executable File

#!/usr/bin/python3
import struct
from typing import Optional, Tuple
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import AES, PKCS1_OAEP
from Cryptodome.Protocol.KDF import PBKDF2
from Cryptodome.Random import get_random_bytes
from Cryptodome.Util.Padding import pad, unpad
from Cryptodome.Signature import pss
from Cryptodome.Hash import SHA256
RSA_BITS = 4096  # default RSA modulus size in bits for generate_keypair(); larger keys are stronger but slower
AES_KEY_LEN = 32  # AES-256 key length in bytes
AES_NONCE_LEN = 12  # recommended nonce length for GCM, in bytes
TAG_LEN = 16  # GCM authentication tag length in bytes
def encrypt_bytes(data: bytes, password: str, *, salt: Optional[bytes] = None) -> bytes:
    """Encrypt *data* under a password with PBKDF2 and AES-256-CBC.

    The key is derived from *password* and *salt* with PBKDF2 (100,000
    iterations, 32-byte output). A fresh random 16-byte IV is generated on
    every call and the plaintext is padded to the AES block size.

    Args:
        data: Plaintext bytes to encrypt.
        password: Password the key is derived from.
        salt: Optional 16-byte PBKDF2 salt. A random one is generated when
            omitted.

    Returns:
        ``salt + iv + ciphertext``, the layout expected by decrypt_bytes().

    Raises:
        ValueError: If *salt* is given but is not exactly 16 bytes long.

    Note:
        CBC provides confidentiality only; the returned token carries no
        authentication tag, so tampering is not reliably detected.
    """
    salt_len = 16  # decrypt_bytes() always reads exactly this many salt bytes
    if salt is None:
        salt = get_random_bytes(salt_len)
    elif len(salt) != salt_len:
        # Any other length would shift the IV/ciphertext offsets and produce
        # a token that decrypt_bytes() cannot parse.
        raise ValueError(f"salt must be exactly {salt_len} bytes, got {len(salt)}")
    key = PBKDF2(password, salt, dkLen=32, count=100_000)
    iv = get_random_bytes(16)
    cipher = AES.new(key, AES.MODE_CBC, iv)
    ct = cipher.encrypt(pad(data, AES.block_size))
    return salt + iv + ct
def decrypt_bytes(token: bytes, password: str) -> bytes:
    """Decrypt a token produced by encrypt_bytes().

    The token layout is ``salt (16 bytes) + iv (16 bytes) + ciphertext``.
    The key is re-derived from *password* and the embedded salt with PBKDF2
    (100,000 iterations, 32-byte output), then the ciphertext is decrypted
    with AES-CBC and unpadded.

    Args:
        token: Bytes returned by encrypt_bytes().
        password: Password used at encryption time.

    Returns:
        The recovered plaintext bytes.

    Raises:
        ValueError: If the token is too short, its ciphertext is not a
            positive whole number of AES blocks, or the padding is invalid
            (for example because the password is wrong).
    """
    header_len = 32  # 16-byte salt followed by 16-byte IV
    block_len = 16  # AES block size in bytes
    body_len = len(token) - header_len
    # Reject structurally invalid tokens before paying for the PBKDF2 run.
    if body_len < block_len or body_len % block_len != 0:
        raise ValueError("token is truncated or not a whole number of AES blocks")
    salt = token[:16]
    iv = token[16:32]
    ct = token[32:]
    key = PBKDF2(password, salt, dkLen=32, count=100_000)
    cipher = AES.new(key, AES.MODE_CBC, iv)
    return unpad(cipher.decrypt(ct), AES.block_size)
def generate_keypair(bits: int = RSA_BITS) -> Tuple[bytes, bytes]:
    """Create a new RSA keypair and return both halves PEM-encoded.

    Args:
        bits: RSA modulus size in bits (defaults to RSA_BITS).

    Returns:
        A ``(private_pem_bytes, public_pem_bytes)`` tuple.
    """
    private_key = RSA.generate(bits)
    private_pem = private_key.export_key(format='PEM')
    public_pem = private_key.publickey().export_key(format='PEM')
    return (private_pem, public_pem)
def encrypt(plaintext: bytes, encryption_public_key_pem: bytes, signing_private_key_pem: bytes) -> Optional[Tuple[bytes, bytes]]:
    """Hybrid-encrypt *plaintext* for a recipient and sign the result.

    A random AES-256 key is wrapped with RSA-OAEP (SHA-256) under the
    recipient's public key, the plaintext is encrypted with AES-GCM, and the
    assembled blob is signed with RSA-PSS (SHA-256) using the sender's
    private key.

    Wire format (big-endian lengths)::

        u16 len | RSA-wrapped key
        u8  len | GCM nonce
        u32 len | ciphertext
        TAG_LEN bytes of GCM tag
        u16 len | signature over everything above

    Args:
        plaintext: Bytes to encrypt.
        encryption_public_key_pem: Recipient RSA public key (PEM bytes).
        signing_private_key_pem: Sender RSA private key (PEM bytes).

    Returns:
        ``(encrypted_blob_bytes, signature_bytes)`` on success, where the
        blob already contains the signature; ``None`` on any failure
        (invalid key, oversized field, ...).
    """
    try:
        recipient_pub = RSA.import_key(encryption_public_key_pem)
        signer_priv = RSA.import_key(signing_private_key_pem)

        # 1) Ephemeral AES key, wrapped with RSA-OAEP (SHA-256).
        aes_key = get_random_bytes(AES_KEY_LEN)
        rsa_ct = PKCS1_OAEP.new(recipient_pub, hashAlgo=SHA256).encrypt(aes_key)

        # 2) AES-GCM with an explicit nonce of the module's declared length.
        #    Passing nonce=None would make the library fall back to its own
        #    default size instead of AES_NONCE_LEN.
        nonce = get_random_bytes(AES_NONCE_LEN)
        aes_cipher = AES.new(aes_key, AES.MODE_GCM, nonce=nonce, mac_len=TAG_LEN)
        ciphertext, tag = aes_cipher.encrypt_and_digest(plaintext)

        # 3) Assemble the signed portion of the wire format.
        encoded = b"".join((
            struct.pack(">H", len(rsa_ct)),
            rsa_ct,
            struct.pack("B", len(nonce)),
            nonce,
            struct.pack(">I", len(ciphertext)),
            ciphertext,
            tag,
        ))

        # 4) Sign the encoded blob with RSA-PSS (SHA-256).
        signature = pss.new(signer_priv).sign(SHA256.new(encoded))

        # 5) Append the length-prefixed signature.
        final = encoded + struct.pack(">H", len(signature)) + signature
        return final, signature
    except Exception:
        # Deliberate API contract: callers get None on any invalid input or
        # internal failure rather than an exception.
        return None
def _read_prefixed(buf: bytes, offset: int, fmt: str) -> Tuple[bytes, int]:
    """Read one length-prefixed field from *buf* starting at *offset*.

    *fmt* is the struct format of the unsigned length prefix. Returns the
    field bytes and the offset just past the field. Raises ValueError when
    the prefix or the field would run past the end of *buf*.
    """
    body_start = offset + struct.calcsize(fmt)
    if body_start > len(buf):
        raise ValueError("truncated length prefix")
    (length,) = struct.unpack_from(fmt, buf, offset)
    body_end = body_start + length
    if body_end > len(buf):
        raise ValueError("truncated field")
    return buf[body_start:body_end], body_end


def decrypt(encrypted_bytes: bytes, decryption_private_key_pem: bytes, verification_public_key_pem: bytes) -> Optional[Tuple[bytes, bool]]:
    """Decrypt a blob produced by encrypt() and check its signature.

    Args:
        encrypted_bytes: Blob produced by encrypt().
        decryption_private_key_pem: Recipient RSA private key (PEM bytes).
        verification_public_key_pem: Sender RSA public key (PEM bytes).

    Returns:
        ``(plaintext_bytes, verification_result_bool)`` when the blob parses
        and decrypts; the bool is True only if the RSA-PSS signature over
        the signed portion verifies. ``None`` when the blob is malformed
        (truncated fields or trailing bytes after the signature), a key
        cannot be imported, RSA-OAEP unwrapping fails, or the GCM tag does
        not match.
    """
    try:
        buf = encrypted_bytes

        # Parse the wire format: u16|rsa_ct, u8|nonce, u32|ciphertext, tag.
        rsa_ct, idx = _read_prefixed(buf, 0, ">H")
        nonce, idx = _read_prefixed(buf, idx, "B")
        ciphertext, idx = _read_prefixed(buf, idx, ">I")
        signed_end = idx + TAG_LEN
        if signed_end > len(buf):
            return None
        tag = buf[idx:signed_end]

        # Everything before the signature length field is what was signed.
        # Take it from the parse position, not from the end of the buffer,
        # so extra trailing bytes cannot shift the signed span.
        signature, idx = _read_prefixed(buf, signed_end, ">H")
        if idx != len(buf):
            return None  # trailing bytes: not a blob produced by encrypt()
        signed_part = buf[:signed_end]

        # Unwrap the AES key with RSA-OAEP (SHA-256).
        recipient_priv = RSA.import_key(decryption_private_key_pem)
        aes_key = PKCS1_OAEP.new(recipient_priv, hashAlgo=SHA256).decrypt(rsa_ct)

        # AES-GCM decrypt; a tag mismatch raises and is mapped to None below.
        aes_cipher = AES.new(aes_key, AES.MODE_GCM, nonce=nonce)
        plaintext = aes_cipher.decrypt_and_verify(ciphertext, tag)

        # Verify the sender's RSA-PSS signature over the signed portion.
        verifier_pub = RSA.import_key(verification_public_key_pem)
        verifier = pss.new(verifier_pub)
        try:
            verifier.verify(SHA256.new(signed_part), signature)
            verified = True
        except (ValueError, TypeError):
            verified = False
        return plaintext, verified
    except Exception:
        # Deliberate API contract: malformed input or any crypto failure
        # yields None instead of propagating an exception.
        return None
#
# # Demonstration (only runs when invoked as script)
# if __name__ == "__main__":
# # generate recipient encryption keypair (recipient)
# priv_enc, pub_enc = generate_keypair()
# # generate sender signing keypair (sender)
# priv_sig, pub_sig = generate_keypair()
#
# message = b"I love C!"
# out = encrypt(message, pub_enc, priv_sig)
# if out is None:
# raise SystemExit("Encryption failed")
# blob, sig = out
# print("Encrypted blob length:", len(blob), "signature length:", len(sig))
#
# got = decrypt(blob, priv_enc, pub_sig)
# if got is None:
# raise SystemExit("Decryption/verification failed")
# plaintext, verified = got
# print("Decrypted:", plaintext)
# print("Signature verified:", verified)