mirror of
				https://github.com/devine-dl/devine.git
				synced 2025-11-04 03:44:49 +00:00 
			
		
		
		
	Multi-thread the new HLS download system
This mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. Benefits of this is we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
		
							parent
							
								
									314079c75f
								
							
						
					
					
						commit
						9e6f5b25f3
					
				@ -4,8 +4,14 @@ import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
import re
 | 
			
		||||
import sys
 | 
			
		||||
import time
 | 
			
		||||
import traceback
 | 
			
		||||
from concurrent import futures
 | 
			
		||||
from concurrent.futures import ThreadPoolExecutor
 | 
			
		||||
from hashlib import md5
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from queue import Queue
 | 
			
		||||
from threading import Event
 | 
			
		||||
from typing import Any, Callable, Optional, Union
 | 
			
		||||
 | 
			
		||||
import m3u8
 | 
			
		||||
@ -205,21 +211,17 @@ class HLS:
 | 
			
		||||
            log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
 | 
			
		||||
            sys.exit(1)
 | 
			
		||||
 | 
			
		||||
        init_data = None
 | 
			
		||||
        last_segment_key: tuple[Optional[Union[ClearKey, Widevine]], Optional[m3u8.Key]] = (None, None)
 | 
			
		||||
        state_event = Event()
 | 
			
		||||
 | 
			
		||||
        for i, segment in enumerate(tqdm(master.segments, unit="segments")):
 | 
			
		||||
            segment_filename = str(i).zfill(len(str(len(master.segments))))
 | 
			
		||||
            segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
 | 
			
		||||
        def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue):
 | 
			
		||||
            time.sleep(0.1)
 | 
			
		||||
            if state_event.is_set():
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            if segment.key and last_segment_key[1] != segment.key:
 | 
			
		||||
                # try:
 | 
			
		||||
                #     drm = HLS.get_drm([segment.key])
 | 
			
		||||
                # except NotImplementedError:
 | 
			
		||||
                #     drm = None  # never mind, try with master.keys
 | 
			
		||||
                # if not drm and master.keys:
 | 
			
		||||
                #     # TODO: segment might have multiple keys but m3u8 only grabs the last!
 | 
			
		||||
                #     drm = HLS.get_drm(master.keys)
 | 
			
		||||
            segment_save_path = (save_dir / filename).with_suffix(".mp4")
 | 
			
		||||
 | 
			
		||||
            newest_segment_key = segment_key.get()
 | 
			
		||||
            if segment.key and newest_segment_key[1] != segment.key:
 | 
			
		||||
                try:
 | 
			
		||||
                    drm = HLS.get_drm(
 | 
			
		||||
                        # TODO: We append master.keys because m3u8 class only puts the last EXT-X-KEY
 | 
			
		||||
@ -242,12 +244,14 @@ class HLS:
 | 
			
		||||
                            if not license_widevine:
 | 
			
		||||
                                raise ValueError("license_widevine func must be supplied to use Widevine DRM")
 | 
			
		||||
                            license_widevine(drm)
 | 
			
		||||
                        last_segment_key = (drm, segment.key)
 | 
			
		||||
                        newest_segment_key = (drm, segment.key)
 | 
			
		||||
            segment_key.put(newest_segment_key)
 | 
			
		||||
 | 
			
		||||
            if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
 | 
			
		||||
                continue
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            if segment.init_section and (not init_data or segment.discontinuity):
 | 
			
		||||
            newest_init_data = init_data.get()
 | 
			
		||||
            if segment.init_section and (not newest_init_data or segment.discontinuity):
 | 
			
		||||
                # Only use the init data if there's no init data yet (e.g., start of file)
 | 
			
		||||
                # or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
 | 
			
		||||
                # Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
 | 
			
		||||
@ -258,7 +262,8 @@ class HLS:
 | 
			
		||||
                log.debug("Got new init segment, %s", segment.init_section.uri)
 | 
			
		||||
                res = session.get(segment.init_section.uri)
 | 
			
		||||
                res.raise_for_status()
 | 
			
		||||
                init_data = res.content
 | 
			
		||||
                newest_init_data = res.content
 | 
			
		||||
            init_data.put(newest_init_data)
 | 
			
		||||
 | 
			
		||||
            if not segment.uri.startswith(segment.base_uri):
 | 
			
		||||
                segment.uri = segment.base_uri + segment.uri
 | 
			
		||||
@ -270,7 +275,7 @@ class HLS:
 | 
			
		||||
                proxy
 | 
			
		||||
            ))
 | 
			
		||||
 | 
			
		||||
            if isinstance(track, Audio) or init_data:
 | 
			
		||||
            if isinstance(track, Audio) or newest_init_data:
 | 
			
		||||
                with open(segment_save_path, "rb+") as f:
 | 
			
		||||
                    segment_data = f.read()
 | 
			
		||||
                    if isinstance(track, Audio):
 | 
			
		||||
@ -282,17 +287,53 @@ class HLS:
 | 
			
		||||
                            segment_data
 | 
			
		||||
                        )
 | 
			
		||||
                    # prepend the init data to be able to decrypt
 | 
			
		||||
                    if init_data:
 | 
			
		||||
                    if newest_init_data:
 | 
			
		||||
                        f.seek(0)
 | 
			
		||||
                        f.write(init_data)
 | 
			
		||||
                        f.write(newest_init_data)
 | 
			
		||||
                        f.write(segment_data)
 | 
			
		||||
 | 
			
		||||
            if last_segment_key[0]:
 | 
			
		||||
                last_segment_key[0].decrypt(segment_save_path)
 | 
			
		||||
            if newest_segment_key[0]:
 | 
			
		||||
                newest_segment_key[0].decrypt(segment_save_path)
 | 
			
		||||
                track.drm = None
 | 
			
		||||
                if callable(track.OnDecrypted):
 | 
			
		||||
                    track.OnDecrypted(track)
 | 
			
		||||
 | 
			
		||||
        init_data = Queue(maxsize=1)
 | 
			
		||||
        segment_key = Queue(maxsize=1)
 | 
			
		||||
        # otherwise will be stuck waiting on the first pool, forever
 | 
			
		||||
        init_data.put(None)
 | 
			
		||||
        segment_key.put((None, None))
 | 
			
		||||
 | 
			
		||||
        with tqdm(total=len(master.segments), unit="segments") as pbar:
 | 
			
		||||
            with ThreadPoolExecutor(max_workers=16) as pool:
 | 
			
		||||
                try:
 | 
			
		||||
                    for download in futures.as_completed((
 | 
			
		||||
                        pool.submit(
 | 
			
		||||
                            download_segment,
 | 
			
		||||
                            filename=str(i).zfill(len(str(len(master.segments)))),
 | 
			
		||||
                            segment=segment,
 | 
			
		||||
                            init_data=init_data,
 | 
			
		||||
                            segment_key=segment_key
 | 
			
		||||
                        )
 | 
			
		||||
                        for i, segment in enumerate(master.segments)
 | 
			
		||||
                    )):
 | 
			
		||||
                        if download.cancelled():
 | 
			
		||||
                            continue
 | 
			
		||||
                        e = download.exception()
 | 
			
		||||
                        if e:
 | 
			
		||||
                            state_event.set()
 | 
			
		||||
                            pool.shutdown(wait=False, cancel_futures=True)
 | 
			
		||||
                            traceback.print_exception(e)
 | 
			
		||||
                            log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
 | 
			
		||||
                            sys.exit(1)
 | 
			
		||||
                        else:
 | 
			
		||||
                            pbar.update(1)
 | 
			
		||||
                except KeyboardInterrupt:
 | 
			
		||||
                    state_event.set()
 | 
			
		||||
                    pool.shutdown(wait=False, cancel_futures=True)
 | 
			
		||||
                    log.info("Received Keyboard Interrupt, stopping...")
 | 
			
		||||
                    return
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_drm(
 | 
			
		||||
        keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user