[core] Add close hook API to YoutubeDL

This commit is contained in:
coletdjnz 2025-04-12 13:24:10 +12:00
parent 8e9a6553a6
commit d14c0fe223
No known key found for this signature in database
GPG Key ID: 91984263BB39894A
3 changed files with 31 additions and 1 deletions

View File

@ -1435,6 +1435,27 @@ class TestYoutubeDL(unittest.TestCase):
FakeYDL().close()
assert all_plugins_loaded.value
def test_close_hooks(self):
# Should call all registered close hooks on close
close_hook_called = False
close_hook_two_called = False
def close_hook():
nonlocal close_hook_called
close_hook_called = True
def close_hook_two():
nonlocal close_hook_two_called
close_hook_two_called = True
ydl = FakeYDL()
ydl.add_close_hook(close_hook)
ydl.add_close_hook(close_hook_two)
ydl.close()
self.assertTrue(close_hook_called, 'Close hook was not called')
self.assertTrue(close_hook_two_called, 'Close hook two was not called')
if __name__ == '__main__':
unittest.main()

View File

@ -640,6 +640,7 @@ class YoutubeDL:
self._printed_messages = set()
self._first_webpage_request = True
self._post_hooks = []
self._close_hooks = []
self._progress_hooks = []
self._postprocessor_hooks = []
self._download_retcode = 0
@ -908,6 +909,10 @@ class YoutubeDL:
"""Add the post hook"""
self._post_hooks.append(ph)
def add_close_hook(self, ch):
"""Add a close hook, called when YoutubeDL.close() is called"""
self._close_hooks.append(ch)
def add_progress_hook(self, ph):
"""Add the download progress hook"""
self._progress_hooks.append(ph)
@ -1016,6 +1021,9 @@ class YoutubeDL:
self._request_director.close()
del self._request_director
for close_hook in self._close_hooks:
close_hook()
def trouble(self, message=None, tb=None, is_error=True):
"""Determine action to take when a download problem appears.

View File

@ -312,7 +312,6 @@ class PoTokenRequestDirector:
return pot_response.po_token
def close(self):
# TODO: hook into ydl close
for provider in self.providers.values():
provider.close()
@ -355,6 +354,8 @@ def initialize_pot_director(ie):
cache=cache,
)
ie._downloader.add_close_hook(director.close)
for provider in _pot_providers.value.values():
settings = traverse_obj(ie._downloader.params, ('extractor_args', f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}'))
logger = YoutubeIEContentProviderLogger(ie, f'pot:{provider.PROVIDER_NAME}', log_level=log_level)