# 2025-04-12 13:24:10 +12:00

# 434 lines
# 19 KiB
# Python

from __future__ import annotations
import base64
import binascii
import dataclasses
import datetime as dt
import hashlib
import json
import typing
import urllib.parse
from collections.abc import Iterable
from yt_dlp.extractor.youtube.pot._provider import (
BuiltInIEContentProvider,
IEContentProvider,
IEContentProviderLogger,
)
from yt_dlp.extractor.youtube.pot._registry import (
_pot_cache_provider_preferences,
_pot_cache_providers,
_pot_pcs_providers,
_pot_providers,
_ptp_preferences,
)
from yt_dlp.extractor.youtube.pot.cache import (
CacheProviderWritePolicy,
PoTokenCacheProvider,
PoTokenCacheProviderError,
PoTokenCacheSpec,
PoTokenCacheSpecProvider,
)
from yt_dlp.extractor.youtube.pot.provider import (
PoTokenProvider,
PoTokenProviderError,
PoTokenProviderRejectedRequest,
PoTokenRequest,
PoTokenResponse,
provider_bug_report_message,
)
from yt_dlp.utils import ExtractorError, bug_reports_message, format_field, join_nonempty, traverse_obj
if typing.TYPE_CHECKING:
from yt_dlp.extractor.youtube.pot.cache import PCPPreference
class YoutubeIEContentProviderLogger(IEContentProviderLogger):
    """Forward PO Token framework log messages to a YouTube extractor instance.

    Each message gets a ``[<prefix>] `` prefix (via ``format_field``, which yields an
    empty string when no prefix is set). A message is emitted only when
    ``log_level`` is at or below that message's level:

    - TRACE and DEBUG go to ``ie.write_debug``
    - INFO goes to ``ie.to_screen``
    - WARNING goes to ``ie.report_warning``
    - ERROR goes to ``ie._downloader.report_error(..., is_error=False)``
    """

    def __init__(self, ie, prefix, log_level=IEContentProviderLogger.LogLevel.INFO):
        self.__ie = ie
        self.prefix = prefix
        self.log_level = log_level

    def _format_msg(self, message: str):
        """Return *message* with the bracketed logger prefix prepended."""
        return f"{format_field(self.prefix, None, '[%s] ')}{message}"

    def trace(self, message: str):
        if self.log_level > self.LogLevel.TRACE:
            return
        # TRACE output shares the debug channel, so it is tagged explicitly
        self.__ie.write_debug(self._format_msg('TRACE: ' + message))

    def debug(self, message: str):
        if self.log_level > self.LogLevel.DEBUG:
            return
        self.__ie.write_debug(self._format_msg(message))

    def info(self, message: str):
        if self.log_level > self.LogLevel.INFO:
            return
        self.__ie.to_screen(self._format_msg(message))

    def warning(self, message: str, *, once=False):
        if self.log_level > self.LogLevel.WARNING:
            return
        self.__ie.report_warning(self._format_msg(message), only_once=once)

    def error(self, message: str):
        if self.log_level > self.LogLevel.ERROR:
            return
        self.__ie._downloader.report_error(self._format_msg(message), is_error=False)
class PoTokenCache:
    """Read and write PO Token responses through the registered cache providers.

    The cache key for a request is derived from the first available
    ``PoTokenCacheSpecProvider`` that returns a valid spec. Cache providers are
    consulted in descending order of their summed preference scores, skipping any
    provider whose ``is_available()`` returns False.
    """

    def __init__(
        self,
        logger: IEContentProviderLogger,
        cache_providers: list[PoTokenCacheProvider],
        cache_spec_providers: list[PoTokenCacheSpecProvider],
        cache_provider_preferences: list[PCPPreference] | None = None,
    ):
        # Keyed by PROVIDER_KEY: a later provider with a duplicate key replaces the earlier one.
        self.cache_providers: dict[str, PoTokenCacheProvider] = {
            provider.PROVIDER_KEY: provider for provider in (cache_providers or [])
        }
        self.cache_provider_preferences: list[PCPPreference] = cache_provider_preferences or []
        self.cache_spec_providers: dict[str, PoTokenCacheSpecProvider] = {
            provider.PROVIDER_KEY: provider for provider in (cache_spec_providers or [])
        }
        self.logger = logger

    def _get_cache_providers(self, request: PoTokenRequest) -> Iterable[PoTokenCacheProvider]:
        """Sorts available cache providers by preference, given a request

        Preference scores are computed and sorted eagerly; ``is_available()`` is
        only called as the returned iterator is consumed.
        """
        preferences = {
            provider: sum(pref(provider, request) for pref in self.cache_provider_preferences)
            for provider in self.cache_providers.values()
        }
        if self.logger.log_level <= self.logger.LogLevel.TRACE:
            # calling is_available() for every PO Token provider upfront may have some overhead
            self.logger.trace(f'PO Token Cache Providers: {provider_display_list(self.cache_providers.values())}')
            self.logger.trace('Cache Provider preferences for this request: {}'.format(', '.join(
                f'{provider.PROVIDER_KEY}={pref}' for provider, pref in preferences.items())))
        return (
            provider
            for provider in sorted(self.cache_providers.values(), key=preferences.get, reverse=True)
            if provider.is_available()
        )

    def _get_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None:
        """Return the first valid cache spec from an available spec provider, or None.

        The returned spec is a copy with ``_provider`` set to the spec provider that
        produced it. Invalid specs and exceptions raised by ``generate_cache_spec()``
        are logged and the next spec provider is tried.
        """
        for provider in self.cache_spec_providers.values():
            if not provider.is_available():
                continue
            try:
                spec = provider.generate_cache_spec(request)
                if not spec:
                    continue
                if not validate_cache_spec(spec):
                    self.logger.error(
                        f'PoTokenCacheSpecProvider "{provider.PROVIDER_KEY}" generate_cache_spec() '
                        f'returned invalid spec {spec}{provider_bug_report_message(provider)}')
                    continue
                spec = dataclasses.replace(spec, _provider=provider)
                self.logger.trace(f'Retrieved cache spec {spec} from cache spec provider "{provider.PROVIDER_NAME}"')
                return spec
            except Exception as e:
                self.logger.error(
                    f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache spec provider: '
                    f'{e!r}{provider_bug_report_message(provider)}',
                )
        return None

    def _generate_key_bindings(self, spec: PoTokenCacheSpec) -> dict[str, str]:
        """Return the spec's non-None key bindings plus internal version/provider bindings.

        ``spec._provider`` must already be set (``_get_cache_spec`` does this).
        """
        bindings_cleaned = {
            **{k: v for k, v in spec.key_bindings.items() if v is not None},
            # Allow us to invalidate caches if such need arises
            '_yt': 'v1',
            '_p': spec._provider.PROVIDER_KEY,
        }
        self.logger.trace('Generate cache key bindings: {}'.format(', '.join(f'{k}={v}' for k, v in bindings_cleaned.items())))
        return bindings_cleaned

    def _generate_key(self, bindings: dict) -> str:
        """Hash *bindings* into an order-independent cache key (SHA-256 hex digest).

        The bindings are serialized as sorted-key JSON instead of by concatenating
        keys and values. Plain concatenation is ambiguous: ``{'a': 'bc'}`` and
        ``{'ab': 'c'}`` would both serialize to ``'abc'`` and share a cache entry.
        """
        binding_string = json.dumps(bindings, sort_keys=True, separators=(',', ':'))
        return hashlib.sha256(binding_string.encode()).hexdigest()

    def get(self, request: PoTokenRequest) -> PoTokenResponse | None:
        """Return a valid cached PO Token response for *request*, or None.

        Providers are tried in preference order. An entry that fails to parse or
        validate is deleted from the provider that returned it. A hit from anything
        other than the top-ranked provider is written back with ``WRITE_FIRST``.
        Provider errors are logged and the next provider is tried.
        """
        spec = self._get_cache_spec(request)
        if not spec:
            self.logger.trace('No cache spec available for this request, unable to fetch from cache')
            return None
        cache_key = self._generate_key(self._generate_key_bindings(spec))
        self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}')
        for idx, provider in enumerate(self._get_cache_providers(request)):
            try:
                self.logger.trace(f'Attempting to fetch PO Token response from "{provider.PROVIDER_NAME}" cache provider')
                cache_response = provider.get(cache_key)
                if not cache_response:
                    continue
                try:
                    po_token_response = PoTokenResponse(**json.loads(cache_response))
                except (TypeError, ValueError, json.JSONDecodeError):
                    po_token_response = None
                if not validate_response(po_token_response):
                    self.logger.error(
                        f'Invalid PO Token response retrieved from cache provider "{provider.PROVIDER_NAME}": '
                        f'{cache_response}{provider_bug_report_message(provider)}')
                    provider.delete(cache_key)
                    continue
                self.logger.trace(
                    f'PO Token response retrieved from cache using "{provider.PROVIDER_NAME}" provider: {po_token_response}')
                if idx > 0:
                    # Write back to the highest priority cache provider,
                    # so we stop trying to fetch from lower priority providers
                    self.logger.trace('Writing PO Token response to highest priority cache provider')
                    self.store(request, po_token_response, write_policy=CacheProviderWritePolicy.WRITE_FIRST)
                return po_token_response
            except PoTokenCacheProviderError as e:
                self.logger.warning(
                    f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
            except Exception as e:
                self.logger.error(
                    f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider)}',
                )
        return None

    def store(self, request: PoTokenRequest, response: PoTokenResponse, write_policy: CacheProviderWritePolicy | None = None):
        """Store *response* for *request* in the cache providers.

        A response without ``expires_at`` gets ``now + spec.default_ttl``. With
        ``WRITE_FIRST`` (explicit, or the spec's ``write_policy``), only the
        top-ranked available provider is written, even if that write fails.
        Provider errors are logged, not raised.
        """
        spec = self._get_cache_spec(request)
        if not spec:
            self.logger.trace('No cache spec available for this request. Not caching.')
            return
        if not validate_response(response):
            self.logger.error(f'Invalid PO Token response provided to PoTokenCache.store(): {response}{bug_reports_message()}')
            return
        cache_key = self._generate_key(self._generate_key_bindings(spec))
        self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}')
        default_expires_at = int(dt.datetime.now(dt.timezone.utc).timestamp()) + spec.default_ttl
        cache_response = dataclasses.replace(response, expires_at=response.expires_at or default_expires_at)
        write_policy = write_policy or spec.write_policy
        self.logger.trace(f'Using write policy: {write_policy}')
        for idx, provider in enumerate(self._get_cache_providers(request)):
            try:
                self.logger.trace(
                    f'Caching PO Token response in "{provider.PROVIDER_NAME}" cache provider '
                    f'(key={cache_key}, expires_at={cache_response.expires_at})',
                )
                provider.store(
                    key=cache_key,
                    value=json.dumps(dataclasses.asdict(cache_response)),
                    expires_at=cache_response.expires_at)
            except PoTokenCacheProviderError as e:
                self.logger.warning(
                    f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
            except Exception as e:
                self.logger.error(
                    f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider)}',
                )
            # WRITE_FIRST should not write to lower priority providers in the case the highest priority provider fails
            if idx == 0 and write_policy == CacheProviderWritePolicy.WRITE_FIRST:
                return

    def close(self):
        """Close every cache provider, then every cache spec provider."""
        for provider in self.cache_providers.values():
            provider.close()
        for spec_provider in self.cache_spec_providers.values():
            spec_provider.close()
class PoTokenRequestDirector:
    """Obtain PO Tokens by consulting the cache first and then the registered providers.

    Providers are tried in descending order of their summed preference scores.
    The first valid response wins and is written to the cache unless its
    ``expires_at`` is zero or negative.
    """

    def __init__(self, logger: IEContentProviderLogger, cache: PoTokenCache):
        self.logger = logger
        self.cache = cache
        # callables invoked as pref(provider, request); their results are summed per provider
        self.preferences = []
        # PROVIDER_KEY -> provider instance
        self.providers = {}

    def register_provider(self, provider: PoTokenProvider):
        """Add *provider*, replacing any provider already registered under its PROVIDER_KEY."""
        key = provider.PROVIDER_KEY
        self.providers[key] = provider

    def register_preference(self, preference):
        """Add a preference callable used to rank providers for each request."""
        self.preferences.append(preference)

    def _get_providers(self, request: PoTokenRequest) -> Iterable[PoTokenProvider]:
        """Return the registered providers ordered by descending preference score.

        Scores are computed and sorted up front. ``is_available()`` is evaluated
        lazily, so unavailable providers are skipped only as the iterator is consumed.
        """
        scores = {}
        for provider in self.providers.values():
            scores[provider] = sum(pref(provider, request) for pref in self.preferences)

        if self.logger.log_level <= self.logger.LogLevel.TRACE:
            # provider_display_list() calls is_available() on every provider, so only do it when tracing
            self.logger.trace(f'PO Token Providers: {provider_display_list(self.providers.values())}')
            summary = ', '.join(f'{provider.PROVIDER_NAME}={score}' for provider, score in scores.items())
            self.logger.trace(f'Provider preferences for this request: {summary}')

        ranked = sorted(self.providers.values(), key=scores.get, reverse=True)
        return (provider for provider in ranked if provider.is_available())

    def _get_po_token(self, request) -> PoTokenResponse | None:
        """Ask each available provider in turn and return the first valid response, or None.

        Rejections, provider errors, unexpected exceptions and invalid responses are
        all logged, and the next provider is tried.
        """
        for provider in self._get_providers(request):
            try:
                self.logger.trace(f'Attempting to fetch a PO Token from "{provider.PROVIDER_NAME}" provider')
                response = provider.request_pot(request.copy())
            except PoTokenProviderRejectedRequest as e:
                self.logger.trace(
                    f'PO Token Provider "{provider.PROVIDER_NAME}" does not support this request, '
                    f'trying next available provider. Reason: {e}')
            except PoTokenProviderError as e:
                self.logger.warning(
                    f'Error fetching PO Token from "{provider.PROVIDER_NAME}" provider: '
                    f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
            except Exception as e:
                self.logger.error(
                    f'Unexpected error when fetching PO Token from "{provider.PROVIDER_NAME}" provider: '
                    f'{e!r}{provider_bug_report_message(provider)}')
            else:
                self.logger.trace(f'PO Token response from "{provider.PROVIDER_NAME}" provider: {response}')
                if validate_response(response):
                    return response
                self.logger.error(
                    f'Invalid PO Token response received from "{provider.PROVIDER_NAME}" provider: '
                    f'{response}{provider_bug_report_message(provider)}')

        self.logger.trace('No PO Token providers were able to provide a valid PO Token')
        return None

    def get_po_token(self, request: PoTokenRequest) -> str | None:
        """Return a cleaned PO Token for *request*, or None if none could be obtained.

        Unless ``request.bypass_cache`` is set, a cached response is returned when one
        exists. Otherwise the providers are asked; a fresh response is cleaned in place
        and cached unless its ``expires_at`` is zero or negative.
        """
        use_cache = not request.bypass_cache
        if use_cache:
            cached = self.cache.get(request)
            if cached:
                return clean_pot(cached.po_token)

        if not self.providers:
            self.logger.trace('No PO Token providers registered')
            return None

        fresh = self._get_po_token(request)
        if not fresh:
            return None

        fresh.po_token = clean_pot(fresh.po_token)
        expires_at = fresh.expires_at
        if expires_at is not None and expires_at <= 0:
            self.logger.trace(f'PO Token response will not be cached (expires_at={expires_at})')
        else:
            self.cache.store(request, fresh)
        return fresh.po_token

    def close(self):
        """Close every registered provider, then the cache."""
        for provider in self.providers.values():
            provider.close()
        self.cache.close()
# Extractor-args namespace prefix for provider settings: each provider's settings are read
# from the "youtubepot-<lowercased PROVIDER_KEY>" entry of the downloader's extractor_args
# (see initialize_pot_director).
EXTRACTOR_ARG_PREFIX = 'youtubepot'
def initialize_pot_director(ie):
    """Build a PoTokenRequestDirector wired up with every registered PO Token component.

    Instantiates all registered cache providers, cache spec providers and PO Token
    providers for *ie*. Each one gets its own logger and its settings from the
    ``youtubepot-<provider key>`` extractor-args entry. The function also registers
    the preferences and adds ``director.close`` as a downloader close hook.

    The log level is the more verbose of the ``pot_log_level`` youtube extractor arg
    and DEBUG/INFO (depending on the ``verbose`` param).

    Raises:
        ExtractorError: if *ie* has no downloader attached.
    """
    downloader = ie._downloader
    if not downloader:
        raise ExtractorError('Downloader not set', expected=False)

    level_cls = IEContentProviderLogger.LogLevel
    requested_level = level_cls(
        ie._configuration_arg('pot_log_level', ['INFO'], ie_key='youtube', casesense=False)[0].upper())
    baseline_level = level_cls.DEBUG if downloader.params.get('verbose', False) else level_cls.INFO
    log_level = min(requested_level, baseline_level)

    def _instantiate(provider_cls, logger_namespace):
        # Look up per-provider extractor args, give it a namespaced logger, then construct it
        extractor_key = f'{EXTRACTOR_ARG_PREFIX}-{provider_cls.PROVIDER_KEY.lower()}'
        settings = traverse_obj(downloader.params, ('extractor_args', extractor_key))
        provider_logger = YoutubeIEContentProviderLogger(
            ie, f'{logger_namespace}:{provider_cls.PROVIDER_NAME}', log_level=log_level)
        return provider_cls(ie, provider_logger, settings or {})

    cache_providers = [
        _instantiate(provider_cls, 'pot:cache') for provider_cls in _pot_cache_providers.value.values()]
    cache_spec_providers = [
        _instantiate(provider_cls, 'pot:cache:spec') for provider_cls in _pot_pcs_providers.value.values()]

    cache = PoTokenCache(
        logger=YoutubeIEContentProviderLogger(ie, 'pot:cache', log_level=log_level),
        cache_providers=cache_providers,
        cache_spec_providers=cache_spec_providers,
        cache_provider_preferences=list(_pot_cache_provider_preferences.value),
    )
    director = PoTokenRequestDirector(
        logger=YoutubeIEContentProviderLogger(ie, 'pot', log_level=log_level),
        cache=cache,
    )
    downloader.add_close_hook(director.close)

    for provider_cls in _pot_providers.value.values():
        director.register_provider(_instantiate(provider_cls, 'pot'))
    for preference in _ptp_preferences.value:
        director.register_preference(preference)

    log = director.logger
    if log.log_level <= log.LogLevel.DEBUG:
        # provider_display_list() calls is_available() on every provider, so only do it when debugging
        log.debug(f'PO Token Providers: {provider_display_list(director.providers.values())}')
        log.debug(f'PO Token Cache Providers: {provider_display_list(cache.cache_providers.values())}')
        log.debug(f'PO Token Cache Spec Providers: {provider_display_list(cache.cache_spec_providers.values())}')
        log.trace(f'Registered {len(director.preferences)} provider preferences')
        log.trace(f'Registered {len(cache.cache_provider_preferences)} cache provider preferences')
    return director
def provider_display_list(providers: Iterable[IEContentProvider]):
    """Return a comma-separated, human-readable summary of *providers*, or ``'none'``.

    Each entry is the provider name, plus its version for non-built-in providers.
    It is followed by a parenthesised status list when the provider is external
    and/or ``is_available()`` returns False.
    """
    labels = []
    for provider in providers:
        builtin = isinstance(provider, BuiltInIEContentProvider)
        label = join_nonempty(provider.PROVIDER_NAME, None if builtin else provider.PROVIDER_VERSION)
        flags = [] if builtin else ['external']
        if not provider.is_available():
            flags.append('unavailable')
        if flags:
            label += f' ({", ".join(flags)})'
        labels.append(label)
    return ', '.join(labels) or 'none'
def clean_pot(po_token: str):
    """Clean and validate a PO Token, returning it as URL-safe base64.

    The token is percent-decoded, base64-decoded (URL-safe alphabet) and
    re-encoded. This strips invalid characters off, e.g. additional URL params a
    user may accidentally include.

    Raises:
        ValueError: ``'Invalid PO Token'`` when the token cannot be decoded. The
            underlying decoding error is chained as ``__cause__``.
    """
    try:
        raw = base64.urlsafe_b64decode(urllib.parse.unquote(po_token))
    except (binascii.Error, ValueError) as err:
        raise ValueError('Invalid PO Token') from err
    return base64.urlsafe_b64encode(raw).decode()
def validate_response(response: PoTokenResponse):
    """Return True if *response* is a usable PO Token response.

    Requirements:
    - *response* is a ``PoTokenResponse``.
    - ``po_token`` is a non-empty string that ``clean_pot`` accepts.
    - ``expires_at`` is None, or an int that is <= 0 or lies in the future (UTC epoch seconds).
    """
    if not isinstance(response, PoTokenResponse):
        return False
    if not response.po_token:
        return False
    if not isinstance(response.po_token, str):
        return False

    try:
        clean_pot(response.po_token)
    except ValueError:
        return False

    expires_at = response.expires_at
    if expires_at is None:
        return True
    if not isinstance(expires_at, int):
        return False
    if expires_at <= 0:
        return True
    return expires_at > int(dt.datetime.now(dt.timezone.utc).timestamp())
def validate_cache_spec(spec: PoTokenCacheSpec):
    """Return True if *spec* is a well-formed ``PoTokenCacheSpec``.

    Requirements:
    - ``write_policy`` is a ``CacheProviderWritePolicy``.
    - ``default_ttl`` is an int.
    - ``key_bindings`` is a dict with str keys and str-or-None values.
    - At least one binding value is not None.
    """
    if not isinstance(spec, PoTokenCacheSpec):
        return False
    if not isinstance(spec.write_policy, CacheProviderWritePolicy):
        return False
    if not isinstance(spec.default_ttl, int):
        return False
    bindings = spec.key_bindings
    if not isinstance(bindings, dict):
        return False
    if not all(isinstance(key, str) for key in bindings):
        return False
    if not all(value is None or isinstance(value, str) for value in bindings.values()):
        return False
    return any(value is not None for value in bindings.values())