diff --git a/custom_functions/database/cache_to_db_mariadb.py b/custom_functions/database/cache_to_db_mariadb.py index e25eb7b..28c61fd 100644 --- a/custom_functions/database/cache_to_db_mariadb.py +++ b/custom_functions/database/cache_to_db_mariadb.py @@ -1,3 +1,5 @@ +"""Module to cache data to MariaDB.""" + import os import yaml import mysql.connector @@ -5,8 +7,10 @@ from mysql.connector import Error def get_db_config(): - # Configure your MariaDB connection - with open(f"{os.getcwd()}/configs/config.yaml", "r") as file: + """Get the database configuration for MariaDB.""" + with open( + os.path.join(os.getcwd(), "configs", "config.yaml"), "r", encoding="utf-8" + ) as file: config = yaml.safe_load(file) db_config = { "host": f'{config["mariadb"]["host"]}', @@ -18,6 +22,7 @@ def get_db_config(): def create_database(): + """Create the database for MariaDB.""" try: with mysql.connector.connect(**get_db_config()) as conn: cursor = conn.cursor() @@ -41,15 +46,16 @@ def create_database(): def cache_to_db( - service=None, - pssh=None, - kid=None, - key=None, - license_url=None, - headers=None, - cookies=None, - data=None, + service: str = "", + pssh: str = "", + kid: str = "", + key: str = "", + license_url: str = "", + headers: str = "", + cookies: str = "", + data: str = "", ): + """Cache data to the database for MariaDB.""" try: with mysql.connector.connect(**get_db_config()) as conn: cursor = conn.cursor() @@ -81,6 +87,7 @@ def cache_to_db( def search_by_pssh_or_kid(search_filter): + """Search the database by PSSH or KID for MariaDB.""" results = set() try: with mysql.connector.connect(**get_db_config()) as conn: @@ -109,6 +116,7 @@ def search_by_pssh_or_kid(search_filter): def get_key_by_kid_and_service(kid, service): + """Get the key by KID and service for MariaDB.""" try: with mysql.connector.connect(**get_db_config()) as conn: cursor = conn.cursor() @@ -124,6 +132,7 @@ def get_key_by_kid_and_service(kid, service): def get_kid_key_dict(service_name): + """Get the KID and key dictionary for MariaDB.""" try: with mysql.connector.connect(**get_db_config()) as conn: cursor = conn.cursor() @@ -137,6 +146,7 @@ def get_kid_key_dict(service_name): def get_unique_services(): + """Get the unique services for MariaDB.""" try: with mysql.connector.connect(**get_db_config()) as conn: cursor = conn.cursor() @@ -148,6 +158,7 @@ def get_unique_services(): def key_count(): + """Get the key count for MariaDB.""" try: with mysql.connector.connect(**get_db_config()) as conn: cursor = conn.cursor() diff --git a/custom_functions/database/cache_to_db_sqlite.py b/custom_functions/database/cache_to_db_sqlite.py index 101ba71..2558a9a 100644 --- a/custom_functions/database/cache_to_db_sqlite.py +++ b/custom_functions/database/cache_to_db_sqlite.py @@ -1,10 +1,14 @@ +"""Module to cache data to SQLite.""" + import sqlite3 import os def create_database(): - # Using with statement to manage the connection and cursor - with sqlite3.connect(f"{os.getcwd()}/databases/sql/key_cache.db") as conn: + """Create the database for SQLite.""" + with sqlite3.connect( + os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") + ) as conn: cursor = conn.cursor() cursor.execute( """ @@ -23,16 +27,19 @@ def create_database(): def cache_to_db( - service: str = None, - pssh: str = None, - kid: str = None, - key: str = None, - license_url: str = None, - headers: str = None, - cookies: str = None, - data: str = None, + service: str = "", + pssh: str = "", + kid: str = "", + key: str = "", + license_url: str = "", + headers: str = "", + cookies: str = "", + data: str = "", ): - with sqlite3.connect(f"{os.getcwd()}/databases/sql/key_cache.db") as conn: + """Cache data to the database for SQLite.""" + with sqlite3.connect( + os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") + ) as conn: cursor = conn.cursor() # Check if the record with the given KID already exists @@ -53,8 +60,10 @@ def cache_to_db( def search_by_pssh_or_kid(search_filter): - # Using with statement to automatically close the connection - with sqlite3.connect(f"{os.getcwd()}/databases/sql/key_cache.db") as conn: + """Search the database by PSSH or KID for SQLite.""" + with sqlite3.connect( + os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") + ) as conn: cursor = conn.cursor() # Initialize a set to store unique matching records @@ -92,8 +101,10 @@ def search_by_pssh_or_kid(search_filter): def get_key_by_kid_and_service(kid, service): - # Using 'with' to automatically close the connection when done - with sqlite3.connect(f"{os.getcwd()}/databases/sql/key_cache.db") as conn: + """Get the key by KID and service for SQLite.""" + with sqlite3.connect( + os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") + ) as conn: cursor = conn.cursor() # Query to search by KID and SERVICE @@ -114,8 +125,10 @@ def get_key_by_kid_and_service(kid, service): def get_kid_key_dict(service_name): - # Using with statement to automatically manage the connection and cursor - with sqlite3.connect(f"{os.getcwd()}/databases/sql/key_cache.db") as conn: + """Get the KID and key dictionary for SQLite.""" + with sqlite3.connect( + os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") + ) as conn: cursor = conn.cursor() # Query to fetch KID and Key for the selected service @@ -133,8 +146,10 @@ def get_kid_key_dict(service_name): def get_unique_services(): - # Using with statement to automatically manage the connection and cursor - with sqlite3.connect(f"{os.getcwd()}/databases/sql/key_cache.db") as conn: + """Get the unique services for SQLite.""" + with sqlite3.connect( + os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") + ) as conn: cursor = conn.cursor() # Query to get distinct services from the 'licenses' table @@ -150,8 +165,10 @@ def get_unique_services(): def key_count(): - # Using with statement to automatically manage the connection and cursor - with sqlite3.connect(f"{os.getcwd()}/databases/sql/key_cache.db") as conn: + """Get the key count for SQLite.""" + with sqlite3.connect( + os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") + ) as conn: cursor = conn.cursor() # Count the number of KID entries in the licenses table diff --git a/custom_functions/database/unified_db_ops.py b/custom_functions/database/unified_db_ops.py new file mode 100644 index 0000000..d95665b --- /dev/null +++ b/custom_functions/database/unified_db_ops.py @@ -0,0 +1,159 @@ +"""Unified database operations module that automatically uses the correct backend.""" + +import os +from typing import Optional, List, Dict, Any +import yaml + +# Import both backend modules +try: + import custom_functions.database.cache_to_db_sqlite as sqlite_db +except ImportError: + sqlite_db = None + +try: + import custom_functions.database.cache_to_db_mariadb as mariadb_db +except ImportError: + mariadb_db = None + + +class DatabaseOperations: + """Unified database operations class that automatically selects the correct backend.""" + + def __init__(self): + self.backend = self._get_database_backend() + self.db_module = self._get_db_module() + + def _get_database_backend(self) -> str: + """Get the database backend from config, default to sqlite.""" + try: + config_path = os.path.join(os.getcwd(), "configs", "config.yaml") + with open(config_path, "r", encoding="utf-8") as file: + config = yaml.safe_load(file) + return config.get("database_type", "sqlite").lower() + except (FileNotFoundError, KeyError, yaml.YAMLError): + return "sqlite" + + def _get_db_module(self): + """Get the appropriate database module based on backend.""" + if self.backend == "mariadb" and mariadb_db: + return mariadb_db + if sqlite_db: + return sqlite_db + raise ImportError(f"Database module for {self.backend} not available") + + def get_backend_info(self) -> Dict[str, str]: + """Get information about the current database backend being used.""" + return { + "backend": self.backend, + "module": self.db_module.__name__ if self.db_module else "None", + } + + def create_database(self) -> None: + """Create the database using the configured backend.""" + return self.db_module.create_database() + + def cache_to_db( + self, + service: str = "", + pssh: str = "", + kid: str = "", + key: str = "", + license_url: str = "", + headers: str = "", + cookies: str = "", + data: str = "", + ) -> bool: + """Cache data to the database using the configured backend.""" + return self.db_module.cache_to_db( + service=service, + pssh=pssh, + kid=kid, + key=key, + license_url=license_url, + headers=headers, + cookies=cookies, + data=data, + ) + + def search_by_pssh_or_kid(self, search_filter: str) -> List[Dict[str, str]]: + """Search the database by PSSH or KID using the configured backend.""" + return self.db_module.search_by_pssh_or_kid(search_filter) + + def get_key_by_kid_and_service(self, kid: str, service: str) -> Optional[str]: + """Get the key by KID and service using the configured backend.""" + return self.db_module.get_key_by_kid_and_service(kid, service) + + def get_kid_key_dict(self, service_name: str) -> Dict[str, str]: + """Get the KID and key dictionary using the configured backend.""" + return self.db_module.get_kid_key_dict(service_name) + + def get_unique_services(self) -> List[str]: + """Get the unique services using the configured backend.""" + return self.db_module.get_unique_services() + + def key_count(self) -> int: + """Get the key count using the configured backend.""" + return self.db_module.key_count() + + +# Create a singleton instance for easy import and use +db_ops = DatabaseOperations() + + +# Convenience functions that use the singleton instance +def get_backend_info() -> Dict[str, str]: + """Get information about the current database backend being used.""" + return db_ops.get_backend_info() + + +def create_database() -> None: + """Create the database using the configured backend.""" + return db_ops.create_database() + + +def cache_to_db( + service: str = "", + pssh: str = "", + kid: str = "", + key: str = "", + license_url: str = "", + headers: str = "", + cookies: str = "", + data: str = "", +) -> bool: + """Cache data to the database using the configured backend.""" + return db_ops.cache_to_db( + service=service, + pssh=pssh, + kid=kid, + key=key, + license_url=license_url, + headers=headers, + cookies=cookies, + data=data, + ) + + +def search_by_pssh_or_kid(search_filter: str) -> List[Dict[str, str]]: + """Search the database by PSSH or KID using the configured backend.""" + return db_ops.search_by_pssh_or_kid(search_filter) + + +def get_key_by_kid_and_service(kid: str, service: str) -> Optional[str]: + """Get the key by KID and service using the configured backend.""" + return db_ops.get_key_by_kid_and_service(kid, service) + + +def get_kid_key_dict(service_name: str) -> Dict[str, str]: + """Get the KID and key dictionary using the configured backend.""" + return db_ops.get_kid_key_dict(service_name) + + +def get_unique_services() -> List[str]: + """Get the unique services using the configured backend.""" + return db_ops.get_unique_services() + + +def key_count() -> int: + """Get the key count using the configured backend.""" + return db_ops.key_count() diff --git a/custom_functions/decrypt/api_decrypt.py b/custom_functions/decrypt/api_decrypt.py index c6eac25..91857d7 100644 --- a/custom_functions/decrypt/api_decrypt.py +++ b/custom_functions/decrypt/api_decrypt.py @@ -1,19 +1,29 @@ +"""Module to decrypt the license using the API.""" + +import base64 +import ast +import glob +import json +import os +from urllib.parse import urlparse +import binascii + +import requests +from requests.exceptions import Timeout, RequestException +import yaml + from pywidevine.cdm import Cdm as widevineCdm from pywidevine.device import Device as widevineDevice from pywidevine.pssh import PSSH as widevinePSSH from pyplayready.cdm import Cdm as playreadyCdm from pyplayready.device import Device as playreadyDevice from pyplayready.system.pssh import PSSH as playreadyPSSH -import requests -import base64 -import ast -import glob -import os -import yaml -from urllib.parse import urlparse + +from custom_functions.database.unified_db_ops import cache_to_db def find_license_key(data, keywords=None): + """Find the license key in the data.""" if keywords is None: keywords = [ "license", @@ -47,6 +57,7 @@ def find_license_key(data, keywords=None): def find_license_challenge(data, keywords=None, new_value=None): + """Find the license challenge in the data.""" if keywords is None: keywords = [ "license", @@ -79,17 +90,19 @@ def find_license_challenge(data, keywords=None, new_value=None): def is_base64(string): + """Check if the string is base64 encoded.""" try: # Try decoding the string decoded_data = base64.b64decode(string) # Check if the decoded data, when re-encoded, matches the original string return base64.b64encode(decoded_data).decode("utf-8") == string - except Exception: + except (binascii.Error, TypeError): # If decoding or encoding fails, it's not Base64 return False def is_url_and_split(input_str): + """Check if the string is a URL and split it into protocol and FQDN.""" parsed = urlparse(input_str) # Check if it's a valid URL with scheme and netloc @@ -97,390 +110,286 @@ def is_url_and_split(input_str): protocol = parsed.scheme fqdn = parsed.netloc return True, protocol, fqdn + return False, None, None + + +def load_device(device_type, device, username, config): + """Load the appropriate device file for PlayReady or Widevine.""" + if device_type == "PR": + ext, config_key, class_loader = ".prd", "default_pr_cdm", playreadyDevice.load + base_dir = "PR" else: - return False, None, None + ext, config_key, class_loader = ".wvd", "default_wv_cdm", widevineDevice.load + base_dir = "WV" + + if device == "public": + base_name = config[config_key] + if not base_name.endswith(ext): + base_name += ext + search_path = f"{os.getcwd()}/configs/CDMs/{base_dir}/{base_name}" + else: + base_name = device + if not base_name.endswith(ext): + base_name += ext + search_path = f"{os.getcwd()}/configs/CDMs/{username}/{base_dir}/{base_name}" + + files = glob.glob(search_path) + if not files: + return None, f"No {ext} file found for device '{device}'" + try: + return class_loader(files[0]), None + except (IOError, OSError) as e: + return None, f"Failed to read device file: {e}" + except (ValueError, TypeError, AttributeError) as e: + return None, f"Failed to parse device file: {e}" + + +def prepare_request_data(headers, cookies, json_data, challenge, is_widevine): + """Prepare headers, cookies, and json_data for the license request.""" + try: + format_headers = ast.literal_eval(headers) if headers else None + except (ValueError, SyntaxError) as e: + raise ValueError(f"Invalid headers format: {e}") from e + + try: + format_cookies = ast.literal_eval(cookies) if cookies else None + except (ValueError, SyntaxError) as e: + raise ValueError(f"Invalid cookies format: {e}") from e + + format_json_data = None + if json_data and not is_base64(json_data): + try: + format_json_data = ast.literal_eval(json_data) + if is_widevine: + format_json_data = find_license_challenge( + data=format_json_data, + new_value=base64.b64encode(challenge).decode(), + ) + except (ValueError, SyntaxError) as e: + raise ValueError(f"Invalid json_data format: {e}") from e + except (TypeError, AttributeError) as e: + raise ValueError(f"Error processing json_data: {e}") from e + + return format_headers, format_cookies, format_json_data + + +def send_license_request(license_url, headers, cookies, json_data, challenge, proxies): + """Send the license request and return the response.""" + try: + response = requests.post( + url=license_url, + headers=headers, + proxies=proxies, + cookies=cookies, + json=json_data if json_data is not None else None, + data=challenge if json_data is None else None, + timeout=10, + ) + return response, None + except ConnectionError as error: + return None, f"Connection error: {error}" + except Timeout as error: + return None, f"Request timeout: {error}" + except RequestException as error: + return None, f"Request error: {error}" + + +def extract_and_cache_keys( + cdm, + session_id, + cache_to_db, + pssh, + license_url, + headers, + cookies, + challenge, + json_data, + is_widevine, +): + """Extract keys from the session and cache them.""" + returned_keys = "" + try: + keys = list(cdm.get_keys(session_id)) + for index, key in enumerate(keys): + # Widevine: key.type, PlayReady: key.key_type + key_type = getattr(key, "type", getattr(key, "key_type", None)) + kid = getattr(key, "kid", getattr(key, "key_id", None)) + if key_type != "SIGNING" and kid is not None: + cache_to_db( + pssh=pssh, + license_url=license_url, + headers=headers, + cookies=cookies, + data=challenge if json_data is None else json_data, + kid=kid.hex, + key=key.key.hex(), + ) + if index != len(keys) - 1: + returned_keys += f"{kid.hex}:{key.key.hex()}\n" + else: + returned_keys += f"{kid.hex}:{key.key.hex()}" + return returned_keys, None + except AttributeError as error: + return None, f"Error accessing CDM keys: {error}" + except (TypeError, ValueError) as error: + return None, f"Error processing keys: {error}" def api_decrypt( - pssh: str = None, - license_url: str = None, - proxy: str = None, - headers: str = None, - cookies: str = None, - json_data: str = None, + pssh: str = "", + license_url: str = "", + proxy: str = "", + headers: str = "", + cookies: str = "", + json_data: str = "", device: str = "public", - username: str = None, + username: str = "", ): + """Decrypt the license using the API.""" print(f"Using device {device} for user {username}") - with open(f"{os.getcwd()}/configs/config.yaml", "r") as file: + with open(f"{os.getcwd()}/configs/config.yaml", "r", encoding="utf-8") as file: config = yaml.safe_load(file) - if config["database_type"].lower() == "sqlite": - from custom_functions.database.cache_to_db_sqlite import cache_to_db - elif config["database_type"].lower() == "mariadb": - from custom_functions.database.cache_to_db_mariadb import cache_to_db - if pssh is None: + + if pssh == "": return {"status": "error", "message": "No PSSH provided"} + + # Detect PlayReady or Widevine try: - if "".encode("utf-16-le") in base64.b64decode(pssh): # PR - try: - pr_pssh = playreadyPSSH(pssh) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred processing PSSH\n\n{error}", - } - try: - if device == "public": - base_name = config["default_pr_cdm"] - if not base_name.endswith(".prd"): - base_name += ".prd" - prd_files = glob.glob( - f"{os.getcwd()}/configs/CDMs/PR/{base_name}" - ) - else: - prd_files = glob.glob( - f"{os.getcwd()}/configs/CDMs/PR/{base_name}" - ) - if prd_files: - pr_device = playreadyDevice.load(prd_files[0]) - else: - return { - "status": "error", - "message": "No default .prd file found", - } - else: - base_name = device - if not base_name.endswith(".prd"): - base_name += ".prd" - prd_files = glob.glob( - f"{os.getcwd()}/configs/CDMs/{username}/PR/{base_name}" - ) - else: - prd_files = glob.glob( - f"{os.getcwd()}/configs/CDMs/{username}/PR/{base_name}" - ) - if prd_files: - pr_device = playreadyDevice.load(prd_files[0]) - else: - return { - "status": "error", - "message": f"{base_name} does not exist", - } - except Exception as error: - return { - "status": "error", - "message": f"An error occurred location PlayReady CDM file\n\n{error}", - } - try: - pr_cdm = playreadyCdm.from_device(pr_device) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred loading PlayReady CDM\n\n{error}", - } - try: - pr_session_id = pr_cdm.open() - except Exception as error: - return { - "status": "error", - "message": f"An error occurred opening a CDM session\n\n{error}", - } - try: - pr_challenge = pr_cdm.get_license_challenge( - pr_session_id, pr_pssh.wrm_headers[0] - ) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting license challenge\n\n{error}", - } - try: - if headers: - format_headers = ast.literal_eval(headers) - else: - format_headers = None - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting headers\n\n{error}", - } - try: - if cookies: - format_cookies = ast.literal_eval(cookies) - else: - format_cookies = None - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting cookies\n\n{error}", - } - try: - if json_data and not is_base64(json_data): - format_json_data = ast.literal_eval(json_data) - else: - format_json_data = None - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting json_data\n\n{error}", - } - licence = None - proxies = None - if proxy is not None: - is_url, protocol, fqdn = is_url_and_split(proxy) - if is_url: - proxies = {"http": proxy, "https": proxy} - else: - return { - "status": "error", - "message": f"Your proxy is invalid, please put it in the format of http(s)://fqdn.tld:port", - } - try: - licence = requests.post( - url=license_url, - headers=format_headers, - proxies=proxies, - cookies=format_cookies, - json=format_json_data if format_json_data is not None else None, - data=pr_challenge if format_json_data is None else None, - ) - except requests.exceptions.ConnectionError as error: - return { - "status": "error", - "message": f"An error occurred sending license challenge through your proxy\n\n{error}", - } - except Exception as error: - return { - "status": "error", - "message": f"An error occurred sending license reqeust\n\n{error}\n\n{licence.content}", - } - try: - pr_cdm.parse_license(pr_session_id, licence.text) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred parsing license content\n\n{error}\n\n{licence.content}", - } - returned_keys = "" - try: - keys = list(pr_cdm.get_keys(pr_session_id)) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting keys\n\n{error}", - } - try: - for index, key in enumerate(keys): - if key.key_type != "SIGNING": - cache_to_db( - pssh=pssh, - license_url=license_url, - headers=headers, - cookies=cookies, - data=pr_challenge if json_data is None else json_data, - kid=key.key_id.hex, - key=key.key.hex(), - ) - if index != len(keys) - 1: - returned_keys += f"{key.key_id.hex}:{key.key.hex()}\n" - else: - returned_keys += f"{key.key_id.hex}:{key.key.hex()}" - except Exception as error: - return { - "status": "error", - "message": f"An error occurred formatting keys\n\n{error}", - } - try: - pr_cdm.close(pr_session_id) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred closing session\n\n{error}", - } - try: - return {"status": "success", "message": returned_keys} - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting returned_keys\n\n{error}", - } - except Exception as error: + is_pr = "".encode("utf-16-le") in base64.b64decode(pssh) + except (binascii.Error, TypeError) as error: return { "status": "error", "message": f"An error occurred processing PSSH\n\n{error}", } - else: - try: - wv_pssh = widevinePSSH(pssh) - except Exception as error: + + device_type = "PR" if is_pr else "WV" + cdm_class = playreadyCdm if is_pr else widevineCdm + pssh_class = playreadyPSSH if is_pr else widevinePSSH + + # Load device + device_obj, device_err = load_device(device_type, device, username, config) + if device_obj is None: + return {"status": "error", "message": device_err} + + # Create CDM + try: + cdm = cdm_class.from_device(device_obj) + except (IOError, ValueError, AttributeError) as error: + return { + "status": "error", + "message": f"An error occurred loading {device_type} CDM\n\n{error}", + } + + # Open session + try: + session_id = cdm.open() + except (IOError, ValueError, AttributeError) as error: + return { + "status": "error", + "message": f"An error occurred opening a CDM session\n\n{error}", + } + + # Parse PSSH and get challenge + try: + pssh_obj = pssh_class(pssh) + if is_pr: + challenge = cdm.get_license_challenge(session_id, pssh_obj.wrm_headers[0]) + else: + challenge = cdm.get_license_challenge(session_id, pssh_obj) + except (ValueError, AttributeError, IndexError) as error: + return { + "status": "error", + "message": f"An error occurred getting license challenge\n\n{error}", + } + + # Prepare request data + try: + format_headers, format_cookies, format_json_data = prepare_request_data( + headers, cookies, json_data, challenge, is_widevine=(not is_pr) + ) + except (ValueError, SyntaxError) as error: + return { + "status": "error", + "message": f"An error occurred preparing request data\n\n{error}", + } + + # Prepare proxies + proxies = None + if proxy is not None: + is_url, protocol, fqdn = is_url_and_split(proxy) + if is_url: + proxies = {"http": proxy, "https": proxy} + else: return { "status": "error", - "message": f"An error occurred processing PSSH\n\n{error}", + "message": "Your proxy is invalid, please put it in the format of http(s)://fqdn.tld:port", } - try: - if device == "public": - base_name = config["default_wv_cdm"] - if not base_name.endswith(".wvd"): - base_name += ".wvd" - wvd_files = glob.glob(f"{os.getcwd()}/configs/CDMs/WV/{base_name}") - else: - wvd_files = glob.glob(f"{os.getcwd()}/configs/CDMs/WV/{base_name}") - if wvd_files: - wv_device = widevineDevice.load(wvd_files[0]) - else: - return {"status": "error", "message": "No default .wvd file found"} - else: - base_name = device - if not base_name.endswith(".wvd"): - base_name += ".wvd" - wvd_files = glob.glob( - f"{os.getcwd()}/configs/CDMs/{username}/WV/{base_name}" - ) - else: - wvd_files = glob.glob( - f"{os.getcwd()}/configs/CDMs/{username}/WV/{base_name}" - ) - if wvd_files: - wv_device = widevineDevice.load(wvd_files[0]) - else: - return {"status": "error", "message": f"{base_name} does not exist"} - except Exception as error: - return { - "status": "error", - "message": f"An error occurred location Widevine CDM file\n\n{error}", - } - try: - wv_cdm = widevineCdm.from_device(wv_device) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred loading Widevine CDM\n\n{error}", - } - try: - wv_session_id = wv_cdm.open() - except Exception as error: - return { - "status": "error", - "message": f"An error occurred opening a CDM session\n\n{error}", - } - try: - wv_challenge = wv_cdm.get_license_challenge(wv_session_id, wv_pssh) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting license challenge\n\n{error}", - } - try: - if headers: - format_headers = ast.literal_eval(headers) - else: - format_headers = None - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting headers\n\n{error}", - } - try: - if cookies: - format_cookies = ast.literal_eval(cookies) - else: - format_cookies = None - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting cookies\n\n{error}", - } - try: - if json_data and not is_base64(json_data): - format_json_data = ast.literal_eval(json_data) - format_json_data = find_license_challenge( - data=format_json_data, - new_value=base64.b64encode(wv_challenge).decode(), - ) - else: - format_json_data = None - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting json_data\n\n{error}", - } - licence = None - proxies = None - if proxy is not None: - is_url, protocol, fqdn = is_url_and_split(proxy) - if is_url: - proxies = {"http": proxy, "https": proxy} - try: - licence = requests.post( - url=license_url, - headers=format_headers, - proxies=proxies, - cookies=format_cookies, - json=format_json_data if format_json_data is not None else None, - data=wv_challenge if format_json_data is None else None, - ) - except requests.exceptions.ConnectionError as error: - return { - "status": "error", - "message": f"An error occurred sending license challenge through your proxy\n\n{error}", - } - except Exception as error: - return { - "status": "error", - "message": f"An error occurred sending license reqeust\n\n{error}\n\n{licence.content}", - } - try: - wv_cdm.parse_license(wv_session_id, licence.content) - except: + + # Send license request + licence, req_err = send_license_request( + license_url, + format_headers, + format_cookies, + format_json_data, + challenge, + proxies, + ) + if licence is None: + return {"status": "error", "message": req_err} + + # Parse license + try: + if is_pr: + cdm.parse_license(session_id, licence.text) + else: try: - license_json = licence.json() - license_value = find_license_key(license_json) - wv_cdm.parse_license(wv_session_id, license_value) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred parsing license content\n\n{error}\n\n{licence.content}", - } - returned_keys = "" - try: - keys = list(wv_cdm.get_keys(wv_session_id)) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting keys\n\n{error}", - } - try: - for index, key in enumerate(keys): - if key.type != "SIGNING": - cache_to_db( - pssh=pssh, - license_url=license_url, - headers=headers, - cookies=cookies, - data=wv_challenge if json_data is None else json_data, - kid=key.kid.hex, - key=key.key.hex(), - ) - if index != len(keys) - 1: - returned_keys += f"{key.kid.hex}:{key.key.hex()}\n" + cdm.parse_license(session_id, licence.content) # type: ignore[arg-type] + except (ValueError, TypeError): + # Try to extract license from JSON + try: + license_json = licence.json() + license_value = find_license_key(license_json) + if license_value is not None: + cdm.parse_license(session_id, license_value) else: - returned_keys += f"{key.kid.hex}:{key.key.hex()}" - except Exception as error: - return { - "status": "error", - "message": f"An error occurred formatting keys\n\n{error}", - } - try: - wv_cdm.close(wv_session_id) - except Exception as error: - return { - "status": "error", - "message": f"An error occurred closing session\n\n{error}", - } - try: - return {"status": "success", "message": returned_keys} - except Exception as error: - return { - "status": "error", - "message": f"An error occurred getting returned_keys\n\n{error}", - } + return { + "status": "error", + "message": f"Could not extract license from JSON: {license_json}", + } + except (ValueError, json.JSONDecodeError, AttributeError) as error: + return { + "status": "error", + "message": f"An error occurred parsing license content\n\n{error}\n\n{licence.content}", + } + except (ValueError, TypeError, AttributeError) as error: + return { + "status": "error", + "message": f"An error occurred parsing license content\n\n{error}\n\n{licence.content}", + } + + # Extract and cache keys + returned_keys, key_err = extract_and_cache_keys( + cdm, + session_id, + cache_to_db, + pssh, + license_url, + headers, + cookies, + challenge, + json_data, + is_widevine=(not is_pr), + ) + if returned_keys is None: + return {"status": "error", "message": key_err} + + # Close session + try: + cdm.close(session_id) + except (IOError, ValueError, AttributeError) as error: + return { + "status": "error", + "message": f"An error occurred closing session\n\n{error}", + } + + return {"status": "success", "message": returned_keys} diff --git a/custom_functions/prechecks/database_checks.py b/custom_functions/prechecks/database_checks.py index 76b17a4..55d6918 100644 --- a/custom_functions/prechecks/database_checks.py +++ b/custom_functions/prechecks/database_checks.py @@ -1,52 +1,159 @@ -"""Module to check for the database.""" +"""Module to check for the database with unified backend support.""" import os +from typing import Dict, Any import yaml -from custom_functions.database.cache_to_db_mariadb import ( - create_database as create_mariadb_database, -) -from custom_functions.database.cache_to_db_sqlite import ( - create_database as create_sqlite_database, +from custom_functions.database.unified_db_ops import ( + db_ops, + get_backend_info, + key_count, ) from custom_functions.database.user_db import create_user_database -def check_for_sqlite_database(): - """Check for the SQLite database.""" - with open( - os.path.join(os.getcwd(), "configs", "config.yaml"), "r", encoding="utf-8" - ) as file: - config = yaml.safe_load(file) - if os.path.exists(os.path.join(os.getcwd(), "databases", "key_cache.db")): - return - if config["database_type"].lower() == "sqlite": - create_sqlite_database() - return - return +def get_database_config() -> Dict[str, Any]: + """Get the database configuration from config.yaml.""" + try: + config_path = os.path.join(os.getcwd(), "configs", "config.yaml") + with open(config_path, "r", encoding="utf-8") as file: + config = yaml.safe_load(file) + return config + except (FileNotFoundError, KeyError, yaml.YAMLError) as e: + print(f"Warning: Could not load config.yaml: {e}") + return {"database_type": "sqlite"} # Default fallback -def check_for_user_database(): - """Check for the user database.""" - if os.path.exists(os.path.join(os.getcwd(), "databases", "users.db")): - return - create_user_database() +def check_for_sqlite_database() -> None: + """Check for the SQLite database file and create if needed.""" + config = get_database_config() + database_type = config.get("database_type", "sqlite").lower() + + # Only check for SQLite file if we're using SQLite + if database_type == "sqlite": + sqlite_path = os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") + if not os.path.exists(sqlite_path): + print("SQLite database not found, creating...") + # Ensure directory exists + os.makedirs(os.path.dirname(sqlite_path), exist_ok=True) + db_ops.create_database() + print(f"SQLite database created at: {sqlite_path}") + else: + print(f"SQLite database found at: {sqlite_path}") -def check_for_mariadb_database(): - """Check for the MariaDB database.""" - with open( - os.path.join(os.getcwd(), "configs", "config.yaml"), "r", encoding="utf-8" - ) as file: - config = yaml.safe_load(file) - if config["database_type"].lower() == "mariadb": - create_mariadb_database() - return - return +def check_for_mariadb_database() -> None: + """Check for the MariaDB database and create if needed.""" + config = get_database_config() + database_type = config.get("database_type", "sqlite").lower() + + # Only check MariaDB if we're using MariaDB + if database_type == "mariadb": + try: + print("Checking MariaDB connection and creating database if needed...") + db_ops.create_database() + print("MariaDB database check completed successfully") + except Exception as e: + print(f"Error checking/creating MariaDB database: {e}") + print("Falling back to SQLite...") + # Fallback to SQLite if MariaDB fails + fallback_config_path = os.path.join(os.getcwd(), "configs", "config.yaml") + try: + with open(fallback_config_path, "r", encoding="utf-8") as file: + config = yaml.safe_load(file) + config["database_type"] = "sqlite" + with open(fallback_config_path, "w", encoding="utf-8") as file: + yaml.safe_dump(config, file) + check_for_sqlite_database() + except Exception as fallback_error: + print(f"Error during fallback to SQLite: {fallback_error}") -def check_for_sql_database(): - """Check for the SQL database.""" - check_for_sqlite_database() - check_for_mariadb_database() +def check_for_user_database() -> None: + """Check for the user database and create if needed.""" + user_db_path = os.path.join(os.getcwd(), "databases", "users.db") + if not os.path.exists(user_db_path): + print("User database not found, creating...") + # Ensure directory exists + os.makedirs(os.path.dirname(user_db_path), exist_ok=True) + create_user_database() + print(f"User database created at: {user_db_path}") + else: + print(f"User database found at: {user_db_path}") + + +def check_for_sql_database() -> None: + """Check for the SQL database based on configuration.""" + print("=== Database Check Starting ===") + + # Get backend information + backend_info = get_backend_info() + print(f"Database backend: {backend_info['backend']}") + print(f"Using module: {backend_info['module']}") + + config = get_database_config() + database_type = config.get("database_type", "sqlite").lower() + + # Ensure databases directory exists + os.makedirs(os.path.join(os.getcwd(), "databases"), exist_ok=True) + os.makedirs(os.path.join(os.getcwd(), "databases", "sql"), exist_ok=True) + + # Check main database based on type + if database_type == "mariadb": + check_for_mariadb_database() + else: # Default to SQLite + check_for_sqlite_database() + + # Always check user database (always SQLite) check_for_user_database() + + print("=== Database Check Completed ===") + + +def get_database_status() -> Dict[str, Any]: + """Get the current database status and configuration.""" + config = get_database_config() + backend_info = get_backend_info() + + status = { + "configured_backend": config.get("database_type", "sqlite").lower(), + "active_backend": backend_info["backend"], + "module_in_use": backend_info["module"], + "sqlite_file_exists": os.path.exists( + os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") + ), + "user_db_exists": os.path.exists( + os.path.join(os.getcwd(), "databases", "users.db") + ), + } + + # Try to get key count to verify database is working + try: + + status["key_count"] = key_count() + status["database_operational"] = True + except Exception as e: + status["key_count"] = "Error" + status["database_operational"] = False + status["error"] = str(e) + + return status + + +def print_database_status() -> None: + """Print a formatted database status report.""" + status = get_database_status() + + print("\n=== Database Status Report ===") + print(f"Configured Backend: {status['configured_backend']}") + print(f"Active Backend: {status['active_backend']}") + print(f"Module in Use: {status['module_in_use']}") + print(f"SQLite File Exists: {status['sqlite_file_exists']}") + print(f"User DB Exists: {status['user_db_exists']}") + print(f"Database Operational: {status['database_operational']}") + print(f"Key Count: {status['key_count']}") + + if not status["database_operational"]: + print(f"Error: {status.get('error', 'Unknown error')}") + + print("==============================\n") diff --git a/routes/api.py b/routes/api.py index 13a302f..0df884f 100644 --- a/routes/api.py +++ b/routes/api.py @@ -1,56 +1,53 @@ +"""Module to handle the API routes.""" + import os import sqlite3 -from flask import Blueprint, jsonify, request, send_file, session import json -from custom_functions.decrypt.api_decrypt import api_decrypt -from custom_functions.user_checks.device_allowed import user_allowed_to_use_device import shutil import math -import yaml -import mysql.connector from io import StringIO import tempfile import time + +from flask import Blueprint, jsonify, request, send_file, session, after_this_request +import yaml +import mysql.connector + +from custom_functions.decrypt.api_decrypt import api_decrypt +from custom_functions.user_checks.device_allowed import user_allowed_to_use_device +from custom_functions.database.unified_db_ops import ( + search_by_pssh_or_kid, + cache_to_db, + get_key_by_kid_and_service, + get_unique_services, + get_kid_key_dict, + key_count, +) from configs.icon_links import data as icon_data api_bp = Blueprint("api", __name__) -with open(f"{os.getcwd()}/configs/config.yaml", "r") as file: +with open(os.path.join(os.getcwd(), "configs", "config.yaml"), "r", encoding="utf-8") as file: config = yaml.safe_load(file) -if config["database_type"].lower() != "mariadb": - from custom_functions.database.cache_to_db_sqlite import ( - search_by_pssh_or_kid, - cache_to_db, - get_key_by_kid_and_service, - get_unique_services, - get_kid_key_dict, - key_count, - ) -elif config["database_type"].lower() == "mariadb": - from custom_functions.database.cache_to_db_mariadb import ( - search_by_pssh_or_kid, - cache_to_db, - get_key_by_kid_and_service, - get_unique_services, - get_kid_key_dict, - key_count, - ) def get_db_config(): - # Configure your MariaDB connection - with open(f"{os.getcwd()}/configs/config.yaml", "r") as file: - config = yaml.safe_load(file) + """Get the MariaDB database configuration.""" + with open( + os.path.join(os.getcwd(), "configs", "config.yaml"), "r", encoding="utf-8" + ) as file_mariadb: + config_mariadb = yaml.safe_load(file_mariadb) db_config = { - "host": f'{config["mariadb"]["host"]}', - "user": f'{config["mariadb"]["user"]}', - "password": f'{config["mariadb"]["password"]}', - "database": f'{config["mariadb"]["database"]}', + "host": f'{config_mariadb["mariadb"]["host"]}', + "user": f'{config_mariadb["mariadb"]["user"]}', + "password": f'{config_mariadb["mariadb"]["password"]}', + "database": f'{config_mariadb["mariadb"]["database"]}', } return db_config @api_bp.route("/api/cache/search", methods=["POST"]) def get_data(): + """Get the data from the database.""" search_argument = json.loads(request.data)["input"] results = search_by_pssh_or_kid(search_filter=search_argument) return jsonify(results) @@ -58,6 +55,7 @@ def get_data(): @api_bp.route("/api/cache//", methods=["GET"]) def get_single_key_service(service, kid): + """Get the single key from the database.""" result = get_key_by_kid_and_service(kid=kid, service=service) return jsonify( { @@ -69,6 +67,7 @@ def get_single_key_service(service, kid): @api_bp.route("/api/cache/", methods=["GET"]) def get_multiple_key_service(service): + """Get the multiple keys from the database.""" result = get_kid_key_dict(service_name=service) pages = math.ceil(len(result) / 10) return jsonify({"code": 0, "content_keys": result, "pages": pages}) @@ -76,6 +75,7 @@ def get_multiple_key_service(service): @api_bp.route("/api/cache//", methods=["POST"]) def add_single_key_service(service, kid): + """Add the single key to the database.""" body = request.get_json() content_key = body["content_key"] result = cache_to_db(service=service, kid=kid, key=content_key) @@ -86,17 +86,17 @@ def add_single_key_service(service, kid): "updated": True, } ) - elif result is False: - return jsonify( - { - "code": 0, - "updated": True, - } - ) + return jsonify( + { + "code": 0, + "updated": True, + } + ) @api_bp.route("/api/cache/", methods=["POST"]) def add_multiple_key_service(service): + """Add the multiple keys to the database.""" body = request.get_json() keys_added = 0 keys_updated = 0 @@ -104,7 +104,7 @@ def add_multiple_key_service(service): result = cache_to_db(service=service, kid=kid, key=key) if result is True: keys_updated += 1 - elif result is False: + else: keys_added += 1 return jsonify( { @@ -117,6 +117,7 @@ def add_multiple_key_service(service): @api_bp.route("/api/cache", methods=["POST"]) def unique_service(): + """Get the unique services from the database.""" services = get_unique_services() return jsonify( { @@ -128,11 +129,12 @@ def unique_service(): @api_bp.route("/api/cache/download", methods=["GET"]) def download_database(): + """Download the database.""" if config["database_type"].lower() != "mariadb": - original_database_path = f"{os.getcwd()}/databases/sql/key_cache.db" + original_database_path = os.path.join(os.getcwd(), "databases", "sql", "key_cache.db") # Make a copy of the original database (without locking the original) - modified_database_path = f"{os.getcwd()}/databases/sql/key_cache_modified.db" + modified_database_path = os.path.join(os.getcwd(), "databases", "sql", "key_cache_modified.db") # Using shutil.copy2 to preserve metadata (timestamps, etc.) shutil.copy2(original_database_path, modified_database_path) @@ -157,51 +159,56 @@ def download_database(): return send_file( modified_database_path, as_attachment=True, download_name="key_cache.db" ) - if config["database_type"].lower() == "mariadb": - try: - # Connect to MariaDB - conn = mysql.connector.connect(**get_db_config()) - cursor = conn.cursor() + try: + conn = mysql.connector.connect(**get_db_config()) + cursor = conn.cursor() - # Update sensitive data (this updates the live DB, you may want to duplicate rows instead) - cursor.execute( - """ - UPDATE licenses - SET Headers = NULL, - Cookies = NULL - """ + # Get column names + cursor.execute("SHOW COLUMNS FROM licenses") + columns = [row[0] for row in cursor.fetchall()] + + # Build SELECT with Headers and Cookies as NULL + select_columns = [] + for col in columns: + if col.lower() in ("headers", "cookies"): + select_columns.append("NULL AS " + col) + else: + select_columns.append(col) + select_query = f"SELECT {', '.join(select_columns)} FROM licenses" + cursor.execute(select_query) + rows = cursor.fetchall() + + # Dump to SQL-like format + output = StringIO() + output.write("-- Dump of `licenses` table (Headers and Cookies are NULL)\n") + for row in rows: + values = ", ".join( + f"'{str(v).replace('\'', '\\\'')}'" if v is not None else "NULL" + for v in row + ) + output.write( + f"INSERT INTO licenses ({', '.join(columns)}) VALUES ({values});\n" ) - conn.commit() + # Write to a temp file for download + temp_dir = tempfile.gettempdir() + temp_path = os.path.join(temp_dir, "key_cache.sql") + with open(temp_path, "w", encoding="utf-8") as f: + f.write(output.getvalue()) - # Now export the table - cursor.execute("SELECT * FROM licenses") - rows = cursor.fetchall() - column_names = [desc[0] for desc in cursor.description] + @after_this_request + def remove_file(response): + try: + os.remove(temp_path) + except Exception: + pass + return response - # Dump to SQL-like format - output = StringIO() - output.write(f"-- Dump of `licenses` table\n") - for row in rows: - values = ", ".join( - f"'{str(v).replace('\'', '\\\'')}'" if v is not None else "NULL" - for v in row - ) - output.write( - f"INSERT INTO licenses ({', '.join(column_names)}) VALUES ({values});\n" - ) - - # Write to a temp file for download - temp_dir = tempfile.gettempdir() - temp_path = os.path.join(temp_dir, "key_cache.sql") - with open(temp_path, "w", encoding="utf-8") as f: - f.write(output.getvalue()) - - return send_file( - temp_path, as_attachment=True, download_name="licenses_dump.sql" - ) - except mysql.connector.Error as err: - return {"error": str(err)}, 500 + return send_file( + temp_path, as_attachment=True, download_name="licenses_dump.sql" + ) + except mysql.connector.Error as err: + return {"error": str(err)}, 500 _keycount_cache = {"count": None, "timestamp": 0} @@ -209,6 +216,7 @@ _keycount_cache = {"count": None, "timestamp": 0} @api_bp.route("/api/cache/keycount", methods=["GET"]) def get_count(): + """Get the count of the keys in the database.""" now = time.time() if now - _keycount_cache["timestamp"] > 10 or _keycount_cache["count"] is None: _keycount_cache["count"] = key_count() @@ -218,69 +226,42 @@ def get_count(): @api_bp.route("/api/decrypt", methods=["POST"]) def decrypt_data(): - api_request_data = json.loads(request.data) - if "pssh" in api_request_data: - if api_request_data["pssh"] == "": - api_request_pssh = None - else: - api_request_pssh = api_request_data["pssh"] - else: - api_request_pssh = None - if "licurl" in api_request_data: - if api_request_data["licurl"] == "": - api_request_licurl = None - else: - api_request_licurl = api_request_data["licurl"] - else: - api_request_licurl = None - if "proxy" in api_request_data: - if api_request_data["proxy"] == "": - api_request_proxy = None - else: - api_request_proxy = api_request_data["proxy"] - else: - api_request_proxy = None - if "headers" in api_request_data: - if api_request_data["headers"] == "": - api_request_headers = None - else: - api_request_headers = api_request_data["headers"] - else: - api_request_headers = None - if "cookies" in api_request_data: - if api_request_data["cookies"] == "": - api_request_cookies = None - else: - api_request_cookies = api_request_data["cookies"] - else: - api_request_cookies = None - if "data" in api_request_data: - if api_request_data["data"] == "": - api_request_data_func = None - else: - api_request_data_func = api_request_data["data"] - else: - api_request_data_func = None - if "device" in api_request_data: - if ( - api_request_data["device"] == "default" - or api_request_data["device"] == "CDRM-Project Public Widevine CDM" - or api_request_data["device"] == "CDRM-Project Public PlayReady CDM" - ): - api_request_device = "public" - else: - api_request_device = api_request_data["device"] - else: + """Decrypt the data.""" + api_request_data = request.get_json(force=True) + + # Helper to get fields or None if missing/empty + def get_field(key, default=""): + value = api_request_data.get(key, default) + return value if value != "" else default + + api_request_pssh = get_field("pssh") + api_request_licurl = get_field("licurl") + api_request_proxy = get_field("proxy") + api_request_headers = get_field("headers") + api_request_cookies = get_field("cookies") + api_request_data_func = get_field("data") + + # Device logic + device = get_field("device", "public") + if device in [ + "default", + "CDRM-Project Public Widevine CDM", + "CDRM-Project Public PlayReady CDM", + "", + None, + ]: api_request_device = "public" - username = None + else: + api_request_device = device + + username = "" if api_request_device != "public": username = session.get("username") if not username: return jsonify({"message": "Not logged in, not allowed"}), 400 - if user_allowed_to_use_device(device=api_request_device, username=username): - api_request_device = api_request_device - else: - return jsonify({"message": f"Not authorized / Not found"}), 403 + if not user_allowed_to_use_device(device=api_request_device, username=username): + return jsonify({"message": "Not authorized / Not found"}), 403 + result = api_decrypt( pssh=api_request_pssh, proxy=api_request_proxy, @@ -293,12 +274,12 @@ def decrypt_data(): ) if result["status"] == "success": return jsonify({"status": "success", "message": result["message"]}) - else: - return jsonify({"status": "fail", "message": result["message"]}) + return jsonify({"status": "fail", "message": result["message"]}) @api_bp.route("/api/links", methods=["GET"]) def get_links(): + """Get the links.""" return jsonify( { "discord": icon_data["discord"], @@ -310,6 +291,7 @@ def get_links(): @api_bp.route("/api/extension", methods=["POST"]) def verify_extension(): + """Verify the extension.""" return jsonify( { "status": True,