From 50d12bbc5e13b2a34472408150c36a1e7ba7361f Mon Sep 17 00:00:00 2001 From: coletdjnz Date: Fri, 16 May 2025 20:30:06 +1200 Subject: [PATCH] cleanup some long lines --- yt_dlp/extractor/youtube/pot/README.md | 4 +- .../youtube/pot/_builtin/memory_cache.py | 12 +- .../youtube/pot/_builtin/webpo_cachespec.py | 4 +- yt_dlp/extractor/youtube/pot/_director.py | 115 ++++++++++++------ yt_dlp/extractor/youtube/pot/_provider.py | 11 +- yt_dlp/extractor/youtube/pot/provider.py | 1 + yt_dlp/extractor/youtube/pot/utils.py | 10 +- 7 files changed, 103 insertions(+), 54 deletions(-) diff --git a/yt_dlp/extractor/youtube/pot/README.md b/yt_dlp/extractor/youtube/pot/README.md index 63af46cd68..79e1578e41 100644 --- a/yt_dlp/extractor/youtube/pot/README.md +++ b/yt_dlp/extractor/youtube/pot/README.md @@ -146,7 +146,7 @@ class MyPoTokenProviderPTP(PoTokenProvider): # Provider name must end with "PTP # you can define a preference function to increase/decrease the priority of providers. @register_preference(MyPoTokenProviderPTP) -def my_provider_preference(provider: PoTokenProvider, request: PoTokenRequest, *_, **__) -> int: +def my_provider_preference(provider: PoTokenProvider, request: PoTokenRequest) -> int: return 50 ``` @@ -228,7 +228,7 @@ class MyCacheProviderPCP(PoTokenCacheProvider): # Provider name must end with " @register_preference(MyCacheProviderPCP) -def my_cache_preference(provider: PoTokenCacheProvider, request: PoTokenRequest, *_, **__) -> int: +def my_cache_preference(provider: PoTokenCacheProvider, request: PoTokenRequest) -> int: return 50 ``` diff --git a/yt_dlp/extractor/youtube/pot/_builtin/memory_cache.py b/yt_dlp/extractor/youtube/pot/_builtin/memory_cache.py index 95e1c775f0..36ecba2ce5 100644 --- a/yt_dlp/extractor/youtube/pot/_builtin/memory_cache.py +++ b/yt_dlp/extractor/youtube/pot/_builtin/memory_cache.py @@ -23,7 +23,11 @@ def initialize_global_cache(max_size: int): if _pot_memory_cache.value['max_size'] != max_size: raise ValueError('Cannot change max_size of initialized global memory cache') - return _pot_memory_cache.value['cache'], _pot_memory_cache.value['lock'], _pot_memory_cache.value['max_size'] + return ( + _pot_memory_cache.value['cache'], + _pot_memory_cache.value['lock'], + _pot_memory_cache.value['max_size'], + ) @register_provider @@ -46,31 +50,25 @@ class MemoryLRUPCP(PoTokenCacheProvider, BuiltInIEContentProvider): def get(self, key: str) -> str | None: with self.lock: if key not in self.cache: - self.logger.trace('cache miss') return None value, expires_at = self.cache.pop(key) if expires_at < int(dt.datetime.now(dt.timezone.utc).timestamp()): - self.logger.trace(f'cache expired key={key}') return None self.cache[key] = (value, expires_at) - self.logger.trace(f'cache hit key={key}') return value def store(self, key: str, value: str, expires_at: int): with self.lock: if expires_at < int(dt.datetime.now(dt.timezone.utc).timestamp()): - self.logger.trace(f'ignoring expired key={key}') return if key in self.cache: self.cache.pop(key) self.cache[key] = (value, expires_at) if len(self.cache) > self.max_size: self.cache.popitem(last=False) - self.logger.trace(f'storing key={key}') def delete(self, key: str): with self.lock: - self.logger.trace(f'deleting key={key}') self.cache.pop(key, None) diff --git a/yt_dlp/extractor/youtube/pot/_builtin/webpo_cachespec.py b/yt_dlp/extractor/youtube/pot/_builtin/webpo_cachespec.py index aac30add09..cfe9db79e2 100644 --- a/yt_dlp/extractor/youtube/pot/_builtin/webpo_cachespec.py +++ b/yt_dlp/extractor/youtube/pot/_builtin/webpo_cachespec.py @@ -19,7 +19,9 @@ class WebPoPCSP(PoTokenCacheSpecProvider, BuiltInIEContentProvider): PROVIDER_NAME = 'webpo' def generate_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None: - bind_to_visitor_id = self._configuration_arg('bind_to_visitor_id', default=['true'])[0] == 'true' + bind_to_visitor_id = self._configuration_arg( + 'bind_to_visitor_id', default=['true'])[0] == 'true' + content_binding, content_binding_type = get_webpo_content_binding( request, bind_to_visitor_id=bind_to_visitor_id) diff --git a/yt_dlp/extractor/youtube/pot/_director.py b/yt_dlp/extractor/youtube/pot/_director.py index 2903634777..75964ecbb8 100644 --- a/yt_dlp/extractor/youtube/pot/_director.py +++ b/yt_dlp/extractor/youtube/pot/_director.py @@ -85,14 +85,10 @@ class PoTokenCache: cache_provider_preferences: list[CacheProviderPreference] | None = None, ): self.cache_providers: dict[str, PoTokenCacheProvider] = { - provider.PROVIDER_KEY: provider for provider in (cache_providers or []) - } + provider.PROVIDER_KEY: provider for provider in (cache_providers or [])} self.cache_provider_preferences: list[CacheProviderPreference] = cache_provider_preferences or [] - self.cache_spec_providers: dict[str, PoTokenCacheSpecProvider] = { - provider.PROVIDER_KEY: provider for provider in (cache_spec_providers or []) - } - + provider.PROVIDER_KEY: provider for provider in (cache_spec_providers or [])} self.logger = logger def _get_cache_providers(self, request: PoTokenRequest) -> Iterable[PoTokenCacheProvider]: @@ -120,15 +116,18 @@ class PoTokenCache: if not spec: continue if not validate_cache_spec(spec): - self.logger.error(f'PoTokenCacheSpecProvider "{provider.PROVIDER_KEY}" generate_cache_spec() returned invalid spec {spec}{provider_bug_report_message(provider)}') + self.logger.error( + f'PoTokenCacheSpecProvider "{provider.PROVIDER_KEY}" generate_cache_spec() ' + f'returned invalid spec {spec}{provider_bug_report_message(provider)}') continue spec = dataclasses.replace(spec, _provider=provider) - self.logger.trace(f'Retrieved cache spec {spec} from cache spec provider "{provider.PROVIDER_NAME}"') + self.logger.trace( + f'Retrieved cache spec {spec} from cache spec provider "{provider.PROVIDER_NAME}"') return spec except Exception as e: self.logger.error( - f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache spec provider: {e!r}{provider_bug_report_message(provider)}', - ) + f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache spec provider: ' + f'{e!r}{provider_bug_report_message(provider)}') continue return None @@ -140,7 +139,8 @@ class PoTokenCache: } if spec._provider: bindings_cleaned['_p'] = spec._provider.PROVIDER_KEY - self.logger.trace('Generate cache key bindings: {}'.format(', '.join(f'{k}={v}' for k, v in bindings_cleaned.items()))) + self.logger.trace( + 'Generate cache key bindings: {}'.format(', '.join(f'{k}={v}' for k, v in bindings_cleaned.items()))) return bindings_cleaned def _generate_key(self, bindings: dict) -> str: @@ -158,7 +158,8 @@ class PoTokenCache: for idx, provider in enumerate(self._get_cache_providers(request)): try: - self.logger.trace(f'Attempting to fetch PO Token response from "{provider.PROVIDER_NAME}" cache provider') + self.logger.trace( + f'Attempting to fetch PO Token response from "{provider.PROVIDER_NAME}" cache provider') cache_response = provider.get(cache_key) if not cache_response: continue @@ -167,10 +168,14 @@ class PoTokenCache: except (TypeError, ValueError, json.JSONDecodeError): po_token_response = None if not validate_response(po_token_response): - self.logger.error(f'Invalid PO Token response retrieved from cache provider "{provider.PROVIDER_NAME}": {cache_response}{provider_bug_report_message(provider)}') + self.logger.error( + f'Invalid PO Token response retrieved from cache provider "{provider.PROVIDER_NAME}": ' + f'{cache_response}{provider_bug_report_message(provider)}') provider.delete(cache_key) continue - self.logger.trace(f'PO Token response retrieved from cache using "{provider.PROVIDER_NAME}" provider: {po_token_response}') + self.logger.trace( + f'PO Token response retrieved from cache using "{provider.PROVIDER_NAME}" provider: ' + f'{po_token_response}') if idx > 0: # Write back to the highest priority cache provider, # so we stop trying to fetch from lower priority providers @@ -180,23 +185,32 @@ class PoTokenCache: return po_token_response except PoTokenCacheProviderError as e: self.logger.warning( - f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider) if not e.expected else ""}') + f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: ' + f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}') continue except Exception as e: self.logger.error( - f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider)}', + f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: ' + f'{e!r}{provider_bug_report_message(provider)}', ) continue return None - def store(self, request: PoTokenRequest, response: PoTokenResponse, write_policy: CacheProviderWritePolicy | None = None): + def store( + self, + request: PoTokenRequest, + response: PoTokenResponse, + write_policy: CacheProviderWritePolicy | None = None, + ): spec = self._get_cache_spec(request) if not spec: self.logger.trace('No cache spec available for this request. Not caching.') return if not validate_response(response): - self.logger.error(f'Invalid PO Token response provided to PoTokenCache.store(): {response}{bug_reports_message()}') + self.logger.error( + f'Invalid PO Token response provided to PoTokenCache.store(): ' + f'{response}{bug_reports_message()}') return cache_key = self._generate_key(self._generate_key_bindings(spec)) @@ -211,16 +225,20 @@ class PoTokenCache: for idx, provider in enumerate(self._get_cache_providers(request)): try: self.logger.trace( - f'Caching PO Token response in "{provider.PROVIDER_NAME}" cache provider (key={cache_key}, expires_at={cache_response.expires_at})', - ) - provider.store(key=cache_key, value=json.dumps(dataclasses.asdict(cache_response)), expires_at=cache_response.expires_at) + f'Caching PO Token response in "{provider.PROVIDER_NAME}" cache provider ' + f'(key={cache_key}, expires_at={cache_response.expires_at})') + provider.store( + key=cache_key, + value=json.dumps(dataclasses.asdict(cache_response)), + expires_at=cache_response.expires_at) except PoTokenCacheProviderError as e: self.logger.warning( - f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider) if not e.expected else ""}') + f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: ' + f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}') except Exception as e: self.logger.error( - f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider)}', - ) + f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: ' + f'{e!r}{provider_bug_report_message(provider)}') # WRITE_FIRST should not write to lower priority providers in the case the highest priority provider fails if idx == 0 and write_policy == CacheProviderWritePolicy.WRITE_FIRST: @@ -260,32 +278,39 @@ class PoTokenRequestDirector: f'{provider.PROVIDER_NAME}={pref}' for provider, pref in preferences.items()))) return ( - provider for provider in sorted(self.providers.values(), key=preferences.get, reverse=True) if provider.is_available() + provider for provider in sorted( + self.providers.values(), key=preferences.get, reverse=True) + if provider.is_available() ) def _get_po_token(self, request) -> PoTokenResponse | None: for provider in self._get_providers(request): try: - self.logger.trace(f'Attempting to fetch a PO Token from "{provider.PROVIDER_NAME}" provider') + self.logger.trace( + f'Attempting to fetch a PO Token from "{provider.PROVIDER_NAME}" provider') response = provider.request_pot(request.copy()) except PoTokenProviderRejectedRequest as e: self.logger.trace( - f'PO Token Provider "{provider.PROVIDER_NAME}" rejected this request, trying next available provider. Reason: {e}') + f'PO Token Provider "{provider.PROVIDER_NAME}" rejected this request, ' + f'trying next available provider. Reason: {e}') continue except PoTokenProviderError as e: self.logger.warning( - f'Error fetching PO Token from "{provider.PROVIDER_NAME}" provider: {e!r}{provider_bug_report_message(provider) if not e.expected else ""}') + f'Error fetching PO Token from "{provider.PROVIDER_NAME}" provider: ' + f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}') continue except Exception as e: self.logger.error( - f'Unexpected error when fetching PO Token from "{provider.PROVIDER_NAME}" provider: {e!r}{provider_bug_report_message(provider)}') + f'Unexpected error when fetching PO Token from "{provider.PROVIDER_NAME}" provider: ' + f'{e!r}{provider_bug_report_message(provider)}') continue self.logger.trace(f'PO Token response from "{provider.PROVIDER_NAME}" provider: {response}') if not validate_response(response): self.logger.error( - f'Invalid PO Token response received from "{provider.PROVIDER_NAME}" provider: {response}{provider_bug_report_message(provider)}') + f'Invalid PO Token response received from "{provider.PROVIDER_NAME}" provider: ' + f'{response}{provider_bug_report_message(provider)}') continue return response @@ -311,14 +336,14 @@ class PoTokenRequestDirector: if pot_response.expires_at is None or pot_response.expires_at > 0: self.cache.store(request, pot_response) else: - self.logger.trace(f'PO Token response will not be cached (expires_at={pot_response.expires_at})') + self.logger.trace( + f'PO Token response will not be cached (expires_at={pot_response.expires_at})') return pot_response.po_token def close(self): for provider in self.providers.values(): provider.close() - self.cache.close() @@ -342,7 +367,8 @@ def initialize_pot_director(ie): cache_providers = [] for cache_provider in _pot_cache_providers.value.values(): settings = traverse_obj( - ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_provider.PROVIDER_KEY.lower()}')) + ie._downloader.params, + ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_provider.PROVIDER_KEY.lower()}')) cache_provider_logger = YoutubeIEContentProviderLogger( ie, f'pot:cache:{cache_provider.PROVIDER_NAME}', log_level=log_level) cache_providers.append(cache_provider(ie, cache_provider_logger, settings or {})) @@ -350,7 +376,8 @@ def initialize_pot_director(ie): cache_spec_providers = [] for cache_spec_provider in _pot_pcs_providers.value.values(): settings = traverse_obj( - ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_spec_provider.PROVIDER_KEY.lower()}')) + ie._downloader.params, + ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_spec_provider.PROVIDER_KEY.lower()}')) cache_spec_provider_logger = YoutubeIEContentProviderLogger( ie, f'pot:cache:spec:{cache_spec_provider.PROVIDER_NAME}', log_level=log_level) cache_spec_providers.append(cache_spec_provider(ie, cache_spec_provider_logger, settings or {})) @@ -370,14 +397,18 @@ def initialize_pot_director(ie): ie._downloader.add_close_hook(director.close) for provider in _pot_providers.value.values(): - settings = traverse_obj(ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}')) - logger = YoutubeIEContentProviderLogger(ie, f'pot:{provider.PROVIDER_NAME}', log_level=log_level) + settings = traverse_obj( + ie._downloader.params, + ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}')) + logger = YoutubeIEContentProviderLogger( + ie, f'pot:{provider.PROVIDER_NAME}', log_level=log_level) director.register_provider(provider(ie, logger, settings or {})) for preference in _ptp_preferences.value: director.register_preference(preference) if director.logger.log_level <= director.logger.LogLevel.DEBUG: + # calling is_available() for every PO Token provider upfront may have some overhead director.logger.debug(f'PO Token Providers: {provider_display_list(director.providers.values())}') director.logger.debug(f'PO Token Cache Providers: {provider_display_list(cache.cache_providers.values())}') director.logger.debug(f'PO Token Cache Spec Providers: {provider_display_list(cache.cache_spec_providers.values())}') @@ -389,7 +420,9 @@ def initialize_pot_director(ie): def provider_display_list(providers: Iterable[IEContentProvider]): def provider_display_name(provider): - display_str = join_nonempty(provider.PROVIDER_NAME, provider.PROVIDER_VERSION if not isinstance(provider, BuiltInIEContentProvider) else None) + display_str = join_nonempty( + provider.PROVIDER_NAME, + provider.PROVIDER_VERSION if not isinstance(provider, BuiltInIEContentProvider) else None) statuses = [] if not isinstance(provider, BuiltInIEContentProvider): statuses.append('external') @@ -406,7 +439,8 @@ def clean_pot(po_token: str): # Clean and validate the PO Token. This will strip invalid characters off # (e.g. additional url params the user may accidentally include) try: - return base64.urlsafe_b64encode(base64.urlsafe_b64decode(urllib.parse.unquote(po_token))).decode() + return base64.urlsafe_b64encode( + base64.urlsafe_b64decode(urllib.parse.unquote(po_token))).decode() except (binascii.Error, ValueError): raise ValueError('Invalid PO Token') @@ -428,7 +462,10 @@ def validate_response(response: PoTokenResponse | None): response.expires_at is None or ( isinstance(response.expires_at, int) - and (response.expires_at <= 0 or response.expires_at > int(dt.datetime.now(dt.timezone.utc).timestamp())) + and ( + response.expires_at <= 0 + or response.expires_at > int(dt.datetime.now(dt.timezone.utc).timestamp()) + ) ) ) diff --git a/yt_dlp/extractor/youtube/pot/_provider.py b/yt_dlp/extractor/youtube/pot/_provider.py index dfc208a2db..b107c8b7b7 100644 --- a/yt_dlp/extractor/youtube/pot/_provider.py +++ b/yt_dlp/extractor/youtube/pot/_provider.py @@ -62,7 +62,12 @@ class IEContentProvider(abc.ABC): PROVIDER_VERSION: str = '0.0.0' BUG_REPORT_LOCATION: str = '(developer has not provided a bug report location)' - def __init__(self, ie: InfoExtractor, logger: IEContentProviderLogger, settings: dict[str, list[str]], *_, **__): + def __init__( + self, + ie: InfoExtractor, + logger: IEContentProviderLogger, + settings: dict[str, list[str]], *_, **__, + ): self.ie = ie self.settings = settings or {} self.logger = logger @@ -103,12 +108,12 @@ class IEContentProvider(abc.ABC): pass def _configuration_arg(self, key, default=NO_DEFAULT, *, casesense=False): - ''' + """ @returns A list of values for the setting given by "key" or "default" if no such key is present @param default The default value to return when the key is not present (default: []) @param casesense When false, the values are converted to lower case - ''' + """ val = traverse_obj(self.settings, key) if val is None: return [] if default is NO_DEFAULT else default diff --git a/yt_dlp/extractor/youtube/pot/provider.py b/yt_dlp/extractor/youtube/pot/provider.py index 6649f71443..80b21bfc77 100644 --- a/yt_dlp/extractor/youtube/pot/provider.py +++ b/yt_dlp/extractor/youtube/pot/provider.py @@ -241,6 +241,7 @@ def register_preference(*providers: type[PoTokenProvider]) -> typing.Callable[[P if typing.TYPE_CHECKING: Preference = typing.Callable[[PoTokenProvider, PoTokenRequest], int] + __all__.append('Preference') # Barebones innertube context. There may be more fields. class ClientInfo(typing.TypedDict, total=False): diff --git a/yt_dlp/extractor/youtube/pot/utils.py b/yt_dlp/extractor/youtube/pot/utils.py index bd9a12b099..68d1d6f02d 100644 --- a/yt_dlp/extractor/youtube/pot/utils.py +++ b/yt_dlp/extractor/youtube/pot/utils.py @@ -31,7 +31,12 @@ class ContentBindingType(enum.Enum): VISITOR_ID = 'visitor_id' -def get_webpo_content_binding(request: PoTokenRequest, webpo_clients=WEBPO_CLIENTS, bind_to_visitor_id=False) -> tuple[str | None, ContentBindingType | None]: +def get_webpo_content_binding( + request: PoTokenRequest, + webpo_clients=WEBPO_CLIENTS, + bind_to_visitor_id=False, +) -> tuple[str | None, ContentBindingType | None]: + client_name = traverse_obj(request.innertube_context, ('client', 'clientName')) if not client_name or client_name not in webpo_clients: return None, None @@ -59,7 +64,8 @@ def _extract_visitor_id(visitor_data): # Attempt to extract the visitor ID from the visitor_data protobuf # xxx: ideally should use a protobuf parser with contextlib.suppress(Exception): - visitor_id = base64.urlsafe_b64decode(urllib.parse.unquote_plus(visitor_data))[2:13].decode() + visitor_id = base64.urlsafe_b64decode( + urllib.parse.unquote_plus(visitor_data))[2:13].decode() # check that visitor id is all letters and numbers if re.fullmatch(r'[A-Za-z0-9_-]{11}', visitor_id): return visitor_id