from __future__ import annotations import collections.abc from Crypto.PublicKey import ECC from .exceptions import InvalidCertificateChain # monkey patch for construct 2.8.8 compatibility if not hasattr(collections, 'Sequence'): collections.Sequence = collections.abc.Sequence import base64 from pathlib import Path from typing import Union from Crypto.Hash import SHA256 from Crypto.Signature import DSS from construct import Bytes, Const, Int32ub, GreedyRange, Switch, Container, ListContainer from construct import Int16ub, Array from construct import Struct, this from .ecc_key import ECCKey class _BCertStructs: DrmBCertBasicInfo = Struct( "cert_id" / Bytes(16), "security_level" / Int32ub, "flags" / Int32ub, "cert_type" / Int32ub, "public_key_digest" / Bytes(32), "expiration_date" / Int32ub, "client_id" / Bytes(16) ) # TODO: untested DrmBCertDomainInfo = Struct( "service_id" / Bytes(16), "account_id" / Bytes(16), "revision_timestamp" / Int32ub, "domain_url_length" / Int32ub, "domain_url" / Bytes((this.domain_url_length + 3) & 0xfffffffc) ) # TODO: untested DrmBCertPCInfo = Struct( "security_version" / Int32ub ) # TODO: untested DrmBCertDeviceInfo = Struct( "max_license" / Int32ub, "max_header" / Int32ub, "max_chain_depth" / Int32ub ) DrmBCertFeatureInfo = Struct( "feature_count" / Int32ub, # max. 32 "features" / Array(this.feature_count, Int32ub) ) DrmBCertKeyInfo = Struct( "key_count" / Int32ub, "cert_keys" / Array(this.key_count, Struct( "type" / Int16ub, "length" / Int16ub, "flags" / Int32ub, "key" / Bytes(this.length // 8), "usages_count" / Int32ub, "usages" / Array(this.usages_count, Int32ub) )) ) DrmBCertManufacturerInfo = Struct( "flags" / Int32ub, "manufacturer_name_length" / Int32ub, "manufacturer_name" / Bytes((this.manufacturer_name_length + 3) & 0xfffffffc), "model_name_length" / Int32ub, "model_name" / Bytes((this.model_name_length + 3) & 0xfffffffc), "model_number_length" / Int32ub, "model_number" / Bytes((this.model_number_length + 3) & 0xfffffffc), ) DrmBCertSignatureInfo = Struct( "signature_type" / Int16ub, "signature_size" / Int16ub, "signature" / Bytes(this.signature_size), "signature_key_size" / Int32ub, "signature_key" / Bytes(this.signature_key_size // 8) ) # TODO: untested DrmBCertSilverlightInfo = Struct( "security_version" / Int32ub, "platform_identifier" / Int32ub ) # TODO: untested DrmBCertMeteringInfo = Struct( "metering_id" / Bytes(16), "metering_url_length" / Int32ub, "metering_url" / Bytes((this.metering_url_length + 3) & 0xfffffffc) ) # TODO: untested DrmBCertExtDataSignKeyInfo = Struct( "key_type" / Int16ub, "key_length" / Int16ub, "flags" / Int32ub, "key" / Bytes(this.length // 8) ) # TODO: untested BCertExtDataRecord = Struct( "data_size" / Int32ub, "data" / Bytes(this.data_size) ) # TODO: untested DrmBCertExtDataSignature = Struct( "signature_type" / Int16ub, "signature_size" / Int16ub, "signature" / Bytes(this.signature_size) ) # TODO: untested BCertExtDataContainer = Struct( "record_count" / Int32ub, # always 1 "records" / Array(this.record_count, BCertExtDataRecord), "signature" / DrmBCertExtDataSignature ) # TODO: untested DrmBCertServerInfo = Struct( "warning_days" / Int32ub ) # TODO: untested DrmBcertSecurityVersion = Struct( "security_version" / Int32ub, "platform_identifier" / Int32ub ) Attribute = Struct( "flags" / Int16ub, "tag" / Int16ub, "length" / Int32ub, "attribute" / Switch( lambda this_: this_.tag, { 1: DrmBCertBasicInfo, 2: DrmBCertDomainInfo, 3: DrmBCertPCInfo, 4: DrmBCertDeviceInfo, 5: DrmBCertFeatureInfo, 6: DrmBCertKeyInfo, 7: DrmBCertManufacturerInfo, 8: DrmBCertSignatureInfo, 9: DrmBCertSilverlightInfo, 10: DrmBCertMeteringInfo, 11: DrmBCertExtDataSignKeyInfo, 12: BCertExtDataContainer, 13: DrmBCertExtDataSignature, 14: Bytes(this.length - 8), 15: DrmBCertServerInfo, 16: DrmBcertSecurityVersion, 17: DrmBcertSecurityVersion }, default=Bytes(this.length - 8) ) ) BCert = Struct( "signature" / Const(b"CERT"), "version" / Int32ub, "total_length" / Int32ub, "certificate_length" / Int32ub, "attributes" / GreedyRange(Attribute) ) BCertChain = Struct( "signature" / Const(b"CHAI"), "version" / Int32ub, "total_length" / Int32ub, "flags" / Int32ub, "certificate_count" / Int32ub, "certificates" / GreedyRange(BCert) ) class Certificate(_BCertStructs): """Represents a BCert""" def __init__( self, parsed_bcert: Container, bcert_obj: _BCertStructs.BCert = _BCertStructs.BCert ): self.parsed = parsed_bcert self._BCERT = bcert_obj @classmethod def new_leaf_cert( cls, cert_id: bytes, security_level: int, client_id: bytes, signing_key: ECCKey, encryption_key: ECCKey, group_key: ECCKey, parent: CertificateChain, expiry: int = 0xFFFFFFFF, max_license: int = 10240, max_header: int = 15360, max_chain_depth: int = 2 ) -> Certificate: if not cert_id: raise ValueError("Certificate ID is required") if not client_id: raise ValueError("Client ID is required") basic_info = Container( cert_id=cert_id, security_level=security_level, flags=0, cert_type=2, public_key_digest=signing_key.public_sha256_digest(), expiration_date=expiry, client_id=client_id ) basic_info_attribute = Container( flags=1, tag=1, length=len(_BCertStructs.DrmBCertBasicInfo.build(basic_info)) + 8, attribute=basic_info ) device_info = Container( max_license=max_license, max_header=max_header, max_chain_depth=max_chain_depth ) device_info_attribute = Container( flags=1, tag=4, length=len(_BCertStructs.DrmBCertDeviceInfo.build(device_info)) + 8, attribute=device_info ) feature = Container( feature_count=3, features=ListContainer([ 4, # SECURE_CLOCK 9, # REVOCATION_LIST_FEATURE 13 # SUPPORTS_PR3_FEATURES ]) ) feature_attribute = Container( flags=1, tag=5, length=len(_BCertStructs.DrmBCertFeatureInfo.build(feature)) + 8, attribute=feature ) cert_key_sign = Container( type=1, length=512, # bits flags=0, key=signing_key.public_bytes(), usages_count=1, usages=ListContainer([ 1 # KEYUSAGE_SIGN ]) ) cert_key_encrypt = Container( type=1, length=512, # bits flags=0, key=encryption_key.public_bytes(), usages_count=1, usages=ListContainer([ 2 # KEYUSAGE_ENCRYPT_KEY ]) ) key_info = Container( key_count=2, cert_keys=ListContainer([ cert_key_sign, cert_key_encrypt ]) ) key_info_attribute = Container( flags=1, tag=6, length=len(_BCertStructs.DrmBCertKeyInfo.build(key_info)) + 8, attribute=key_info ) manufacturer_info = parent.get_certificate(0).get_attribute(7) new_bcert_container = Container( signature=b"CERT", version=1, total_length=0, # filled at a later time certificate_length=0, # filled at a later time attributes=ListContainer([ basic_info_attribute, device_info_attribute, feature_attribute, key_info_attribute, manufacturer_info, ]) ) payload = _BCertStructs.BCert.build(new_bcert_container) new_bcert_container.certificate_length = len(payload) new_bcert_container.total_length = len(payload) + 144 # signature length sign_payload = _BCertStructs.BCert.build(new_bcert_container) hash_obj = SHA256.new(sign_payload) signer = DSS.new(group_key.key, 'fips-186-3') signature = signer.sign(hash_obj) signature_info = Container( signature_type=1, signature_size=64, signature=signature, signature_key_size=512, # bits signature_key=group_key.public_bytes() ) signature_info_attribute = Container( flags=1, tag=8, length=len(_BCertStructs.DrmBCertSignatureInfo.build(signature_info)) + 8, attribute=signature_info ) new_bcert_container.attributes.append(signature_info_attribute) return cls(new_bcert_container) @classmethod def loads(cls, data: Union[str, bytes]) -> Certificate: if isinstance(data, str): data = base64.b64decode(data) if not isinstance(data, bytes): raise ValueError(f"Expecting Bytes or Base64 input, got {data!r}") cert = _BCertStructs.BCert return cls( parsed_bcert=cert.parse(data), bcert_obj=cert ) @classmethod def load(cls, path: Union[Path, str]) -> Certificate: if not isinstance(path, (Path, str)): raise ValueError(f"Expecting Path object or path string, got {path!r}") with Path(path).open(mode="rb") as f: return cls.loads(f.read()) def get_attribute(self, type_: int): for attribute in self.parsed.attributes: if attribute.tag == type_: return attribute def get_security_level(self) -> int: basic_info_attribute = self.get_attribute(1).attribute if basic_info_attribute: return basic_info_attribute.security_level @staticmethod def _unpad(name: bytes): return name.rstrip(b'\x00').decode("utf-8", errors="ignore") def get_name(self): manufacturer_info = self.get_attribute(7).attribute if manufacturer_info: return f"{self._unpad(manufacturer_info.manufacturer_name)} {self._unpad(manufacturer_info.model_name)} {self._unpad(manufacturer_info.model_number)}" def dumps(self) -> bytes: return self._BCERT.build(self.parsed) def struct(self) -> _BCertStructs.BCert: return self._BCERT def verify_signature(self): sign_payload = self.dumps()[:-144] signature_attribute = self.get_attribute(8).attribute raw_signature_key = signature_attribute.signature_key signature_key = ECC.construct( curve='P-256', point_x=int.from_bytes(raw_signature_key[:32], 'big'), point_y=int.from_bytes(raw_signature_key[32:], 'big') ) hash_obj = SHA256.new(sign_payload) verifier = DSS.new(signature_key, 'fips-186-3') try: verifier.verify(hash_obj, signature_attribute.signature) return True except ValueError: return False class CertificateChain(_BCertStructs): """Represents a BCertChain""" def __init__( self, parsed_bcert_chain: Container, bcert_chain_obj: _BCertStructs.BCertChain = _BCertStructs.BCertChain ): self.parsed = parsed_bcert_chain self._BCERT_CHAIN = bcert_chain_obj @classmethod def loads(cls, data: Union[str, bytes]) -> CertificateChain: if isinstance(data, str): data = base64.b64decode(data) if not isinstance(data, bytes): raise ValueError(f"Expecting Bytes or Base64 input, got {data!r}") cert_chain = _BCertStructs.BCertChain return cls( parsed_bcert_chain=cert_chain.parse(data), bcert_chain_obj=cert_chain ) @classmethod def load(cls, path: Union[Path, str]) -> CertificateChain: if not isinstance(path, (Path, str)): raise ValueError(f"Expecting Path object or path string, got {path!r}") with Path(path).open(mode="rb") as f: return cls.loads(f.read()) def dumps(self) -> bytes: return self._BCERT_CHAIN.build(self.parsed) def struct(self) -> _BCertStructs.BCertChain: return self._BCERT_CHAIN def get_certificate(self, index: int) -> Certificate: return Certificate(self.parsed.certificates[index]) def get_security_level(self) -> int: # not sure if there's a better way than this return self.get_certificate(0).get_security_level() def get_name(self) -> str: return self.get_certificate(0).get_name() def append(self, bcert: Certificate) -> None: self.parsed.certificate_count += 1 self.parsed.certificates.append(bcert.parsed) self.parsed.total_length += len(bcert.dumps()) def prepend(self, bcert: Certificate) -> None: self.parsed.certificate_count += 1 self.parsed.certificates.insert(0, bcert.parsed) self.parsed.total_length += len(bcert.dumps()) def remove(self, index: int) -> None: if self.parsed.certificate_count <= 0: raise InvalidCertificateChain("CertificateChain does not contain any Certificates") if index >= self.parsed.certificate_count: raise IndexError(f"No Certificate at index {index}, {self.parsed.certificate_count} total") self.parsed.certificate_count -= 1 bcert = Certificate(self.parsed.certificates[index]) self.parsed.total_length -= len(bcert.dumps()) self.parsed.certificates.pop(index) def get(self, index: int) -> Certificate: if self.parsed.certificate_count <= 0: raise InvalidCertificateChain("CertificateChain does not contain any Certificates") if index >= self.parsed.certificate_count: raise IndexError(f"No Certificate at index {index}, {self.parsed.certificate_count} total") return Certificate(self.parsed.certificates[index]) def count(self) -> int: return self.parsed.certificate_count