[utils] Update traverse_obj() from yt-dlp

* remove `is_user_input` option per https://github.com/yt-dlp/yt-dlp/pull/8673
* support traversal of compat_xml_etree_ElementTree_Element per https://github.com/yt-dlp/yt-dlp/pull/8911
* allow un/branching using all and any per https://github.com/yt-dlp/yt-dlp/pull/9571
* support traversal of compat_cookies.Morsel and multiple types in `set()` keys per https://github.com/yt-dlp/yt-dlp/pull/9577
thx Grub4k for these
* also, move traversal tests to a separate class
* allow for unordered dicts in tests for Py<3.7
This commit is contained in:
dirkf
2024-04-21 23:42:08 +01:00
parent a08f2b7e45
commit 06da64ee51
3 changed files with 267 additions and 101 deletions

View File

@ -49,11 +49,14 @@ from .compat import (
compat_cookiejar,
compat_ctypes_WINFUNCTYPE,
compat_datetime_timedelta_total_seconds,
compat_etree_Element,
compat_etree_fromstring,
compat_etree_iterfind,
compat_expanduser,
compat_html_entities,
compat_html_entities_html5,
compat_http_client,
compat_http_cookies,
compat_integer_types,
compat_kwargs,
compat_ncompress as ncompress,
@ -6253,15 +6256,16 @@ if __debug__:
def traverse_obj(obj, *paths, **kwargs):
"""
Safely traverse nested `dict`s and `Iterable`s
Safely traverse nested `dict`s and `Iterable`s, etc
>>> obj = [{}, {"key": "value"}]
>>> traverse_obj(obj, (1, "key"))
"value"
'value'
Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
Supported values for traversal are `Mapping`, `Iterable`, `re.Match`, `xml.etree.ElementTree`
(xpath) and `http.cookies.Morsel`.
Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@ -6269,8 +6273,9 @@ def traverse_obj(obj, *paths, **kwargs):
The keys in the path can be one of:
- `None`: Return the current object.
- `set`: Requires the only item in the set to be a type or function,
like `{type}`/`{func}`. If a `type`, returns only values
of this type. If a function, returns `func(obj)`.
like `{type}`/`{type, type, ...}`/`{func}`. If one or more `type`s,
return only values that have one of the types. If a function,
return `func(obj)`.
- `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values.
@ -6282,8 +6287,10 @@ def traverse_obj(obj, *paths, **kwargs):
For `Iterable`s, `key` is the enumeration count of the value.
For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given.
- `dict` Transform the current object and return a matching dict.
- `dict`: Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
- `any`-builtin: Take the first matching object and return it, resetting branching.
- `all`-builtin: Take all matching objects and return them as a list, resetting branching.
`tuple`, `list`, and `dict` all support nested paths and branches.
@ -6299,10 +6306,8 @@ def traverse_obj(obj, *paths, **kwargs):
@param get_all If `False`, return the first matching result, otherwise all matching ones.
@param casesense If `False`, consider string dictionary keys as case insensitive.
The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API
The following is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API
@param _is_user_input Whether the keys are generated from user input.
If `True` strings get converted to `int`/`slice` if needed.
@param _traverse_string Whether to traverse into objects as strings.
If `True`, any non-compatible object will first be
converted into a string and then traversed into.
@ -6322,7 +6327,6 @@ def traverse_obj(obj, *paths, **kwargs):
expected_type = kwargs.get('expected_type')
get_all = kwargs.get('get_all', True)
casesense = kwargs.get('casesense', True)
_is_user_input = kwargs.get('_is_user_input', False)
_traverse_string = kwargs.get('_traverse_string', False)
# instant compat
@ -6336,10 +6340,8 @@ def traverse_obj(obj, *paths, **kwargs):
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
def lookup_or_none(v, k, getter=None):
try:
with compat_contextlib_suppress(LookupError):
return getter(v, k) if getter else v[k]
except IndexError:
return None
def from_iterable(iterables):
# chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
@ -6361,12 +6363,13 @@ def traverse_obj(obj, *paths, **kwargs):
result = obj
elif isinstance(key, set):
assert len(key) == 1, 'Set should only be used to wrap a single item'
item = next(iter(key))
if isinstance(item, type):
result = obj if isinstance(obj, item) else None
assert len(key) >= 1, 'At least one item is required in a `set` key'
if all(isinstance(item, type) for item in key):
result = obj if isinstance(obj, tuple(key)) else None
else:
result = try_call(item, args=(obj,))
item = next(iter(key))
assert len(key) == 1, 'Multiple items in a `set` key must all be types'
result = try_call(item, args=(obj,)) if not isinstance(item, type) else None
elif isinstance(key, (list, tuple)):
branching = True
@ -6375,9 +6378,11 @@ def traverse_obj(obj, *paths, **kwargs):
elif key is Ellipsis:
branching = True
if isinstance(obj, compat_http_cookies.Morsel):
obj = dict(obj, key=obj.key, value=obj.value)
if isinstance(obj, compat_collections_abc.Mapping):
result = obj.values()
elif is_iterable_like(obj):
elif is_iterable_like(obj, (compat_collections_abc.Iterable, compat_etree_Element)):
result = obj
elif isinstance(obj, compat_re_Match):
result = obj.groups()
@ -6389,9 +6394,11 @@ def traverse_obj(obj, *paths, **kwargs):
elif callable(key):
branching = True
if isinstance(obj, compat_http_cookies.Morsel):
obj = dict(obj, key=obj.key, value=obj.value)
if isinstance(obj, compat_collections_abc.Mapping):
iter_obj = obj.items()
elif is_iterable_like(obj):
elif is_iterable_like(obj, (compat_collections_abc.Iterable, compat_etree_Element)):
iter_obj = enumerate(obj)
elif isinstance(obj, compat_re_Match):
iter_obj = itertools.chain(
@ -6413,6 +6420,8 @@ def traverse_obj(obj, *paths, **kwargs):
if v is not None or default is not NO_DEFAULT) or None
elif isinstance(obj, compat_collections_abc.Mapping):
if isinstance(obj, compat_http_cookies.Morsel):
obj = dict(obj, key=obj.key, value=obj.value)
result = (try_call(obj.get, args=(key,))
if casesense or try_call(obj.__contains__, args=(key,))
else next((v for k, v in obj.items() if casefold(k) == key), None))
@ -6430,12 +6439,40 @@ def traverse_obj(obj, *paths, **kwargs):
else:
result = None
if isinstance(key, (int, slice)):
if is_iterable_like(obj, compat_collections_abc.Sequence):
if is_iterable_like(obj, (compat_collections_abc.Sequence, compat_etree_Element)):
branching = isinstance(key, slice)
result = lookup_or_none(obj, key)
elif _traverse_string:
result = lookup_or_none(str(obj), key)
elif isinstance(obj, compat_etree_Element) and isinstance(key, str):
xpath, _, special = key.rpartition('/')
if not special.startswith('@') and not special.endswith('()'):
xpath = key
special = None
# Allow abbreviations of relative paths, absolute paths error
if xpath.startswith('/'):
xpath = '.' + xpath
elif xpath and not xpath.startswith('./'):
xpath = './' + xpath
def apply_specials(element):
if special is None:
return element
if special == '@':
return element.attrib
if special.startswith('@'):
return try_call(element.attrib.get, args=(special[1:],))
if special == 'text()':
return element.text
raise SyntaxError('apply_specials is missing case for {0!r}'.format(special))
if xpath:
result = list(map(apply_specials, compat_etree_iterfind(obj, xpath)))
else:
result = apply_specials(obj)
return branching, result if branching else (result,)
def lazy_last(iterable):
@ -6456,17 +6493,18 @@ def traverse_obj(obj, *paths, **kwargs):
key = None
for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
if _is_user_input and isinstance(key, str):
if key == ':':
key = Ellipsis
elif ':' in key:
key = slice(*map(int_or_none, key.split(':')))
elif int_or_none(key) is not None:
key = int(key)
if not casesense and isinstance(key, str):
key = compat_casefold(key)
if key in (any, all):
has_branched = False
filtered_objs = (obj for obj in objs if obj not in (None, {}))
if key is any:
objs = (next(filtered_objs, None),)
else:
objs = (list(filtered_objs),)
continue
if __debug__ and callable(key):
# Verify function signature
_try_bind_args(key, None, None)
@ -6505,9 +6543,9 @@ def traverse_obj(obj, *paths, **kwargs):
return None if default is NO_DEFAULT else default
def T(x):
""" For use in yt-dl instead of {type} or set((type,)) """
return set((x,))
def T(*x):
""" For use in yt-dl instead of {type, ...} or set((type, ...)) """
return set(x)
def get_first(obj, keys, **kwargs):