You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

85 lines
1.7 KiB

"""Comparison functions"""
from enum import Enum
from typing import Any, List, Literal, Optional
class Cmp(Enum):
    """Three-way comparison result.

    The member values follow the usual comparator convention: negative for
    "less than", zero for "equal" and positive for "greater than".
    """

    LT = -1  # left operand is less than the right operand
    EQ = 0  # operands compare equal
    GT = 1  # left operand is greater than the right operand
def deep_compare(val_a: Any, val_b: Any) -> Cmp:
    """Compare two values, descending into dictionaries key by key.

    When ``val_a`` is a dict, the values stored under each of its keys are
    compared recursively, in ``val_a``'s iteration order, and the first
    non-equal result is returned.  If no key yields a difference, the two
    key lists are compared with ``compare_arrays``.  Any other kind of value
    is compared directly with the ``>`` and ``<`` operators.

    Args:
        val_a: Left operand.
        val_b: Right operand; indexed by key when ``val_a`` is a dict.

    Returns:
        ``Cmp.GT``, ``Cmp.LT`` or ``Cmp.EQ``.
    """
    if isinstance(val_a, dict):
        for key, elem_a in val_a.items():
            try:
                elem_b = val_b[key]
            except KeyError:
                # The key exists only in val_a, so there is no value pair to
                # compare; the key-list comparison below decides the result.
                continue
            cmp = deep_compare(elem_a, elem_b)
            # Stop at the first key whose values differ.  Returning on
            # Cmp.EQ here would hide every real difference between values.
            if cmp != Cmp.EQ:
                return cmp
        return compare_arrays(list(val_a.keys()), list(val_b.keys()))
    if val_a > val_b:
        return Cmp.GT
    if val_a < val_b:
        return Cmp.LT
    return Cmp.EQ
def compare_basic(val_a: Any, val_b: Any, order: Literal['ASC', 'DESC'] = 'ASC') -> Cmp:
    """Compare two values, honouring a sort direction.

    Args:
        val_a: Left operand.
        val_b: Right operand.
        order: With ``'ASC'`` the result of ``deep_compare`` is returned
            unchanged; with any other value (normally ``'DESC'``) ``Cmp.LT``
            and ``Cmp.GT`` are swapped.

    Returns:
        The comparison result adjusted for ``order``.  ``Cmp.EQ`` is never
        affected by the direction.
    """
    cmp = deep_compare(val_a, val_b)
    # Equality is direction-independent, and ascending order needs no flip.
    if cmp == Cmp.EQ or order == 'ASC':
        return cmp
    return Cmp.LT if cmp == Cmp.GT else Cmp.GT
def _sort_order_at(
    sort_orders: List[Literal['ASC', 'DESC']],
    idx: int,
) -> Literal['ASC', 'DESC']:
    """Return the sort order for position ``idx``, or ``'ASC'`` if there is none."""
    return sort_orders[idx] if idx < len(sort_orders) else 'ASC'


def compare_arrays(
    arr_a: List[Any],
    arr_b: List[Any],
    sort_orders: Optional[List[Literal['ASC', 'DESC']]] = None,
) -> Cmp:
    """Compare two arrays element by element.

    Elements are compared pairwise with ``compare_basic`` and the first
    non-equal result is returned.  If one array is a prefix of the other,
    their lengths are compared instead.

    Args:
        arr_a: Left array.
        arr_b: Right array.
        sort_orders: Per-position sort direction.  Positions beyond the end
            of this list (or all positions when it is ``None``/empty) use
            ``'ASC'``.

    Returns:
        ``Cmp.LT``, ``Cmp.EQ`` or ``Cmp.GT``.
    """
    sort_orders = sort_orders or []
    for idx, (elem_a, elem_b) in enumerate(zip(arr_a, arr_b)):
        elem_cmp = compare_basic(elem_a, elem_b, _sort_order_at(sort_orders, idx))
        if elem_cmp != Cmp.EQ:
            return elem_cmp
    if len(arr_a) == len(arr_b):
        return Cmp.EQ
    # All shared positions are equal: order by length, using the direction
    # configured for the first position that exists in only one array.
    idx = min(len(arr_a), len(arr_b))
    return compare_basic(len(arr_a), len(arr_b), _sort_order_at(sort_orders, idx))