diff --git a/test/test_traversal.py b/test/test_traversal.py index d48606e99c..bc433029d8 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -481,7 +481,7 @@ class TestTraversalHelpers: 'id': 'name', 'data': 'content', 'url': 'url', - }, all, {subs_list_to_dict}]) == { + }, all, {subs_list_to_dict(lang=None)}]) == { 'de': [{'url': 'https://example.com/subs/de.ass'}], 'en': [{'data': 'content'}], }, 'subs with mandatory items missing should be filtered' @@ -507,6 +507,54 @@ class TestTraversalHelpers: {'url': 'https://example.com/subs/en1', 'ext': 'ext'}, {'url': 'https://example.com/subs/en2', 'ext': 'ext'}, ]}, '`quality` key should sort subtitle list accordingly' + assert traverse_obj([ + {'name': 'de', 'url': 'https://example.com/subs/de.ass'}, + {'name': 'de'}, + {'name': 'en', 'content': 'content'}, + {'url': 'https://example.com/subs/en'}, + ], [..., { + 'id': 'name', + 'url': 'url', + 'data': 'content', + }, all, {subs_list_to_dict(lang='en')}]) == { + 'de': [{'url': 'https://example.com/subs/de.ass'}], + 'en': [ + {'data': 'content'}, + {'url': 'https://example.com/subs/en'}, + ], + }, 'optionally provided lang should be used if no id available' + assert traverse_obj([ + {'name': 1, 'url': 'https://example.com/subs/de1'}, + {'name': {}, 'url': 'https://example.com/subs/de2'}, + {'name': 'de', 'ext': 1, 'url': 'https://example.com/subs/de3'}, + {'name': 'de', 'ext': {}, 'url': 'https://example.com/subs/de4'}, + ], [..., { + 'id': 'name', + 'url': 'url', + 'ext': 'ext', + }, all, {subs_list_to_dict(lang=None)}]) == { + 'de': [ + {'url': 'https://example.com/subs/de3'}, + {'url': 'https://example.com/subs/de4'}, + ], + }, 'non str types should be ignored for id and ext' + assert traverse_obj([ + {'name': 1, 'url': 'https://example.com/subs/de1'}, + {'name': {}, 'url': 'https://example.com/subs/de2'}, + {'name': 'de', 'ext': 1, 'url': 'https://example.com/subs/de3'}, + {'name': 'de', 'ext': {}, 'url': 'https://example.com/subs/de4'}, + ], [..., { + 'id': 'name', + 'url': 'url', + 'ext': 'ext', + }, all, {subs_list_to_dict(lang='de')}]) == { + 'de': [ + {'url': 'https://example.com/subs/de1'}, + {'url': 'https://example.com/subs/de2'}, + {'url': 'https://example.com/subs/de3'}, + {'url': 'https://example.com/subs/de4'}, + ], + }, 'non str types should be replaced by default id' def test_trim_str(self): with pytest.raises(TypeError): @@ -525,7 +573,7 @@ class TestTraversalHelpers: def test_unpack(self): assert unpack(lambda *x: ''.join(map(str, x)))([1, 2, 3]) == '123' assert unpack(join_nonempty)([1, 2, 3]) == '1-2-3' - assert unpack(join_nonempty(delim=' '))([1, 2, 3]) == '1 2 3' + assert unpack(join_nonempty, delim=' ')([1, 2, 3]) == '1 2 3' with pytest.raises(TypeError): unpack(join_nonempty)() with pytest.raises(TypeError): diff --git a/test/test_utils.py b/test/test_utils.py index b5f35736b6..835774a912 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -72,7 +72,6 @@ from yt_dlp.utils import ( intlist_to_bytes, iri_to_uri, is_html, - join_nonempty, js_to_json, limit_length, locked_file, @@ -2158,10 +2157,6 @@ Line 1 assert int_or_none(v=10) == 10, 'keyword passed positional should call function' assert int_or_none(scale=0.1)(10) == 100, 'call after partial application should call the function' - assert callable(join_nonempty(delim=', ')), 'varargs positional should apply partially' - assert callable(join_nonempty()), 'varargs positional should apply partially' - assert join_nonempty(None, delim=', ') == '', 'passed varargs should call the function' - if __name__ == '__main__': unittest.main() diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py index b28bb555e1..89c53c39e7 100644 --- a/yt_dlp/utils/_utils.py +++ b/yt_dlp/utils/_utils.py @@ -216,7 +216,7 @@ def partial_application(func): sig = inspect.signature(func) required_args = [ param.name for param in sig.parameters.values() - if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) + if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) if param.default is inspect.Parameter.empty ] @@ -4837,7 +4837,6 @@ def number_of_digits(number): return len('%d' % number) -@partial_application def join_nonempty(*values, delim='-', from_dict=None): if from_dict is not None: values = (traversal.traverse_obj(from_dict, variadic(v)) for v in values) diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 361f239ba6..76b51f53d1 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -332,14 +332,14 @@ class _RequiredError(ExtractorError): @typing.overload -def subs_list_to_dict(*, ext: str | None = None) -> collections.abc.Callable[[list[dict]], dict[str, list[dict]]]: ... +def subs_list_to_dict(*, lang: str | None = 'und', ext: str | None = None) -> collections.abc.Callable[[list[dict]], dict[str, list[dict]]]: ... @typing.overload -def subs_list_to_dict(subs: list[dict] | None, /, *, ext: str | None = None) -> dict[str, list[dict]]: ... +def subs_list_to_dict(subs: list[dict] | None, /, *, lang: str | None = 'und', ext: str | None = None) -> dict[str, list[dict]]: ... -def subs_list_to_dict(subs: list[dict] | None = None, /, *, ext=None): +def subs_list_to_dict(subs: list[dict] | None = None, /, *, lang='und', ext=None): """ Convert subtitles from a traversal into a subtitle dict. The path should have an `all` immediately before this function. @@ -352,7 +352,7 @@ def subs_list_to_dict(subs: list[dict] | None = None, /, *, ext=None): `quality` The sort order for each subtitle """ if subs is None: - return functools.partial(subs_list_to_dict, ext=ext) + return functools.partial(subs_list_to_dict, lang=lang, ext=ext) result = collections.defaultdict(list) @@ -360,10 +360,16 @@ def subs_list_to_dict(subs: list[dict] | None = None, /, *, ext=None): if not url_or_none(sub.get('url')) and not sub.get('data'): continue sub_id = sub.pop('id', None) - if sub_id is None: - continue - if ext is not None and not sub.get('ext'): - sub['ext'] = ext + if not isinstance(sub_id, str): + if not lang: + continue + sub_id = lang + sub_ext = sub.get('ext') + if not isinstance(sub_ext, str): + if not ext: + sub.pop('ext', None) + else: + sub['ext'] = ext result[sub_id].append(sub) result = dict(result) @@ -452,9 +458,9 @@ def trim_str(*, start=None, end=None): return trim -def unpack(func): +def unpack(func, **kwargs): @functools.wraps(func) - def inner(items, **kwargs): + def inner(items): return func(*items, **kwargs) return inner