Add RSA encryption
parent
9a0f720dc0
commit
3805f80b75
|
@ -21,3 +21,190 @@ def decrypt_bytes(token: bytes, password: str) -> bytes:
|
|||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
pt = unpad(cipher.decrypt(ct), AES.block_size)
|
||||
return pt
|
||||
|
||||
from typing import Optional, Tuple
|
||||
import struct
|
||||
|
||||
from Cryptodome.PublicKey import RSA
|
||||
from Cryptodome.Cipher import PKCS1_OAEP
|
||||
from Cryptodome.Signature import pss
|
||||
from Cryptodome.Hash import SHA256
|
||||
|
||||
# Configuration
|
||||
RSA_BITS = 4096 # change to 4096 if you want stronger RSA keys (slower)
|
||||
AES_KEY_LEN = 32 # AES-256
|
||||
AES_NONCE_LEN = 12 # recommended nonce length for GCM
|
||||
TAG_LEN = 16 # GCM tag length
|
||||
# Wire format:
|
||||
# [2 bytes rsa_ct_len][rsa_ct]
|
||||
# [1 byte nonce_len][nonce]
|
||||
# [4 bytes ct_len][ciphertext]
|
||||
# [16 bytes tag]
|
||||
# [2 bytes sig_len][signature]
|
||||
|
||||
def generate_keypair(bits: int = RSA_BITS) -> Tuple[bytes, bytes]:
|
||||
"""
|
||||
Generate an RSA keypair. Returns (private_pem_bytes, public_pem_bytes).
|
||||
"""
|
||||
key = RSA.generate(bits)
|
||||
priv_pem = key.export_key(format='PEM')
|
||||
pub_pem = key.publickey().export_key(format='PEM')
|
||||
return priv_pem, pub_pem
|
||||
|
||||
|
||||
def encrypt(plaintext: bytes, encryption_public_key_pem: bytes, signing_private_key_pem: bytes) -> Optional[Tuple[bytes, bytes]]:
|
||||
"""
|
||||
Encrypt + sign.
|
||||
|
||||
Inputs:
|
||||
plaintext: bytes to encrypt
|
||||
encryption_public_key_pem: recipient RSA public key (PEM bytes)
|
||||
signing_private_key_pem: sender RSA private key (PEM bytes) for signing
|
||||
|
||||
Returns:
|
||||
(encrypted_blob_bytes, signature_bytes) on success, or None on error.
|
||||
"""
|
||||
try:
|
||||
# Import keys
|
||||
recipient_pub = RSA.import_key(encryption_public_key_pem)
|
||||
signer_priv = RSA.import_key(signing_private_key_pem)
|
||||
|
||||
# 1) generate ephemeral AES key and encrypt it with RSA-OAEP (SHA-256)
|
||||
aes_key = get_random_bytes(AES_KEY_LEN)
|
||||
rsa_cipher = PKCS1_OAEP.new(recipient_pub, hashAlgo=SHA256)
|
||||
rsa_ct = rsa_cipher.encrypt(aes_key)
|
||||
|
||||
# 2) encrypt plaintext with AES-GCM
|
||||
aes_cipher = AES.new(aes_key, AES.MODE_GCM, nonce=None) # let library choose nonce
|
||||
nonce = aes_cipher.nonce
|
||||
ciphertext, tag = aes_cipher.encrypt_and_digest(plaintext)
|
||||
|
||||
# 3) assemble wire format (without signature)
|
||||
parts = []
|
||||
parts.append(struct.pack(">H", len(rsa_ct)))
|
||||
parts.append(rsa_ct)
|
||||
parts.append(struct.pack("B", len(nonce)))
|
||||
parts.append(nonce)
|
||||
parts.append(struct.pack(">I", len(ciphertext)))
|
||||
parts.append(ciphertext)
|
||||
parts.append(tag) # 16 bytes
|
||||
encoded = b"".join(parts)
|
||||
|
||||
# 4) sign the encoded blob with RSA-PSS (SHA-256)
|
||||
h = SHA256.new(encoded)
|
||||
signer = pss.new(signer_priv)
|
||||
signature = signer.sign(h)
|
||||
|
||||
# 5) append signature + length
|
||||
final = encoded + struct.pack(">H", len(signature)) + signature
|
||||
|
||||
return final, signature
|
||||
|
||||
except Exception:
|
||||
# Per your API: on invalid input or failure, return None
|
||||
return None
|
||||
|
||||
|
||||
def decrypt(encrypted_bytes: bytes, decryption_private_key_pem: bytes, verification_public_key_pem: bytes) -> Optional[Tuple[bytes, bool]]:
|
||||
"""
|
||||
Decrypt + verify.
|
||||
|
||||
Inputs:
|
||||
encrypted_bytes: blob produced by encrypt()
|
||||
decryption_private_key_pem: recipient RSA private key (PEM bytes)
|
||||
verification_public_key_pem: sender RSA public key (PEM bytes)
|
||||
|
||||
Returns:
|
||||
(plaintext_bytes, verification_result_bool) on success, or None on malformed input / failure.
|
||||
verification_result_bool is True if signature verified, False otherwise.
|
||||
"""
|
||||
try:
|
||||
buf = encrypted_bytes
|
||||
idx = 0
|
||||
|
||||
# Parse RSA ciphertext length (2 bytes)
|
||||
if idx + 2 > len(buf): return None
|
||||
rsa_ct_len = struct.unpack_from(">H", buf, idx)[0]; idx += 2
|
||||
|
||||
if idx + rsa_ct_len > len(buf): return None
|
||||
rsa_ct = buf[idx: idx + rsa_ct_len]; idx += rsa_ct_len
|
||||
|
||||
# nonce length (1 byte)
|
||||
if idx + 1 > len(buf): return None
|
||||
nonce_len = struct.unpack_from("B", buf, idx)[0]; idx += 1
|
||||
|
||||
if idx + nonce_len > len(buf): return None
|
||||
nonce = buf[idx: idx + nonce_len]; idx += nonce_len
|
||||
|
||||
# ciphertext length (4 bytes)
|
||||
if idx + 4 > len(buf): return None
|
||||
ct_len = struct.unpack_from(">I", buf, idx)[0]; idx += 4
|
||||
|
||||
if idx + ct_len > len(buf): return None
|
||||
ciphertext = buf[idx: idx + ct_len]; idx += ct_len
|
||||
|
||||
# tag (16 bytes)
|
||||
if idx + TAG_LEN > len(buf): return None
|
||||
tag = buf[idx: idx + TAG_LEN]; idx += TAG_LEN
|
||||
|
||||
# signature length (2 bytes) + signature
|
||||
if idx + 2 > len(buf): return None
|
||||
sig_len = struct.unpack_from(">H", buf, idx)[0]; idx += 2
|
||||
|
||||
if idx + sig_len > len(buf): return None
|
||||
signature = buf[idx: idx + sig_len]; idx += sig_len
|
||||
|
||||
# the part that was signed is everything up to the signature length field
|
||||
signed_part = buf[: (len(buf) - (2 + sig_len))]
|
||||
|
||||
# RSA decapsulate AES key
|
||||
recipient_priv = RSA.import_key(decryption_private_key_pem)
|
||||
rsa_cipher = PKCS1_OAEP.new(recipient_priv, hashAlgo=SHA256)
|
||||
try:
|
||||
aes_key = rsa_cipher.decrypt(rsa_ct)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# AES-GCM decrypt and verify
|
||||
aes_cipher = AES.new(aes_key, AES.MODE_GCM, nonce=nonce)
|
||||
try:
|
||||
plaintext = aes_cipher.decrypt_and_verify(ciphertext, tag)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# verify signature with RSA-PSS
|
||||
verifier_pub = RSA.import_key(verification_public_key_pem)
|
||||
h = SHA256.new(signed_part)
|
||||
verifier = pss.new(verifier_pub)
|
||||
try:
|
||||
verifier.verify(h, signature)
|
||||
verified = True
|
||||
except (ValueError, TypeError):
|
||||
verified = False
|
||||
|
||||
return plaintext, verified
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
#
|
||||
# # Demonstration (only runs when invoked as script)
|
||||
# if __name__ == "__main__":
|
||||
# # generate recipient encryption keypair (recipient)
|
||||
# priv_enc, pub_enc = generate_keypair()
|
||||
# # generate sender signing keypair (sender)
|
||||
# priv_sig, pub_sig = generate_keypair()
|
||||
#
|
||||
# message = b"I love Ruby!"
|
||||
# out = encrypt(message, pub_enc, priv_sig)
|
||||
# if out is None:
|
||||
# raise SystemExit("Encryption failed")
|
||||
# blob, sig = out
|
||||
# print("Encrypted blob length:", len(blob), "signature length:", len(sig))
|
||||
#
|
||||
# got = decrypt(blob, priv_enc, pub_sig)
|
||||
# if got is None:
|
||||
# raise SystemExit("Decryption/verification failed")
|
||||
# plaintext, verified = got
|
||||
# print("Decrypted:", plaintext)
|
||||
# print("Signature verified:", verified)
|
||||
|
|
Loading…
Reference in New Issue