mirror of
				https://github.com/devine-dl/devine.git
				synced 2025-11-04 03:44:49 +00:00 
			
		
		
		
	Multi-thread the new DASH download system, improve redundancy
Just like the commit for HLS multi-threading, this mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. The benefits of this are that we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c, which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
		
							parent
							
								
									9e6f5b25f3
								
							
						
					
					
						commit
						4e875f5ffc
					
				@ -6,9 +6,14 @@ import logging
 | 
			
		||||
import math
 | 
			
		||||
import re
 | 
			
		||||
import sys
 | 
			
		||||
import time
 | 
			
		||||
import traceback
 | 
			
		||||
from concurrent import futures
 | 
			
		||||
from concurrent.futures import ThreadPoolExecutor
 | 
			
		||||
from copy import copy
 | 
			
		||||
from hashlib import md5
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from threading import Event
 | 
			
		||||
from typing import Any, Callable, Optional, Union
 | 
			
		||||
from urllib.parse import urljoin, urlparse
 | 
			
		||||
from uuid import UUID
 | 
			
		||||
@ -303,7 +308,6 @@ class DASH:
 | 
			
		||||
        else:
 | 
			
		||||
            drm = None
 | 
			
		||||
 | 
			
		||||
        segment_urls: list[str] = []
 | 
			
		||||
        manifest = load_xml(session.get(manifest_url).text)
 | 
			
		||||
        manifest_url_query = urlparse(manifest_url).query
 | 
			
		||||
 | 
			
		||||
@ -312,107 +316,151 @@ class DASH:
 | 
			
		||||
            period_base_url = urljoin(manifest_url, period_base_url)
 | 
			
		||||
        period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
 | 
			
		||||
 | 
			
		||||
        init_data: Optional[bytes] = None
 | 
			
		||||
        base_url = representation.findtext("BaseURL") or period_base_url
 | 
			
		||||
 | 
			
		||||
        segment_template = representation.find("SegmentTemplate")
 | 
			
		||||
        if segment_template is None:
 | 
			
		||||
            segment_template = adaptation_set.find("SegmentTemplate")
 | 
			
		||||
 | 
			
		||||
        segment_base = representation.find("SegmentBase")
 | 
			
		||||
        if segment_base is None:
 | 
			
		||||
            segment_base = adaptation_set.find("SegmentBase")
 | 
			
		||||
 | 
			
		||||
        segment_list = representation.find("SegmentList")
 | 
			
		||||
        if segment_list is None:
 | 
			
		||||
            segment_list = adaptation_set.find("SegmentList")
 | 
			
		||||
 | 
			
		||||
        if segment_template is not None:
 | 
			
		||||
            segment_template = copy(segment_template)
 | 
			
		||||
            start_number = int(segment_template.get("startNumber") or 1)
 | 
			
		||||
            segment_timeline = segment_template.find("SegmentTimeline")
 | 
			
		||||
        if segment_template is None and segment_list is None and base_url:
 | 
			
		||||
            # If there's no SegmentTemplate and no SegmentList, then SegmentBase is used or just BaseURL
 | 
			
		||||
            # Regardless which of the two is used, we can just directly grab the BaseURL
 | 
			
		||||
            # Players would normally calculate segments via Byte-Ranges, but we don't care
 | 
			
		||||
            track.url = urljoin(period_base_url, base_url)
 | 
			
		||||
            track.descriptor = track.Descriptor.URL
 | 
			
		||||
            track.drm = [drm] if drm else []
 | 
			
		||||
        else:
 | 
			
		||||
            segments: list[tuple[str, Optional[str]]] = []
 | 
			
		||||
 | 
			
		||||
            for item in ("initialization", "media"):
 | 
			
		||||
                value = segment_template.get(item)
 | 
			
		||||
                if not value:
 | 
			
		||||
                    continue
 | 
			
		||||
                if not re.match("^https?://", value, re.IGNORECASE):
 | 
			
		||||
                    if not base_url:
 | 
			
		||||
                        raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
 | 
			
		||||
                    value = urljoin(base_url, value)
 | 
			
		||||
                if not urlparse(value).query and manifest_url_query:
 | 
			
		||||
                    value += f"?{manifest_url_query}"
 | 
			
		||||
                segment_template.set(item, value)
 | 
			
		||||
            if segment_template is not None:
 | 
			
		||||
                segment_template = copy(segment_template)
 | 
			
		||||
                start_number = int(segment_template.get("startNumber") or 1)
 | 
			
		||||
                segment_timeline = segment_template.find("SegmentTimeline")
 | 
			
		||||
 | 
			
		||||
            if segment_timeline is not None:
 | 
			
		||||
                seg_time_list = []
 | 
			
		||||
                current_time = 0
 | 
			
		||||
                for s in segment_timeline.findall("S"):
 | 
			
		||||
                    if s.get("t"):
 | 
			
		||||
                        current_time = int(s.get("t"))
 | 
			
		||||
                    for _ in range(1 + (int(s.get("r") or 0))):
 | 
			
		||||
                        seg_time_list.append(current_time)
 | 
			
		||||
                        current_time += int(s.get("d"))
 | 
			
		||||
                seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
 | 
			
		||||
                segment_urls += [
 | 
			
		||||
                    DASH.replace_fields(
 | 
			
		||||
                        segment_template.get("media"),
 | 
			
		||||
                for item in ("initialization", "media"):
 | 
			
		||||
                    value = segment_template.get(item)
 | 
			
		||||
                    if not value:
 | 
			
		||||
                        continue
 | 
			
		||||
                    if not re.match("^https?://", value, re.IGNORECASE):
 | 
			
		||||
                        if not base_url:
 | 
			
		||||
                            raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
 | 
			
		||||
                        value = urljoin(base_url, value)
 | 
			
		||||
                    if not urlparse(value).query and manifest_url_query:
 | 
			
		||||
                        value += f"?{manifest_url_query}"
 | 
			
		||||
                    segment_template.set(item, value)
 | 
			
		||||
 | 
			
		||||
                init_url = segment_template.get("initialization")
 | 
			
		||||
                if init_url:
 | 
			
		||||
                    res = session.get(DASH.replace_fields(
 | 
			
		||||
                        init_url,
 | 
			
		||||
                        Bandwidth=representation.get("bandwidth"),
 | 
			
		||||
                        Number=n,
 | 
			
		||||
                        RepresentationID=representation.get("id"),
 | 
			
		||||
                        Time=t
 | 
			
		||||
                    )
 | 
			
		||||
                    for t, n in zip(seg_time_list, seg_num_list)
 | 
			
		||||
                ]
 | 
			
		||||
                        RepresentationID=representation.get("id")
 | 
			
		||||
                    ))
 | 
			
		||||
                    res.raise_for_status()
 | 
			
		||||
                    init_data = res.content
 | 
			
		||||
 | 
			
		||||
                if segment_timeline is not None:
 | 
			
		||||
                    seg_time_list = []
 | 
			
		||||
                    current_time = 0
 | 
			
		||||
                    for s in segment_timeline.findall("S"):
 | 
			
		||||
                        if s.get("t"):
 | 
			
		||||
                            current_time = int(s.get("t"))
 | 
			
		||||
                        for _ in range(1 + (int(s.get("r") or 0))):
 | 
			
		||||
                            seg_time_list.append(current_time)
 | 
			
		||||
                            current_time += int(s.get("d"))
 | 
			
		||||
                    seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
 | 
			
		||||
 | 
			
		||||
                    for t, n in zip(seg_time_list, seg_num_list):
 | 
			
		||||
                        segments.append((
 | 
			
		||||
                            DASH.replace_fields(
 | 
			
		||||
                                segment_template.get("media"),
 | 
			
		||||
                                Bandwidth=representation.get("bandwidth"),
 | 
			
		||||
                                Number=n,
 | 
			
		||||
                                RepresentationID=representation.get("id"),
 | 
			
		||||
                                Time=t
 | 
			
		||||
                            ), None
 | 
			
		||||
                        ))
 | 
			
		||||
                else:
 | 
			
		||||
                    if not period_duration:
 | 
			
		||||
                        raise ValueError("Duration of the Period was unable to be determined.")
 | 
			
		||||
                    period_duration = DASH.pt_to_sec(period_duration)
 | 
			
		||||
                    segment_duration = float(segment_template.get("duration"))
 | 
			
		||||
                    segment_timescale = float(segment_template.get("timescale") or 1)
 | 
			
		||||
                    total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
 | 
			
		||||
 | 
			
		||||
                    for s in range(start_number, start_number + total_segments):
 | 
			
		||||
                        segments.append((
 | 
			
		||||
                            DASH.replace_fields(
 | 
			
		||||
                                segment_template.get("media"),
 | 
			
		||||
                                Bandwidth=representation.get("bandwidth"),
 | 
			
		||||
                                Number=s,
 | 
			
		||||
                                RepresentationID=representation.get("id"),
 | 
			
		||||
                                Time=s
 | 
			
		||||
                            ), None
 | 
			
		||||
                        ))
 | 
			
		||||
            elif segment_list is not None:
 | 
			
		||||
                base_media_url = urljoin(period_base_url, base_url)
 | 
			
		||||
 | 
			
		||||
                init_data = None
 | 
			
		||||
                initialization = segment_list.find("Initialization")
 | 
			
		||||
                if initialization:
 | 
			
		||||
                    source_url = initialization.get("sourceURL")
 | 
			
		||||
                    if source_url is None:
 | 
			
		||||
                        source_url = base_media_url
 | 
			
		||||
 | 
			
		||||
                    res = session.get(source_url)
 | 
			
		||||
                    res.raise_for_status()
 | 
			
		||||
                    init_data = res.content
 | 
			
		||||
 | 
			
		||||
                segment_urls = segment_list.findall("SegmentURL")
 | 
			
		||||
                for segment_url in segment_urls:
 | 
			
		||||
                    media_url = segment_url.get("media")
 | 
			
		||||
                    if media_url is None:
 | 
			
		||||
                        media_url = base_media_url
 | 
			
		||||
 | 
			
		||||
                    segments.append((
 | 
			
		||||
                        media_url,
 | 
			
		||||
                        segment_url.get("mediaRange")
 | 
			
		||||
                    ))
 | 
			
		||||
            else:
 | 
			
		||||
                if not period_duration:
 | 
			
		||||
                    raise ValueError("Duration of the Period was unable to be determined.")
 | 
			
		||||
                period_duration = DASH.pt_to_sec(period_duration)
 | 
			
		||||
                segment_duration = float(segment_template.get("duration"))
 | 
			
		||||
                segment_timescale = float(segment_template.get("timescale") or 1)
 | 
			
		||||
                log.error("Could not find a way to get segments from this MPD manifest.")
 | 
			
		||||
                log.debug(manifest_url)
 | 
			
		||||
                sys.exit(1)
 | 
			
		||||
 | 
			
		||||
                total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
 | 
			
		||||
                segment_urls += [
 | 
			
		||||
                    DASH.replace_fields(
 | 
			
		||||
                        segment_template.get("media"),
 | 
			
		||||
                        Bandwidth=representation.get("bandwidth"),
 | 
			
		||||
                        Number=s,
 | 
			
		||||
                        RepresentationID=representation.get("id"),
 | 
			
		||||
                        Time=s
 | 
			
		||||
                    )
 | 
			
		||||
                    for s in range(start_number, start_number + total_segments)
 | 
			
		||||
                ]
 | 
			
		||||
            if not drm and isinstance(track, (Video, Audio)):
 | 
			
		||||
                try:
 | 
			
		||||
                    drm = Widevine.from_init_data(init_data)
 | 
			
		||||
                except Widevine.Exceptions.PSSHNotFound:
 | 
			
		||||
                    # it might not have Widevine DRM, or might not have found the PSSH
 | 
			
		||||
                    log.warning("No Widevine PSSH was found for this track, is it DRM free?")
 | 
			
		||||
                else:
 | 
			
		||||
                    # license and grab content keys
 | 
			
		||||
                    if not license_widevine:
 | 
			
		||||
                        raise ValueError("license_widevine func must be supplied to use Widevine DRM")
 | 
			
		||||
                    license_widevine(drm)
 | 
			
		||||
 | 
			
		||||
            init_data = None
 | 
			
		||||
            init_url = segment_template.get("initialization")
 | 
			
		||||
            if init_url:
 | 
			
		||||
                res = session.get(DASH.replace_fields(
 | 
			
		||||
                    init_url,
 | 
			
		||||
                    Bandwidth=representation.get("bandwidth"),
 | 
			
		||||
                    RepresentationID=representation.get("id")
 | 
			
		||||
                ))
 | 
			
		||||
                res.raise_for_status()
 | 
			
		||||
                init_data = res.content
 | 
			
		||||
                if not drm:
 | 
			
		||||
                    try:
 | 
			
		||||
                        drm = Widevine.from_init_data(init_data)
 | 
			
		||||
                    except Widevine.Exceptions.PSSHNotFound:
 | 
			
		||||
                        # it might not have Widevine DRM, or might not have found the PSSH
 | 
			
		||||
                        log.warning("No Widevine PSSH was found for this track, is it DRM free?")
 | 
			
		||||
                    else:
 | 
			
		||||
                        # license and grab content keys
 | 
			
		||||
                        if not license_widevine:
 | 
			
		||||
                            raise ValueError("license_widevine func must be supplied to use Widevine DRM")
 | 
			
		||||
                        license_widevine(drm)
 | 
			
		||||
            state_event = Event()
 | 
			
		||||
 | 
			
		||||
            for i, segment_url in enumerate(tqdm(segment_urls, unit="segments")):
 | 
			
		||||
                segment_filename = str(i).zfill(len(str(len(segment_urls))))
 | 
			
		||||
                segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
 | 
			
		||||
            def download_segment(filename: str, segment: tuple[str, Optional[str]]):
 | 
			
		||||
                time.sleep(0.1)
 | 
			
		||||
                if state_event.is_set():
 | 
			
		||||
                    return
 | 
			
		||||
 | 
			
		||||
                segment_save_path = (save_dir / filename).with_suffix(".mp4")
 | 
			
		||||
 | 
			
		||||
                segment_uri, segment_range = segment
 | 
			
		||||
 | 
			
		||||
                asyncio.run(aria2c(
 | 
			
		||||
                    segment_url,
 | 
			
		||||
                    segment_uri,
 | 
			
		||||
                    segment_save_path,
 | 
			
		||||
                    session.headers,
 | 
			
		||||
                    proxy
 | 
			
		||||
                    proxy,
 | 
			
		||||
                    byte_range=segment_range
 | 
			
		||||
                ))
 | 
			
		||||
 | 
			
		||||
                if isinstance(track, Audio) or init_data:
 | 
			
		||||
@ -438,84 +486,34 @@ class DASH:
 | 
			
		||||
                    track.drm = None
 | 
			
		||||
                    if callable(track.OnDecrypted):
 | 
			
		||||
                        track.OnDecrypted(track)
 | 
			
		||||
        elif segment_list is not None:
 | 
			
		||||
            base_media_url = urljoin(period_base_url, base_url)
 | 
			
		||||
            if any(x.get("media") is not None for x in segment_list.findall("SegmentURL")):
 | 
			
		||||
                # at least one segment has no URL specified, it uses the base url and ranges
 | 
			
		||||
                track.url = base_media_url
 | 
			
		||||
                track.descriptor = track.Descriptor.URL
 | 
			
		||||
                track.drm = [drm] if drm else []
 | 
			
		||||
            else:
 | 
			
		||||
                init_data = None
 | 
			
		||||
                initialization = segment_list.find("Initialization")
 | 
			
		||||
                if initialization:
 | 
			
		||||
                    source_url = initialization.get("sourceURL")
 | 
			
		||||
                    if source_url is None:
 | 
			
		||||
                        source_url = base_media_url
 | 
			
		||||
 | 
			
		||||
                    res = session.get(source_url)
 | 
			
		||||
                    res.raise_for_status()
 | 
			
		||||
                    init_data = res.content
 | 
			
		||||
                    if not drm:
 | 
			
		||||
                        try:
 | 
			
		||||
                            drm = Widevine.from_init_data(init_data)
 | 
			
		||||
                        except Widevine.Exceptions.PSSHNotFound:
 | 
			
		||||
                            # it might not have Widevine DRM, or might not have found the PSSH
 | 
			
		||||
                            log.warning("No Widevine PSSH was found for this track, is it DRM free?")
 | 
			
		||||
                        else:
 | 
			
		||||
                            # license and grab content keys
 | 
			
		||||
                            if not license_widevine:
 | 
			
		||||
                                raise ValueError("license_widevine func must be supplied to use Widevine DRM")
 | 
			
		||||
                            license_widevine(drm)
 | 
			
		||||
 | 
			
		||||
                for i, segment_url in enumerate(tqdm(segment_list.findall("SegmentURL"), unit="segments")):
 | 
			
		||||
                    segment_filename = str(i).zfill(len(str(len(segment_urls))))
 | 
			
		||||
                    segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
 | 
			
		||||
 | 
			
		||||
                    media_url = segment_url.get("media")
 | 
			
		||||
                    if media_url is None:
 | 
			
		||||
                        media_url = base_media_url
 | 
			
		||||
 | 
			
		||||
                    asyncio.run(aria2c(
 | 
			
		||||
                        media_url,
 | 
			
		||||
                        segment_save_path,
 | 
			
		||||
                        session.headers,
 | 
			
		||||
                        proxy,
 | 
			
		||||
                        byte_range=segment_url.get("mediaRange")
 | 
			
		||||
                    ))
 | 
			
		||||
 | 
			
		||||
                    if isinstance(track, Audio) or init_data:
 | 
			
		||||
                        with open(segment_save_path, "rb+") as f:
 | 
			
		||||
                            segment_data = f.read()
 | 
			
		||||
                            if isinstance(track, Audio):
 | 
			
		||||
                                # fix audio decryption on ATVP by fixing the sample description index
 | 
			
		||||
                                # TODO: Is this in mpeg data, or init data?
 | 
			
		||||
                                segment_data = re.sub(
 | 
			
		||||
                                    b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
 | 
			
		||||
                                    b"\\g<1>\x01",
 | 
			
		||||
                                    segment_data
 | 
			
		||||
                                )
 | 
			
		||||
                            # prepend the init data to be able to decrypt
 | 
			
		||||
                            if init_data:
 | 
			
		||||
                                f.seek(0)
 | 
			
		||||
                                f.write(init_data)
 | 
			
		||||
                                f.write(segment_data)
 | 
			
		||||
 | 
			
		||||
                    if drm:
 | 
			
		||||
                        # TODO: What if the manifest does not mention DRM, but has DRM
 | 
			
		||||
                        drm.decrypt(segment_save_path)
 | 
			
		||||
                        track.drm = None
 | 
			
		||||
                        if callable(track.OnDecrypted):
 | 
			
		||||
                            track.OnDecrypted(track)
 | 
			
		||||
        elif segment_base is not None or base_url:
 | 
			
		||||
            # SegmentBase more or less boils down to defined ByteRanges
 | 
			
		||||
            # So, we don't care, just download the full file
 | 
			
		||||
            track.url = urljoin(period_base_url, base_url)
 | 
			
		||||
            track.descriptor = track.Descriptor.URL
 | 
			
		||||
            track.drm = [drm] if drm else []
 | 
			
		||||
        else:
 | 
			
		||||
            log.error("Could not find a way to get segments from this MPD manifest.")
 | 
			
		||||
            sys.exit(1)
 | 
			
		||||
            with tqdm(total=len(segments), unit="segments") as pbar:
 | 
			
		||||
                with ThreadPoolExecutor(max_workers=16) as pool:
 | 
			
		||||
                    try:
 | 
			
		||||
                        for download in futures.as_completed((
 | 
			
		||||
                            pool.submit(
 | 
			
		||||
                                download_segment,
 | 
			
		||||
                                filename=str(i).zfill(len(str(len(segments)))),
 | 
			
		||||
                                segment=segment
 | 
			
		||||
                            )
 | 
			
		||||
                            for i, segment in enumerate(segments)
 | 
			
		||||
                        )):
 | 
			
		||||
                            if download.cancelled():
 | 
			
		||||
                                continue
 | 
			
		||||
                            e = download.exception()
 | 
			
		||||
                            if e:
 | 
			
		||||
                                state_event.set()
 | 
			
		||||
                                pool.shutdown(wait=False, cancel_futures=True)
 | 
			
		||||
                                traceback.print_exception(e)
 | 
			
		||||
                                log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
 | 
			
		||||
                                sys.exit(1)
 | 
			
		||||
                            else:
 | 
			
		||||
                                pbar.update(1)
 | 
			
		||||
                    except KeyboardInterrupt:
 | 
			
		||||
                        state_event.set()
 | 
			
		||||
                        pool.shutdown(wait=False, cancel_futures=True)
 | 
			
		||||
                        log.info("Received Keyboard Interrupt, stopping...")
 | 
			
		||||
                        return
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_language(*options: Any) -> Optional[Language]:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user