From 038699cae61a73ae326ce6f2d1df0b7be9db0088 Mon Sep 17 00:00:00 2001 From: coletdjnz Date: Sat, 5 Apr 2025 14:29:53 +1300 Subject: [PATCH] [ie/youtube] Add a PO Token Provider Framework --- README.md | 11 + test/test_networking_utils.py | 3 +- test/test_pot/conftest.py | 72 + test/test_pot/test_pot_builtin_memorycache.py | 131 ++ test/test_pot/test_pot_builtin_utils.py | 46 + test/test_pot/test_pot_builtin_webpospec.py | 92 ++ test/test_pot/test_pot_director.py | 1396 +++++++++++++++++ test/test_pot/test_pot_framework.py | 510 ++++++ yt_dlp/extractor/youtube/_video.py | 72 +- yt_dlp/extractor/youtube/pot/README.md | 277 ++++ yt_dlp/extractor/youtube/pot/__init__.py | 1 + yt_dlp/extractor/youtube/pot/_director.py | 424 +++++ yt_dlp/extractor/youtube/pot/_provider.py | 131 ++ .../extractor/youtube/pot/builtin/__init__.py | 2 + .../youtube/pot/builtin/cache/memory.py | 82 + .../youtube/pot/builtin/cachespec/webpo.py | 46 + yt_dlp/extractor/youtube/pot/builtin/utils.py | 64 + yt_dlp/extractor/youtube/pot/cache.py | 93 ++ yt_dlp/extractor/youtube/pot/provider.py | 211 +++ yt_dlp/globals.py | 9 + yt_dlp/networking/_curlcffi.py | 3 +- yt_dlp/networking/_helper.py | 14 - yt_dlp/networking/_requests.py | 3 +- yt_dlp/networking/_urllib.py | 3 +- yt_dlp/networking/_websockets.py | 2 +- yt_dlp/utils/networking.py | 16 +- 26 files changed, 3688 insertions(+), 26 deletions(-) create mode 100644 test/test_pot/conftest.py create mode 100644 test/test_pot/test_pot_builtin_memorycache.py create mode 100644 test/test_pot/test_pot_builtin_utils.py create mode 100644 test/test_pot/test_pot_builtin_webpospec.py create mode 100644 test/test_pot/test_pot_director.py create mode 100644 test/test_pot/test_pot_framework.py create mode 100644 yt_dlp/extractor/youtube/pot/README.md create mode 100644 yt_dlp/extractor/youtube/pot/__init__.py create mode 100644 yt_dlp/extractor/youtube/pot/_director.py create mode 100644 yt_dlp/extractor/youtube/pot/_provider.py create mode 100644 yt_dlp/extractor/youtube/pot/builtin/__init__.py create mode 100644 yt_dlp/extractor/youtube/pot/builtin/cache/memory.py create mode 100644 yt_dlp/extractor/youtube/pot/builtin/cachespec/webpo.py create mode 100644 yt_dlp/extractor/youtube/pot/builtin/utils.py create mode 100644 yt_dlp/extractor/youtube/pot/cache.py create mode 100644 yt_dlp/extractor/youtube/pot/provider.py diff --git a/README.md b/README.md index c0329f5394..727b83fd84 100644 --- a/README.md +++ b/README.md @@ -1784,6 +1784,17 @@ The following extractors use this feature: * `po_token`: Proof of Origin (PO) Token(s) to use. Comma seperated list of PO Tokens in the format `CLIENT.CONTEXT+PO_TOKEN`, e.g. `youtube:po_token=web.gvs+XXX,web.player=XXX,web_safari.gvs+YYY`. Context can be either `gvs` (Google Video Server URLs) or `player` (Innertube player request) * `player_js_variant`: The player javascript variant to use for signature and nsig deciphering. The known variants are: `main`, `tce`, `tv`, `tv_es6`, `phone`, `tablet`. Only `main` is recommended as a possible workaround; the others are for debugging purposes. The default is to use what is prescribed by the site, and can be selected with `actual` +##### PO Token settings +* `po_token`: Proof of Origin (PO) Token(s) to use. Comma seperated list of PO Tokens in the format `CLIENT.CONTEXT+PO_TOKEN`, e.g. `youtube:po_token=web.gvs+XXX,web.player=XXX,web_safari.gvs+YYY`. Context can be either `gvs` (Google Video Server URLs) or `player` (Innertube player request) +* `pot_debug`: Enable debug logging for PO Token fetching. Either `true` or `false` (default `false`) +* `fetch_pot`: Policy to use for fetching a PO Token from providers. `always` to always try fetch a PO Token regardless if the client requires one for the given context. `when_required` to only fetch a PO Token if the client requires one for the given context (default) + +###### youtubepot-memory +* `max_size`: Max size of the PO Token LRU memory cache (default 25) + +###### youtubepot-webpo +* `bind_to_visitor_id`: Whether to use the Visitor ID instead of Visitor Data for caching WebPO tokens. Either `true` or `false` (default `true`) + #### youtubetab (YouTube playlists, channels, feeds, etc.) * `skip`: One or more of `webpage` (skip initial webpage download), `authcheck` (allow the download of playlists requiring authentication when no initial webpage is downloaded. This may cause unwanted behavior, see [#1122](https://github.com/yt-dlp/yt-dlp/pull/1122) for more details) * `approximate_date`: Extract approximate `upload_date` and `timestamp` in flat-playlist. This may cause date-based filters to be slightly off diff --git a/test/test_networking_utils.py b/test/test_networking_utils.py index 204fe87bda..a2feacba71 100644 --- a/test/test_networking_utils.py +++ b/test/test_networking_utils.py @@ -20,7 +20,6 @@ from yt_dlp.networking._helper import ( add_accept_encoding_header, get_redirect_method, make_socks_proxy_opts, - select_proxy, ssl_load_certs, ) from yt_dlp.networking.exceptions import ( @@ -28,7 +27,7 @@ from yt_dlp.networking.exceptions import ( IncompleteRead, ) from yt_dlp.socks import ProxyType -from yt_dlp.utils.networking import HTTPHeaderDict +from yt_dlp.utils.networking import HTTPHeaderDict, select_proxy TEST_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/test/test_pot/conftest.py b/test/test_pot/conftest.py new file mode 100644 index 0000000000..7ba46bfa26 --- /dev/null +++ b/test/test_pot/conftest.py @@ -0,0 +1,72 @@ +import pytest + +from yt_dlp import YoutubeDL +from yt_dlp.cookies import YoutubeDLCookieJar +from yt_dlp.extractor.common import InfoExtractor +from yt_dlp.extractor.youtube.pot._provider import IEContentProviderLogger +from yt_dlp.extractor.youtube.pot.provider import PoTokenRequest, PoTokenContext +from yt_dlp.utils.networking import HTTPHeaderDict + + +class MockLogger(IEContentProviderLogger): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.messages = {} + + def trace(self, message: str): + self.messages.setdefault('trace', []).append(message) + pass + + def debug(self, message: str): + self.messages.setdefault('debug', []).append(message) + pass + + def info(self, message: str): + self.messages.setdefault('info', []).append(message) + pass + + def warning(self, message: str, *, once=False): + self.messages.setdefault('warning', []).append(message) + pass + + def error(self, message: str): + self.messages.setdefault('error', []).append(message) + pass + + +@pytest.fixture +def ie() -> InfoExtractor: + ydl = YoutubeDL() + return ydl.get_info_extractor('Youtube') + + +@pytest.fixture +def logger() -> MockLogger: + return MockLogger() + + +@pytest.fixture() +def pot_request() -> PoTokenRequest: + return PoTokenRequest( + context=PoTokenContext.GVS, + innertube_context={'client': {'clientName': 'WEB'}}, + innertube_host='youtube.com', + session_index=None, + player_url=None, + is_authenticated=False, + video_webpage=None, + + visitor_data='example-visitor-data', + data_sync_id='example-data-sync-id', + video_id='example-video-id', + + request_cookiejar=YoutubeDLCookieJar(), + request_proxy=None, + request_headers=HTTPHeaderDict(), + request_timeout=None, + request_source_address=None, + request_verify_tls=True, + + bypass_cache=False, + ) diff --git a/test/test_pot/test_pot_builtin_memorycache.py b/test/test_pot/test_pot_builtin_memorycache.py new file mode 100644 index 0000000000..79b2e15d4e --- /dev/null +++ b/test/test_pot/test_pot_builtin_memorycache.py @@ -0,0 +1,131 @@ +import threading +import time +from collections import OrderedDict +import pytest +from yt_dlp.extractor.youtube.pot._provider import IEContentProvider, BuiltInIEContentProvider +from yt_dlp.utils import bug_reports_message +from yt_dlp.extractor.youtube.pot.builtin.cache.memory import MemoryLRUPCP, memorylru_preference +from yt_dlp.version import __version__ +from yt_dlp.globals import _pot_memory_cache, _pot_cache_providers + + +class TestMemoryLRUPCS: + + def test_base_type(self): + assert issubclass(MemoryLRUPCP, IEContentProvider) + assert issubclass(MemoryLRUPCP, BuiltInIEContentProvider) + + @pytest.fixture + def pcp(self, ie, logger) -> MemoryLRUPCP: + return MemoryLRUPCP(ie, logger, {}, initialize_cache=lambda max_size: (OrderedDict(), threading.Lock(), max_size)) + + def test_is_registered(self): + assert _pot_cache_providers.value.get('MemoryLRU') == MemoryLRUPCP + + def test_initialization(self, pcp): + assert pcp.PROVIDER_NAME == 'memory' + assert pcp.PROVIDER_VERSION == __version__ + assert pcp.BUG_REPORT_MESSAGE == bug_reports_message(before='') + assert pcp.is_available() + + def test_store_and_get(self, pcp): + pcp.store('key1', 'value1', int(time.time()) + 60) + assert pcp.get('key1') == 'value1' + assert len(pcp.cache) == 1 + + def test_store_ignore_expired(self, pcp): + pcp.store('key1', 'value1', int(time.time()) - 1) + assert len(pcp.cache) == 0 + assert pcp.get('key1') is None + assert len(pcp.cache) == 0 + + def test_store_override_existing_key(self, ie, logger): + MAX_SIZE = 2 + pcp = MemoryLRUPCP(ie, logger, {}, initialize_cache=lambda max_size: (OrderedDict(), threading.Lock(), MAX_SIZE)) + pcp.store('key1', 'value1', int(time.time()) + 60) + pcp.store('key2', 'value2', int(time.time()) + 60) + assert len(pcp.cache) == 2 + pcp.store('key1', 'value2', int(time.time()) + 60) + # Ensure that the override key gets added to the end of the cache instead of in the same position + pcp.store('key3', 'value3', int(time.time()) + 60) + assert pcp.get('key1') == 'value2' + + def test_store_ignore_expired_existing_key(self, pcp): + pcp.store('key1', 'value2', int(time.time()) + 60) + pcp.store('key1', 'value1', int(time.time()) - 1) + assert len(pcp.cache) == 1 + assert pcp.get('key1') == 'value2' + assert len(pcp.cache) == 1 + + def test_get_key_expired(self, pcp): + pcp.store('key1', 'value1', int(time.time()) + 60) + assert pcp.get('key1') == 'value1' + assert len(pcp.cache) == 1 + pcp.cache['key1'] = ('value1', int(time.time()) - 1) + assert pcp.get('key1') is None + assert len(pcp.cache) == 0 + + def test_lru_eviction(self, ie, logger): + MAX_SIZE = 2 + provider = MemoryLRUPCP(ie, logger, {}, initialize_cache=lambda max_size: (OrderedDict(), threading.Lock(), MAX_SIZE)) + provider.store('key1', 'value1', int(time.time()) + 5) + provider.store('key2', 'value2', int(time.time()) + 5) + assert len(provider.cache) == 2 + + assert provider.get('key1') == 'value1' + + provider.store('key3', 'value3', int(time.time()) + 5) + assert len(provider.cache) == 2 + + assert provider.get('key2') is None + + provider.store('key4', 'value4', int(time.time()) + 5) + assert len(provider.cache) == 2 + + assert provider.get('key1') is None + assert provider.get('key3') == 'value3' + assert provider.get('key4') == 'value4' + + def test_delete(self, pcp): + pcp.store('key1', 'value1', int(time.time()) + 5) + assert len(pcp.cache) == 1 + assert pcp.get('key1') == 'value1' + pcp.delete('key1') + assert len(pcp.cache) == 0 + assert pcp.get('key1') is None + + def test_max_size_setting(self, ie, logger): + MAX_SIZE = 5 + called_max_size = 0 + + def initialize_cache(max_size): + nonlocal called_max_size + called_max_size = max_size + return OrderedDict(), threading.Lock(), max_size + + pcp = MemoryLRUPCP(ie, logger, {'max_size': [str(MAX_SIZE)]}, initialize_cache=initialize_cache) + + assert called_max_size == MAX_SIZE + assert pcp.max_size == MAX_SIZE + + def test_use_global_cache_default(self, ie, logger): + pcp = MemoryLRUPCP(ie, logger, {}) + assert pcp.max_size == _pot_memory_cache.value['max_size'] == 25 + assert pcp.cache is _pot_memory_cache.value['cache'] + assert pcp.lock is _pot_memory_cache.value['lock'] + + pcp2 = MemoryLRUPCP(ie, logger, {}) + assert pcp.max_size == pcp2.max_size == _pot_memory_cache.value['max_size'] == 25 + assert pcp.cache is pcp2.cache is _pot_memory_cache.value['cache'] + assert pcp.lock is pcp2.lock is _pot_memory_cache.value['lock'] + + def test_fail_max_size_change_global(self, ie, logger): + pcp = MemoryLRUPCP(ie, logger, {}) + assert pcp.max_size == _pot_memory_cache.value['max_size'] == 25 + with pytest.raises(ValueError, match='Cannot change max_size of initialized global memory cache'): + MemoryLRUPCP(ie, logger, {'max_size': ['50']}) + + assert pcp.max_size == _pot_memory_cache.value['max_size'] == 25 + + def test_memory_lru_preference(self, pcp, ie, pot_request): + assert memorylru_preference(pcp, pot_request) == 10000 diff --git a/test/test_pot/test_pot_builtin_utils.py b/test/test_pot/test_pot_builtin_utils.py new file mode 100644 index 0000000000..d65839c588 --- /dev/null +++ b/test/test_pot/test_pot_builtin_utils.py @@ -0,0 +1,46 @@ +import pytest +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenContext, + +) + +from yt_dlp.extractor.youtube.pot.builtin.utils import get_webpo_content_binding, ContentBindingType + + +class TestGetWebPoContentBinding: + + @pytest.mark.parametrize('client_name, context, is_authenticated, expected', [ + *[(client, context, is_authenticated, expected) for client in [ + 'WEB', 'MWEB', 'TVHTML5', 'WEB_EMBEDDED_PLAYER', 'WEB_CREATOR', 'TVHTML5_SIMPLY_EMBEDDED_PLAYER'] + for context, is_authenticated, expected in [ + (PoTokenContext.GVS, False, ('example-visitor-data', ContentBindingType.VISITOR_DATA)), + (PoTokenContext.PLAYER, False, ('example-video-id', ContentBindingType.VIDEO_ID)), + (PoTokenContext.GVS, True, ('example-data-sync-id', ContentBindingType.DATASYNC_ID)), + ]], + ('WEB_REMIX', PoTokenContext.GVS, False, ('example-visitor-data', ContentBindingType.VISITOR_DATA)), + ('WEB_REMIX', PoTokenContext.PLAYER, False, ('example-visitor-data', ContentBindingType.VISITOR_DATA)), + ('ANDROID', PoTokenContext.GVS, False, (None, None)), + ('IOS', PoTokenContext.GVS, False, (None, None)), + ]) + def test_get_webpo_content_binding(self, pot_request, client_name, context, is_authenticated, expected): + pot_request.innertube_context['client']['clientName'] = client_name + pot_request.context = context + pot_request.is_authenticated = is_authenticated + assert get_webpo_content_binding(pot_request) == expected + + def test_extract_visitor_id(self, pot_request): + pot_request.visitor_data = 'CgsxMjM0NTY3ODkwMSiA4s-qBg%3D%3D' + assert get_webpo_content_binding(pot_request, bind_to_visitor_id=True) == ('12345678901', ContentBindingType.VISITOR_ID) + + def test_invalid_visitor_id(self, pot_request): + # visitor id not alphanumeric (i.e. protobuf extraction failed) + pot_request.visitor_data = 'CggxMjM0NTY3OCiA4s-qBg%3D%3D' + assert get_webpo_content_binding(pot_request, bind_to_visitor_id=True) == (pot_request.visitor_data, ContentBindingType.VISITOR_DATA) + + def test_no_visitor_id(self, pot_request): + pot_request.visitor_data = 'KIDiz6oG' + assert get_webpo_content_binding(pot_request, bind_to_visitor_id=True) == (pot_request.visitor_data, ContentBindingType.VISITOR_DATA) + + def test_invalid_base64(self, pot_request): + pot_request.visitor_data = 'invalid-base64' + assert get_webpo_content_binding(pot_request, bind_to_visitor_id=True) == (pot_request.visitor_data, ContentBindingType.VISITOR_DATA) diff --git a/test/test_pot/test_pot_builtin_webpospec.py b/test/test_pot/test_pot_builtin_webpospec.py new file mode 100644 index 0000000000..b557e41e2f --- /dev/null +++ b/test/test_pot/test_pot_builtin_webpospec.py @@ -0,0 +1,92 @@ +import pytest + +from yt_dlp.extractor.youtube.pot._provider import IEContentProvider, BuiltInIEContentProvider +from yt_dlp.extractor.youtube.pot.cache import CacheProviderWritePolicy +from yt_dlp.utils import bug_reports_message +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, + PoTokenContext, + +) +from yt_dlp.version import __version__ + +from yt_dlp.extractor.youtube.pot.builtin.cachespec.webpo import WebPoPCSP +from yt_dlp.globals import _pot_pcs_providers + + +@pytest.fixture() +def pot_request(pot_request) -> PoTokenRequest: + pot_request.visitor_data = 'CgsxMjM0NTY3ODkwMSiA4s-qBg%3D%3D' # visitor_id=12345678901 + return pot_request + + +class TestWebPoPCSP: + def test_base_type(self): + assert issubclass(WebPoPCSP, IEContentProvider) + assert issubclass(WebPoPCSP, BuiltInIEContentProvider) + + def test_init(self, ie, logger): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + assert pcs.PROVIDER_NAME == 'webpo' + assert pcs.PROVIDER_VERSION == __version__ + assert pcs.BUG_REPORT_MESSAGE == bug_reports_message(before='') + assert pcs.is_available() + + def test_is_registered(self): + assert _pot_pcs_providers.value.get('WebPo') == WebPoPCSP + + @pytest.mark.parametrize('client_name, context, is_authenticated', [ + ('ANDROID', PoTokenContext.GVS, False), + ('IOS', PoTokenContext.GVS, False), + ('IOS', PoTokenContext.PLAYER, False), + ]) + def test_not_supports(self, ie, logger, pot_request, client_name, context, is_authenticated): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + pot_request.innertube_context['client']['clientName'] = client_name + pot_request.context = context + pot_request.is_authenticated = is_authenticated + assert pcs.generate_cache_spec(pot_request) is None + + @pytest.mark.parametrize('client_name, context, is_authenticated, remote_host, source_address, request_proxy, expected', [ + *[(client, context, is_authenticated, remote_host, source_address, request_proxy, expected) for client in [ + 'WEB', 'MWEB', 'TVHTML5', 'WEB_EMBEDDED_PLAYER', 'WEB_CREATOR', 'TVHTML5_SIMPLY_EMBEDDED_PLAYER'] + for context, is_authenticated, remote_host, source_address, request_proxy, expected in [ + (PoTokenContext.GVS, False, 'example-remote-host', 'example-source-address', 'example-request-proxy', {'t': 'webpo', 'ip': 'example-remote-host', 'sa': 'example-source-address', 'px': 'example-request-proxy', 'cb': '12345678901', 'cbt': 'visitor_id'}), + (PoTokenContext.PLAYER, False, 'example-remote-host', 'example-source-address', 'example-request-proxy', {'t': 'webpo', 'ip': 'example-remote-host', 'sa': 'example-source-address', 'px': 'example-request-proxy', 'cb': '12345678901', 'cbt': 'video_id'}), + (PoTokenContext.GVS, True, 'example-remote-host', 'example-source-address', 'example-request-proxy', {'t': 'webpo', 'ip': 'example-remote-host', 'sa': 'example-source-address', 'px': 'example-request-proxy', 'cb': 'example-data-sync-id', 'cbt': 'datasync_id'}), + ]], + ('WEB_REMIX', PoTokenContext.PLAYER, False, 'example-remote-host', 'example-source-address', 'example-request-proxy', {'t': 'webpo', 'ip': 'example-remote-host', 'sa': 'example-source-address', 'px': 'example-request-proxy', 'cb': '12345678901', 'cbt': 'visitor_id'}), + ('WEB', PoTokenContext.GVS, False, None, None, None, {'t': 'webpo', 'cb': '12345678901', 'cbt': 'visitor_id', 'ip': None, 'sa': None, 'px': None}), + ('TVHTML5', PoTokenContext.PLAYER, False, None, None, 'http://example.com', {'t': 'webpo', 'cb': '12345678901', 'cbt': 'video_id', 'ip': None, 'sa': None, 'px': 'http://example.com'}), + + ]) + def test_generate_key_bindings(self, ie, logger, pot_request, client_name, context, is_authenticated, remote_host, source_address, request_proxy, expected): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + pot_request.innertube_context['client']['clientName'] = client_name + pot_request.context = context + pot_request.is_authenticated = is_authenticated + pot_request.innertube_context['client']['remoteHost'] = remote_host + pot_request.request_source_address = source_address + pot_request.request_proxy = request_proxy + pot_request.video_id = '12345678901' # same as visitor id to test type + + assert pcs.generate_cache_spec(pot_request).key_bindings == expected + + def test_no_bind_visitor_id(self, ie, logger, pot_request): + # Should not bind to visitor id if setting is set to False + pcs = WebPoPCSP(ie=ie, logger=logger, settings={'bind_to_visitor_id': ['false']}) + pot_request.innertube_context['client']['clientName'] = 'WEB' + pot_request.context = PoTokenContext.GVS + pot_request.is_authenticated = False + assert pcs.generate_cache_spec(pot_request).key_bindings == {'t': 'webpo', 'ip': None, 'sa': None, 'px': None, 'cb': 'CgsxMjM0NTY3ODkwMSiA4s-qBg%3D%3D', 'cbt': 'visitor_data'} + + def test_default_ttl(self, ie, logger, pot_request): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + assert pcs.generate_cache_spec(pot_request).default_ttl == 6 * 60 * 60 # should default to 6 hours + + def test_write_policy(self, ie, logger, pot_request): + pcs = WebPoPCSP(ie=ie, logger=logger, settings={}) + pot_request.context = PoTokenContext.GVS + assert pcs.generate_cache_spec(pot_request).write_policy == CacheProviderWritePolicy.WRITE_ALL + pot_request.context = PoTokenContext.PLAYER + assert pcs.generate_cache_spec(pot_request).write_policy == CacheProviderWritePolicy.WRITE_FIRST diff --git a/test/test_pot/test_pot_director.py b/test/test_pot/test_pot_director.py new file mode 100644 index 0000000000..83b3ebaf90 --- /dev/null +++ b/test/test_pot/test_pot_director.py @@ -0,0 +1,1396 @@ +from __future__ import annotations +import abc +import base64 +import dataclasses +import hashlib +import json +import time +import pytest + +from yt_dlp.extractor.youtube.pot._provider import BuiltInIEContentProvider, IEContentProvider + +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, + PoTokenContext, + PoTokenProviderError, + PoTokenProviderRejectedRequest, +) +from yt_dlp.extractor.youtube.pot._director import ( + PoTokenCache, + validate_cache_spec, + clean_pot, + validate_response, + PoTokenRequestDirector, + provider_display_list, +) + +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheSpec, + PoTokenCacheSpecProvider, + PoTokenCacheProvider, + CacheProviderWritePolicy, + PoTokenCacheProviderError, +) + + +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenResponse, + PoTokenProvider, +) + + +class BaseMockPoTokenProvider(PoTokenProvider, abc.ABC): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.available_called_times = 0 + self.request_called_times = 0 + self.close_called = False + + def is_available(self) -> bool: + self.available_called_times += 1 + return True + + def request_pot(self, *args, **kwargs): + self.request_called_times += 1 + return super().request_pot(*args, **kwargs) + + def close(self): + self.close_called = True + super().close() + + +class ExamplePTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + _SUPPORTED_CLIENTS = ('WEB',) + _SUPPORTED_CONTEXTS = (PoTokenContext.GVS, ) + _SUPPORTED_PROXY_SCHEMES = ('socks5', 'http') + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + if request.data_sync_id == 'example': + return PoTokenResponse(request.video_id) + return PoTokenResponse(EXAMPLE_PO_TOKEN) + + +def success_ptp(response: PoTokenResponse | None = None, key: str | None = None): + class SuccessPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'success' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://success.example.com/issues' + + _SUPPORTED_CLIENTS = ('WEB',) + _SUPPORTED_CONTEXTS = (PoTokenContext.GVS,) + _SUPPORTED_PROXY_SCHEMES = ('socks5', 'http') + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + return response or PoTokenResponse(EXAMPLE_PO_TOKEN) + + if key: + SuccessPTP.PROVIDER_KEY = key + return SuccessPTP + + +@pytest.fixture +def pot_provider(ie, logger): + return success_ptp()(ie=ie, logger=logger, settings={}) + + +class UnavailablePTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'unavailable' + BUG_REPORT_LOCATION = 'https://unavailable.example.com/issues' + _SUPPORTED_CLIENTS = None + _SUPPORTED_CONTEXTS = None + + def is_available(self) -> bool: + super().is_available() + return False + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderError('something went wrong') + + +class UnsupportedPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'unsupported' + BUG_REPORT_LOCATION = 'https://unsupported.example.com/issues' + _SUPPORTED_CLIENTS = None + _SUPPORTED_CONTEXTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderRejectedRequest('unsupported request') + + +class ErrorPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'error' + BUG_REPORT_LOCATION = 'https://error.example.com/issues' + _SUPPORTED_CLIENTS = None + _SUPPORTED_CONTEXTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + expected = request.video_id == 'expected' + raise PoTokenProviderError('an error occurred', expected=expected) + + +class UnexpectedErrorPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'unexpected_error' + BUG_REPORT_LOCATION = 'https://unexpected.example.com/issues' + _SUPPORTED_CLIENTS = None + _SUPPORTED_CONTEXTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise ValueError('an unexpected error occurred') + + +class InvalidPTP(BaseMockPoTokenProvider): + PROVIDER_NAME = 'invalid' + BUG_REPORT_LOCATION = 'https://invalid.example.com/issues' + _SUPPORTED_CLIENTS = None + _SUPPORTED_CONTEXTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + if request.video_id == 'invalid_type': + return 'invalid-response' + else: + return PoTokenResponse('example-token?', expires_at='123') + + +class BaseMockCacheSpecProvider(PoTokenCacheSpecProvider, abc.ABC): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.generate_called_times = 0 + self.is_available_called_times = 0 + self.close_called = False + + def is_available(self) -> bool: + self.is_available_called_times += 1 + return super().is_available() + + def generate_cache_spec(self, request: PoTokenRequest): + self.generate_called_times += 1 + + def close(self): + self.close_called = True + super().close() + + +class ExampleCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return PoTokenCacheSpec( + key_bindings={'v': request.video_id, 'e': None}, + default_ttl=60, + ) + + +class UnavailableCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'unavailable' + PROVIDER_VERSION = '0.0.1' + + def is_available(self) -> bool: + super().is_available() + return False + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return None + + +class UnsupportedCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'unsupported' + PROVIDER_VERSION = '0.0.1' + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return None + + +class InvalidSpecCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'invalid' + PROVIDER_VERSION = '0.0.1' + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return 'invalid-spec' + + +class ErrorSpecCacheSpecProviderPCSP(BaseMockCacheSpecProvider): + + PROVIDER_NAME = 'invalid' + PROVIDER_VERSION = '0.0.1' + + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + raise ValueError('something went wrong') + + +class BaseMockCacheProvider(PoTokenCacheProvider, abc.ABC): + BUG_REPORT_MESSAGE = 'example bug report message' + + def __init__(self, *args, available=True, **kwargs): + super().__init__(*args, **kwargs) + self.store_calls = 0 + self.delete_calls = 0 + self.get_calls = 0 + self.available = available + + def is_available(self) -> bool: + return self.available + + def store(self, *args, **kwargs): + self.store_calls += 1 + + def delete(self, *args, **kwargs): + self.delete_calls += 1 + + def get(self, *args, **kwargs): + self.get_calls += 1 + + def close(self): + self.close_called = True + super().close() + + +class ErrorPCP(BaseMockCacheProvider): + PROVIDER_NAME = 'error' + + def store(self, *args, **kwargs): + super().store(*args, **kwargs) + raise PoTokenCacheProviderError('something went wrong') + + def get(self, *args, **kwargs): + super().get(*args, **kwargs) + raise PoTokenCacheProviderError('something went wrong') + + +class UnexpectedErrorPCP(BaseMockCacheProvider): + PROVIDER_NAME = 'unexpected_error' + + def store(self, *args, **kwargs): + super().store(*args, **kwargs) + raise ValueError('something went wrong') + + def get(self, *args, **kwargs): + super().get(*args, **kwargs) + raise ValueError('something went wrong') + + +class MockMemoryPCP(BaseMockCacheProvider): + PROVIDER_NAME = 'memory' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cache = {} + + def store(self, key, value, expires_at): + super().store(key, value, expires_at) + self.cache[key] = (value, expires_at) + + def delete(self, key): + super().delete(key) + self.cache.pop(key, None) + + def get(self, key): + super().get(key) + return self.cache.get(key, [None])[0] + + +def create_memory_pcp(ie, logger, provider_key='memory'): + cache = MockMemoryPCP(ie, logger, {}) + cache.PROVIDER_KEY = provider_key + return cache + + +@pytest.fixture +def memorypcp(ie, logger) -> MockMemoryPCP: + return create_memory_pcp(ie, logger) + + +@pytest.fixture +def pot_cache(ie, logger): + class MockPoTokenCache(PoTokenCache): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.get_calls = 0 + self.store_calls = 0 + self.close_called = False + + def get(self, *args, **kwargs): + self.get_calls += 1 + return super().get(*args, **kwargs) + + def store(self, *args, **kwargs): + self.store_calls += 1 + return super().store(*args, **kwargs) + + def close(self): + self.close_called = True + super().close() + + return MockPoTokenCache( + cache_providers=[MockMemoryPCP(ie, logger, {})], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie, logger, settings={})], + logger=logger, + ) + + +EXAMPLE_PO_TOKEN = base64.urlsafe_b64encode(b'example-token').decode() + + +class TestPoTokenCache: + + def test_cache_success(self, memorypcp, pot_request, ie, logger): + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + + cached_response = cache.get(pot_request) + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + assert cache.get(dataclasses.replace(pot_request, video_id='another-video-id')) is None + + def test_unsupported_cache_spec_no_fallback(self, memorypcp, pot_request, ie, logger): + unsupported_provider = UnsupportedCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[unsupported_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + assert cache.get(pot_request) is None + assert unsupported_provider.generate_called_times == 1 + cache.store(pot_request, response) + assert len(memorypcp.cache) == 0 + assert unsupported_provider.generate_called_times == 2 + assert cache.get(pot_request) is None + assert unsupported_provider.generate_called_times == 3 + assert len(logger.messages.get('error', [])) == 0 + + def test_unsupported_cache_spec_fallback(self, memorypcp, pot_request, ie, logger): + unsupported_provider = UnsupportedCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[unsupported_provider, example_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + assert unsupported_provider.generate_called_times == 1 + assert example_provider.generate_called_times == 1 + + cache.store(pot_request, response) + assert unsupported_provider.generate_called_times == 2 + assert example_provider.generate_called_times == 2 + + cached_response = cache.get(pot_request) + assert unsupported_provider.generate_called_times == 3 + assert example_provider.generate_called_times == 3 + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + assert len(logger.messages.get('error', [])) == 0 + + def test_invalid_cache_spec_no_fallback(self, memorypcp, pot_request, ie, logger): + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[InvalidSpecCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + + assert cache.get(pot_request) is None + + assert 'PoTokenCacheSpecProvider "InvalidSpecCacheSpecProvider" generate_cache_spec() returned invalid spec invalid-spec; please report this issue to the provider developer at (developer has not provided a bug report location) .' in logger.messages['error'] + + def test_invalid_cache_spec_fallback(self, memorypcp, pot_request, ie, logger): + + invalid_provider = InvalidSpecCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[invalid_provider, example_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + assert invalid_provider.generate_called_times == example_provider.generate_called_times == 1 + + cache.store(pot_request, response) + assert invalid_provider.generate_called_times == example_provider.generate_called_times == 2 + + cached_response = cache.get(pot_request) + assert invalid_provider.generate_called_times == example_provider.generate_called_times == 3 + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + assert 'PoTokenCacheSpecProvider "InvalidSpecCacheSpecProvider" generate_cache_spec() returned invalid spec invalid-spec; please report this issue to the provider developer at (developer has not provided a bug report location) .' in logger.messages['error'] + + def test_unavailable_cache_spec_no_fallback(self, memorypcp, pot_request, ie, logger): + unavailable_provider = UnavailableCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[unavailable_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + assert cache.get(pot_request) is None + assert unavailable_provider.generate_called_times == 0 + + def test_unavailable_cache_spec_fallback(self, memorypcp, pot_request, ie, logger): + unavailable_provider = UnavailableCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[unavailable_provider, example_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + assert unavailable_provider.generate_called_times == 0 + assert unavailable_provider.is_available_called_times == 1 + assert example_provider.generate_called_times == 1 + + cache.store(pot_request, response) + assert unavailable_provider.generate_called_times == 0 + assert unavailable_provider.is_available_called_times == 2 + assert example_provider.generate_called_times == 2 + + cached_response = cache.get(pot_request) + assert unavailable_provider.generate_called_times == 0 + assert unavailable_provider.is_available_called_times == 3 + assert example_provider.generate_called_times == 3 + assert example_provider.is_available_called_times == 3 + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + def test_unexpected_error_cache_spec(self, memorypcp, pot_request, ie, logger): + error_provider = ErrorSpecCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[error_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + assert cache.get(pot_request) is None + assert error_provider.generate_called_times == 3 + assert error_provider.is_available_called_times == 3 + + assert 'Error occurred with "invalid" PO Token cache spec provider: ValueError(\'something went wrong\'); please report this issue to the provider developer at (developer has not provided a bug report location) .' in logger.messages['error'] + + def test_unexpected_error_cache_spec_fallback(self, memorypcp, pot_request, ie, logger): + error_provider = ErrorSpecCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[error_provider, example_provider], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + assert cache.get(pot_request) is None + assert error_provider.generate_called_times == 1 + assert error_provider.is_available_called_times == 1 + assert example_provider.generate_called_times == 1 + + cache.store(pot_request, response) + assert error_provider.generate_called_times == 2 + assert error_provider.is_available_called_times == 2 + assert example_provider.generate_called_times == 2 + + cached_response = cache.get(pot_request) + assert error_provider.generate_called_times == 3 + assert error_provider.is_available_called_times == 3 + assert example_provider.generate_called_times == 3 + assert example_provider.is_available_called_times == 3 + assert cached_response is not None + assert cached_response.po_token == EXAMPLE_PO_TOKEN + assert cached_response.expires_at is not None + + assert 'Error occurred with "invalid" PO Token cache spec provider: ValueError(\'something went wrong\'); please report this issue to the provider developer at (developer has not provided a bug report location) .' in logger.messages['error'] + + def test_key_bindings_spec_provider(self, memorypcp, pot_request, ie, logger): + + class ExampleProviderPCSP(PoTokenCacheSpecProvider): + PROVIDER_NAME = 'example' + + def generate_cache_spec(self, request: PoTokenRequest): + return PoTokenCacheSpec( + key_bindings={'v': request.video_id}, + default_ttl=60, + ) + + class ExampleProviderTwoPCSP(ExampleProviderPCSP): + pass + + example_provider = ExampleProviderPCSP(ie=ie, logger=logger, settings={}) + example_provider_two = ExampleProviderTwoPCSP(ie=ie, logger=logger, settings={}) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[example_provider], + logger=logger, + ) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + assert len(memorypcp.cache) == 1 + assert hashlib.sha256(f'_pExampleProvider_ytv1v{pot_request.video_id}'.encode()).hexdigest() in memorypcp.cache + + # The second spec provider returns the exact same key bindings as the first one, + # however the PoTokenCache should use the provider key to differentiate between them + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[example_provider_two], + logger=logger, + ) + + assert cache.get(pot_request) is None + cache.store(pot_request, response) + assert len(memorypcp.cache) == 2 + assert hashlib.sha256(f'_pExampleProviderTwo_ytv1v{pot_request.video_id}'.encode()).hexdigest() in memorypcp.cache + + def test_cache_provider_preferences(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN), write_policy=CacheProviderWritePolicy.WRITE_FIRST) + assert len(pcp_one.cache) == 1 + assert len(pcp_two.cache) == 0 + + assert cache.get(pot_request) + assert pcp_one.get_calls == 1 + assert pcp_two.get_calls == 0 + + standard_preference_called = False + pcp_one_preference_claled = False + + def standard_preference(provider, request, *_, **__): + nonlocal standard_preference_called + standard_preference_called = True + assert isinstance(provider, PoTokenCacheProvider) + assert isinstance(request, PoTokenRequest) + return 1 + + def pcp_one_preference(provider, request, *_, **__): + nonlocal pcp_one_preference_claled + pcp_one_preference_claled = True + assert isinstance(provider, PoTokenCacheProvider) + assert isinstance(request, PoTokenRequest) + if provider.PROVIDER_KEY == pcp_one.PROVIDER_KEY: + return -100 + return 0 + + # test that it can hanldle multiple preferences + cache.cache_provider_preferences.append(standard_preference) + cache.cache_provider_preferences.append(pcp_one_preference) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN), write_policy=CacheProviderWritePolicy.WRITE_FIRST) + assert cache.get(pot_request) + assert len(pcp_one.cache) == len(pcp_two.cache) == 1 + assert pcp_two.get_calls == pcp_one.get_calls == 1 + assert pcp_one.store_calls == pcp_two.store_calls == 1 + assert standard_preference_called + assert pcp_one_preference_claled + + def test_secondary_cache_provider_hit(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_two], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + # Given the lower priority provider has the cache hit, store the response in the higher priority provider + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN)) + assert cache.get(pot_request) + + cache.cache_providers[pcp_one.PROVIDER_KEY] = pcp_one + + def pcp_one_pref(provider, *_, **__): + if provider.PROVIDER_KEY == pcp_one.PROVIDER_KEY: + return 1 + return -1 + + cache.cache_provider_preferences.append(pcp_one_pref) + + assert cache.get(pot_request) + assert pcp_one.get_calls == 1 + assert pcp_two.get_calls == 2 + # Should write back to pcp_one (now the highest priority cache provider) + assert pcp_one.store_calls == pcp_two.store_calls == 1 + assert 'Writing PO Token response to highest priority cache provider' in logger.messages['trace'] + + def test_cache_provider_no_hits(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + assert cache.get(pot_request) is None + assert pcp_one.get_calls == pcp_two.get_calls == 1 + + def test_get_invalid_po_token_response(self, pot_request, ie, logger): + # Test various scenarios where the po token response stored in the cache provider is invalid + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + valid_response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, valid_response) + assert len(pcp_one.cache) == len(pcp_two.cache) == 1 + # Overwrite the valid response with an invalid one in the cache + pcp_one.store(next(iter(pcp_one.cache.keys())), json.dumps(dataclasses.asdict(PoTokenResponse(None))), int(time.time() + 1000)) + assert cache.get(pot_request).po_token == valid_response.po_token + assert pcp_one.get_calls == pcp_two.get_calls == 1 + assert pcp_one.delete_calls == 1 # Invalid response should be deleted from cache + assert pcp_one.store_calls == 3 # Since response was fetched from second cache provider, it should be stored in the first one + assert len(pcp_one.cache) == 1 + assert 'Invalid PO Token response retrieved from cache provider "memory": {"po_token": null, "expires_at": null}; example bug report message' in logger.messages['error'] + + # Overwrite the valid response with an invalid json in the cache + pcp_one.store(next(iter(pcp_one.cache.keys())), 'invalid-json', int(time.time() + 1000)) + assert cache.get(pot_request).po_token == valid_response.po_token + assert pcp_one.get_calls == pcp_two.get_calls == 2 + assert pcp_one.delete_calls == 2 + assert pcp_one.store_calls == 5 # 3 + 1 store we made in the test + 1 store from lower priority cache provider + assert len(pcp_one.cache) == 1 + + assert 'Invalid PO Token response retrieved from cache provider "memory": invalid-json; example bug report message' in logger.messages['error'] + + # Valid json, but missing required fields + pcp_one.store(next(iter(pcp_one.cache.keys())), '{"unknown_param": 0}', int(time.time() + 1000)) + assert cache.get(pot_request).po_token == valid_response.po_token + assert pcp_one.get_calls == pcp_two.get_calls == 3 + assert pcp_one.delete_calls == 3 + assert pcp_one.store_calls == 7 # 5 + 1 store from test + 1 store from lower priority cache provider + assert len(pcp_one.cache) == 1 + + assert 'Invalid PO Token response retrieved from cache provider "memory": {"unknown_param": 0}; example bug report message' in logger.messages['error'] + + def test_store_invalid_po_token_response(self, pot_request, ie, logger): + # Should not store an invalid po token response + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + + cache = PoTokenCache( + cache_providers=[pcp_one], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(po_token=EXAMPLE_PO_TOKEN, expires_at=80)) + assert cache.get(pot_request) is None + assert pcp_one.store_calls == 0 + assert 'Invalid PO Token response provided to PoTokenCache.store()' in logger.messages['error'][0] + + def test_store_write_policy(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN), write_policy=CacheProviderWritePolicy.WRITE_FIRST) + assert pcp_one.store_calls == 1 + assert pcp_two.store_calls == 0 + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN), write_policy=CacheProviderWritePolicy.WRITE_ALL) + assert pcp_one.store_calls == 2 + assert pcp_two.store_calls == 1 + + def test_store_write_first_policy_cache_spec(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + class WriteFirstPCSP(BaseMockCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return PoTokenCacheSpec( + key_bindings={'v': request.video_id, 'e': None}, + default_ttl=60, + write_policy=CacheProviderWritePolicy.WRITE_FIRST, + ) + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[WriteFirstPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN)) + assert pcp_one.store_calls == 1 + assert pcp_two.store_calls == 0 + + def test_store_write_all_policy_cache_spec(self, pot_request, ie, logger): + pcp_one = create_memory_pcp(ie, logger, provider_key='memory_pcp_one') + pcp_two = create_memory_pcp(ie, logger, provider_key='memory_pcp_two') + + class WriteAllPCSP(BaseMockCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return PoTokenCacheSpec( + key_bindings={'v': request.video_id, 'e': None}, + default_ttl=60, + write_policy=CacheProviderWritePolicy.WRITE_ALL, + ) + + cache = PoTokenCache( + cache_providers=[pcp_one, pcp_two], + cache_spec_providers=[WriteAllPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN)) + assert pcp_one.store_calls == 1 + assert pcp_two.store_calls == 1 + + def test_expires_at_pot_response(self, pot_request, memorypcp, ie, logger): + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=10000000000) + cache.store(pot_request, response) + assert next(iter(memorypcp.cache.values()))[1] == 10000000000 + + def test_expires_at_default_spec(self, pot_request, memorypcp, ie, logger): + + class TtlPCSP(BaseMockCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + super().generate_cache_spec(request) + return PoTokenCacheSpec( + key_bindings={'v': request.video_id, 'e': None}, + default_ttl=10000000000, + ) + + cache = PoTokenCache( + cache_providers=[memorypcp], + cache_spec_providers=[TtlPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + assert next(iter(memorypcp.cache.values()))[1] >= 10000000000 + + def test_cache_provider_error_no_fallback(self, pot_request, ie, logger): + error_pcp = ErrorPCP(ie, logger, {}) + cache = PoTokenCache( + cache_providers=[error_pcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + assert cache.get(pot_request) is None + assert error_pcp.get_calls == 1 + assert error_pcp.store_calls == 1 + + assert logger.messages['warning'].count("Error from \"error\" PO Token cache provider: PoTokenCacheProviderError('something went wrong'); example bug report message") == 2 + + def test_cache_provider_error_fallback(self, pot_request, ie, logger): + error_pcp = ErrorPCP(ie, logger, {}) + memory_pcp = create_memory_pcp(ie, logger, provider_key='memory') + + cache = PoTokenCache( + cache_providers=[error_pcp, memory_pcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + + # 1. Store fails for error_pcp, stored in memory_pcp + # 2. Get fails for error_pcp, fetched from memory_pcp + # 3. Since fetched from lower priority, it should be stored in the highest priority cache provider + # 4. Store fails in error_pcp. Since write policy is WRITE_FIRST, it should not try to store in memory_pcp regardless of if the store in error_pcp fails + + assert cache.get(pot_request) + assert error_pcp.get_calls == 1 + assert error_pcp.store_calls == 2 # since highest priority, when fetched from lower priority, it should be stored in the highest priority cache provider + assert memory_pcp.get_calls == 1 + assert memory_pcp.store_calls == 1 + + assert logger.messages['warning'].count("Error from \"error\" PO Token cache provider: PoTokenCacheProviderError('something went wrong'); example bug report message") == 3 + + def test_cache_provider_unexpected_error_no_fallback(self, pot_request, ie, logger): + error_pcp = UnexpectedErrorPCP(ie, logger, {}) + cache = PoTokenCache( + cache_providers=[error_pcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + assert cache.get(pot_request) is None + assert error_pcp.get_calls == 1 + assert error_pcp.store_calls == 1 + + assert logger.messages['error'].count("Error occurred with \"unexpected_error\" PO Token cache provider: ValueError('something went wrong'); example bug report message") == 2 + + def test_cache_provider_unexpected_error_fallback(self, pot_request, ie, logger): + error_pcp = UnexpectedErrorPCP(ie, logger, {}) + memory_pcp = create_memory_pcp(ie, logger, provider_key='memory') + + cache = PoTokenCache( + cache_providers=[error_pcp, memory_pcp], + cache_spec_providers=[ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={})], + logger=logger, + ) + + response = PoTokenResponse(EXAMPLE_PO_TOKEN) + cache.store(pot_request, response) + + # 1. Store fails for error_pcp, stored in memory_pcp + # 2. Get fails for error_pcp, fetched from memory_pcp + # 3. Since fetched from lower priority, it should be stored in the highest priority cache provider + # 4. Store fails in error_pcp. Since write policy is WRITE_FIRST, it should not try to store in memory_pcp regardless of if the store in error_pcp fails + + assert cache.get(pot_request) + assert error_pcp.get_calls == 1 + assert error_pcp.store_calls == 2 # since highest priority, when fetched from lower priority, it should be stored in the highest priority cache provider + assert memory_pcp.get_calls == 1 + assert memory_pcp.store_calls == 1 + + assert logger.messages['error'].count("Error occurred with \"unexpected_error\" PO Token cache provider: ValueError('something went wrong'); example bug report message") == 3 + + def test_close(self, ie, pot_request, logger): + # Should call close on the cache providers and cache specs + memory_pcp = create_memory_pcp(ie, logger, provider_key='memory') + memory2_pcp = create_memory_pcp(ie, logger, provider_key='memory2') + + spec1 = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + spec2 = UnavailableCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + + cache = PoTokenCache( + cache_providers=[memory2_pcp, memory_pcp], + cache_spec_providers=[spec1, spec2], + logger=logger, + ) + + cache.close() + assert memory_pcp.close_called + assert memory2_pcp.close_called + assert spec1.close_called + assert spec2.close_called + + +class TestPoTokenRequestDirector: + + def test_request_pot_success(self, ie, pot_request, pot_cache, pot_provider, logger): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + + def test_request_and_cache(self, ie, pot_request, pot_cache, pot_provider, logger): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_provider.request_called_times == 1 + assert pot_cache.get_calls == 1 + assert pot_cache.store_calls == 1 + + # Second request, should be cached + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_cache.get_calls == 2 + assert pot_cache.store_calls == 1 + assert pot_provider.request_called_times == 1 + + def test_bypass_cache(self, ie, pot_request, pot_cache, logger, pot_provider): + pot_request.bypass_cache = True + + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_provider.request_called_times == 1 + assert pot_cache.get_calls == 0 + assert pot_cache.store_calls == 1 + + # Second request, should not get from cache + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_provider.request_called_times == 2 + assert pot_cache.get_calls == 0 + assert pot_cache.store_calls == 2 + + # POT is still cached, should get from cache + pot_request.bypass_cache = False + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_provider.request_called_times == 2 + assert pot_cache.get_calls == 1 + assert pot_cache.store_calls == 2 + + def test_clean_pot_generate(self, ie, pot_request, pot_cache, logger): + # Token should be cleaned before returning + base_token = base64.urlsafe_b64encode(b'token').decode() + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = success_ptp(PoTokenResponse(base_token + '?extra=params'))(ie, logger, settings={}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == base_token + assert provider.request_called_times == 1 + + # Confirm the cleaned version was stored in the cache + cached_token = pot_cache.get(pot_request) + assert cached_token.po_token == base_token + + def test_clean_pot_cache(self, ie, pot_request, pot_cache, logger, pot_provider): + # Token retrieved from cache should be cleaned before returning + base_token = base64.urlsafe_b64encode(b'token').decode() + pot_cache.store(pot_request, PoTokenResponse(base_token + '?extra=params')) + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == base_token + assert pot_cache.get_calls == 1 + assert pot_provider.request_called_times == 0 + + def test_cache_expires_at_none(self, ie, pot_request, pot_cache, logger, pot_provider): + # Should cache if expires_at=None in the response + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = success_ptp(PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=None))(ie, logger, settings={}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_cache.store_calls == 1 + assert pot_cache.get(pot_request).po_token == EXAMPLE_PO_TOKEN + + def test_cache_expires_at_positive(self, ie, pot_request, pot_cache, logger, pot_provider): + # Should cache if expires_at is a positive number in the response + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = success_ptp(PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=99999999999))(ie, logger, settings={}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_cache.store_calls == 1 + assert pot_cache.get(pot_request).po_token == EXAMPLE_PO_TOKEN + + @pytest.mark.parametrize('expires_at', [0, -1]) + def test_not_cache_expires_at(self, ie, pot_request, pot_cache, logger, pot_provider, expires_at): + # Should not cache if expires_at <= 0 in the response + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = success_ptp(PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=expires_at))(ie, logger, settings={}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert pot_cache.store_calls == 0 + assert pot_cache.get(pot_request) is None + + def test_no_providers(self, ie, pot_request, pot_cache, logger): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + response = director.get_po_token(pot_request) + assert response is None + + def test_try_cache_no_providers(self, ie, pot_request, pot_cache, logger): + # Should still try the cache even if no providers are configured + pot_cache.store(pot_request, PoTokenResponse(EXAMPLE_PO_TOKEN)) + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + + def test_close(self, ie, pot_request, pot_cache, pot_provider, logger): + # Should call close on the pot cache and any providers + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + + provider2 = UnavailablePTP(ie, logger, {}) + director.register_provider(pot_provider) + director.register_provider(provider2) + + director.close() + assert pot_provider.close_called + assert provider2.close_called + assert pot_cache.close_called + + def test_pot_provider_preferences(self, pot_request, pot_cache, ie, logger): + pot_request.bypass_cache = True + provider_two_pot = base64.urlsafe_b64encode(b'token2').decode() + + example_provider = success_ptp(response=PoTokenResponse(EXAMPLE_PO_TOKEN), key='exampleone')(ie, logger, settings={}) + example_provider_two = success_ptp(response=PoTokenResponse(provider_two_pot), key='exampletwo')(ie, logger, settings={}) + + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + director.register_provider(example_provider) + director.register_provider(example_provider_two) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert example_provider.request_called_times == 1 + assert example_provider_two.request_called_times == 0 + + standard_preference_called = False + example_preference_called = False + + # Test that the provider preferences are respected + def standard_preference(provider, request, *_, **__): + nonlocal standard_preference_called + standard_preference_called = True + assert isinstance(provider, PoTokenProvider) + assert isinstance(request, PoTokenRequest) + return 1 + + def example_preference(provider, request, *_, **__): + nonlocal example_preference_called + example_preference_called = True + assert isinstance(provider, PoTokenProvider) + assert isinstance(request, PoTokenRequest) + if provider.PROVIDER_KEY == example_provider.PROVIDER_KEY: + return -100 + return 0 + + # test that it can handle multiple preferences + director.register_preference(example_preference) + director.register_preference(standard_preference) + + response = director.get_po_token(pot_request) + assert response == provider_two_pot + assert example_provider.request_called_times == 1 + assert example_provider_two.request_called_times == 1 + assert standard_preference_called + assert example_preference_called + + def test_unsupported_request_no_fallback(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnsupportedPTP(ie, logger, {}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + + def test_unsupported_request_fallback(self, ie, logger, pot_cache, pot_request, pot_provider): + # Should fallback to the next provider if the first one does not support the request + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnsupportedPTP(ie, logger, {}) + director.register_provider(provider) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 1 + assert pot_provider.request_called_times == 1 + assert 'PO Token Provider "unsupported" does not support this request, trying next available provider. Reason: unsupported request' in logger.messages['trace'] + + def test_unavailable_request_no_fallback(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnavailablePTP(ie, logger, {}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 0 + assert provider.available_called_times + + def test_unavailable_request_fallback(self, ie, logger, pot_cache, pot_request, pot_provider): + # Should fallback to the next provider if the first one is unavailable + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnavailablePTP(ie, logger, {}) + director.register_provider(provider) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 0 + assert provider.available_called_times + assert pot_provider.request_called_times == 1 + assert pot_provider.available_called_times + # should not even try use the provider for the request + assert 'Provider preferences for this request: success=0' in logger.messages['trace'] + + def test_provider_error_no_fallback_unexpected(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = ErrorPTP(ie, logger, {}) + director.register_provider(provider) + pot_request.video_id = 'unexpected' + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert "Error fetching PO Token from \"error\" provider: PoTokenProviderError('an error occurred'); please report this issue to the provider developer at https://error.example.com/issues ." in logger.messages['warning'] + + def test_provider_error_no_fallback_expected(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = ErrorPTP(ie, logger, {}) + director.register_provider(provider) + pot_request.video_id = 'expected' + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert "Error fetching PO Token from \"error\" provider: PoTokenProviderError('an error occurred')" in logger.messages['warning'] + + def test_provider_error_fallback(self, ie, logger, pot_cache, pot_request, pot_provider): + # Should fallback to the next provider if the first one raises an error + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = ErrorPTP(ie, logger, {}) + director.register_provider(provider) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 1 + assert pot_provider.request_called_times == 1 + assert "Error fetching PO Token from \"error\" provider: PoTokenProviderError('an error occurred'); please report this issue to the provider developer at https://error.example.com/issues ." in logger.messages['warning'] + + def test_provider_unexpected_error_no_fallback(self, ie, logger, pot_cache, pot_request): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnexpectedErrorPTP(ie, logger, {}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert "Unexpected error when fetching PO Token from \"unexpected_error\" provider: ValueError('an unexpected error occurred'); please report this issue to the provider developer at https://unexpected.example.com/issues ." in logger.messages['error'] + + def test_provider_unexpected_error_fallback(self, ie, logger, pot_cache, pot_request, pot_provider): + # Should fallback to the next provider if the first one raises an unexpected error + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = UnexpectedErrorPTP(ie, logger, {}) + director.register_provider(provider) + director.register_provider(pot_provider) + + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 1 + assert pot_provider.request_called_times == 1 + assert "Unexpected error when fetching PO Token from \"unexpected_error\" provider: ValueError('an unexpected error occurred'); please report this issue to the provider developer at https://unexpected.example.com/issues ." in logger.messages['error'] + + def test_invalid_po_token_response_type(self, ie, logger, pot_cache, pot_request, pot_provider): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = InvalidPTP(ie, logger, {}) + director.register_provider(provider) + + pot_request.video_id = 'invalid_type' + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert 'Invalid PO Token response received from "invalid" provider: invalid-response; please report this issue to the provider developer at https://invalid.example.com/issues .' in logger.messages['error'] + + # Should fallback to next available provider + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 2 + assert pot_provider.request_called_times == 1 + + def test_invalid_po_token_response(self, ie, logger, pot_cache, pot_request, pot_provider): + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + provider = InvalidPTP(ie, logger, {}) + director.register_provider(provider) + + response = director.get_po_token(pot_request) + assert response is None + assert provider.request_called_times == 1 + assert "Invalid PO Token response received from \"invalid\" provider: PoTokenResponse(po_token='example-token?', expires_at='123'); please report this issue to the provider developer at https://invalid.example.com/issues ." in logger.messages['error'] + + # Should fallback to next available provider + director.register_provider(pot_provider) + response = director.get_po_token(pot_request) + assert response == EXAMPLE_PO_TOKEN + assert provider.request_called_times == 2 + assert pot_provider.request_called_times == 1 + + def test_copy_request_provider(self, ie, logger, pot_cache, pot_request): + + class BadProviderPTP(BaseMockPoTokenProvider): + _SUPPORTED_CONTEXTS = None + _SUPPORTED_CLIENTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + # Providers should not modify the request object, but we should guard against it + request.video_id = 'bad' + raise PoTokenProviderRejectedRequest('bad request') + + class GoodProviderPTP(BaseMockPoTokenProvider): + _SUPPORTED_CONTEXTS = None + _SUPPORTED_CLIENTS = None + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + return PoTokenResponse(base64.urlsafe_b64encode(request.video_id.encode()).decode()) + + director = PoTokenRequestDirector(logger=logger, cache=pot_cache) + + bad_provider = BadProviderPTP(ie, logger, {}) + good_provider = GoodProviderPTP(ie, logger, {}) + + director.register_provider(bad_provider) + director.register_provider(good_provider) + + pot_request.video_id = 'good' + response = director.get_po_token(pot_request) + assert response == base64.urlsafe_b64encode(b'good').decode() + assert bad_provider.request_called_times == 1 + assert good_provider.request_called_times == 1 + assert pot_request.video_id == 'good' + + +@pytest.mark.parametrize('spec, expected', [ + (None, False), + (PoTokenCacheSpec(key_bindings={'v': 'video-id'}, default_ttl=60, write_policy=None), False), # type: ignore + (PoTokenCacheSpec(key_bindings={'v': 'video-id'}, default_ttl='invalid'), False), # type: ignore + (PoTokenCacheSpec(key_bindings='invalid', default_ttl=60), False), # type: ignore + (PoTokenCacheSpec(key_bindings={2: 'video-id'}, default_ttl=60), False), # type: ignore + (PoTokenCacheSpec(key_bindings={'v': 2}, default_ttl=60), False), # type: ignore + (PoTokenCacheSpec(key_bindings={'v': None}, default_ttl=60), False), # type: ignore + + (PoTokenCacheSpec(key_bindings={'v': 'video_id', 'e': None}, default_ttl=60), True), + (PoTokenCacheSpec(key_bindings={'v': 'video_id'}, default_ttl=60, write_policy=CacheProviderWritePolicy.WRITE_FIRST), True), +]) +def test_validate_cache_spec(spec, expected): + assert validate_cache_spec(spec) == expected + + +@pytest.mark.parametrize('po_token', [ + 'invalid-token?', + '123', +]) +def test_clean_pot_fail(po_token): + with pytest.raises(ValueError, match='Invalid PO Token'): + clean_pot(po_token) + + +@pytest.mark.parametrize('po_token,expected', [ + ('TwAA/+8=', 'TwAA_-8='), + ('TwAA%5F%2D9VA6Q92v%5FvEQ4==?extra-param=2', 'TwAA_-9VA6Q92v_vEQ4='), +]) +def test_clean_pot(po_token, expected): + assert clean_pot(po_token) == expected + + +@pytest.mark.parametrize( + 'response, expected', + [ + (None, False), + (PoTokenResponse(None), False), + (PoTokenResponse(1), False), + (PoTokenResponse('invalid-token?'), False), + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at='abc'), False), # type: ignore + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=100), False), + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=time.time() + 10000.0), False), # type: ignore + (PoTokenResponse(EXAMPLE_PO_TOKEN), True), + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=-1), True), + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=0), True), + (PoTokenResponse(EXAMPLE_PO_TOKEN, expires_at=int(time.time()) + 10000), True), + ], +) +def test_validate_pot_response(response, expected): + assert validate_response(response) == expected + + +def test_built_in_provider(ie, logger): + class BuiltInProviderDefaultT(BuiltInIEContentProvider, suffix='T'): + def is_available(self): + return True + + class BuiltInProviderCustomNameT(BuiltInIEContentProvider, suffix='T'): + PROVIDER_NAME = 'CustomName' + + def is_available(self): + return True + + class ExternalProviderDefaultT(IEContentProvider, suffix='T'): + def is_available(self): + return True + + class ExternalProviderCustomT(IEContentProvider, suffix='T'): + PROVIDER_NAME = 'custom' + PROVIDER_VERSION = '5.4b2' + + def is_available(self): + return True + + class ExternalProviderUnavailableT(IEContentProvider, suffix='T'): + def is_available(self) -> bool: + return False + + class BuiltInProviderUnavailableT(IEContentProvider, suffix='T'): + def is_available(self) -> bool: + return False + + built_in_default = BuiltInProviderDefaultT(ie=ie, logger=logger, settings={}) + built_in_custom_name = BuiltInProviderCustomNameT(ie=ie, logger=logger, settings={}) + built_in_unavailable = BuiltInProviderUnavailableT(ie=ie, logger=logger, settings={}) + external_default = ExternalProviderDefaultT(ie=ie, logger=logger, settings={}) + external_custom = ExternalProviderCustomT(ie=ie, logger=logger, settings={}) + external_unavailable = ExternalProviderUnavailableT(ie=ie, logger=logger, settings={}) + + assert provider_display_list([]) == 'none' + assert provider_display_list([built_in_default]) == 'BuiltInProviderDefault' + assert provider_display_list([external_unavailable]) == 'ExternalProviderUnavailable-0.0.0 (external, unavailable)' + assert provider_display_list([built_in_default, built_in_custom_name, external_default, external_custom, external_unavailable, built_in_unavailable]) == 'BuiltInProviderDefault, CustomName, ExternalProviderDefault-0.0.0 (external), custom-5.4b2 (external), ExternalProviderUnavailable-0.0.0 (external, unavailable), BuiltInProviderUnavailable-0.0.0 (external, unavailable)' diff --git a/test/test_pot/test_pot_framework.py b/test/test_pot/test_pot_framework.py new file mode 100644 index 0000000000..fce4f55fd8 --- /dev/null +++ b/test/test_pot/test_pot_framework.py @@ -0,0 +1,510 @@ +import pytest + +from yt_dlp.extractor.youtube.pot._provider import IEContentProvider +from yt_dlp.cookies import YoutubeDLCookieJar +from yt_dlp.utils.networking import HTTPHeaderDict +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, + PoTokenContext, + +) + +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheProvider, + register_pcp, + register_pcp_preference, + PoTokenCacheSpec, + PoTokenCacheSpecProvider, + register_pcsp, + CacheProviderWritePolicy, +) + +from yt_dlp.globals import _pot_cache_providers, _pot_cache_provider_preferences, _pot_pcs_providers + +from yt_dlp.networking import Request +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenResponse, + PoTokenProvider, + PoTokenProviderRejectedRequest, + provider_bug_report_message, + register_provider, + register_preference, +) + +from yt_dlp.globals import _pot_providers, _ptp_preferences + + +class ExamplePTP(PoTokenProvider): + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + _SUPPORTED_CLIENTS = ('WEB',) + _SUPPORTED_CONTEXTS = (PoTokenContext.GVS, ) + _SUPPORTED_PROXY_SCHEMES = ('socks5', 'http') + + def is_available(self) -> bool: + return True + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + return PoTokenResponse('example-token', expires_at=123) + + +class ExampleCacheProviderPCP(PoTokenCacheProvider): + + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + def is_available(self) -> bool: + return True + + def get(self, key: str): + return 'example-cache' + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + +class ExampleCacheSpecProviderPCSP(PoTokenCacheSpecProvider): + + PROVIDER_NAME = 'example' + PROVIDER_VERSION = '0.0.1' + BUG_REPORT_LOCATION = 'https://example.com/issues' + + def generate_cache_spec(self, request: PoTokenRequest): + return PoTokenCacheSpec( + key_bindings={'field': 'example-key'}, + default_ttl=60, + write_policy=CacheProviderWritePolicy.WRITE_FIRST, + ) + + +class TestPoTokenProvider: + + def test_base_type(self): + assert issubclass(PoTokenProvider, IEContentProvider) + + def test_create_provider_missing_fetch_method(self, ie, logger): + class MissingMethodsPTP(PoTokenProvider): + def is_available(self) -> bool: + return True + + with pytest.raises(TypeError): + MissingMethodsPTP(ie=ie, logger=logger, settings={}) + + def test_create_provider_missing_available_method(self, ie, logger): + class MissingMethodsPTP(PoTokenProvider): + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderRejectedRequest('Not implemented') + + with pytest.raises(TypeError): + MissingMethodsPTP(ie=ie, logger=logger, settings={}) + + def test_barebones_provider(self, ie, logger): + class BarebonesProviderPTP(PoTokenProvider): + def is_available(self) -> bool: + return True + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderRejectedRequest('Not implemented') + + provider = BarebonesProviderPTP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'BarebonesProvider' + assert provider.PROVIDER_KEY == 'BarebonesProvider' + assert provider.PROVIDER_VERSION == '0.0.0' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at (developer has not provided a bug report location) .' + + def test_example_provider_success(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'example' + assert provider.PROVIDER_KEY == 'Example' + assert provider.PROVIDER_VERSION == '0.0.1' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at https://example.com/issues .' + assert provider.is_available() + + response = provider.request_pot(pot_request) + + assert response.po_token == 'example-token' + assert response.expires_at == 123 + + def test_provider_unsupported_context(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + pot_request.context = PoTokenContext.PLAYER + + with pytest.raises(PoTokenProviderRejectedRequest): + provider.request_pot(pot_request) + + def test_provider_unsupported_client(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + pot_request.innertube_context['client']['clientName'] = 'ANDROID' + + with pytest.raises(PoTokenProviderRejectedRequest): + provider.request_pot(pot_request) + + def test_provider_unsupported_proxy_scheme(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + pot_request.request_proxy = 'socks4://example.com' + + with pytest.raises(PoTokenProviderRejectedRequest): + provider.request_pot(pot_request) + + def test_provider_urlopen(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + + cookiejar = YoutubeDLCookieJar() + pot_request.request_headers = HTTPHeaderDict({'User-Agent': 'example-user-agent'}) + pot_request.request_proxy = 'socks5://example-proxy.com' + pot_request.request_cookiejar = cookiejar + + def mock_urlopen(request): + return request + + ie._downloader.urlopen = mock_urlopen + + sent_request = provider._urlopen(pot_request, Request( + 'https://example.com', + )) + + assert sent_request.url == 'https://example.com' + assert sent_request.headers['User-Agent'] == 'example-user-agent' + assert sent_request.proxies == {'all': 'socks5://example-proxy.com'} + assert sent_request.extensions['cookiejar'] is cookiejar + + def test_provider_urlopen_override(self, ie, logger, pot_request): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + + cookiejar_request = YoutubeDLCookieJar() + pot_request.request_headers = HTTPHeaderDict({'User-Agent': 'example-user-agent'}) + pot_request.request_proxy = 'socks5://example-proxy.com' + pot_request.request_cookiejar = cookiejar_request + + def mock_urlopen(request): + return request + + ie._downloader.urlopen = mock_urlopen + + sent_request = provider._urlopen(pot_request, Request( + 'https://example.com', + headers={'User-Agent': 'override-user-agent-override'}, + proxies={'http': 'http://example-proxy-override.com'}, + extensions={'cookiejar': YoutubeDLCookieJar()}, + )) + + assert sent_request.url == 'https://example.com' + assert sent_request.headers['User-Agent'] == 'override-user-agent-override' + assert sent_request.proxies == {'http': 'http://example-proxy-override.com'} + assert sent_request.extensions['cookiejar'] is not cookiejar_request + + def test_get_setting(self, ie, logger): + provider = ExamplePTP(ie=ie, logger=logger, settings={'abc': ['123D'], 'xyz': ['456a', '789B']}) + + assert provider.get_setting('abc') == ['123d'] + assert provider.get_setting('abc', default=['default']) == ['123d'] + assert provider.get_setting('ABC', default=['default']) == ['default'] + assert provider.get_setting('abc', casesense=True) == ['123D'] + assert provider.get_setting('xyz', casesense=False) == ['456a', '789b'] + + def test_require_class_end_with_suffix(self, ie, logger): + class InvalidSuffix(PoTokenProvider): + PROVIDER_NAME = 'invalid-suffix' + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderRejectedRequest('Not implemented') + + def is_available(self) -> bool: + return True + + provider = InvalidSuffix(ie=ie, logger=logger, settings={}) + + with pytest.raises(AssertionError): + provider.PROVIDER_KEY # noqa: B018 + + +class TestPoTokenCacheProvider: + + def test_base_type(self): + assert issubclass(PoTokenCacheProvider, IEContentProvider) + + def test_create_provider_missing_get_method(self, ie, logger): + class MissingMethodsPCP(PoTokenCacheProvider): + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + def is_available(self) -> bool: + return True + + with pytest.raises(TypeError): + MissingMethodsPCP(ie=ie, logger=logger, settings={}) + + def test_create_provider_missing_store_method(self, ie, logger): + class MissingMethodsPCP(PoTokenCacheProvider): + def get(self, key: str): + pass + + def delete(self, key: str): + pass + + def is_available(self) -> bool: + return True + + with pytest.raises(TypeError): + MissingMethodsPCP(ie=ie, logger=logger, settings={}) + + def test_create_provider_missing_delete_method(self, ie, logger): + class MissingMethodsPCP(PoTokenCacheProvider): + def get(self, key: str): + pass + + def store(self, key: str, value: str, expires_at: int): + pass + + def is_available(self) -> bool: + return True + + with pytest.raises(TypeError): + MissingMethodsPCP(ie=ie, logger=logger, settings={}) + + def test_create_provider_missing_is_available_method(self, ie, logger): + class MissingMethodsPCP(PoTokenCacheProvider): + def get(self, key: str): + pass + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + with pytest.raises(TypeError): + MissingMethodsPCP(ie=ie, logger=logger, settings={}) + + def test_barebones_provider(self, ie, logger): + class BarebonesProviderPCP(PoTokenCacheProvider): + + def is_available(self) -> bool: + return True + + def get(self, key: str): + return 'example-cache' + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + provider = BarebonesProviderPCP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'BarebonesProvider' + assert provider.PROVIDER_KEY == 'BarebonesProvider' + assert provider.PROVIDER_VERSION == '0.0.0' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at (developer has not provided a bug report location) .' + + def test_create_provider_example(self, ie, logger): + provider = ExampleCacheProviderPCP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'example' + assert provider.PROVIDER_KEY == 'ExampleCacheProvider' + assert provider.PROVIDER_VERSION == '0.0.1' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at https://example.com/issues .' + assert provider.is_available() + + def test_get_setting(self, ie, logger): + provider = ExampleCacheProviderPCP(ie=ie, logger=logger, settings={'abc': ['123D'], 'xyz': ['456a', '789B']}) + assert provider.get_setting('abc') == ['123d'] + assert provider.get_setting('abc', default=['default']) == ['123d'] + assert provider.get_setting('ABC', default=['default']) == ['default'] + assert provider.get_setting('abc', casesense=True) == ['123D'] + assert provider.get_setting('xyz', casesense=False) == ['456a', '789b'] + + def test_require_class_end_with_suffix(self, ie, logger): + class InvalidSuffix(PoTokenCacheProvider): + def get(self, key: str): + return 'example-cache' + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + def is_available(self) -> bool: + return True + + provider = InvalidSuffix(ie=ie, logger=logger, settings={}) + + with pytest.raises(AssertionError): + provider.PROVIDER_KEY # noqa: B018 + + +class TestPoTokenCacheSpecProvider: + + def test_base_type(self): + assert issubclass(PoTokenCacheSpecProvider, IEContentProvider) + + def test_create_provider_missing_supports_method(self, ie, logger): + class MissingMethodsPCS(PoTokenCacheSpecProvider): + pass + + with pytest.raises(TypeError): + MissingMethodsPCS(ie=ie, logger=logger, settings={}) + + def test_create_provider_barebones(self, ie, pot_request, logger): + class BarebonesProviderPCSP(PoTokenCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + return PoTokenCacheSpec( + default_ttl=100, + key_bindings={}, + ) + + provider = BarebonesProviderPCSP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'BarebonesProvider' + assert provider.PROVIDER_KEY == 'BarebonesProvider' + assert provider.PROVIDER_VERSION == '0.0.0' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at (developer has not provided a bug report location) .' + assert provider.is_available() + assert provider.generate_cache_spec(request=pot_request).default_ttl == 100 + assert provider.generate_cache_spec(request=pot_request).key_bindings == {} + assert provider.generate_cache_spec(request=pot_request).write_policy == CacheProviderWritePolicy.WRITE_ALL + + def test_create_provider_example(self, ie, pot_request, logger): + provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={}) + assert provider.PROVIDER_NAME == 'example' + assert provider.PROVIDER_KEY == 'ExampleCacheSpecProvider' + assert provider.PROVIDER_VERSION == '0.0.1' + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at https://example.com/issues .' + assert provider.is_available() + assert provider.generate_cache_spec(pot_request) + assert provider.generate_cache_spec(pot_request).key_bindings == {'field': 'example-key'} + assert provider.generate_cache_spec(pot_request).default_ttl == 60 + assert provider.generate_cache_spec(pot_request).write_policy == CacheProviderWritePolicy.WRITE_FIRST + + def test_get_setting(self, ie, logger): + provider = ExampleCacheSpecProviderPCSP(ie=ie, logger=logger, settings={'abc': ['123D'], 'xyz': ['456a', '789B']}) + + assert provider.get_setting('abc') == ['123d'] + assert provider.get_setting('abc', default=['default']) == ['123d'] + assert provider.get_setting('ABC', default=['default']) == ['default'] + assert provider.get_setting('abc', casesense=True) == ['123D'] + assert provider.get_setting('xyz', casesense=False) == ['456a', '789b'] + + def test_require_class_end_with_suffix(self, ie, logger): + class InvalidSuffix(PoTokenCacheSpecProvider): + def generate_cache_spec(self, request: PoTokenRequest): + return None + + provider = InvalidSuffix(ie=ie, logger=logger, settings={}) + + with pytest.raises(AssertionError): + provider.PROVIDER_KEY # noqa: B018 + + +class TestPoTokenRequest: + def test_copy_request(self, pot_request): + copied_request = pot_request.copy() + + assert copied_request is not pot_request + assert copied_request.context == pot_request.context + assert copied_request.innertube_context == pot_request.innertube_context + assert copied_request.innertube_context is not pot_request.innertube_context + copied_request.innertube_context['client']['clientName'] = 'ANDROID' + assert pot_request.innertube_context['client']['clientName'] != 'ANDROID' + assert copied_request.innertube_host == pot_request.innertube_host + assert copied_request.session_index == pot_request.session_index + assert copied_request.player_url == pot_request.player_url + assert copied_request.is_authenticated == pot_request.is_authenticated + assert copied_request.visitor_data == pot_request.visitor_data + assert copied_request.data_sync_id == pot_request.data_sync_id + assert copied_request.video_id == pot_request.video_id + assert copied_request.request_cookiejar is pot_request.request_cookiejar + assert copied_request.request_proxy == pot_request.request_proxy + assert copied_request.request_headers == pot_request.request_headers + assert copied_request.request_headers is not pot_request.request_headers + assert copied_request.request_timeout == pot_request.request_timeout + assert copied_request.request_source_address == pot_request.request_source_address + assert copied_request.request_verify_tls == pot_request.request_verify_tls + assert copied_request.bypass_cache == pot_request.bypass_cache + + +def test_provider_bug_report_message(ie, logger): + provider = ExamplePTP(ie=ie, logger=logger, settings={}) + assert provider.BUG_REPORT_MESSAGE == 'please report this issue to the provider developer at https://example.com/issues .' + + message = provider_bug_report_message(provider) + assert message == '; please report this issue to the provider developer at https://example.com/issues .' + + message_before = provider_bug_report_message(provider, before='custom message!') + assert message_before == 'custom message! Please report this issue to the provider developer at https://example.com/issues .' + + +def test_register_provider(ie): + + @register_provider + class UnavailableProviderPTP(PoTokenProvider): + def is_available(self) -> bool: + return False + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + raise PoTokenProviderRejectedRequest('Not implemented') + + assert _pot_providers.value.get('UnavailableProvider') == UnavailableProviderPTP + _pot_providers.value.pop('UnavailableProvider') + + +def test_register_pot_preference(ie): + before = len(_ptp_preferences.value) + + @register_preference(ExamplePTP) + def unavailable_preference(provider: PoTokenProvider, request: PoTokenRequest): + return 1 + + assert len(_ptp_preferences.value) == before + 1 + + +def test_register_cache_provider(ie): + + @register_pcp + class UnavailableCacheProviderPCP(PoTokenCacheProvider): + def is_available(self) -> bool: + return False + + def get(self, key: str): + return 'example-cache' + + def store(self, key: str, value: str, expires_at: int): + pass + + def delete(self, key: str): + pass + + assert _pot_cache_providers.value.get('UnavailableCacheProvider') == UnavailableCacheProviderPCP + _pot_cache_providers.value.pop('UnavailableCacheProvider') + + +def test_register_cache_provider_spec(ie): + + @register_pcsp + class UnavailableCacheProviderPCSP(PoTokenCacheSpecProvider): + def is_available(self) -> bool: + return False + + def generate_cache_spec(self, request: PoTokenRequest): + return None + + assert _pot_pcs_providers.value.get('UnavailableCacheProvider') == UnavailableCacheProviderPCSP + _pot_pcs_providers.value.pop('UnavailableCacheProvider') + + +def test_register_cache_provider_preference(ie): + before = len(_pot_cache_provider_preferences.value) + + @register_pcp_preference(ExampleCacheProviderPCP) + def unavailable_preference(provider: PoTokenCacheProvider, request: PoTokenRequest): + return 1 + + assert len(_pot_cache_provider_preferences.value) == before + 1 diff --git a/yt_dlp/extractor/youtube/_video.py b/yt_dlp/extractor/youtube/_video.py index 074a2a0d8d..d731e8d6b9 100644 --- a/yt_dlp/extractor/youtube/_video.py +++ b/yt_dlp/extractor/youtube/_video.py @@ -23,6 +23,8 @@ from ._base import ( _split_innertube_client, short_client_name, ) +from .pot._director import initialize_pot_director +from .pot.provider import PoTokenContext, PoTokenRequest from ..openload import PhantomJSwrapper from ...jsinterp import JSInterpreter from ...networking.exceptions import HTTPError @@ -66,6 +68,7 @@ from ...utils import ( urljoin, variadic, ) +from ...utils.networking import clean_headers, clean_proxies, select_proxy STREAMING_DATA_CLIENT_NAME = '__yt_dlp_client' STREAMING_DATA_INITIAL_PO_TOKEN = '__yt_dlp_po_token' @@ -1784,6 +1787,11 @@ class YoutubeIE(YoutubeBaseInfoExtractor): super().__init__(*args, **kwargs) self._code_cache = {} self._player_cache = {} + self._pot_director = None + + def _real_initialize(self): + super()._real_initialize() + self._pot_director = initialize_pot_director(self) def _prepare_live_from_start_formats(self, formats, video_id, live_start_time, url, webpage_url, smuggled_data, is_live): lock = threading.Lock() @@ -2818,7 +2826,7 @@ class YoutubeIE(YoutubeBaseInfoExtractor): continue def fetch_po_token(self, client='web', context=_PoTokenContext.GVS, ytcfg=None, visitor_data=None, - data_sync_id=None, session_index=None, player_url=None, video_id=None, **kwargs): + data_sync_id=None, session_index=None, player_url=None, video_id=None, webpage=None, **kwargs): """ Fetch a PO Token for a given client and context. This function will validate required parameters for a given context and client. @@ -2832,10 +2840,14 @@ class YoutubeIE(YoutubeBaseInfoExtractor): @param session_index: session index. @param player_url: player URL. @param video_id: video ID. + @param webpage: video webpage. @param kwargs: Additional arguments to pass down. May be more added in the future. @return: The fetched PO Token. None if it could not be fetched. """ + # TODO(future): This validation should be moved into pot framework. + # Some sort of middleware or validation provider perhaps? + # GVS WebPO Token is bound to visitor_data / Visitor ID when logged out. # Must have visitor_data for it to function. if player_url and context == _PoTokenContext.GVS and not visitor_data and not self.is_authenticated: @@ -2875,11 +2887,64 @@ class YoutubeIE(YoutubeBaseInfoExtractor): session_index=session_index, player_url=player_url, video_id=video_id, + video_webpage=webpage, **kwargs, ) def _fetch_po_token(self, client, **kwargs): - """(Unstable) External PO Token fetch stub""" + + context = kwargs.get('context') + + # Avoid fetching PO Tokens when not required + if not ( + _PoTokenContext(context) in self._get_default_ytcfg(client)['PO_TOKEN_REQUIRED_CONTEXTS'] + or self._configuration_arg('fetch_pot', ['when_required'], ie_key=YoutubeIE)[0] == 'always' + ): + return None + + headers = self.get_param('http_headers').copy() + proxies = self._downloader.proxies.copy() + clean_headers(headers) + clean_proxies(proxies, headers) + + innertube_host = self._select_api_hostname(None, default_client=client) + + pot_request = PoTokenRequest( + context=PoTokenContext(context), + innertube_context=traverse_obj(kwargs, ('ytcfg', 'INNERTUBE_CONTEXT')), + innertube_host=innertube_host, + internal_client_name=client, + session_index=kwargs.get('session_index'), + player_url=kwargs.get('player_url'), + video_webpage=kwargs.get('video_webpage'), + is_authenticated=self.is_authenticated, + visitor_data=kwargs.get('visitor_data'), + data_sync_id=kwargs.get('data_sync_id'), + video_id=kwargs.get('video_id'), + request_cookiejar=self._downloader.cookiejar, + + # All requests that would need to be proxied should be in the + # context of www.youtube.com or the innertube host + request_proxy=( + select_proxy('https://www.youtube.com', proxies) + or select_proxy(f'https://{innertube_host}', proxies) + ), + request_headers=headers, + request_timeout=self.get_param('socket_timeout'), + request_verify_tls=not self.get_param('nocheckcertificate'), + request_source_address=self.get_param('source_address'), + + bypass_cache=False, + ) + + po_token = self._pot_director.get_po_token(pot_request) + + if not po_token: + self.write_debug(f'{kwargs.get("video_id")}: No {pot_request.context.value} PO Token available for {client} client') + return + + self.write_debug(f'{kwargs.get("video_id")}: Fetched a {pot_request.context.value} PO Token for {client} client') + return po_token @staticmethod def _is_agegated(player_response): @@ -3037,8 +3102,9 @@ class YoutubeIE(YoutubeBaseInfoExtractor): 'video_id': video_id, 'data_sync_id': data_sync_id if self.is_authenticated else None, 'player_url': player_url if require_js_player else None, + 'webpage': webpage, 'session_index': self._extract_session_index(master_ytcfg, player_ytcfg), - 'ytcfg': player_ytcfg, + 'ytcfg': player_ytcfg or self._get_default_ytcfg(client), } player_po_token = self.fetch_po_token( diff --git a/yt_dlp/extractor/youtube/pot/README.md b/yt_dlp/extractor/youtube/pot/README.md new file mode 100644 index 0000000000..34199cbf04 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/README.md @@ -0,0 +1,277 @@ +# YoutubeIE PO Token Provider Framework + +As part of the YouTube extractor, we have a framework for providing PO Tokens programmatically. This can be used in plugins. + +Refer to the [PO Token Guide](https://github.com/yt-dlp/yt-dlp/wiki/PO-Token-Guide) for more information on PO Tokens. + +## Public APIs + +- `yt_dlp.extractor.youtube.pot.cache` +- `yt_dlp.extractor.youtube.pot.provider` +- `yt_dlp.extractor.youtube.pot.builtin.utils` + +Everything else is internal-only and no guarantees are made about the API stability. + +> [!WARNING] +> We will try our best to maintain stability with the public APIs. +> However, due to the nature of extractors and YouTube, we may need to remove or change APIs in the future. +> If you are using these APIs outside yt-dlp plugins, please account for this by importing them safely. + +## PO Token Provider + +`yt_dlp.extractor.youtube.pot.provider` + +```python +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, + PoTokenContext, + PoTokenProvider, + PoTokenResponse, + PoTokenProviderError, + PoTokenProviderRejectedRequest, + register_provider, + register_preference, +) +from yt_dlp.networking.common import Request +from yt_dlp.extractor.youtube.pot.builtin.utils import get_webpo_content_binding +from yt_dlp.utils import traverse_obj +from yt_dlp.networking.exceptions import RequestError +import json + + +@register_provider +class MyPoTokenProviderPTP(PoTokenProvider): # Provider name must end with "PTP" + PROVIDER_VERSION = '0.2.1' + # Define a unique display name for the provider + PROVIDER_NAME = 'my-provider' + BUG_REPORT_LOCATION = 'https://issues.example.com/report' + + # -- Validation shortcuts. Set these to None to disable. -- + + # Innertube Client Name. + # For example, "WEB", "ANDROID", "TVHTML5". + # For a list of WebPO client names, see yt_dlp.extractor.youtube.pot.builtin.utils.WEBPO_CLIENTS. + # Also see yt_dlp.extractor.youtube._base.INNERTUBE_CLIENTS for a list of client names currently supported by the YouTube extractor. + _SUPPORTED_CLIENTS = ('WEB', 'TVHTML5') + + _SUPPORTED_CONTEXTS = ( + PoTokenContext.GVS, + ) + + # Possible values: http, https, socks4, socks4a, socks5, socks5h + # If your provider makes requests outside PoTokenProvider._urllib, you should set this to any proxy schemes supported. + # If you use PoTokenProvider._urllib to make requests, set to None. + _SUPPORTED_PROXY_SCHEMES = ('http',) + + def is_available(self) -> bool: + """ + Check if the provider is available (e.g. all required dependencies are available) + This is used to determine if the provider should be used and to provide debug information. + + IMPORTANT: This method SHOULD NOT make any network requests or perform any expensive operations. + + Since this is called multiple times, we recommend caching the result. + """ + return True + + def close(self): + # Optional close hook, called when YoutubeDL is closed. + pass + + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + # ℹ️ If you need to validate the request before making the request to the external source + # Raise yt_dlp.extractor.youtube.pot.provider.PoTokenProviderRejectedRequest if the request is not supported + if request.is_authenticated: + raise PoTokenProviderRejectedRequest( + 'This provider does not support authenticated requests' + ) + + # ℹ️ Settings are pulled from extractor args passed to yt-dlp with the key `youtubepot-`. + # For this example, the extractor arg would be `--extractor-args "youtubepot-mypotokenprovider:url=https://custom.example.com/get_pot"` + external_provider_url = self.get_setting('url', default=['https://provider.example.com/get_pot'])[0] + + # You should use the internal HTTP client to make requests where possible, + # as it will handle cookies and other networking settings passed to yt-dlp. + try: + # See below for logging guidelines + self.logger.info(f'Requesting {request.context.value} PO Token for {request.internal_client_name} client from external provider') + + # See docstring in _urlopen method for request tips + response = self._urlopen( + request, Request(external_provider_url, data=json.dumps({ + 'content_binding': get_webpo_content_binding(request), + 'proxy': request.request_proxy, + 'headers': request.request_headers, + 'source_address': request.request_source_address, + 'verify_tls': request.request_verify_tls, + # Important: If your provider has its own caching, please respect `bypass_cache`. + # This may be used in the future to request a fresh PO Token if required. + 'do_not_cache': request.bypass_cache, + }).encode(), proxies={'all': None})) + + except RequestError as e: + # ℹ️ If there is an error, raise PoTokenProviderError. + # You can specify whether it is expected or not. If it is unexpected, the log will include a link to the bug report location (BUG_REPORT_LOCATION). + raise PoTokenProviderError( + 'Networking error while fetching to get PO Token from external provider', + expected=True + ) from e + + # Note: PO Token is expected to be base64url encoded + po_token = traverse_obj(response, 'po_token') + if not po_token: + raise PoTokenProviderError( + 'Bad PO Token Response from external provider', + expected=False + ) + + return PoTokenResponse( + po_token=po_token, + # Optional, add a custom expiration time for the token. Use for caching. + # By default, yt-dlp will use the default ttl from a registered cache spec (see below) + # Set to 0 or -1 to not cache this response. + expires_at=None, + ) + + +# If there are multiple PO Token Providers that can handle the same PoTokenRequest, +# you can define a preference function to increase/decrease the priority of providers. + +@register_preference(MyPoTokenProviderPTP) +def my_provider_preference(provider: PoTokenProvider, request: PoTokenRequest, *_, **__) -> int: + return 50 +``` + +## Logging Guidelines + +- Use the `self.logger` object to log messages. +- When making HTTP requests, use `self.logger.info` to log a message to standard non-verbose output. This lets users know what is happening when a time-expensive operation is taking place. + - For example, `self.logger.info(f'Requesting {request.context.value} PO Token for {request.internal_client_name} client from external provider')` +- Use `self.logger.debug` to log a message to the verbose output (`--verbose`). Try to keep this to a minimum. +- Use `self.logger.trace` to log a message to the PO Token debug output (`--extractor-args "youtube:pot_debug=true"`). Log as much as you like here as needed for debugging your provider. +- Avoid logging PO Tokens or any sensitive information to debug or info output. + +## Debugging + +- Use `-v --extractor-args "youtube:pot_debug=true"` to enable PO Token debug output. + +## Caching + +> [!IMPORTANT] +> yt-dlp currently has a built-in LRU Memory Cache Provider and a cache spec provider for WebPO Tokens. +> You should only need to implement cache providers if you want an external cache, or a cache spec if you are handling non-WebPO Tokens. + +### Cache Providers + +`yt_dlp.extractor.youtube.pot.cache` + +```python +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheProvider, + register_pcp_preference, + register_pcp +) + +from yt_dlp.extractor.youtube.pot.provider import PoTokenRequest + + +@register_pcp +class MyCacheProviderPCP(PoTokenCacheProvider): # Provider name must end with "PCP" + PROVIDER_VERSION = '0.1.0' + # Define a unique display name for the provider + PROVIDER_NAME = 'my-cache-provider' + BUG_REPORT_LOCATION = 'https://issues.example.com/report' + + def is_available(self) -> bool: + """ + Check if the provider is available (e.g. all required dependencies are available) + This is used to determine if the provider should be used and to provide debug information. + + IMPORTANT: This method SHOULD NOT make any network requests or perform any expensive operations. + + Since this is called multiple times, we recommend caching the result. + """ + return True + + """ + Implement the below cache operations. + + - expires_at is a timestamp in UTC. It MUST be respected - cache entries should not be returned if they have expired. + """ + + def get(self, key: str): + # ℹ️ Similar to PO Token Providers, Cache Providers and Cache Spec Providers are passed down extractor args matching key youtubepot-. + some_setting = self.get_setting('some_setting', default=['default_value'])[0] + return self.my_cache.get(key) + + def store(self, key: str, value: str, expires_at: str): + self.my_cache.store(key, value, expires_at) + + def delete(self, key: str): + self.my_cache.delete(key) + + def close(self): + # Optional close hook, called when YoutubeDL is closed. + pass + +# If there are multiple PO Token Cache Providers available, you can define a preference function to increase/decrease the priority of providers. +# IMPORTANT: Providers should be in preference of cache lookup time. For example, a memory cache should have a higher preference than a disk cache. +# VERY IMPORTANT: yt-dlp has a built-in memory cache with a priority of 10000. Your cache provider should be lower than this. + + +@register_pcp_preference(MyCacheProviderPCP) +def my_cache_preference(provider: PoTokenCacheProvider, request: PoTokenRequest, *_, **__) -> int: + return 50 +``` + +### Cache Specs + +`yt_dlp.extractor.youtube.pot.cache` + +These are used to provide information on how to cache a particular PO Token Request. +You might have a different cache spec for different kinds of PO Tokens (e.g. WebPO vs iOSGuard) + +```python +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheSpec, + PoTokenCacheSpecProvider, + CacheProviderWritePolicy, + register_pcsp, +) +from yt_dlp.utils import traverse_obj +from yt_dlp.extractor.youtube.pot.provider import PoTokenRequest + + +@register_pcsp +class MyCacheSpecProviderPCSP(PoTokenCacheSpecProvider): # Provider name must end with "PCSP" + PROVIDER_VERSION = '0.1.0' + # Define a unique display name for the provider + PROVIDER_NAME = 'mycachespec' + BUG_REPORT_LOCATION = 'https://issues.example.com/report' + + def generate_cache_spec(self, request: PoTokenRequest): + + client_name = traverse_obj(request.innertube_context, ('client', 'clientName')) + if client_name != 'ANDROID': + # ℹ️ If the request is not supported by the cache spec, return None + return None + + # Generate a cache spec for the request + return PoTokenCacheSpec( + # Key bindings to uniquely identify the request. These are used to generate a cache key. + key_bindings={ + 'client_name': client_name, + 'content_binding': 'unique_content_binding', + 'ip': traverse_obj(request.innertube_context, ('client', 'remoteHost')), + 'source_address': request.request_source_address, + 'proxy': request.request_proxy, + }, + # Default Cache TTL in seconds + default_ttl=21600, + + # Optional: Specify a write policy. + # WRITE_FIRST will write to the highest priority provider only, whereas WRITE_ALL will write to all providers. + # WRITE_FIRST may be useful if the PO Token is short-lived and there is no use writing to all providers. + write_policy=CacheProviderWritePolicy.WRITE_ALL, + ) +``` \ No newline at end of file diff --git a/yt_dlp/extractor/youtube/pot/__init__.py b/yt_dlp/extractor/youtube/pot/__init__.py new file mode 100644 index 0000000000..f2075d349f --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/__init__.py @@ -0,0 +1 @@ +from .builtin import utils as _ # noqa: F401 diff --git a/yt_dlp/extractor/youtube/pot/_director.py b/yt_dlp/extractor/youtube/pot/_director.py new file mode 100644 index 0000000000..96d2bad348 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/_director.py @@ -0,0 +1,424 @@ +from __future__ import annotations + +import base64 +import binascii +import dataclasses +import datetime as dt +import hashlib +import json +import typing +import urllib.parse +from collections.abc import Iterable + +from yt_dlp.extractor.youtube.pot._provider import ( + BuiltInIEContentProvider, + IEContentProvider, + IEContentProviderLogger, +) +from yt_dlp.extractor.youtube.pot.cache import ( + CacheProviderWritePolicy, + PoTokenCacheProvider, + PoTokenCacheProviderError, + PoTokenCacheSpec, + PoTokenCacheSpecProvider, +) +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenProvider, + PoTokenProviderError, + PoTokenProviderRejectedRequest, + PoTokenRequest, + PoTokenResponse, + provider_bug_report_message, +) +from yt_dlp.globals import ( + _pot_cache_provider_preferences, + _pot_cache_providers, + _pot_pcs_providers, + _pot_providers, + _ptp_preferences, +) +from yt_dlp.utils import ExtractorError, bug_reports_message, format_field, join_nonempty, traverse_obj + +if typing.TYPE_CHECKING: + from yt_dlp.extractor.youtube.pot.cache import PCPPreference + + +class YoutubeIEContentProviderLogger(IEContentProviderLogger): + def __init__(self, ie, prefix, enable_trace=False): + self.__ie = ie + self.prefix = prefix + self.enable_trace = enable_trace + + def _format_msg(self, message: str): + prefixstr = format_field(self.prefix, None, '[%s] ') + return f'{prefixstr}{message}' + + def trace(self, message: str): + if self.enable_trace: + self.__ie.write_debug(self._format_msg('TRACE: ' + message)) + + def debug(self, message: str): + self.__ie.write_debug(self._format_msg(message)) + + def info(self, message: str): + self.__ie.to_screen(self._format_msg(message)) + + def warning(self, message: str, *, once=False): + self.__ie.report_warning(self._format_msg(message), only_once=once) + + def error(self, message: str): + self.__ie._downloader.report_error(self._format_msg(message), is_error=False) + + +class PoTokenCache: + + def __init__( + self, + logger: IEContentProviderLogger, + cache_providers: list[PoTokenCacheProvider], + cache_spec_providers: list[PoTokenCacheSpecProvider], + cache_provider_preferences: list[PCPPreference] | None = None, + ): + self.cache_providers: dict[str, PoTokenCacheProvider] = { + provider.PROVIDER_KEY: provider for provider in (cache_providers or []) + } + self.cache_provider_preferences: list[PCPPreference] = cache_provider_preferences or [] + + self.cache_spec_providers: dict[str, PoTokenCacheSpecProvider] = { + provider.PROVIDER_KEY: provider for provider in (cache_spec_providers or []) + } + + self.logger = logger + + def _get_available_providers(self) -> Iterable[PoTokenCacheProvider]: + return (provider for provider in self.cache_providers.values() if provider.is_available()) + + def _get_cache_providers(self, request: PoTokenRequest) -> list[PoTokenCacheProvider]: + """Sorts cache providers by preference, given a request""" + preferences = { + provider: sum(pref(provider, request) for pref in self.cache_provider_preferences) + for provider in self._get_available_providers() + } + self.logger.trace(f'Available PO Token Providers: {provider_display_list(self._get_available_providers())}') + self.logger.trace('Cache Provider preferences for this request: {}'.format(', '.join( + f'{provider.PROVIDER_KEY}={pref}' for provider, pref in preferences.items()))) + + return sorted(self._get_available_providers(), key=preferences.get, reverse=True) + + def _get_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None: + for provider in self.cache_spec_providers.values(): + if not provider.is_available(): + continue + try: + spec = provider.generate_cache_spec(request) + if not spec: + continue + if not validate_cache_spec(spec): + self.logger.error(f'PoTokenCacheSpecProvider "{provider.PROVIDER_KEY}" generate_cache_spec() returned invalid spec {spec}{provider_bug_report_message(provider)}') + continue + spec = dataclasses.replace(spec, _provider=provider) + self.logger.trace(f'Retrieved cache spec {spec} from cache spec provider "{provider.PROVIDER_NAME}"') + return spec + except Exception as e: + self.logger.error( + f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache spec provider: {e!r}{provider_bug_report_message(provider)}', + ) + continue + + def _generate_key_bindings(self, spec: PoTokenCacheSpec) -> dict[str, str]: + bindings_cleaned = { + **{k: v for k, v in spec.key_bindings.items() if v is not None}, + # Allow us to invalidate caches if such need arises + '_yt': 'v1', + '_p': spec._provider.PROVIDER_KEY, + } + self.logger.trace('Generate cache key bindings: {}'.format(', '.join(f'{k}={v}' for k, v in bindings_cleaned.items()))) + return bindings_cleaned + + def _generate_key(self, bindings: dict) -> str: + binding_string = ''.join(f'{k}{v}' for k, v in sorted(bindings.items())) + return hashlib.sha256(binding_string.encode()).hexdigest() + + def get(self, request: PoTokenRequest) -> PoTokenResponse | None: + spec = self._get_cache_spec(request) + if not spec: + self.logger.trace('No cache spec available for this request, unable to fetch from cache') + return None + + cache_key = self._generate_key(self._generate_key_bindings(spec)) + self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}') + + for idx, provider in enumerate(self._get_cache_providers(request)): + try: + cache_response = provider.get(cache_key) + if not cache_response: + continue + try: + po_token_response = PoTokenResponse(**json.loads(cache_response)) + except (TypeError, ValueError, json.JSONDecodeError): + po_token_response = None + if not validate_response(po_token_response): + self.logger.error(f'Invalid PO Token response retrieved from cache provider "{provider.PROVIDER_NAME}": {cache_response}{provider_bug_report_message(provider)}') + provider.delete(cache_key) + continue + self.logger.trace(f'PO Token response retrieved from cache using "{provider.PROVIDER_NAME}" provider: {po_token_response}') + if idx > 0: + # Write back to the highest priority cache provider, + # so we stop trying to fetch from lower priority providers + self.logger.trace('Writing PO Token response to highest priority cache provider') + self.store(request, po_token_response, write_policy=CacheProviderWritePolicy.WRITE_FIRST) + + return po_token_response + except PoTokenCacheProviderError as e: + self.logger.warning( + f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider) if not e.expected else ""}') + continue + except Exception as e: + self.logger.error( + f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider)}', + ) + continue + + def store(self, request: PoTokenRequest, response: PoTokenResponse, write_policy: CacheProviderWritePolicy | None = None): + spec = self._get_cache_spec(request) + if not spec: + self.logger.trace('No cache spec available for this request. Not caching.') + return + + if not validate_response(response): + self.logger.error(f'Invalid PO Token response provided to PoTokenCache.store(): {response}{bug_reports_message()}') + return + + cache_key = self._generate_key(self._generate_key_bindings(spec)) + self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}') + + default_expires_at = int(dt.datetime.now(dt.timezone.utc).timestamp()) + spec.default_ttl + cache_response = dataclasses.replace(response, expires_at=response.expires_at or default_expires_at) + + write_policy = write_policy or spec.write_policy + self.logger.trace(f'Using write policy: {write_policy}') + + for idx, provider in enumerate(self._get_cache_providers(request)): + # WRITE_FIRST should not write to lower priority providers in the case the highest priority provider fails + if idx > 0 and write_policy == CacheProviderWritePolicy.WRITE_FIRST: + return + try: + self.logger.trace( + f'Caching PO Token response in "{provider.PROVIDER_NAME}" cache provider (key={cache_key}, expires_at={cache_response.expires_at})', + ) + provider.store(key=cache_key, value=json.dumps(dataclasses.asdict(cache_response)), expires_at=cache_response.expires_at) + + except PoTokenCacheProviderError as e: + self.logger.warning( + f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider) if not e.expected else ""}') + continue + except Exception as e: + self.logger.error( + f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: {e!r}{provider_bug_report_message(provider)}', + ) + continue + + def close(self): + for provider in self.cache_providers.values(): + provider.close() + for spec_provider in self.cache_spec_providers.values(): + spec_provider.close() + + +class PoTokenRequestDirector: + + def __init__(self, logger: IEContentProviderLogger, cache: PoTokenCache): + self.providers = {} + self.preferences = [] + self.cache = cache + self.logger = logger + + def register_provider(self, provider: PoTokenProvider): + self.providers[provider.PROVIDER_KEY] = provider + + def register_preference(self, preference): + self.preferences.append(preference) + + def _get_available_providers(self) -> list[PoTokenProvider]: + return [provider for provider in self.providers.values() if provider.is_available()] + + def _get_providers(self, request: PoTokenRequest) -> list[PoTokenProvider]: + """Sorts available providers by preference, given a request""" + preferences = { + provider: sum(pref(provider, request) for pref in self.preferences) + for provider in self._get_available_providers() + if provider.is_available() + } + self.logger.trace(f'Available PO Token Providers: {provider_display_list(self._get_available_providers())}') + self.logger.trace('Provider preferences for this request: {}'.format(', '.join( + f'{provider.PROVIDER_NAME}={pref}' for provider, pref in preferences.items()))) + + return sorted(self._get_available_providers(), key=preferences.get, reverse=True) + + def _get_po_token(self, request) -> PoTokenResponse | None: + for provider in self._get_providers(request): + try: + self.logger.trace(f'Attempting to fetch a PO Token from "{provider.PROVIDER_NAME}" provider') + response = provider.request_pot(request.copy()) + except PoTokenProviderRejectedRequest as e: + self.logger.trace( + f'PO Token Provider "{provider.PROVIDER_NAME}" does not support this request, trying next available provider. Reason: {e}') + continue + except PoTokenProviderError as e: + self.logger.warning( + f'Error fetching PO Token from "{provider.PROVIDER_NAME}" provider: {e!r}{provider_bug_report_message(provider) if not e.expected else ""}') + continue + except Exception as e: + self.logger.error( + f'Unexpected error when fetching PO Token from "{provider.PROVIDER_NAME}" provider: {e!r}{provider_bug_report_message(provider)}') + continue + + self.logger.trace(f'PO Token response from "{provider.PROVIDER_NAME}" provider: {response}') + + if not validate_response(response): + self.logger.error( + f'Invalid PO Token response received from "{provider.PROVIDER_NAME}" provider: {response}{provider_bug_report_message(provider)}') + continue + + return response + + self.logger.trace('No PO Token providers were able to provide a valid PO Token') + return None + + def get_po_token(self, request: PoTokenRequest) -> str | None: + if not request.bypass_cache: + if pot_response := self.cache.get(request): + return clean_pot(pot_response.po_token) + + if not self.providers: + self.logger.trace('No PO Token providers registered') + return None + + pot_response = self._get_po_token(request) + if not pot_response: + return None + + pot_response.po_token = clean_pot(pot_response.po_token) + + if pot_response.expires_at is None or pot_response.expires_at > 0: + self.cache.store(request, pot_response) + else: + self.logger.trace(f'PO Token response will not be cached (expires_at={pot_response.expires_at})') + + return pot_response.po_token + + def close(self): + # TODO: hook into ydl close + for provider in self.providers.values(): + provider.close() + + self.cache.close() + + +EXTRACTOR_ARG_PREFIX = 'youtubepot' + + +def initialize_pot_director(ie): + if not ie._downloader: + raise ExtractorError('Downloader not set', expected=False) + + enable_trace = ie._configuration_arg('pot_debug', ['false'], ie_key='youtube', casesense=False)[0] == 'true' + + cache_providers = [] + for cache_provider in _pot_cache_providers.value.values(): + settings = traverse_obj(ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_provider.PROVIDER_KEY.lower()}')) + cache_provider_logger = YoutubeIEContentProviderLogger(ie, f'pot:cache:{cache_provider.PROVIDER_NAME}', enable_trace=enable_trace) + cache_providers.append(cache_provider(ie, cache_provider_logger, settings or {})) + + cache_spec_providers = [] + for cache_spec_provider in _pot_pcs_providers.value.values(): + settings = traverse_obj(ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{cache_spec_provider.PROVIDER_KEY.lower()}')) + cache_spec_provider_logger = YoutubeIEContentProviderLogger(ie, f'pot:cache:spec:{cache_spec_provider.PROVIDER_NAME}', enable_trace=enable_trace) + cache_spec_providers.append(cache_spec_provider(ie, cache_spec_provider_logger, settings or {})) + + cache = PoTokenCache( + logger=YoutubeIEContentProviderLogger(ie, 'pot:cache', enable_trace=enable_trace), + cache_providers=cache_providers, + cache_spec_providers=cache_spec_providers, + cache_provider_preferences=list(_pot_cache_provider_preferences.value), + ) + + director = PoTokenRequestDirector( + logger=YoutubeIEContentProviderLogger(ie, 'pot', enable_trace=enable_trace), + cache=cache, + ) + + for provider in _pot_providers.value.values(): + settings = traverse_obj(ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}')) + logger = YoutubeIEContentProviderLogger(ie, f'pot:{provider.PROVIDER_NAME}', enable_trace=enable_trace) + director.register_provider(provider(ie, logger, settings or {})) + + for preference in _ptp_preferences.value: + director.register_preference(preference) + + director.logger.debug(f'PO Token Providers: {provider_display_list(director.providers.values())}') + director.logger.debug(f'PO Token Cache Providers: {provider_display_list(cache.cache_providers.values())}') + director.logger.debug(f'PO Token Cache Spec Providers: {provider_display_list(cache.cache_spec_providers.values())}') + director.logger.trace(f'Registered {len(director.preferences)} provider preferences') + director.logger.trace(f'Registered {len(cache.cache_provider_preferences)} cache provider preferences') + + return director + + +def provider_display_list(providers: Iterable[IEContentProvider]): + def provider_display_name(provider): + display_str = join_nonempty(provider.PROVIDER_NAME, provider.PROVIDER_VERSION if not isinstance(provider, BuiltInIEContentProvider) else None) + statuses = [] + if not isinstance(provider, BuiltInIEContentProvider): + statuses.append('external') + if not provider.is_available(): + statuses.append('unavailable') + if statuses: + display_str += f' ({", ".join(statuses)})' + return display_str + + return ', '.join(provider_display_name(provider) for provider in providers) or 'none' + + +def clean_pot(po_token: str): + # Clean and validate the PO Token. This will strip invalid characters off + # (e.g. additional url params the user may accidentally include) + try: + return base64.urlsafe_b64encode(base64.urlsafe_b64decode(urllib.parse.unquote(po_token))).decode() + except (binascii.Error, ValueError): + raise ValueError('Invalid PO Token') + + +def validate_response(response: PoTokenResponse): + if ( + not isinstance(response, PoTokenResponse) + or not response.po_token + or not isinstance(response.po_token, str) + ): # noqa: SIM103 + return False + + try: + clean_pot(response.po_token) + except ValueError: + return False + + return ( + response.expires_at is None + or ( + isinstance(response.expires_at, int) + and (response.expires_at <= 0 or response.expires_at > int(dt.datetime.now(dt.timezone.utc).timestamp())) + ) + ) + + +def validate_cache_spec(spec: PoTokenCacheSpec): + + return ( + isinstance(spec, PoTokenCacheSpec) + and isinstance(spec.write_policy, CacheProviderWritePolicy) + and isinstance(spec.default_ttl, int) + and isinstance(spec.key_bindings, dict) + and all(isinstance(k, str) for k in spec.key_bindings) + and all(v is None or isinstance(v, str) for v in spec.key_bindings.values()) + and len({k for k in spec.key_bindings.values() if k is not None}) > 0 + ) diff --git a/yt_dlp/extractor/youtube/pot/_provider.py b/yt_dlp/extractor/youtube/pot/_provider.py new file mode 100644 index 0000000000..cda88986b2 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/_provider.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import abc +import functools + +from yt_dlp.extractor.common import InfoExtractor +from yt_dlp.utils import NO_DEFAULT, bug_reports_message, classproperty, traverse_obj +from yt_dlp.version import __version__ + +# xxx: these could be generalized outside YoutubeIE eventually + + +class IEContentProviderLogger(abc.ABC): + @abc.abstractmethod + def trace(self, message: str): + pass + + @abc.abstractmethod + def debug(self, message: str): + pass + + @abc.abstractmethod + def info(self, message: str): + pass + + @abc.abstractmethod + def warning(self, message: str, *, once=False): + pass + + @abc.abstractmethod + def error(self, message: str): + pass + + +class IEContentProviderError(Exception): + def __init__(self, msg=None, expected=False): + super().__init__(msg) + self.expected = expected + + +class IEContentProvider(abc.ABC): + PROVIDER_VERSION: str = '0.0.0' + BUG_REPORT_LOCATION: str = '(developer has not provided a bug report location)' + + def __init__(self, ie: InfoExtractor, logger: IEContentProviderLogger, settings: dict[str, list[str]], *_, **__): + self.ie = ie + self.settings = settings or {} + self.logger = logger + super().__init__() + + @classmethod + def __init_subclass__(cls, *, suffix=None, **kwargs): + if suffix: + cls._PROVIDER_KEY_SUFFIX = suffix + return super().__init_subclass__(**kwargs) + + @classproperty + def PROVIDER_NAME(cls) -> str: + return cls.__name__[:-len(cls._PROVIDER_KEY_SUFFIX)] + + @classproperty + def BUG_REPORT_MESSAGE(cls): + return f'please report this issue to the provider developer at {cls.BUG_REPORT_LOCATION} .' + + @classproperty + def PROVIDER_KEY(cls) -> str: + assert hasattr(cls, '_PROVIDER_KEY_SUFFIX'), 'Content Provider implementation must define a suffix for the provider key' + assert cls.__name__.endswith(cls._PROVIDER_KEY_SUFFIX), f'PoTokenProvider class names must end with "{cls._PROVIDER_KEY_SUFFIX}"' + return cls.__name__[:-len(cls._PROVIDER_KEY_SUFFIX)] + + @abc.abstractmethod + def is_available(self) -> bool: + """ + Check if the provider is available (e.g. all required dependencies are available) + This is used to determine if the provider should be used and to provide debug information. + + IMPORTANT: This method should not make any network requests or perform any expensive operations. + It is called multiple times. + """ + raise NotImplementedError + + def close(self): # noqa: B027 + pass + + def get_setting(self, key, default=NO_DEFAULT, *, casesense=False): + ''' + @returns A list of values for the setting given by "key" + or "default" if no such key is present + @param default The default value to return when the key is not present (default: []) + @param casesense When false, the values are converted to lower case + ''' + val = traverse_obj(self.settings, key) + if val is None: + return [] if default is NO_DEFAULT else default + return list(val) if casesense else [x.lower() for x in val] + + +class BuiltInIEContentProvider(IEContentProvider, abc.ABC): + PROVIDER_VERSION = __version__ + BUG_REPORT_MESSAGE = bug_reports_message(before='') + + +def register_provider_generic( + provider, + base_class, + registry, +): + """Generic function to register a provider class""" + assert issubclass(provider, base_class), f'{provider} must be a subclass of {base_class.__name__}' + assert provider.PROVIDER_KEY not in registry, f'{base_class.__name__} {provider.PROVIDER_KEY} already registered' + registry[provider.PROVIDER_KEY] = provider + return provider + + +def register_preference_generic( + base_class, + registry, + *providers, +): + """Generic function to register a preference for a provider""" + assert all(issubclass(provider, base_class) for provider in providers) + + def outer(preference): + @functools.wraps(preference) + def inner(provider, *args, **kwargs): + if not providers or isinstance(provider, providers): + return preference(provider, *args, **kwargs) + return 0 + registry.add(inner) + return preference + return outer diff --git a/yt_dlp/extractor/youtube/pot/builtin/__init__.py b/yt_dlp/extractor/youtube/pot/builtin/__init__.py new file mode 100644 index 0000000000..ceeca557f5 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/builtin/__init__.py @@ -0,0 +1,2 @@ +from .cache.memory import MemoryLRUPCP as _MemoryPCP # noqa: F401 +from .cachespec.webpo import WebPoPCSP as _WebPoPCSP # noqa: F401 diff --git a/yt_dlp/extractor/youtube/pot/builtin/cache/memory.py b/yt_dlp/extractor/youtube/pot/builtin/cache/memory.py new file mode 100644 index 0000000000..2a7b7991e9 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/builtin/cache/memory.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import datetime as dt +import typing +from collections import OrderedDict +from threading import Lock + +from yt_dlp.extractor.youtube.pot._provider import BuiltInIEContentProvider +from yt_dlp.extractor.youtube.pot.cache import ( + PoTokenCacheProvider, + register_pcp, + register_pcp_preference, +) +from yt_dlp.globals import _pot_memory_cache +from yt_dlp.utils import int_or_none + + +def initialize_global_cache(max_size: int): + if _pot_memory_cache.value.get('cache') is None: + _pot_memory_cache.value['cache'] = OrderedDict() + _pot_memory_cache.value['lock'] = Lock() + _pot_memory_cache.value['max_size'] = max_size + + if _pot_memory_cache.value['max_size'] != max_size: + raise ValueError('Cannot change max_size of initialized global memory cache') + + return _pot_memory_cache.value['cache'], _pot_memory_cache.value['lock'], _pot_memory_cache.value['max_size'] + + +@register_pcp +class MemoryLRUPCP(PoTokenCacheProvider, BuiltInIEContentProvider): + PROVIDER_NAME = 'memory' + DEFAULT_CACHE_SIZE = 25 + + def __init__( + self, + *args, + initialize_cache: typing.Callable[[int], tuple[OrderedDict, Lock, int]] = initialize_global_cache, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.cache, self.lock, self.max_size = initialize_cache( + int_or_none(self.settings.get('max_size', [''])[0]) or self.DEFAULT_CACHE_SIZE) + + def is_available(self) -> bool: + return True + + def get(self, key: str) -> str | None: + with self.lock: + if key not in self.cache: + self.logger.trace('cache miss') + return None + value, expires_at = self.cache.pop(key) + if expires_at < int(dt.datetime.now(dt.timezone.utc).timestamp()): + self.logger.trace(f'cache expired key={key}') + return None + self.cache[key] = (value, expires_at) + self.logger.trace(f'cache hit key={key}') + return value + + def store(self, key: str, value: str, expires_at: int): + with self.lock: + if expires_at < int(dt.datetime.now(dt.timezone.utc).timestamp()): + self.logger.trace(f'ignoring expired key={key}') + return + if key in self.cache: + self.cache.pop(key) + self.cache[key] = (value, expires_at) + if len(self.cache) > self.max_size: + self.cache.popitem(last=False) + self.logger.trace(f'storing key={key}') + + def delete(self, key: str): + with self.lock: + self.logger.trace(f'deleting key={key}') + self.cache.pop(key, None) + + +@register_pcp_preference(MemoryLRUPCP) +def memorylru_preference(*_, **__): + # Memory LRU Cache SHOULD be the highest priority + return 10000 diff --git a/yt_dlp/extractor/youtube/pot/builtin/cachespec/webpo.py b/yt_dlp/extractor/youtube/pot/builtin/cachespec/webpo.py new file mode 100644 index 0000000000..b098645ad4 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/builtin/cachespec/webpo.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from yt_dlp.extractor.youtube.pot._provider import BuiltInIEContentProvider +from yt_dlp.extractor.youtube.pot.builtin.utils import ContentBindingType, get_webpo_content_binding +from yt_dlp.extractor.youtube.pot.cache import ( + CacheProviderWritePolicy, + PoTokenCacheSpec, + PoTokenCacheSpecProvider, + register_pcsp, +) +from yt_dlp.extractor.youtube.pot.provider import ( + PoTokenRequest, +) +from yt_dlp.utils import traverse_obj + + +@register_pcsp +class WebPoPCSP(PoTokenCacheSpecProvider, BuiltInIEContentProvider): + PROVIDER_NAME = 'webpo' + + def generate_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None: + bind_to_visitor_id = self.get_setting('bind_to_visitor_id', default=['true'])[0] == 'true' + content_binding, content_binding_type = get_webpo_content_binding( + request, bind_to_visitor_id=bind_to_visitor_id) + + if not content_binding: + return None + + write_policy = CacheProviderWritePolicy.WRITE_ALL + if content_binding_type == ContentBindingType.VIDEO_ID: + write_policy = CacheProviderWritePolicy.WRITE_FIRST + + return PoTokenCacheSpec( + key_bindings={ + 't': 'webpo', + 'cb': content_binding, + 'cbt': content_binding_type.value, + 'ip': traverse_obj(request.innertube_context, ('client', 'remoteHost')), + 'sa': request.request_source_address, + 'px': request.request_proxy, + }, + # Integrity token response usually states it has a ttl of 12 hours (43200 seconds). + # We will default to 6 hours to be safe. + default_ttl=21600, + write_policy=write_policy, + ) diff --git a/yt_dlp/extractor/youtube/pot/builtin/utils.py b/yt_dlp/extractor/youtube/pot/builtin/utils.py new file mode 100644 index 0000000000..34c06b21bb --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/builtin/utils.py @@ -0,0 +1,64 @@ +"""PUBLIC API""" + +from __future__ import annotations + +import base64 +import contextlib +import enum +import urllib.parse + +from yt_dlp.extractor.youtube.pot.provider import PoTokenContext, PoTokenRequest +from yt_dlp.utils import traverse_obj + +__all__ = ['WEBPO_CLIENTS', 'get_webpo_content_binding'] + +WEBPO_CLIENTS = ( + 'WEB', + 'MWEB', + 'TVHTML5', + 'WEB_EMBEDDED_PLAYER', + 'WEB_CREATOR', + 'WEB_REMIX', + 'TVHTML5_SIMPLY_EMBEDDED_PLAYER', +) + + +class ContentBindingType(enum.Enum): + VISITOR_DATA = 'visitor_data' + DATASYNC_ID = 'datasync_id' + VIDEO_ID = 'video_id' + VISITOR_ID = 'visitor_id' + + +def get_webpo_content_binding(request: PoTokenRequest, webpo_clients=WEBPO_CLIENTS, bind_to_visitor_id=False) -> tuple[str | None, ContentBindingType | None]: + client_name = traverse_obj(request.innertube_context, ('client', 'clientName')) + if not client_name or client_name not in webpo_clients: + return None, None + + if request.context == PoTokenContext.GVS or client_name in ('WEB_REMIX', ): + if request.is_authenticated: + return request.data_sync_id, ContentBindingType.DATASYNC_ID + else: + if bind_to_visitor_id: + visitor_id = _extract_visitor_id(request.visitor_data) + if visitor_id: + return visitor_id, ContentBindingType.VISITOR_ID + return request.visitor_data, ContentBindingType.VISITOR_DATA + + elif request.context == PoTokenContext.PLAYER or client_name != 'WEB_REMIX': + return request.video_id, ContentBindingType.VIDEO_ID + + +def _extract_visitor_id(visitor_data): + if not visitor_data: + return + + # Attempt to extract the visitor ID from the visitor_data protobuf + # xxx: ideally should use a protobuf parser + with contextlib.suppress(Exception): + visitor_id = base64.urlsafe_b64decode(urllib.parse.unquote_plus(visitor_data))[2:13].decode() + # check that visitor id is all letters and numbers + if visitor_id.isalnum() and len(visitor_id) == 11: + return visitor_id + + return None diff --git a/yt_dlp/extractor/youtube/pot/cache.py b/yt_dlp/extractor/youtube/pot/cache.py new file mode 100644 index 0000000000..7dd1fbcdb5 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/cache.py @@ -0,0 +1,93 @@ +"""PUBLIC API""" + +from __future__ import annotations + +import abc +import dataclasses +import enum +import typing + +from yt_dlp.extractor.youtube.pot._provider import ( + IEContentProvider, + IEContentProviderError, + register_preference_generic, + register_provider_generic, +) +from yt_dlp.extractor.youtube.pot.provider import PoTokenRequest +from yt_dlp.globals import _pot_cache_provider_preferences, _pot_cache_providers, _pot_pcs_providers + + +class PoTokenCacheProviderError(IEContentProviderError): + """An error occurred while fetching a PO Token""" + + +class PoTokenCacheProvider(IEContentProvider, abc.ABC, suffix='PCP'): + @abc.abstractmethod + def get(self, key: str) -> str | None: + pass + + @abc.abstractmethod + def store(self, key: str, value: str, expires_at: int): + pass + + @abc.abstractmethod + def delete(self, key: str): + pass + + +class CacheProviderWritePolicy(enum.Enum): + WRITE_ALL = enum.auto() # Write to all cache providers + WRITE_FIRST = enum.auto() # Write to only the first cache provider + + +@dataclasses.dataclass +class PoTokenCacheSpec: + key_bindings: dict[str, str | None] + default_ttl: int + write_policy: CacheProviderWritePolicy = CacheProviderWritePolicy.WRITE_ALL + + # Internal + _provider: PoTokenCacheSpecProvider | None = None + + +class PoTokenCacheSpecProvider(IEContentProvider, abc.ABC, suffix='PCSP'): + + def is_available(self) -> bool: + return True + + @abc.abstractmethod + def generate_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None: + """Generate a cache spec for the given request""" + pass + + +def register_pcp(provider: type[PoTokenCacheProvider]): + """Register a PoTokenCacheProvider class""" + return register_provider_generic( + provider=provider, + base_class=PoTokenCacheProvider, + registry=_pot_cache_providers.value, + ) + + +def register_pcsp(provider: type[PoTokenCacheSpecProvider]): + """Register a PoTokenCacheSpecProvider class""" + return register_provider_generic( + provider=provider, + base_class=PoTokenCacheSpecProvider, + registry=_pot_pcs_providers.value, + ) + + +# XXX: I don't think the typing is correct, and that we need py3.10 to properly type this +def register_pcp_preference(*providers: type[PoTokenCacheProvider]) -> typing.Callable[[PCPPreference], PCPPreference]: + """Register a preference for a PoTokenCacheProvider""" + return register_preference_generic( + PoTokenCacheProvider, + _pot_cache_provider_preferences.value, + *providers, + ) + + +if typing.TYPE_CHECKING: + PCPPreference = typing.Callable[[PoTokenCacheProvider, PoTokenRequest, ...], int] diff --git a/yt_dlp/extractor/youtube/pot/provider.py b/yt_dlp/extractor/youtube/pot/provider.py new file mode 100644 index 0000000000..2c71c3fb60 --- /dev/null +++ b/yt_dlp/extractor/youtube/pot/provider.py @@ -0,0 +1,211 @@ +"""PUBLIC API""" + +from __future__ import annotations + +import abc +import copy +import dataclasses +import enum +import typing +import urllib.parse + +from yt_dlp.cookies import YoutubeDLCookieJar +from yt_dlp.extractor.youtube.pot._provider import ( + IEContentProvider, + IEContentProviderError, + register_preference_generic, + register_provider_generic, +) +from yt_dlp.globals import _pot_providers, _ptp_preferences +from yt_dlp.networking import Request +from yt_dlp.utils import traverse_obj +from yt_dlp.utils.networking import HTTPHeaderDict + +__all__ = [ + 'PoTokenContext', + 'PoTokenProvider', + 'PoTokenProviderError', + 'PoTokenProviderRejectedRequest', + 'PoTokenRequest', + 'PoTokenResponse', + 'provider_bug_report_message', + 'register_preference', + 'register_provider', +] + + +class PoTokenContext(enum.Enum): + GVS = 'gvs' + PLAYER = 'player' + + +@dataclasses.dataclass +class PoTokenRequest: + # YouTube parameters + context: PoTokenContext + innertube_context: InnertubeContext + innertube_host: str | None = None + session_index: str | None = None + player_url: str | None = None + is_authenticated: bool = False + video_webpage: str | None = None + internal_client_name: str | None = None + + # Content binding parameters + visitor_data: str | None = None + data_sync_id: str | None = None + video_id: str | None = None + + # Networking parameters + request_cookiejar: YoutubeDLCookieJar = dataclasses.field(default_factory=YoutubeDLCookieJar) + request_proxy: str | None = None + request_headers: HTTPHeaderDict = dataclasses.field(default_factory=HTTPHeaderDict) + request_timeout: float | None = None + request_source_address: str | None = None + request_verify_tls: bool = True + + # Generate a new token, do not used a cached token + # The token should still be cached for future requests + bypass_cache: bool = False + + def copy(self): + return dataclasses.replace( + self, + request_headers=HTTPHeaderDict(self.request_headers), + innertube_context=copy.deepcopy(self.innertube_context), + ) + + +@dataclasses.dataclass +class PoTokenResponse: + po_token: str + expires_at: int | None = None + + +class PoTokenProviderRejectedRequest(IEContentProviderError): + """Reject the PoTokenRequest (cannot handle the request)""" + + +class PoTokenProviderError(IEContentProviderError): + """An error occurred while fetching a PO Token""" + + +class PoTokenProvider(IEContentProvider, abc.ABC, suffix='PTP'): + + # Set to None to disable the check + _SUPPORTED_CONTEXTS: tuple[PoTokenContext] | None = () + + # Innertube Client Name. + # For example, "WEB", "ANDROID", "TVHTML5". + # For a list of WebPO client names, see yt_dlp.extractor.youtube.pot.builtin.utils.WEBPO_CLIENTS. + # Also see yt_dlp.extractor.youtube._base.INNERTUBE_CLIENTS for a list of client names currently supported by the YouTube extractor. + _SUPPORTED_CLIENTS: tuple[str] | None = () + + # Possible values: http, https, socks4, socks4a, socks5, socks5h + _SUPPORTED_PROXY_SCHEMES: tuple[str] | None = () + + def __validate_request(self, request: PoTokenRequest): + if not self.is_available(): + raise PoTokenProviderRejectedRequest(f'{self.PROVIDER_NAME} is not available') + + # Validate request using built-in settings + if self._SUPPORTED_CONTEXTS is not None and request.context not in self._SUPPORTED_CONTEXTS: + raise PoTokenProviderRejectedRequest(f'PO Token Context "{request.context}" is not supported by {self.PROVIDER_NAME}') + + if self._SUPPORTED_CLIENTS is not None: + client_name = traverse_obj(request.innertube_context, ('client', 'clientName')) + if client_name not in self._SUPPORTED_CLIENTS: + raise PoTokenProviderRejectedRequest( + f'Client "{client_name}" is not supported by {self.PROVIDER_NAME}. Supported clients: {", ".join(self._SUPPORTED_CLIENTS) or "none"}') + + if self._SUPPORTED_PROXY_SCHEMES is not None and request.request_proxy: + scheme = urllib.parse.urlparse(request.request_proxy).scheme + if scheme.lower() not in self._SUPPORTED_PROXY_SCHEMES: + raise PoTokenProviderRejectedRequest( + f'Proxy scheme "{scheme}" is not supported by {self.PROVIDER_NAME}. Supported proxy schemes: {", ".join(self._SUPPORTED_PROXY_SCHEMES) or "none"}') + + def request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + self.__validate_request(request) + return self._real_request_pot(request) + + @abc.abstractmethod + def _real_request_pot(self, request: PoTokenRequest) -> PoTokenResponse: + """To be implemented by subclasses""" + pass + + # Helper functions + + def _urlopen(self, pot_request: PoTokenRequest, http_request: Request): + """Make a request using the request parameters from the PoTokenRequest. + Use this instead of calling requests, urllib3 or other HTTP client libraries directly!! + + YouTube cookies will be automatically applied if this request is made to YouTube. + + Tips: + - Disable proxy (e.g. if calling local service): Request(..., proxies={'all': None}) + - Set request timeout: Request(..., extensions={'timeout': 5.0}) + """ + req = http_request.copy() + + # Merge some ctx request settings into the request + # Most of these will already be used by the configured ydl instance, + # however, the YouTube extractor may override some. + req.headers = HTTPHeaderDict(pot_request.request_headers, req.headers) + req.proxies = req.proxies or ({'all': pot_request.request_proxy} if pot_request.request_proxy else {}) + + if pot_request.request_cookiejar is not None: + req.extensions['cookiejar'] = req.extensions.get('cookiejar', pot_request.request_cookiejar) + + return self.ie._downloader.urlopen(req) + + +def register_provider(provider: type[PoTokenProvider]): + """Register a PoTokenProvider class""" + return register_provider_generic( + provider=provider, + base_class=PoTokenProvider, + registry=_pot_providers.value, + ) + + +def provider_bug_report_message(provider: IEContentProvider, before=';'): + msg = provider.BUG_REPORT_MESSAGE + + before = before.rstrip() + if not before or before.endswith(('.', '!', '?')): + msg = msg[0].title() + msg[1:] + + return (before + ' ' if before else '') + msg + + +# XXX: I don't think the typing is correct, and that we need py3.10 to properly type this +def register_preference(*providers: type[PoTokenProvider]) -> typing.Callable[[Preference], Preference]: + """Register a preference for a PoTokenProvider""" + return register_preference_generic( + PoTokenProvider, + _ptp_preferences.value, + *providers, + ) + + +if typing.TYPE_CHECKING: + Preference = typing.Callable[[PoTokenProvider, PoTokenRequest, ...], int] + + # Barebones innertube context. There may be more fields. + class ClientInfo(typing.TypedDict, total=False): + hl: str | None + gl: str | None + remoteHost: str | None + deviceMake: str | None + deviceModel: str | None + visitorData: str | None + userAgent: str | None + clientName: str + clientVersion: str + osName: str | None + osVersion: str | None + + class InnertubeContext(typing.TypedDict, total=False): + client: ClientInfo + request: dict + user: dict diff --git a/yt_dlp/globals.py b/yt_dlp/globals.py index 0cf276cc9e..8693b02f04 100644 --- a/yt_dlp/globals.py +++ b/yt_dlp/globals.py @@ -28,3 +28,12 @@ plugin_ies_overrides = Indirect(defaultdict(list)) # Misc IN_CLI = Indirect(False) LAZY_EXTRACTORS = Indirect(None) # `False`=force, `None`=disabled, `True`=enabled + + +# YouTube extractor - Internal only +_pot_providers = Indirect({}) +_ptp_preferences = Indirect(set()) +_pot_pcs_providers = Indirect({}) +_pot_cache_providers = Indirect({}) +_pot_cache_provider_preferences = Indirect(set()) +_pot_memory_cache = Indirect({}) diff --git a/yt_dlp/networking/_curlcffi.py b/yt_dlp/networking/_curlcffi.py index c800f2c095..747879da87 100644 --- a/yt_dlp/networking/_curlcffi.py +++ b/yt_dlp/networking/_curlcffi.py @@ -6,7 +6,8 @@ import math import re import urllib.parse -from ._helper import InstanceStoreMixin, select_proxy +from ._helper import InstanceStoreMixin +from ..utils.networking import select_proxy from .common import ( Features, Request, diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py index b86d3606d8..ef9c8bafab 100644 --- a/yt_dlp/networking/_helper.py +++ b/yt_dlp/networking/_helper.py @@ -13,7 +13,6 @@ import urllib.request from .exceptions import RequestError from ..dependencies import certifi from ..socks import ProxyType, sockssocket -from ..utils import format_field, traverse_obj if typing.TYPE_CHECKING: from collections.abc import Iterable @@ -82,19 +81,6 @@ def make_socks_proxy_opts(socks_proxy): } -def select_proxy(url, proxies): - """Unified proxy selector for all backends""" - url_components = urllib.parse.urlparse(url) - if 'no' in proxies: - hostport = url_components.hostname + format_field(url_components.port, None, ':%s') - if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}): - return - elif urllib.request.proxy_bypass(hostport): # check system settings - return - - return traverse_obj(proxies, url_components.scheme or 'http', 'all') - - def get_redirect_method(method, status): """Unified redirect method handling""" diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py index 5b6b264a68..d02e976b57 100644 --- a/yt_dlp/networking/_requests.py +++ b/yt_dlp/networking/_requests.py @@ -10,7 +10,7 @@ import warnings from ..dependencies import brotli, requests, urllib3 from ..utils import bug_reports_message, int_or_none, variadic -from ..utils.networking import normalize_url +from ..utils.networking import normalize_url, select_proxy if requests is None: raise ImportError('requests module is not installed') @@ -41,7 +41,6 @@ from ._helper import ( create_socks_proxy_socket, get_redirect_method, make_socks_proxy_opts, - select_proxy, ) from .common import ( Features, diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py index a188b35f57..cb7a430bb3 100644 --- a/yt_dlp/networking/_urllib.py +++ b/yt_dlp/networking/_urllib.py @@ -26,7 +26,6 @@ from ._helper import ( create_socks_proxy_socket, get_redirect_method, make_socks_proxy_opts, - select_proxy, ) from .common import Features, RequestHandler, Response, register_rh from .exceptions import ( @@ -41,7 +40,7 @@ from .exceptions import ( from ..dependencies import brotli from ..socks import ProxyError as SocksProxyError from ..utils import update_url_query -from ..utils.networking import normalize_url +from ..utils.networking import normalize_url, select_proxy SUPPORTED_ENCODINGS = ['gzip', 'deflate'] CONTENT_DECODE_ERRORS = [zlib.error, OSError] diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py index d29f8e45a9..fd8730ac7e 100644 --- a/yt_dlp/networking/_websockets.py +++ b/yt_dlp/networking/_websockets.py @@ -11,8 +11,8 @@ from ._helper import ( create_connection, create_socks_proxy_socket, make_socks_proxy_opts, - select_proxy, ) +from ..utils.networking import select_proxy from .common import Features, Response, register_rh from .exceptions import ( CertificateVerifyError, diff --git a/yt_dlp/utils/networking.py b/yt_dlp/utils/networking.py index 542abace87..9fcab6456f 100644 --- a/yt_dlp/utils/networking.py +++ b/yt_dlp/utils/networking.py @@ -10,7 +10,8 @@ import urllib.request if typing.TYPE_CHECKING: T = typing.TypeVar('T') -from ._utils import NO_DEFAULT, remove_start +from ._utils import NO_DEFAULT, remove_start, format_field +from .traversal import traverse_obj def random_user_agent(): @@ -278,3 +279,16 @@ def normalize_url(url): query=escape_rfc3986(url_parsed.query), fragment=escape_rfc3986(url_parsed.fragment), ).geturl() + + +def select_proxy(url, proxies): + """Unified proxy selector for all backends""" + url_components = urllib.parse.urlparse(url) + if 'no' in proxies: + hostport = url_components.hostname + format_field(url_components.port, None, ':%s') + if urllib.request.proxy_bypass_environment(hostport, {'no': proxies['no']}): + return + elif urllib.request.proxy_bypass(hostport): # check system settings + return + + return traverse_obj(proxies, url_components.scheme or 'http', 'all')