cleanup some long lines

This commit is contained in:
coletdjnz 2025-05-16 20:30:06 +12:00
parent 6cdeec4332
commit 50d12bbc5e
No known key found for this signature in database
GPG Key ID: 91984263BB39894A
7 changed files with 103 additions and 54 deletions

View File

@ -146,7 +146,7 @@ class MyPoTokenProviderPTP(PoTokenProvider): # Provider name must end with "PTP
# you can define a preference function to increase/decrease the priority of providers.
@register_preference(MyPoTokenProviderPTP)
def my_provider_preference(provider: PoTokenProvider, request: PoTokenRequest, *_, **__) -> int:
def my_provider_preference(provider: PoTokenProvider, request: PoTokenRequest) -> int:
return 50
```
@ -228,7 +228,7 @@ class MyCacheProviderPCP(PoTokenCacheProvider): # Provider name must end with "
@register_preference(MyCacheProviderPCP)
def my_cache_preference(provider: PoTokenCacheProvider, request: PoTokenRequest, *_, **__) -> int:
def my_cache_preference(provider: PoTokenCacheProvider, request: PoTokenRequest) -> int:
return 50
```

View File

@ -23,7 +23,11 @@ def initialize_global_cache(max_size: int):
if _pot_memory_cache.value['max_size'] != max_size:
raise ValueError('Cannot change max_size of initialized global memory cache')
return _pot_memory_cache.value['cache'], _pot_memory_cache.value['lock'], _pot_memory_cache.value['max_size']
return (
_pot_memory_cache.value['cache'],
_pot_memory_cache.value['lock'],
_pot_memory_cache.value['max_size'],
)
@register_provider
@ -46,31 +50,25 @@ class MemoryLRUPCP(PoTokenCacheProvider, BuiltInIEContentProvider):
def get(self, key: str) -> str | None:
with self.lock:
if key not in self.cache:
self.logger.trace('cache miss')
return None
value, expires_at = self.cache.pop(key)
if expires_at < int(dt.datetime.now(dt.timezone.utc).timestamp()):
self.logger.trace(f'cache expired key={key}')
return None
self.cache[key] = (value, expires_at)
self.logger.trace(f'cache hit key={key}')
return value
def store(self, key: str, value: str, expires_at: int):
with self.lock:
if expires_at < int(dt.datetime.now(dt.timezone.utc).timestamp()):
self.logger.trace(f'ignoring expired key={key}')
return
if key in self.cache:
self.cache.pop(key)
self.cache[key] = (value, expires_at)
if len(self.cache) > self.max_size:
self.cache.popitem(last=False)
self.logger.trace(f'storing key={key}')
def delete(self, key: str):
with self.lock:
self.logger.trace(f'deleting key={key}')
self.cache.pop(key, None)

View File

@ -19,7 +19,9 @@ class WebPoPCSP(PoTokenCacheSpecProvider, BuiltInIEContentProvider):
PROVIDER_NAME = 'webpo'
def generate_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None:
bind_to_visitor_id = self._configuration_arg('bind_to_visitor_id', default=['true'])[0] == 'true'
bind_to_visitor_id = self._configuration_arg(
'bind_to_visitor_id', default=['true'])[0] == 'true'
content_binding, content_binding_type = get_webpo_content_binding(
request, bind_to_visitor_id=bind_to_visitor_id)

View File

@ -85,14 +85,10 @@ class PoTokenCache:
cache_provider_preferences: list[CacheProviderPreference] | None = None,
):
self.cache_providers: dict[str, PoTokenCacheProvider] = {
provider.PROVIDER_KEY: provider for provider in (cache_providers or [])
}
provider.PROVIDER_KEY: provider for provider in (cache_providers or [])}
self.cache_provider_preferences: list[CacheProviderPreference] = cache_provider_preferences or []
self.cache_spec_providers: dict[str, PoTokenCacheSpecProvider] = {
provider.PROVIDER_KEY: provider for provider in (cache_spec_providers or [])
}
provider.PROVIDER_KEY: provider for provider in (cache_spec_providers or [])}
self.logger = logger
def _get_cache_providers(self, request: PoTokenRequest) -> Iterable[PoTokenCacheProvider]:
@ -120,15 +116,18 @@ class PoTokenCache:
if not spec:
continue
if not validate_cache_spec(spec):
self.logger.error(f'PoTokenCacheSpecProvider "{provider.PROVIDER_KEY}" generate_cache_spec() returned invalid spec {spec}{provider_bug_report_message(provider)}')
self.logger.error(
f'PoTokenCacheSpecProvider "{provider.PROVIDER_KEY}" generate_cache_spec() '
f'returned invalid spec {spec}{provider_bug_report_message(provider)}')
continue
spec = dataclasses.replace(spec, _provider=provider)
self.logger.trace(f'Retrieved cache spec {spec} from cache spec provider "{provider.PROVIDER_NAME}"')
self.logger.trace(
f'Retrieved cache spec {spec} from cache spec provider "{provider.PROVIDER_NAME}"')
return spec
except Exception as e:
self.logger.error(
f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache spec provider: {e!r}{provider_bug_report_message(provider)}',
)
f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache spec provider: '
f'{e!r}{provider_bug_report_message(provider)}')
continue
return None
@ -140,7 +139,8 @@ class PoTokenCache:
}
if spec._provider:
bindings_cleaned['_p'] = spec._provider.PROVIDER_KEY
self.logger.trace('Generate cache key bindings: {}'.format(', '.join(f'{k}={v}' for k, v in bindings_cleaned.items())))
self.logger.trace(
'Generate cache key bindings: {}'.format(', '.join(f'{k}={v}' for k, v in bindings_cleaned.items())))
return bindings_cleaned
def _generate_key(self, bindings: dict) -> str:
@ -158,7 +158,8 @@ class PoTokenCache:
for idx, provider in enumerate(self._get_cache_providers(request)):
try:
self.logger.trace(f'Attempting to fetch PO Token response from "{provider.PROVIDER_NAME}" cache provider')
self.logger.trace(
f'Attempting to fetch PO Token response from "{provider.PROVIDER_NAME}" cache provider')
cache_response = provider.get(cache_key)
if not cache_response:
continue
@ -167,10 +168,14 @@ class PoTokenCache:
except (TypeError, ValueError, json.JSONDecodeError):
po_token_response = None
if not validate_response(po_token_response):
self.logger.error(f'Invalid PO Token response retrieved from cache provider "{provider.PROVIDER_NAME}": {cache_response}{provider_bug_report_message(provider)}')
self.logger.error(
f'Invalid PO Token response retrieved from cache provider "{provider.PROVIDER_NAME}": '
f'{cache_response}{provider_bug_report_message(provider)}')
provider.delete(cache_key)
continue
self.logger.trace(f'PO Token response retrieved from cache using "{provider.PROVIDER_NAME}" provider: {po_token_response}')
self.logger.trace(
f'PO Token response retrieved from cache using "{provider.PROVIDER_NAME}" provider: '
f'{po_token_response}')
if idx > 0:
# Write back to the highest priority cache provider,
# so we stop trying to fetch from lower priority providers
@ -180,23 +185,32 @@ class PoTokenCache:
return po_token_response
except PoTokenCacheProviderError as e:
self.logger.warning(
f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: '
f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
continue
except Exception as e:
self.logger.error(
f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider)}',
f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: '
f'{e!r}{provider_bug_report_message(provider)}',
)
continue
return None
def store(self, request: PoTokenRequest, response: PoTokenResponse, write_policy: CacheProviderWritePolicy | None = None):
def store(
self,
request: PoTokenRequest,
response: PoTokenResponse,
write_policy: CacheProviderWritePolicy | None = None,
):
spec = self._get_cache_spec(request)
if not spec:
self.logger.trace('No cache spec available for this request. Not caching.')
return
if not validate_response(response):
self.logger.error(f'Invalid PO Token response provided to PoTokenCache.store(): {response}{bug_reports_message()}')
self.logger.error(
f'Invalid PO Token response provided to PoTokenCache.store(): '
f'{response}{bug_reports_message()}')
return
cache_key = self._generate_key(self._generate_key_bindings(spec))
@ -211,16 +225,20 @@ class PoTokenCache:
for idx, provider in enumerate(self._get_cache_providers(request)):
try:
self.logger.trace(
f'Caching PO Token response in "{provider.PROVIDER_NAME}" cache provider (key={cache_key}, expires_at={cache_response.expires_at})',
)
provider.store(key=cache_key, value=json.dumps(dataclasses.asdict(cache_response)), expires_at=cache_response.expires_at)
f'Caching PO Token response in "{provider.PROVIDER_NAME}" cache provider '
f'(key={cache_key}, expires_at={cache_response.expires_at})')
provider.store(
key=cache_key,
value=json.dumps(dataclasses.asdict(cache_response)),
expires_at=cache_response.expires_at)
except PoTokenCacheProviderError as e:
self.logger.warning(
f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: '
f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
except Exception as e:
self.logger.error(
f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider)}',
)
f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: '
f'{e!r}{provider_bug_report_message(provider)}')
# WRITE_FIRST should not write to lower priority providers in the case the highest priority provider fails
if idx == 0 and write_policy == CacheProviderWritePolicy.WRITE_FIRST:
@ -260,32 +278,39 @@ class PoTokenRequestDirector:
f'{provider.PROVIDER_NAME}={pref}' for provider, pref in preferences.items())))
return (
provider for provider in sorted(self.providers.values(), key=preferences.get, reverse=True) if provider.is_available()
provider for provider in sorted(
self.providers.values(), key=preferences.get, reverse=True)
if provider.is_available()
)
def _get_po_token(self, request) -> PoTokenResponse | None:
for provider in self._get_providers(request):
try:
self.logger.trace(f'Attempting to fetch a PO Token from "{provider.PROVIDER_NAME}" provider')
self.logger.trace(
f'Attempting to fetch a PO Token from "{provider.PROVIDER_NAME}" provider')
response = provider.request_pot(request.copy())
except PoTokenProviderRejectedRequest as e:
self.logger.trace(
f'PO Token Provider "{provider.PROVIDER_NAME}" rejected this request, trying next available provider. Reason: {e}')
f'PO Token Provider "{provider.PROVIDER_NAME}" rejected this request, '
f'trying next available provider. Reason: {e}')
continue
except PoTokenProviderError as e:
self.logger.warning(
f'Error fetching PO Token from "{provider.PROVIDER_NAME}" provider: {e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
f'Error fetching PO Token from "{provider.PROVIDER_NAME}" provider: '
f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
continue
except Exception as e:
self.logger.error(
f'Unexpected error when fetching PO Token from "{provider.PROVIDER_NAME}" provider: {e!r}{provider_bug_report_message(provider)}')
f'Unexpected error when fetching PO Token from "{provider.PROVIDER_NAME}" provider: '
f'{e!r}{provider_bug_report_message(provider)}')
continue
self.logger.trace(f'PO Token response from "{provider.PROVIDER_NAME}" provider: {response}')
if not validate_response(response):
self.logger.error(
f'Invalid PO Token response received from "{provider.PROVIDER_NAME}" provider: {response}{provider_bug_report_message(provider)}')
f'Invalid PO Token response received from "{provider.PROVIDER_NAME}" provider: '
f'{response}{provider_bug_report_message(provider)}')
continue
return response
@ -311,14 +336,14 @@ class PoTokenRequestDirector:
if pot_response.expires_at is None or pot_response.expires_at > 0:
self.cache.store(request, pot_response)
else:
self.logger.trace(f'PO Token response will not be cached (expires_at={pot_response.expires_at})')
self.logger.trace(
f'PO Token response will not be cached (expires_at={pot_response.expires_at})')
return pot_response.po_token
def close(self):
for provider in self.providers.values():
provider.close()
self.cache.close()
@ -342,7 +367,8 @@ def initialize_pot_director(ie):
cache_providers = []
for cache_provider in _pot_cache_providers.value.values():
settings = traverse_obj(
ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_provider.PROVIDER_KEY.lower()}'))
ie._downloader.params,
('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_provider.PROVIDER_KEY.lower()}'))
cache_provider_logger = YoutubeIEContentProviderLogger(
ie, f'pot:cache:{cache_provider.PROVIDER_NAME}', log_level=log_level)
cache_providers.append(cache_provider(ie, cache_provider_logger, settings or {}))
@ -350,7 +376,8 @@ def initialize_pot_director(ie):
cache_spec_providers = []
for cache_spec_provider in _pot_pcs_providers.value.values():
settings = traverse_obj(
ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_spec_provider.PROVIDER_KEY.lower()}'))
ie._downloader.params,
('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_spec_provider.PROVIDER_KEY.lower()}'))
cache_spec_provider_logger = YoutubeIEContentProviderLogger(
ie, f'pot:cache:spec:{cache_spec_provider.PROVIDER_NAME}', log_level=log_level)
cache_spec_providers.append(cache_spec_provider(ie, cache_spec_provider_logger, settings or {}))
@ -370,14 +397,18 @@ def initialize_pot_director(ie):
ie._downloader.add_close_hook(director.close)
for provider in _pot_providers.value.values():
settings = traverse_obj(ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}'))
logger = YoutubeIEContentProviderLogger(ie, f'pot:{provider.PROVIDER_NAME}', log_level=log_level)
settings = traverse_obj(
ie._downloader.params,
('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}'))
logger = YoutubeIEContentProviderLogger(
ie, f'pot:{provider.PROVIDER_NAME}', log_level=log_level)
director.register_provider(provider(ie, logger, settings or {}))
for preference in _ptp_preferences.value:
director.register_preference(preference)
if director.logger.log_level <= director.logger.LogLevel.DEBUG:
# calling is_available() for every PO Token provider upfront may have some overhead
director.logger.debug(f'PO Token Providers: {provider_display_list(director.providers.values())}')
director.logger.debug(f'PO Token Cache Providers: {provider_display_list(cache.cache_providers.values())}')
director.logger.debug(f'PO Token Cache Spec Providers: {provider_display_list(cache.cache_spec_providers.values())}')
@ -389,7 +420,9 @@ def initialize_pot_director(ie):
def provider_display_list(providers: Iterable[IEContentProvider]):
def provider_display_name(provider):
display_str = join_nonempty(provider.PROVIDER_NAME, provider.PROVIDER_VERSION if not isinstance(provider, BuiltInIEContentProvider) else None)
display_str = join_nonempty(
provider.PROVIDER_NAME,
provider.PROVIDER_VERSION if not isinstance(provider, BuiltInIEContentProvider) else None)
statuses = []
if not isinstance(provider, BuiltInIEContentProvider):
statuses.append('external')
@ -406,7 +439,8 @@ def clean_pot(po_token: str):
# Clean and validate the PO Token. This will strip invalid characters off
# (e.g. additional url params the user may accidentally include)
try:
return base64.urlsafe_b64encode(base64.urlsafe_b64decode(urllib.parse.unquote(po_token))).decode()
return base64.urlsafe_b64encode(
base64.urlsafe_b64decode(urllib.parse.unquote(po_token))).decode()
except (binascii.Error, ValueError):
raise ValueError('Invalid PO Token')
@ -428,7 +462,10 @@ def validate_response(response: PoTokenResponse | None):
response.expires_at is None
or (
isinstance(response.expires_at, int)
and (response.expires_at <= 0 or response.expires_at > int(dt.datetime.now(dt.timezone.utc).timestamp()))
and (
response.expires_at <= 0
or response.expires_at > int(dt.datetime.now(dt.timezone.utc).timestamp())
)
)
)

View File

@ -62,7 +62,12 @@ class IEContentProvider(abc.ABC):
PROVIDER_VERSION: str = '0.0.0'
BUG_REPORT_LOCATION: str = '(developer has not provided a bug report location)'
def __init__(self, ie: InfoExtractor, logger: IEContentProviderLogger, settings: dict[str, list[str]], *_, **__):
def __init__(
self,
ie: InfoExtractor,
logger: IEContentProviderLogger,
settings: dict[str, list[str]], *_, **__,
):
self.ie = ie
self.settings = settings or {}
self.logger = logger
@ -103,12 +108,12 @@ class IEContentProvider(abc.ABC):
pass
def _configuration_arg(self, key, default=NO_DEFAULT, *, casesense=False):
'''
"""
@returns A list of values for the setting given by "key"
or "default" if no such key is present
@param default The default value to return when the key is not present (default: [])
@param casesense When false, the values are converted to lower case
'''
"""
val = traverse_obj(self.settings, key)
if val is None:
return [] if default is NO_DEFAULT else default

View File

@ -241,6 +241,7 @@ def register_preference(*providers: type[PoTokenProvider]) -> typing.Callable[[P
if typing.TYPE_CHECKING:
Preference = typing.Callable[[PoTokenProvider, PoTokenRequest], int]
__all__.append('Preference')
# Barebones innertube context. There may be more fields.
class ClientInfo(typing.TypedDict, total=False):

View File

@ -31,7 +31,12 @@ class ContentBindingType(enum.Enum):
VISITOR_ID = 'visitor_id'
def get_webpo_content_binding(request: PoTokenRequest, webpo_clients=WEBPO_CLIENTS, bind_to_visitor_id=False) -> tuple[str | None, ContentBindingType | None]:
def get_webpo_content_binding(
request: PoTokenRequest,
webpo_clients=WEBPO_CLIENTS,
bind_to_visitor_id=False,
) -> tuple[str | None, ContentBindingType | None]:
client_name = traverse_obj(request.innertube_context, ('client', 'clientName'))
if not client_name or client_name not in webpo_clients:
return None, None
@ -59,7 +64,8 @@ def _extract_visitor_id(visitor_data):
# Attempt to extract the visitor ID from the visitor_data protobuf
# xxx: ideally should use a protobuf parser
with contextlib.suppress(Exception):
visitor_id = base64.urlsafe_b64decode(urllib.parse.unquote_plus(visitor_data))[2:13].decode()
visitor_id = base64.urlsafe_b64decode(
urllib.parse.unquote_plus(visitor_data))[2:13].decode()
# check that visitor id is all letters and numbers
if re.fullmatch(r'[A-Za-z0-9_-]{11}', visitor_id):
return visitor_id