from __future__ import annotations

import base64
import binascii
import dataclasses
import datetime as dt
import hashlib
import json
import re
import traceback
import typing
import urllib.parse
from collections.abc import Iterable

from yt_dlp.extractor.youtube.pot._provider import (
    BuiltinIEContentProvider,
    IEContentProvider,
    IEContentProviderLogger,
)
from yt_dlp.extractor.youtube.pot._registry import (
    _pot_cache_provider_preferences,
    _pot_cache_providers,
    _pot_pcs_providers,
    _pot_providers,
    _ptp_preferences,
)
from yt_dlp.extractor.youtube.pot.cache import (
    CacheProviderWritePolicy,
    PoTokenCacheProvider,
    PoTokenCacheProviderError,
    PoTokenCacheSpec,
    PoTokenCacheSpecProvider,
)
from yt_dlp.extractor.youtube.pot.provider import (
    PoTokenProvider,
    PoTokenProviderError,
    PoTokenProviderRejectedRequest,
    PoTokenRequest,
    PoTokenResponse,
    provider_bug_report_message,
)
from yt_dlp.utils import bug_reports_message, format_field, join_nonempty

if typing.TYPE_CHECKING:
    from yt_dlp.extractor.youtube.pot.cache import CacheProviderPreference
    from yt_dlp.extractor.youtube.pot.provider import Preference


class YoutubeIEContentProviderLogger(IEContentProviderLogger):
    """Logger that forwards provider messages to an extractor's output methods.

    Messages below ``log_level`` are dropped. Each message is prefixed with
    ``[<prefix>] `` when a prefix is set.
    """

    def __init__(self, ie, prefix, log_level: IEContentProviderLogger.LogLevel | None = None):
        self.__ie = ie
        self.prefix = prefix
        self.log_level = log_level if log_level is not None else self.LogLevel.INFO

    def _format_msg(self, message: str):
        """Return *message* prefixed with ``[prefix] `` (no prefix when it is unset)."""
        prefixstr = format_field(self.prefix, None, '[%s] ')
        return f'{prefixstr}{message}'

    def trace(self, message: str):
        """Emit *message* through ``write_debug`` with a ``TRACE:`` marker."""
        if self.log_level <= self.LogLevel.TRACE:
            self.__ie.write_debug(self._format_msg('TRACE: ' + message))

    def debug(self, message: str, *, once=False):
        """Emit *message* through ``write_debug``; ``once`` is passed as ``only_once``."""
        if self.log_level <= self.LogLevel.DEBUG:
            self.__ie.write_debug(self._format_msg(message), only_once=once)

    def info(self, message: str):
        """Emit *message* through ``to_screen``."""
        if self.log_level <= self.LogLevel.INFO:
            self.__ie.to_screen(self._format_msg(message))

    def warning(self, message: str, *, once=False):
        """Emit *message* through ``report_warning``; ``once`` is passed as ``only_once``."""
        if self.log_level <= self.LogLevel.WARNING:
            self.__ie.report_warning(self._format_msg(message), only_once=once)

    def error(self, message: str, cause=None):
        """Report *message* as a non-fatal error.

        When *cause* (an exception) is given, its formatted traceback is
        attached through the ``tb`` argument of ``report_error``.
        """
        if self.log_level <= self.LogLevel.ERROR:
            self.__ie._downloader.report_error(
                self._format_msg(message), is_error=False,
                tb=''.join(traceback.format_exception(None, cause, cause.__traceback__)) if cause else None)


class PoTokenCache:
    """Reads and writes PO Token responses across the registered cache providers.

    The cache key is derived from the first valid spec produced by an
    available cache spec provider. Without such a spec, nothing is cached or read.
    """

    def __init__(
        self,
        logger: IEContentProviderLogger,
        cache_providers: list[PoTokenCacheProvider] | None,
        cache_spec_providers: list[PoTokenCacheSpecProvider] | None,
        cache_provider_preferences: list[CacheProviderPreference] | None = None,
    ):
        # Providers are indexed by PROVIDER_KEY; a later provider with the same key replaces an earlier one.
        self.cache_providers: dict[str, PoTokenCacheProvider] = {
            provider.PROVIDER_KEY: provider for provider in (cache_providers or [])}
        self.cache_provider_preferences: list[CacheProviderPreference] = cache_provider_preferences or []
        self.cache_spec_providers: dict[str, PoTokenCacheSpecProvider] = {
            provider.PROVIDER_KEY: provider for provider in (cache_spec_providers or [])}
        self.logger = logger

    def _get_cache_providers(self, request: PoTokenRequest) -> Iterable[PoTokenCacheProvider]:
        """Sorts available cache providers by preference, given a request.

        Each provider's score is the sum of all cache provider preference
        callables for this request; the highest score comes first. The result is
        a lazy generator, so ``is_available()`` is checked only as it is consumed.
        """
        preferences = {
            provider: sum(pref(provider, request) for pref in self.cache_provider_preferences)
            for provider in self.cache_providers.values()
        }
        if self.logger.log_level <= self.logger.LogLevel.TRACE:
            # calling is_available() for every PO Token provider upfront may have some overhead
            self.logger.trace(f'PO Token Cache Providers: {provider_display_list(self.cache_providers.values())}')
            self.logger.trace('Cache Provider preferences for this request: {}'.format(', '.join(
                f'{provider.PROVIDER_KEY}={pref}' for provider, pref in preferences.items())))

        return (
            provider for provider in sorted(
                self.cache_providers.values(), key=preferences.get, reverse=True)
            if provider.is_available())

    def _get_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None:
        """Return the first valid cache spec from an available spec provider, or None.

        The returned spec has ``_provider`` set to the spec provider that made it.
        Invalid specs and provider exceptions are logged and skipped.
        """
        for provider in self.cache_spec_providers.values():
            if not provider.is_available():
                continue
            try:
                spec = provider.generate_cache_spec(request)
                if not spec:
                    continue
                if not validate_cache_spec(spec):
                    self.logger.error(
                        f'PoTokenCacheSpecProvider "{provider.PROVIDER_KEY}" generate_cache_spec() '
                        f'returned invalid spec {spec}{provider_bug_report_message(provider)}')
                    continue
                spec = dataclasses.replace(spec, _provider=provider)
                self.logger.trace(
                    f'Retrieved cache spec {spec} from cache spec provider "{provider.PROVIDER_NAME}"')
                return spec
            except Exception as e:
                self.logger.error(
                    f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache spec provider: '
                    f'{e!r}{provider_bug_report_message(provider)}')
                continue
        return None

    def _generate_key_bindings(self, spec: PoTokenCacheSpec) -> dict[str, str]:
        """Build the key bindings used for hashing from *spec*.

        None-valued bindings are dropped. A cache version marker is added, plus
        the spec provider's key (``_p``) when the spec carries one.
        """
        bindings_cleaned = {
            **{k: v for k, v in spec.key_bindings.items() if v is not None},
            # Allow us to invalidate caches if such need arises
            '_dlp_cache': 'v1',
        }
        if spec._provider:
            bindings_cleaned['_p'] = spec._provider.PROVIDER_KEY
        self.logger.trace(f'Generated cache key bindings: {bindings_cleaned}')
        return bindings_cleaned

    def _generate_key(self, bindings: dict) -> str:
        """Return the SHA-256 hex digest of the key-sorted *bindings* repr."""
        # Sorting makes the key independent of binding insertion order.
        binding_string = repr(dict(sorted(bindings.items())))
        return hashlib.sha256(binding_string.encode()).hexdigest()

    def get(self, request: PoTokenRequest) -> PoTokenResponse | None:
        """Return the first valid cached response for *request*, or None.

        Providers are tried in preference order. A cached entry that fails to
        parse or validate is deleted from that provider. When the hit comes from
        a lower-priority provider, it is written back to the highest-priority one.
        Provider errors are logged and the next provider is tried.
        """
        spec = self._get_cache_spec(request)
        if not spec:
            self.logger.trace('No cache spec available for this request, unable to fetch from cache')
            return None

        cache_key = self._generate_key(self._generate_key_bindings(spec))
        self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}')

        for idx, provider in enumerate(self._get_cache_providers(request)):
            try:
                self.logger.trace(
                    f'Attempting to fetch PO Token response from "{provider.PROVIDER_NAME}" cache provider')
                cache_response = provider.get(cache_key)
                if not cache_response:
                    continue
                try:
                    po_token_response = PoTokenResponse(**json.loads(cache_response))
                # json.JSONDecodeError is a ValueError subclass; TypeError covers non-dict
                # payloads and unexpected fields passed to PoTokenResponse.
                except (TypeError, ValueError):
                    po_token_response = None
                if not validate_response(po_token_response):
                    self.logger.error(
                        f'Invalid PO Token response retrieved from cache provider "{provider.PROVIDER_NAME}": '
                        f'{cache_response}{provider_bug_report_message(provider)}')
                    provider.delete(cache_key)
                    continue
                self.logger.trace(
                    f'PO Token response retrieved from cache using "{provider.PROVIDER_NAME}" provider: '
                    f'{po_token_response}')
                if idx > 0:
                    # Write back to the highest priority cache provider,
                    # so we stop trying to fetch from lower priority providers
                    self.logger.trace('Writing PO Token response to highest priority cache provider')
                    self.store(request, po_token_response, write_policy=CacheProviderWritePolicy.WRITE_FIRST)

                return po_token_response
            except PoTokenCacheProviderError as e:
                self.logger.warning(
                    f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
                continue
            except Exception as e:
                self.logger.error(
                    f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider)}',
                )
                continue
        return None

    def store(
        self,
        request: PoTokenRequest,
        response: PoTokenResponse,
        write_policy: CacheProviderWritePolicy | None = None,
    ):
        """Cache *response* for *request* in the available cache providers.

        If ``response.expires_at`` is falsy, it is replaced with
        now + ``spec.default_ttl`` (UTC epoch seconds). *write_policy* defaults to
        the spec's policy. With ``WRITE_FIRST``, only the highest-priority
        available provider is attempted, whether or not that write succeeds.
        Provider errors are logged, not raised.
        """
        spec = self._get_cache_spec(request)
        if not spec:
            self.logger.trace('No cache spec available for this request. Not caching.')
            return

        if not validate_response(response):
            self.logger.error(
                f'Invalid PO Token response provided to PoTokenCache.store(): '
                f'{response}{bug_reports_message()}')
            return

        cache_key = self._generate_key(self._generate_key_bindings(spec))
        self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}')

        default_expires_at = int(dt.datetime.now(dt.timezone.utc).timestamp()) + spec.default_ttl
        cache_response = dataclasses.replace(response, expires_at=response.expires_at or default_expires_at)

        write_policy = write_policy or spec.write_policy
        self.logger.trace(f'Using write policy: {write_policy}')

        for idx, provider in enumerate(self._get_cache_providers(request)):
            try:
                self.logger.trace(
                    f'Caching PO Token response in "{provider.PROVIDER_NAME}" cache provider '
                    f'(key={cache_key}, expires_at={cache_response.expires_at})')
                provider.store(
                    key=cache_key,
                    value=json.dumps(dataclasses.asdict(cache_response)),
                    expires_at=cache_response.expires_at)
            except PoTokenCacheProviderError as e:
                self.logger.warning(
                    f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
            except Exception as e:
                self.logger.error(
                    f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider)}')

            # WRITE_FIRST should not write to lower priority providers in the case the highest priority provider fails
            if idx == 0 and write_policy == CacheProviderWritePolicy.WRITE_FIRST:
                return

    def close(self):
        """Close every cache provider, then every cache spec provider.

        Spec providers are still closed if closing a cache provider raises; the
        first error then propagates after cleanup.
        """
        try:
            for provider in self.cache_providers.values():
                provider.close()
        finally:
            for spec_provider in self.cache_spec_providers.values():
                spec_provider.close()


class PoTokenRequestDirector:
    """Resolves PO Token requests via the cache first, then registered providers."""

    def __init__(self, logger: IEContentProviderLogger, cache: PoTokenCache):
        self.providers: dict[str, PoTokenProvider] = {}
        self.preferences: list[Preference] = []
        self.cache = cache
        self.logger = logger

    def register_provider(self, provider: PoTokenProvider):
        """Register *provider* under its PROVIDER_KEY (replacing any existing entry)."""
        self.providers[provider.PROVIDER_KEY] = provider

    def register_preference(self, preference: Preference):
        """Add a preference callable used to rank providers per request."""
        self.preferences.append(preference)

    def _get_providers(self, request: PoTokenRequest) -> Iterable[PoTokenProvider]:
        """Sorts available providers by preference, given a request.

        Each provider's score is the sum of all registered preference callables
        for this request; the highest score comes first. The result is a lazy
        generator, so ``is_available()`` is checked only as it is consumed.
        """
        preferences = {
            provider: sum(pref(provider, request) for pref in self.preferences)
            for provider in self.providers.values()
        }
        if self.logger.log_level <= self.logger.LogLevel.TRACE:
            # calling is_available() for every PO Token provider upfront may have some overhead
            self.logger.trace(f'PO Token Providers: {provider_display_list(self.providers.values())}')
            self.logger.trace('Provider preferences for this request: {}'.format(', '.join(
                f'{provider.PROVIDER_NAME}={pref}' for provider, pref in preferences.items())))

        return (
            provider for provider in sorted(
                self.providers.values(), key=preferences.get, reverse=True)
            if provider.is_available()
        )

    def _get_po_token(self, request: PoTokenRequest) -> PoTokenResponse | None:
        """Return the first valid response from the ranked providers, or None.

        Each provider receives its own ``request.copy()``. Rejections, provider
        errors, unexpected exceptions and invalid responses are logged, and the
        next provider is tried.
        """
        for provider in self._get_providers(request):
            try:
                self.logger.trace(
                    f'Attempting to fetch a PO Token from "{provider.PROVIDER_NAME}" provider')
                response = provider.request_pot(request.copy())
            except PoTokenProviderRejectedRequest as e:
                self.logger.trace(
                    f'PO Token Provider "{provider.PROVIDER_NAME}" rejected this request, '
                    f'trying next available provider. Reason: {e}')
                continue
            except PoTokenProviderError as e:
                self.logger.warning(
                    f'Error fetching PO Token from "{provider.PROVIDER_NAME}" provider: '
                    f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
                continue
            except Exception as e:
                self.logger.error(
                    f'Unexpected error when fetching PO Token from "{provider.PROVIDER_NAME}" provider: '
                    f'{e!r}{provider_bug_report_message(provider)}')
                continue

            self.logger.trace(f'PO Token response from "{provider.PROVIDER_NAME}" provider: {response}')

            if not validate_response(response):
                self.logger.error(
                    f'Invalid PO Token response received from "{provider.PROVIDER_NAME}" provider: '
                    f'{response}{provider_bug_report_message(provider)}')
                continue

            return response

        self.logger.trace('No PO Token providers were able to provide a valid PO Token')
        return None

    def get_po_token(self, request: PoTokenRequest) -> str | None:
        """Return a cleaned PO Token string for *request*, or None if none was obtained.

        The cache is consulted first unless ``request.bypass_cache`` is set. A
        freshly fetched response is cached only when ``expires_at`` is None or
        positive; a non-positive value opts out of caching.
        """
        if not request.bypass_cache:
            if pot_response := self.cache.get(request):
                return clean_pot(pot_response.po_token)

        if not self.providers:
            self.logger.trace('No PO Token providers registered')
            return None

        pot_response = self._get_po_token(request)
        if not pot_response:
            return None

        pot_response.po_token = clean_pot(pot_response.po_token)

        if pot_response.expires_at is None or pot_response.expires_at > 0:
            self.cache.store(request, pot_response)
        else:
            self.logger.trace(
                f'PO Token response will not be cached (expires_at={pot_response.expires_at})')

        return pot_response.po_token

    def close(self):
        """Close all registered providers, then the cache.

        The cache is still closed if closing a provider raises; the first error
        then propagates after cleanup.
        """
        try:
            for provider in self.providers.values():
                provider.close()
        finally:
            self.cache.close()


# Per-provider settings are read from extractor args keyed '<prefix>-<provider key, lowercased>'.
EXTRACTOR_ARG_PREFIX = 'youtubepot'


def initialize_pot_director(ie):
    """Build a PoTokenRequestDirector (and its PoTokenCache) for extractor *ie*.

    The log level is TRACE when the youtube ``pot_trace`` extractor arg is
    ``true``, DEBUG when the ``verbose`` param is set, and INFO otherwise. All
    providers and preferences are instantiated from the registries. The
    director's ``close`` is registered as a downloader close hook.
    """
    assert ie._downloader is not None, 'Downloader not set'

    enable_trace = ie._configuration_arg(
        'pot_trace', ['false'], ie_key='youtube', casesense=False)[0] == 'true'

    if enable_trace:
        log_level = IEContentProviderLogger.LogLevel.TRACE
    elif ie.get_param('verbose', False):
        log_level = IEContentProviderLogger.LogLevel.DEBUG
    else:
        log_level = IEContentProviderLogger.LogLevel.INFO

    def get_provider_logger_and_settings(provider, logger_key):
        """Return a prefixed logger and the provider's extractor-arg settings dict."""
        logger_prefix = f'{logger_key}:{provider.PROVIDER_NAME}'
        extractor_key = f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}'
        return (
            YoutubeIEContentProviderLogger(ie, logger_prefix, log_level=log_level),
            ie.get_param('extractor_args', {}).get(extractor_key, {}))

    cache_providers = []
    for cache_provider in _pot_cache_providers.value.values():
        logger, settings = get_provider_logger_and_settings(cache_provider, 'pot:cache')
        cache_providers.append(cache_provider(ie, logger, settings))
    cache_spec_providers = []
    for cache_spec_provider in _pot_pcs_providers.value.values():
        logger, settings = get_provider_logger_and_settings(cache_spec_provider, 'pot:cache:spec')
        cache_spec_providers.append(cache_spec_provider(ie, logger, settings))

    cache = PoTokenCache(
        logger=YoutubeIEContentProviderLogger(ie, 'pot:cache', log_level=log_level),
        cache_providers=cache_providers,
        cache_spec_providers=cache_spec_providers,
        cache_provider_preferences=list(_pot_cache_provider_preferences.value),
    )

    director = PoTokenRequestDirector(
        logger=YoutubeIEContentProviderLogger(ie, 'pot', log_level=log_level),
        cache=cache,
    )

    ie._downloader.add_close_hook(director.close)

    for provider in _pot_providers.value.values():
        logger, settings = get_provider_logger_and_settings(provider, 'pot')
        director.register_provider(provider(ie, logger, settings))

    for preference in _ptp_preferences.value:
        director.register_preference(preference)

    if director.logger.log_level <= director.logger.LogLevel.DEBUG:
        # calling is_available() for every PO Token provider upfront may have some overhead
        director.logger.debug(f'PO Token Providers: {provider_display_list(director.providers.values())}')
        director.logger.debug(f'PO Token Cache Providers: {provider_display_list(cache.cache_providers.values())}')
        director.logger.debug(f'PO Token Cache Spec Providers: {provider_display_list(cache.cache_spec_providers.values())}')
        director.logger.trace(f'Registered {len(director.preferences)} provider preferences')
        director.logger.trace(f'Registered {len(cache.cache_provider_preferences)} cache provider preferences')

    return director


def provider_display_list(providers: Iterable[IEContentProvider]):
    """Return a comma-separated, human-readable list of *providers* ('none' if empty).

    Non-builtin providers show their version and an 'external' marker.
    Unavailable providers are marked 'unavailable'.
    """
    def provider_display_name(provider):
        display_str = join_nonempty(
            provider.PROVIDER_NAME,
            provider.PROVIDER_VERSION if not isinstance(provider, BuiltinIEContentProvider) else None)
        statuses = []
        if not isinstance(provider, BuiltinIEContentProvider):
            statuses.append('external')
        if not provider.is_available():
            statuses.append('unavailable')
        if statuses:
            display_str += f' ({", ".join(statuses)})'
        return display_str

    return ', '.join(provider_display_name(provider) for provider in providers) or 'none'


def clean_pot(po_token: str):
    """Return *po_token* normalised to URL-safe base64.

    Raises:
        ValueError: if the token is empty after cleanup or is not decodable base64.
    """
    # Clean and validate the PO Token. This will strip invalid characters off
    # (e.g. additional url params the user may accidentally include)
    mobj = re.match(r'([^?&#]+)', urllib.parse.unquote(po_token))
    if not mobj:
        raise ValueError('Invalid PO Token')
    try:
        return base64.urlsafe_b64encode(
            base64.urlsafe_b64decode(mobj.group(1))).decode()
    except (binascii.Error, ValueError) as e:
        raise ValueError('Invalid PO Token') from e


def validate_response(response: PoTokenResponse | None):
    """Return True if *response* is a usable PoTokenResponse.

    The token must be a non-empty str accepted by clean_pot. ``expires_at``
    must be None, or an int that is either non-positive or later than the
    current UTC epoch second.
    """
    if (
        not isinstance(response, PoTokenResponse)
        or not isinstance(response.po_token, str)
        or not response.po_token
    ):  # noqa: SIM103
        return False

    try:
        clean_pot(response.po_token)
    except ValueError:
        return False

    if not isinstance(response.expires_at, int):
        return response.expires_at is None

    return response.expires_at <= 0 or response.expires_at > int(dt.datetime.now(dt.timezone.utc).timestamp())


def validate_cache_spec(spec: PoTokenCacheSpec):
    """Return True if *spec* is a well-typed PoTokenCacheSpec.

    Requires str keys, str-or-None values, and at least one non-None key binding.
    """
    return (
        isinstance(spec, PoTokenCacheSpec)
        and isinstance(spec.write_policy, CacheProviderWritePolicy)
        and isinstance(spec.default_ttl, int)
        and isinstance(spec.key_bindings, dict)
        and all(isinstance(k, str) for k in spec.key_bindings)
        and all(v is None or isinstance(v, str) for v in spec.key_bindings.values())
        and any(v is not None for v in spec.key_bindings.values())
    )