mirror of
				https://github.com/devine-dl/devine.git
				synced 2025-11-04 03:44:49 +00:00 
			
		
		
		
	Move download_segment() from DASH/HLS download_track() to Class
Various overall small readability improvements have also been made.
This commit is contained in:
		
							parent
							
								
									03c012f88e
								
							
						
					
					
						commit
						dd64212ad2
					
				@ -13,7 +13,7 @@ from functools import partial
 | 
			
		||||
from hashlib import md5
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from threading import Event
 | 
			
		||||
from typing import Any, Callable, Optional, Union
 | 
			
		||||
from typing import Any, Callable, Optional, Union, MutableMapping
 | 
			
		||||
from urllib.parse import urljoin, urlparse
 | 
			
		||||
from uuid import UUID
 | 
			
		||||
 | 
			
		||||
@ -392,7 +392,8 @@ class DASH:
 | 
			
		||||
                # last chance to find the KID, assumes first segment will hold the init data
 | 
			
		||||
                track_kid = track_kid or track.get_key_id(url=segments[0][0], session=session)
 | 
			
		||||
                # license and grab content keys
 | 
			
		||||
                drm = track.drm[0]  # just use the first supported DRM system for now
 | 
			
		||||
                # TODO: What if we don't want to use the first DRM system?
 | 
			
		||||
                drm = track.drm[0]
 | 
			
		||||
                if isinstance(drm, Widevine):
 | 
			
		||||
                    if not license_widevine:
 | 
			
		||||
                        raise ValueError("license_widevine func must be supplied to use Widevine DRM")
 | 
			
		||||
@ -404,74 +405,26 @@ class DASH:
 | 
			
		||||
                progress(downloaded="[yellow]SKIPPED")
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
 | 
			
		||||
                if stop_event.is_set():
 | 
			
		||||
                    # the track already started downloading, but another failed or was stopped
 | 
			
		||||
                    raise KeyboardInterrupt()
 | 
			
		||||
 | 
			
		||||
                segment_save_path = (save_dir / filename).with_suffix(".mp4")
 | 
			
		||||
 | 
			
		||||
                segment_uri, segment_range = segment
 | 
			
		||||
 | 
			
		||||
                attempts = 1
 | 
			
		||||
                while True:
 | 
			
		||||
                    try:
 | 
			
		||||
                        downloader_ = downloader
 | 
			
		||||
                        headers_ = session.headers
 | 
			
		||||
                        if segment_range:
 | 
			
		||||
                            # aria2(c) doesn't support byte ranges, let's use python-requests (likely slower)
 | 
			
		||||
                            downloader_ = requests_downloader
 | 
			
		||||
                            headers_["Range"] = f"bytes={segment_range}"
 | 
			
		||||
                        downloader_(
 | 
			
		||||
                            uri=segment_uri,
 | 
			
		||||
                            out=segment_save_path,
 | 
			
		||||
                            headers=headers_,
 | 
			
		||||
                            proxy=proxy,
 | 
			
		||||
                            silent=attempts != 5,
 | 
			
		||||
                            segmented=True
 | 
			
		||||
                        )
 | 
			
		||||
                        break
 | 
			
		||||
                    except Exception as ee:
 | 
			
		||||
                        if stop_event.is_set() or attempts == 5:
 | 
			
		||||
                            raise ee
 | 
			
		||||
                        time.sleep(2)
 | 
			
		||||
                        attempts += 1
 | 
			
		||||
 | 
			
		||||
                data_size = segment_save_path.stat().st_size
 | 
			
		||||
 | 
			
		||||
                # fix audio decryption on ATVP by fixing the sample description index
 | 
			
		||||
                # TODO: Should this be done in the video data or the init data?
 | 
			
		||||
                if isinstance(track, Audio):
 | 
			
		||||
                    with open(segment_save_path, "rb+") as f:
 | 
			
		||||
                        segment_data = f.read()
 | 
			
		||||
                        fixed_segment_data = re.sub(
 | 
			
		||||
                            b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
 | 
			
		||||
                            b"\\g<1>\x01",
 | 
			
		||||
                            segment_data
 | 
			
		||||
                        )
 | 
			
		||||
                        if fixed_segment_data != segment_data:
 | 
			
		||||
                            f.seek(0)
 | 
			
		||||
                            f.write(fixed_segment_data)
 | 
			
		||||
 | 
			
		||||
                return data_size
 | 
			
		||||
 | 
			
		||||
            progress(total=len(segments))
 | 
			
		||||
 | 
			
		||||
            finished_threads = 0
 | 
			
		||||
            download_sizes = []
 | 
			
		||||
            download_speed_window = 5
 | 
			
		||||
            last_speed_refresh = time.time()
 | 
			
		||||
 | 
			
		||||
            with ThreadPoolExecutor(max_workers=16) as pool:
 | 
			
		||||
                for download in futures.as_completed((
 | 
			
		||||
                for i, download in enumerate(futures.as_completed((
 | 
			
		||||
                    pool.submit(
 | 
			
		||||
                        download_segment,
 | 
			
		||||
                        filename=str(i).zfill(len(str(len(segments)))),
 | 
			
		||||
                        segment=segment
 | 
			
		||||
                        DASH.download_segment,
 | 
			
		||||
                        url=url,
 | 
			
		||||
                        out_path=(save_dir / str(n).zfill(len(str(len(segments))))).with_suffix(".mp4"),
 | 
			
		||||
                        track=track,
 | 
			
		||||
                        proxy=proxy,
 | 
			
		||||
                        headers=session.headers,
 | 
			
		||||
                        bytes_range=bytes_range,
 | 
			
		||||
                        stop_event=stop_event
 | 
			
		||||
                    )
 | 
			
		||||
                    for i, segment in enumerate(segments)
 | 
			
		||||
                )):
 | 
			
		||||
                    finished_threads += 1
 | 
			
		||||
 | 
			
		||||
                    for n, (url, bytes_range) in enumerate(segments)
 | 
			
		||||
                ))):
 | 
			
		||||
                    try:
 | 
			
		||||
                        download_size = download.result()
 | 
			
		||||
                    except KeyboardInterrupt:
 | 
			
		||||
@ -482,16 +435,15 @@ class DASH:
 | 
			
		||||
                        # tell dl that it was cancelled
 | 
			
		||||
                        # the pool is already shut down, so exiting loop is fine
 | 
			
		||||
                        raise
 | 
			
		||||
                    except Exception as e:
 | 
			
		||||
                    except Exception:
 | 
			
		||||
                        stop_event.set()  # skip pending track downloads
 | 
			
		||||
                        progress(downloaded="[red]FAILING")
 | 
			
		||||
                        pool.shutdown(wait=True, cancel_futures=True)
 | 
			
		||||
                        progress(downloaded="[red]FAILED")
 | 
			
		||||
                        # tell dl that it failed
 | 
			
		||||
                        # the pool is already shut down, so exiting loop is fine
 | 
			
		||||
                        raise e
 | 
			
		||||
                        raise
 | 
			
		||||
                    else:
 | 
			
		||||
                        # it successfully downloaded, and it was not cancelled
 | 
			
		||||
                        progress(advance=1)
 | 
			
		||||
 | 
			
		||||
                        now = time.time()
 | 
			
		||||
@ -500,7 +452,7 @@ class DASH:
 | 
			
		||||
                        if download_size:  # no size == skipped dl
 | 
			
		||||
                            download_sizes.append(download_size)
 | 
			
		||||
 | 
			
		||||
                        if download_sizes and (time_since > 5 or finished_threads == len(segments)):
 | 
			
		||||
                        if download_sizes and (time_since > download_speed_window or i == len(segments)):
 | 
			
		||||
                            data_size = sum(download_sizes)
 | 
			
		||||
                            download_speed = data_size / (time_since or 1)
 | 
			
		||||
                            progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
 | 
			
		||||
@ -527,6 +479,76 @@ class DASH:
 | 
			
		||||
 | 
			
		||||
            progress(downloaded="Downloaded")
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def download_segment(
 | 
			
		||||
        url: str,
 | 
			
		||||
        out_path: Path,
 | 
			
		||||
        track: AnyTrack,
 | 
			
		||||
        proxy: Optional[str] = None,
 | 
			
		||||
        headers: Optional[MutableMapping[str, str | bytes]] = None,
 | 
			
		||||
        bytes_range: Optional[str] = None,
 | 
			
		||||
        stop_event: Optional[Event] = None
 | 
			
		||||
    ) -> int:
 | 
			
		||||
        """
 | 
			
		||||
        Download a DASH Media Segment.
 | 
			
		||||
 | 
			
		||||
        Parameters:
 | 
			
		||||
            url: Full HTTP(S) URL to the Segment you want to download.
 | 
			
		||||
            out_path: Path to save the downloaded Segment file to.
 | 
			
		||||
            track: The Track object of which this Segment is for. Currently only used to
 | 
			
		||||
                fix an invalid value in the TFHD box of Audio Tracks.
 | 
			
		||||
            proxy: Proxy URI to use when downloading the Segment file.
 | 
			
		||||
            headers: HTTP Headers to send when requesting the Segment file.
 | 
			
		||||
            bytes_range: Download only specific bytes of the Segment file using the Range header.
 | 
			
		||||
            stop_event: Prematurely stop the Download from beginning. Useful if ran from
 | 
			
		||||
                a Thread Pool. It will raise a KeyboardInterrupt if set.
 | 
			
		||||
 | 
			
		||||
        Returns the file size of the downloaded Segment in bytes.
 | 
			
		||||
        """
 | 
			
		||||
        if stop_event and stop_event.is_set():
 | 
			
		||||
            raise KeyboardInterrupt()
 | 
			
		||||
 | 
			
		||||
        attempts = 1
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                headers_ = headers or {}
 | 
			
		||||
                if bytes_range:
 | 
			
		||||
                    # aria2(c) doesn't support byte ranges, use python-requests
 | 
			
		||||
                    downloader_ = requests_downloader
 | 
			
		||||
                    headers_["Range"] = f"bytes={bytes_range}"
 | 
			
		||||
                else:
 | 
			
		||||
                    downloader_ = downloader
 | 
			
		||||
                downloader_(
 | 
			
		||||
                    uri=url,
 | 
			
		||||
                    out=out_path,
 | 
			
		||||
                    headers=headers_,
 | 
			
		||||
                    proxy=proxy,
 | 
			
		||||
                    silent=attempts != 5,
 | 
			
		||||
                    segmented=True
 | 
			
		||||
                )
 | 
			
		||||
                break
 | 
			
		||||
            except Exception as ee:
 | 
			
		||||
                if (stop_event and stop_event.is_set()) or attempts == 5:
 | 
			
		||||
                    raise ee
 | 
			
		||||
                time.sleep(2)
 | 
			
		||||
                attempts += 1
 | 
			
		||||
 | 
			
		||||
        # fix audio decryption on ATVP by fixing the sample description index
 | 
			
		||||
        # TODO: Should this be done in the video data or the init data?
 | 
			
		||||
        if isinstance(track, Audio):
 | 
			
		||||
            with open(out_path, "rb+") as f:
 | 
			
		||||
                segment_data = f.read()
 | 
			
		||||
                fixed_segment_data = re.sub(
 | 
			
		||||
                    b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
 | 
			
		||||
                    b"\\g<1>\x01",
 | 
			
		||||
                    segment_data
 | 
			
		||||
                )
 | 
			
		||||
                if fixed_segment_data != segment_data:
 | 
			
		||||
                    f.seek(0)
 | 
			
		||||
                    f.write(fixed_segment_data)
 | 
			
		||||
 | 
			
		||||
        return out_path.stat().st_size
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _get(
 | 
			
		||||
        item: str,
 | 
			
		||||
 | 
			
		||||
@ -214,137 +214,6 @@ class HLS:
 | 
			
		||||
            log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
 | 
			
		||||
            sys.exit(1)
 | 
			
		||||
 | 
			
		||||
        drm_lock = Lock()
 | 
			
		||||
 | 
			
		||||
        def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int:
 | 
			
		||||
            if stop_event.is_set():
 | 
			
		||||
                # the track already started downloading, but another failed or was stopped
 | 
			
		||||
                raise KeyboardInterrupt()
 | 
			
		||||
 | 
			
		||||
            if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
 | 
			
		||||
                return 0
 | 
			
		||||
 | 
			
		||||
            segment_save_path = (save_dir / filename).with_suffix(".mp4")
 | 
			
		||||
 | 
			
		||||
            newest_init_data = init_data.get()
 | 
			
		||||
            try:
 | 
			
		||||
                if segment.init_section and (not newest_init_data or segment.discontinuity):
 | 
			
		||||
                    # Only use the init data if there's no init data yet (e.g., start of file)
 | 
			
		||||
                    # or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
 | 
			
		||||
                    # Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
 | 
			
		||||
                    # be unnecessary and slow to re-download the init data each time.
 | 
			
		||||
                    if not segment.init_section.uri.startswith(segment.init_section.base_uri):
 | 
			
		||||
                        segment.init_section.uri = segment.init_section.base_uri + segment.init_section.uri
 | 
			
		||||
 | 
			
		||||
                    if segment.init_section.byterange:
 | 
			
		||||
                        byte_range = HLS.calculate_byte_range(segment.init_section.byterange)
 | 
			
		||||
                        _ = range_offset.get()
 | 
			
		||||
                        range_offset.put(byte_range.split("-")[0])
 | 
			
		||||
                        headers = {
 | 
			
		||||
                            "Range": f"bytes={byte_range}"
 | 
			
		||||
                        }
 | 
			
		||||
                    else:
 | 
			
		||||
                        headers = {}
 | 
			
		||||
 | 
			
		||||
                    log.debug("Got new init segment, %s", segment.init_section.uri)
 | 
			
		||||
                    res = session.get(segment.init_section.uri, headers=headers)
 | 
			
		||||
                    res.raise_for_status()
 | 
			
		||||
                    newest_init_data = res.content
 | 
			
		||||
            finally:
 | 
			
		||||
                init_data.put(newest_init_data)
 | 
			
		||||
 | 
			
		||||
            with drm_lock:
 | 
			
		||||
                newest_segment_key = segment_key.get()
 | 
			
		||||
                try:
 | 
			
		||||
                    if segment.keys and newest_segment_key[1] != segment.keys:
 | 
			
		||||
                        try:
 | 
			
		||||
                            drm = HLS.get_drm(
 | 
			
		||||
                                keys=segment.keys,
 | 
			
		||||
                                proxy=proxy
 | 
			
		||||
                            )
 | 
			
		||||
                        except NotImplementedError as e:
 | 
			
		||||
                            log.error(str(e))
 | 
			
		||||
                            sys.exit(1)
 | 
			
		||||
                        else:
 | 
			
		||||
                            if drm:
 | 
			
		||||
                                track.drm = drm
 | 
			
		||||
                                drm = drm[0]  # just use the first supported DRM system for now
 | 
			
		||||
                                log.debug("Got segment key, %s", drm)
 | 
			
		||||
                                if isinstance(drm, Widevine):
 | 
			
		||||
                                    # license and grab content keys
 | 
			
		||||
                                    track_kid = track.get_key_id(newest_init_data)
 | 
			
		||||
                                    if not license_widevine:
 | 
			
		||||
                                        raise ValueError("license_widevine func must be supplied to use Widevine DRM")
 | 
			
		||||
                                    license_widevine(drm, track_kid=track_kid)
 | 
			
		||||
                                newest_segment_key = (drm, segment.keys)
 | 
			
		||||
                finally:
 | 
			
		||||
                    segment_key.put(newest_segment_key)
 | 
			
		||||
 | 
			
		||||
                if skip_event.is_set():
 | 
			
		||||
                    progress(downloaded="[yellow]SKIPPING")
 | 
			
		||||
                    return 0
 | 
			
		||||
 | 
			
		||||
            if not segment.uri.startswith(segment.base_uri):
 | 
			
		||||
                segment.uri = segment.base_uri + segment.uri
 | 
			
		||||
 | 
			
		||||
            attempts = 1
 | 
			
		||||
            while True:
 | 
			
		||||
                try:
 | 
			
		||||
                    downloader_ = downloader
 | 
			
		||||
                    headers_ = session.headers
 | 
			
		||||
                    if segment.byterange:
 | 
			
		||||
                        # aria2(c) doesn't support byte ranges, let's use python-requests (likely slower)
 | 
			
		||||
                        previous_range_offset = range_offset.get()
 | 
			
		||||
                        byte_range = HLS.calculate_byte_range(segment.byterange, previous_range_offset)
 | 
			
		||||
                        range_offset.put(byte_range.split("-")[0])
 | 
			
		||||
                        downloader_ = requests_downloader
 | 
			
		||||
                        headers_["Range"] = f"bytes={byte_range}"
 | 
			
		||||
                    downloader_(
 | 
			
		||||
                        uri=segment.uri,
 | 
			
		||||
                        out=segment_save_path,
 | 
			
		||||
                        headers=headers_,
 | 
			
		||||
                        proxy=proxy,
 | 
			
		||||
                        silent=attempts != 5,
 | 
			
		||||
                        segmented=True
 | 
			
		||||
                    )
 | 
			
		||||
                    break
 | 
			
		||||
                except Exception as ee:
 | 
			
		||||
                    if stop_event.is_set() or attempts == 5:
 | 
			
		||||
                        raise ee
 | 
			
		||||
                    time.sleep(2)
 | 
			
		||||
                    attempts += 1
 | 
			
		||||
 | 
			
		||||
            data_size = segment_save_path.stat().st_size
 | 
			
		||||
 | 
			
		||||
            if isinstance(track, Audio) or newest_init_data:
 | 
			
		||||
                with open(segment_save_path, "rb+") as f:
 | 
			
		||||
                    segment_data = f.read()
 | 
			
		||||
                    if isinstance(track, Audio):
 | 
			
		||||
                        # fix audio decryption on ATVP by fixing the sample description index
 | 
			
		||||
                        # TODO: Is this in mpeg data, or init data?
 | 
			
		||||
                        segment_data = re.sub(
 | 
			
		||||
                            b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
 | 
			
		||||
                            b"\\g<1>\x01",
 | 
			
		||||
                            segment_data
 | 
			
		||||
                        )
 | 
			
		||||
                    # prepend the init data to be able to decrypt
 | 
			
		||||
                    if newest_init_data:
 | 
			
		||||
                        f.seek(0)
 | 
			
		||||
                        f.write(newest_init_data)
 | 
			
		||||
                        f.write(segment_data)
 | 
			
		||||
 | 
			
		||||
            if newest_segment_key[0]:
 | 
			
		||||
                newest_segment_key[0].decrypt(segment_save_path)
 | 
			
		||||
                track.drm = None
 | 
			
		||||
                if callable(track.OnDecrypted):
 | 
			
		||||
                    track.OnDecrypted(track)
 | 
			
		||||
 | 
			
		||||
            return data_size
 | 
			
		||||
 | 
			
		||||
        segment_key = Queue(maxsize=1)
 | 
			
		||||
        init_data = Queue(maxsize=1)
 | 
			
		||||
        range_offset = Queue(maxsize=1)
 | 
			
		||||
 | 
			
		||||
        if track.drm:
 | 
			
		||||
            session_drm = track.drm[0]  # just use the first supported DRM system for now
 | 
			
		||||
            if isinstance(session_drm, Widevine):
 | 
			
		||||
@ -355,30 +224,39 @@ class HLS:
 | 
			
		||||
        else:
 | 
			
		||||
            session_drm = None
 | 
			
		||||
 | 
			
		||||
        # have data to begin with, or it will be stuck waiting on the first pool forever
 | 
			
		||||
        segment_key.put((session_drm, None))
 | 
			
		||||
        init_data.put(None)
 | 
			
		||||
        range_offset.put(0)
 | 
			
		||||
 | 
			
		||||
        progress(total=len(master.segments))
 | 
			
		||||
 | 
			
		||||
        finished_threads = 0
 | 
			
		||||
        download_sizes = []
 | 
			
		||||
        download_speed_window = 5
 | 
			
		||||
        last_speed_refresh = time.time()
 | 
			
		||||
 | 
			
		||||
        with ThreadPoolExecutor(max_workers=16) as pool:
 | 
			
		||||
            for download in futures.as_completed((
 | 
			
		||||
                pool.submit(
 | 
			
		||||
                    download_segment,
 | 
			
		||||
                    filename=str(i).zfill(len(str(len(master.segments)))),
 | 
			
		||||
                    segment=segment,
 | 
			
		||||
                    init_data=init_data,
 | 
			
		||||
                    segment_key=segment_key
 | 
			
		||||
                )
 | 
			
		||||
                for i, segment in enumerate(master.segments)
 | 
			
		||||
            )):
 | 
			
		||||
                finished_threads += 1
 | 
			
		||||
        segment_key = Queue(maxsize=1)
 | 
			
		||||
        segment_key.put((session_drm, None))
 | 
			
		||||
        init_data = Queue(maxsize=1)
 | 
			
		||||
        init_data.put(None)
 | 
			
		||||
        range_offset = Queue(maxsize=1)
 | 
			
		||||
        range_offset.put(0)
 | 
			
		||||
        drm_lock = Lock()
 | 
			
		||||
 | 
			
		||||
        with ThreadPoolExecutor(max_workers=16) as pool:
 | 
			
		||||
            for i, download in enumerate(futures.as_completed((
 | 
			
		||||
                pool.submit(
 | 
			
		||||
                    HLS.download_segment,
 | 
			
		||||
                    segment=segment,
 | 
			
		||||
                    out_path=(save_dir / str(n).zfill(len(str(len(master.segments))))).with_suffix(".mp4"),
 | 
			
		||||
                    track=track,
 | 
			
		||||
                    init_data=init_data,
 | 
			
		||||
                    segment_key=segment_key,
 | 
			
		||||
                    range_offset=range_offset,
 | 
			
		||||
                    drm_lock=drm_lock,
 | 
			
		||||
                    license_widevine=license_widevine,
 | 
			
		||||
                    session=session,
 | 
			
		||||
                    proxy=proxy,
 | 
			
		||||
                    stop_event=stop_event,
 | 
			
		||||
                    skip_event=skip_event
 | 
			
		||||
                )
 | 
			
		||||
                for n, segment in enumerate(master.segments)
 | 
			
		||||
            ))):
 | 
			
		||||
                try:
 | 
			
		||||
                    download_size = download.result()
 | 
			
		||||
                except KeyboardInterrupt:
 | 
			
		||||
@ -401,13 +279,17 @@ class HLS:
 | 
			
		||||
                    # it successfully downloaded, and it was not cancelled
 | 
			
		||||
                    progress(advance=1)
 | 
			
		||||
 | 
			
		||||
                    if download_size == -1:  # skipped for --skip-dl
 | 
			
		||||
                        progress(downloaded="[yellow]SKIPPING")
 | 
			
		||||
                        continue
 | 
			
		||||
 | 
			
		||||
                    now = time.time()
 | 
			
		||||
                    time_since = now - last_speed_refresh
 | 
			
		||||
 | 
			
		||||
                    if download_size:  # no size == skipped dl
 | 
			
		||||
                        download_sizes.append(download_size)
 | 
			
		||||
 | 
			
		||||
                    if download_sizes and (time_since > 5 or finished_threads == len(master.segments)):
 | 
			
		||||
                    if download_sizes and (time_since > download_speed_window or i == len(master.segments)):
 | 
			
		||||
                        data_size = sum(download_sizes)
 | 
			
		||||
                        download_speed = data_size / (time_since or 1)
 | 
			
		||||
                        progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
 | 
			
		||||
@ -424,6 +306,174 @@ class HLS:
 | 
			
		||||
        track.path = save_path
 | 
			
		||||
        save_dir.rmdir()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def download_segment(
 | 
			
		||||
        segment: m3u8.Segment,
 | 
			
		||||
        out_path: Path,
 | 
			
		||||
        track: AnyTrack,
 | 
			
		||||
        init_data: Queue,
 | 
			
		||||
        segment_key: Queue,
 | 
			
		||||
        range_offset: Queue,
 | 
			
		||||
        drm_lock: Lock,
 | 
			
		||||
        license_widevine: Optional[Callable] = None,
 | 
			
		||||
        session: Optional[Session] = None,
 | 
			
		||||
        proxy: Optional[str] = None,
 | 
			
		||||
        stop_event: Optional[Event] = None,
 | 
			
		||||
        skip_event: Optional[Event] = None
 | 
			
		||||
    ) -> int:
 | 
			
		||||
        """
 | 
			
		||||
        Download (and Decrypt) an HLS Media Segment.
 | 
			
		||||
 | 
			
		||||
        Note: Make sure all Queue objects passed are appropriately initialized with
 | 
			
		||||
              a starting value or this function may get permanently stuck.
 | 
			
		||||
 | 
			
		||||
        Parameters:
 | 
			
		||||
            segment: The m3u8.Segment Object to Download.
 | 
			
		||||
            out_path: Path to save the downloaded Segment file to.
 | 
			
		||||
            track: The Track object of which this Segment is for. Currently used to fix an
 | 
			
		||||
                invalid value in the TFHD box of Audio Tracks, for the OnSegmentFilter, and
 | 
			
		||||
                for DRM-related operations like getting the Track ID and Decryption.
 | 
			
		||||
            init_data: Queue for saving and loading the most recent init section data.
 | 
			
		||||
            segment_key: Queue for saving and loading the most recent DRM object, and it's
 | 
			
		||||
                adjacent Segment.Key object.
 | 
			
		||||
            range_offset: Queue for saving and loading the most recent Segment Bytes Range.
 | 
			
		||||
            drm_lock: Prevent more than one Download from doing anything DRM-related at the
 | 
			
		||||
                same time. Make sure all calls to download_segment() use the same Lock object.
 | 
			
		||||
            license_widevine: Function used to license Widevine DRM objects. It must be passed
 | 
			
		||||
                if the Segment's DRM uses Widevine.
 | 
			
		||||
            proxy: Proxy URI to use when downloading the Segment file.
 | 
			
		||||
            session: Python-Requests Session used when requesting init data.
 | 
			
		||||
            stop_event: Prematurely stop the Download from beginning. Useful if ran from
 | 
			
		||||
                a Thread Pool. It will raise a KeyboardInterrupt if set.
 | 
			
		||||
            skip_event: Prematurely stop the Download from beginning. It returns with a
 | 
			
		||||
                file size of -1 directly after DRM licensing occurs, even if it's DRM-free.
 | 
			
		||||
                This is mainly for `--skip-dl` to allow licensing without downloading.
 | 
			
		||||
 | 
			
		||||
        Returns the file size of the downloaded Segment in bytes.
 | 
			
		||||
        """
 | 
			
		||||
        if stop_event.is_set():
 | 
			
		||||
            raise KeyboardInterrupt()
 | 
			
		||||
 | 
			
		||||
        if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
 | 
			
		||||
            return 0
 | 
			
		||||
 | 
			
		||||
        # handle init section changes
 | 
			
		||||
        newest_init_data = init_data.get()
 | 
			
		||||
        try:
 | 
			
		||||
            if segment.init_section and (not newest_init_data or segment.discontinuity):
 | 
			
		||||
                # Only use the init data if there's no init data yet (e.g., start of file)
 | 
			
		||||
                # or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
 | 
			
		||||
                # Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
 | 
			
		||||
                # be unnecessary and slow to re-download the init data each time.
 | 
			
		||||
                if not segment.init_section.uri.startswith(segment.init_section.base_uri):
 | 
			
		||||
                    segment.init_section.uri = segment.init_section.base_uri + segment.init_section.uri
 | 
			
		||||
 | 
			
		||||
                if segment.init_section.byterange:
 | 
			
		||||
                    byte_range = HLS.calculate_byte_range(segment.init_section.byterange)
 | 
			
		||||
                    _ = range_offset.get()
 | 
			
		||||
                    range_offset.put(byte_range.split("-")[0])
 | 
			
		||||
                    range_header = {
 | 
			
		||||
                        "Range": f"bytes={byte_range}"
 | 
			
		||||
                    }
 | 
			
		||||
                else:
 | 
			
		||||
                    range_header = {}
 | 
			
		||||
 | 
			
		||||
                res = session.get(segment.init_section.uri, headers=range_header)
 | 
			
		||||
                res.raise_for_status()
 | 
			
		||||
                newest_init_data = res.content
 | 
			
		||||
        finally:
 | 
			
		||||
            init_data.put(newest_init_data)
 | 
			
		||||
 | 
			
		||||
        # handle segment key changes
 | 
			
		||||
        with drm_lock:
 | 
			
		||||
            newest_segment_key = segment_key.get()
 | 
			
		||||
            try:
 | 
			
		||||
                if segment.keys and newest_segment_key[1] != segment.keys:
 | 
			
		||||
                    drm = HLS.get_drm(
 | 
			
		||||
                        keys=segment.keys,
 | 
			
		||||
                        proxy=proxy
 | 
			
		||||
                    )
 | 
			
		||||
                    if drm:
 | 
			
		||||
                        track.drm = drm
 | 
			
		||||
                        # license and grab content keys
 | 
			
		||||
                        # TODO: What if we don't want to use the first DRM system?
 | 
			
		||||
                        drm = drm[0]
 | 
			
		||||
                        if isinstance(drm, Widevine):
 | 
			
		||||
                            track_kid = track.get_key_id(newest_init_data)
 | 
			
		||||
                            if not license_widevine:
 | 
			
		||||
                                raise ValueError("license_widevine func must be supplied to use Widevine DRM")
 | 
			
		||||
                            license_widevine(drm, track_kid=track_kid)
 | 
			
		||||
                        newest_segment_key = (drm, segment.keys)
 | 
			
		||||
            finally:
 | 
			
		||||
                segment_key.put(newest_segment_key)
 | 
			
		||||
 | 
			
		||||
            if skip_event.is_set():
 | 
			
		||||
                return -1
 | 
			
		||||
 | 
			
		||||
        if not segment.uri.startswith(segment.base_uri):
 | 
			
		||||
            segment.uri = segment.base_uri + segment.uri
 | 
			
		||||
 | 
			
		||||
        attempts = 1
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                headers_ = session.headers
 | 
			
		||||
                if segment.byterange:
 | 
			
		||||
                    # aria2(c) doesn't support byte ranges, use python-requests
 | 
			
		||||
                    downloader_ = requests_downloader
 | 
			
		||||
                    previous_range_offset = range_offset.get()
 | 
			
		||||
                    byte_range = HLS.calculate_byte_range(segment.byterange, previous_range_offset)
 | 
			
		||||
                    range_offset.put(byte_range.split("-")[0])
 | 
			
		||||
                    headers_["Range"] = f"bytes={byte_range}"
 | 
			
		||||
                else:
 | 
			
		||||
                    downloader_ = downloader
 | 
			
		||||
                downloader_(
 | 
			
		||||
                    uri=segment.uri,
 | 
			
		||||
                    out=out_path,
 | 
			
		||||
                    headers=headers_,
 | 
			
		||||
                    proxy=proxy,
 | 
			
		||||
                    silent=attempts != 5,
 | 
			
		||||
                    segmented=True
 | 
			
		||||
                )
 | 
			
		||||
                break
 | 
			
		||||
            except Exception as ee:
 | 
			
		||||
                if stop_event.is_set() or attempts == 5:
 | 
			
		||||
                    raise ee
 | 
			
		||||
                time.sleep(2)
 | 
			
		||||
                attempts += 1
 | 
			
		||||
 | 
			
		||||
        download_size = out_path.stat().st_size
 | 
			
		||||
 | 
			
		||||
        # fix audio decryption on ATVP by fixing the sample description index
 | 
			
		||||
        # TODO: Should this be done in the video data or the init data?
 | 
			
		||||
        if isinstance(track, Audio):
 | 
			
		||||
            with open(out_path, "rb+") as f:
 | 
			
		||||
                segment_data = f.read()
 | 
			
		||||
                fixed_segment_data = re.sub(
 | 
			
		||||
                    b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
 | 
			
		||||
                    b"\\g<1>\x01",
 | 
			
		||||
                    segment_data
 | 
			
		||||
                )
 | 
			
		||||
                if fixed_segment_data != segment_data:
 | 
			
		||||
                    f.seek(0)
 | 
			
		||||
                    f.write(fixed_segment_data)
 | 
			
		||||
 | 
			
		||||
        # prepend the init data to be able to decrypt
 | 
			
		||||
        if newest_init_data:
 | 
			
		||||
            with open(out_path, "rb+") as f:
 | 
			
		||||
                segment_data = f.read()
 | 
			
		||||
                f.seek(0)
 | 
			
		||||
                f.write(newest_init_data)
 | 
			
		||||
                f.write(segment_data)
 | 
			
		||||
 | 
			
		||||
        # decrypt segment if encrypted
 | 
			
		||||
        if newest_segment_key[0]:
 | 
			
		||||
            newest_segment_key[0].decrypt(out_path)
 | 
			
		||||
            track.drm = None
 | 
			
		||||
            if callable(track.OnDecrypted):
 | 
			
		||||
                track.OnDecrypted(track)
 | 
			
		||||
 | 
			
		||||
        return download_size
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_drm(
 | 
			
		||||
        keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user