2025-03-18 00:23:51 +05:30

87 lines
2.8 KiB
Python

from __future__ import annotations
import base64
from pathlib import Path
from typing import Union
from Crypto.Hash import SHA256
from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccKey
from ecpy.curves import Curve, Point
class ECCKey:
"""Represents a PlayReady ECC key"""
def __init__(self, key: EccKey):
self.key = key
@classmethod
def generate(cls):
"""Generate a new ECC key pair"""
return cls(key=ECC.generate(curve='P-256'))
@classmethod
def construct(cls, private_key: Union[bytes, int]):
"""Construct an ECC key pair from private/public bytes/ints"""
if isinstance(private_key, bytes):
private_key = int.from_bytes(private_key, 'big')
if not isinstance(private_key, int):
raise ValueError(f"Expecting Bytes or Int input, got {private_key!r}")
# The public is always derived from the private key; loading the other stuff won't work
key = ECC.construct(
curve='P-256',
d=private_key,
)
return cls(key=key)
@classmethod
def loads(cls, data: Union[str, bytes]) -> ECCKey:
if isinstance(data, str):
data = base64.b64decode(data)
if not isinstance(data, bytes):
raise ValueError(f"Expecting Bytes or Base64 input, got {data!r}")
if len(data) not in [96, 32]:
raise ValueError(f"Invalid data length. Expecting 96 or 32 bytes, got {len(data)}")
return cls.construct(private_key=data[:32])
@classmethod
def load(cls, path: Union[Path, str]) -> ECCKey:
if not isinstance(path, (Path, str)):
raise ValueError(f"Expecting Path object or path string, got {path!r}")
with Path(path).open(mode="rb") as f:
return cls.loads(f.read())
def dumps(self):
return self.private_bytes() + self.public_bytes()
def dump(self, path: Union[Path, str]) -> None:
if not isinstance(path, (Path, str)):
raise ValueError(f"Expecting Path object or path string, got {path!r}")
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(self.dumps())
def get_point(self, curve: Curve) -> Point:
return Point(self.key.pointQ.x, self.key.pointQ.y, curve)
def private_bytes(self) -> bytes:
return self.key.d.to_bytes()
def private_sha256_digest(self) -> bytes:
hash_object = SHA256.new()
hash_object.update(self.private_bytes())
return hash_object.digest()
def public_bytes(self) -> bytes:
return self.key.pointQ.x.to_bytes() + self.key.pointQ.y.to_bytes()
def public_sha256_digest(self) -> bytes:
hash_object = SHA256.new()
hash_object.update(self.public_bytes())
return hash_object.digest()