diff options
-rw-r--r-- | dts/framework/params/__init__.py | 359 |
1 files changed, 359 insertions, 0 deletions
diff --git a/dts/framework/params/__init__.py b/dts/framework/params/__init__.py new file mode 100644 index 0000000000..5a6fd93053 --- /dev/null +++ b/dts/framework/params/__init__.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright(c) 2024 Arm Limited + +"""Parameter manipulation module. + +This module provides :class:`Params` which can be used to model any data structure +that is meant to represent any command line parameters. +""" + +from dataclasses import dataclass, fields +from enum import Flag +from typing import ( + Any, + Callable, + Iterable, + Literal, + Reversible, + TypedDict, + TypeVar, + cast, +) + +from typing_extensions import Self + +T = TypeVar("T") + +#: Type for a function taking one argument. +FnPtr = Callable[[Any], Any] +#: Type for a switch parameter. +Switch = Literal[True, None] +#: Type for a yes/no switch parameter. +YesNoSwitch = Literal[True, False, None] + + +def _reduce_functions(funcs: Iterable[FnPtr]) -> FnPtr: + """Reduces an iterable of :attr:`FnPtr` from left to right to a single function. + + If the iterable is empty, the created function just returns its fed value back. + + Args: + funcs: An iterable containing the functions to be chained from left to right. + + Returns: + FnPtr: A function that calls the given functions from left to right. + """ + + def reduced_fn(value): + for fn in funcs: + value = fn(value) + return value + + return reduced_fn + + +def modify_str(*funcs: FnPtr) -> Callable[[T], T]: + """Class decorator modifying the ``__str__`` method with a function created from its arguments. + + The :attr:`FnPtr`s fed to the decorator are executed from left to right in the arguments list + order. + + Args: + *funcs: The functions to chain from left to right. + + Returns: + The decorator. + + Example: + .. code:: python + + @convert_str(hex_from_flag_value) + class BitMask(enum.Flag): + A = auto() + B = auto() + + will allow ``BitMask`` to render as a hexadecimal value. + """ + + def _class_decorator(original_class): + original_class.__str__ = _reduce_functions(funcs) + return original_class + + return _class_decorator + + +def comma_separated(values: Iterable[Any]) -> str: + """Converts an iterable into a comma-separated string. + + Args: + values: An iterable of objects. + + Returns: + A comma-separated list of stringified values. + """ + return ",".join([str(value).strip() for value in values if value is not None]) + + +def bracketed(value: str) -> str: + """Adds round brackets to the input. + + Args: + value: Any string. + + Returns: + A string surrounded by round brackets. + """ + return f"({value})" + + +def str_from_flag_value(flag: Flag) -> str: + """Returns the value from a :class:`enum.Flag` as a string. + + Args: + flag: An instance of :class:`Flag`. + + Returns: + The stringified value of the given flag. + """ + return str(flag.value) + + +def hex_from_flag_value(flag: Flag) -> str: + """Returns the value from a :class:`enum.Flag` converted to hexadecimal. + + Args: + flag: An instance of :class:`Flag`. + + Returns: + The value of the given flag in hexadecimal representation. + """ + return hex(flag.value) + + +class ParamsModifier(TypedDict, total=False): + """Params modifiers dict compatible with the :func:`dataclasses.field` metadata parameter.""" + + #: + Params_short: str + #: + Params_long: str + #: + Params_multiple: bool + #: + Params_convert_value: Reversible[FnPtr] + + +@dataclass +class Params: + """Dataclass that renders its fields into command line arguments. + + The parameter name is taken from the field name by default. The following: + + .. code:: python + + name: str | None = "value" + + is rendered as ``--name=value``. + + Through :func:`dataclasses.field` the resulting parameter can be manipulated by applying + this class' metadata modifier functions. These return regular dictionaries which can be combined + together using the pipe (OR) operator, as used in the example for :meth:`~Params.multiple`. + + To use fields as switches, set the value to ``True`` to render them. If you + use a yes/no switch you can also set ``False`` which would render a switch + prefixed with ``--no-``. Examples: + + .. code:: python + + interactive: Switch = True # renders --interactive + numa: YesNoSwitch = False # renders --no-numa + + Setting ``None`` will prevent it from being rendered. The :attr:`~Switch` type alias is provided + for regular switches, whereas :attr:`~YesNoSwitch` is offered for yes/no ones. + + An instance of a dataclass inheriting ``Params`` can also be assigned to an attribute, + this helps with grouping parameters together. + The attribute holding the dataclass will be ignored and the latter will just be rendered as + expected. + """ + + _suffix = "" + """Holder of the plain text value of Params when called directly. A suffix for child classes.""" + + """========= BEGIN FIELD METADATA MODIFIER FUNCTIONS ========""" + + @staticmethod + def short(name: str) -> ParamsModifier: + """Overrides any parameter name with the given short option. + + Args: + name: The short parameter name. + + Returns: + ParamsModifier: A dictionary for the `dataclasses.field` metadata argument containing + the parameter short name modifier. + + Example: + .. code:: python + + logical_cores: str | None = field(default="1-4", metadata=Params.short("l")) + + will render as ``-l=1-4`` instead of ``--logical-cores=1-4``. + """ + return ParamsModifier(Params_short=name) + + @staticmethod + def long(name: str) -> ParamsModifier: + """Overrides the inferred parameter name to the specified one. + + Args: + name: The long parameter name. + + Returns: + ParamsModifier: A dictionary for the `dataclasses.field` metadata argument containing + the parameter long name modifier. + + Example: + .. code:: python + + x_name: str | None = field(default="y", metadata=Params.long("x")) + + will render as ``--x=y``, but the field is accessed and modified through ``x_name``. + """ + return ParamsModifier(Params_long=name) + + @staticmethod + def multiple() -> ParamsModifier: + """Specifies that this parameter is set multiple times. The parameter type must be a list. + + Returns: + ParamsModifier: A dictionary for the `dataclasses.field` metadata argument containing + the multiple parameters modifier. + + Example: + .. code:: python + + ports: list[int] | None = field( + default_factory=lambda: [0, 1, 2], + metadata=Params.multiple() | Params.long("port") + ) + + will render as ``--port=0 --port=1 --port=2``. + """ + return ParamsModifier(Params_multiple=True) + + @staticmethod + def convert_value(*funcs: FnPtr) -> ParamsModifier: + """Takes in a variable number of functions to convert the value text representation. + + Functions can be chained together, executed from left to right in the arguments list order. + + Args: + *funcs: The functions to chain from left to right. + + Returns: + ParamsModifier: A dictionary for the `dataclasses.field` metadata argument containing + the convert value modifier. + + Example: + .. code:: python + + hex_bitmask: int | None = field( + default=0b1101, + metadata=Params.convert_value(hex) | Params.long("mask") + ) + + will render as ``--mask=0xd``. + """ + return ParamsModifier(Params_convert_value=funcs) + + """========= END FIELD METADATA MODIFIER FUNCTIONS ========""" + + def append_str(self, text: str) -> None: + """Appends a string at the end of the string representation. + + Args: + text: Any text to append at the end of the parameters string representation. + """ + self._suffix += text + + def __iadd__(self, text: str) -> Self: + """Appends a string at the end of the string representation. + + Args: + text: Any text to append at the end of the parameters string representation. + + Returns: + The given instance back. + """ + self.append_str(text) + return self + + @classmethod + def from_str(cls, text: str) -> Self: + """Creates a plain Params object from a string. + + Args: + text: The string parameters. + + Returns: + A new plain instance of :class:`Params`. + """ + obj = cls() + obj.append_str(text) + return obj + + @staticmethod + def _make_switch( + name: str, is_short: bool = False, is_no: bool = False, value: str | None = None + ) -> str: + """Make the string representation of the parameter. + + Args: + name: The name of the parameters. + is_short: If the parameters is short or not. + is_no: If the parameter is negated or not. + value: The value of the parameter. + + Returns: + The complete command line parameter. + """ + prefix = f"{'-' if is_short else '--'}{'no-' if is_no else ''}" + name = name.replace("_", "-") + value = f"{' ' if is_short else '='}{value}" if value else "" + return f"{prefix}{name}{value}" + + def __str__(self) -> str: + """Returns a string of command-line-ready arguments from the class fields.""" + arguments: list[str] = [] + + for field in fields(self): + value = getattr(self, field.name) + modifiers = cast(ParamsModifier, field.metadata) + + if value is None: + continue + + if isinstance(value, Params): + arguments.append(str(value)) + continue + + # take the short modifier, or the long modifier, or infer from field name + switch_name = modifiers.get("Params_short", modifiers.get("Params_long", field.name)) + is_short = "Params_short" in modifiers + + if isinstance(value, bool): + arguments.append(self._make_switch(switch_name, is_short, is_no=(not value))) + continue + + convert = _reduce_functions(modifiers.get("Params_convert_value", [])) + multiple = modifiers.get("Params_multiple", False) + + values = value if multiple else [value] + for value in values: + arguments.append(self._make_switch(switch_name, is_short, value=convert(value))) + + if self._suffix: + arguments.append(self._suffix) + + return " ".join(arguments) |