Source code for nlpstack.workflow.base

import argparse
import inspect
import typing
from typing import Callable, ClassVar, Dict, Iterator, Optional, Sequence, Type, Union


class Workflow:
    """Base class that exposes its public methods as CLI sub-commands.

    Every public plain method (name not starting with ``_``) of a subclass
    becomes an ``argparse`` sub-command whose arguments are derived from the
    method signature. Subclasses can also be registered under a name and
    looked up later through :meth:`by_name`.
    """

    # Name -> workflow class. Defined on the base class, so it is shared by
    # all subclasses unless one of them overrides the attribute.
    _registry: ClassVar[Dict[str, Type["Workflow"]]] = {}

    @classmethod
    def register(cls, name: str, exist_ok: bool = False) -> Callable[[Type["Workflow"]], Type["Workflow"]]:
        """Return a class decorator that registers a workflow under ``name``.

        Args:
            name: Key under which the decorated class is stored.
            exist_ok: When ``False`` (default), registering an already used
                name raises; when ``True`` the previous entry is replaced.

        Returns:
            A decorator that stores the class in the registry and returns it
            unchanged.

        Raises:
            ValueError: Raised by the decorator when ``name`` is already
                registered and ``exist_ok`` is ``False``.
        """

        def wrapper(workflow: Type["Workflow"]) -> Type["Workflow"]:
            if not exist_ok and name in cls._registry:
                raise ValueError(f"Workflow '{name}' was already registered.")
            cls._registry[name] = workflow
            return workflow

        return wrapper

    @classmethod
    def by_name(cls, name: str) -> Type["Workflow"]:
        """Return the workflow class registered under ``name``.

        Raises:
            KeyError: If no workflow was registered under ``name``.
        """
        return cls._registry[name]

    @classmethod
    def available_names(cls) -> Sequence[str]:
        """Return the registered workflow names in registration order."""
        return list(cls._registry)

    @staticmethod
    def _setup_parser(parser: argparse.ArgumentParser, func: Callable) -> argparse.ArgumentParser:
        """Populate ``parser`` with one argument per parameter of ``func``.

        Mapping rules:

        * A parameter named ``self`` is skipped.
        * Unannotated parameters are parsed as ``str``.
        * ``Optional[X]`` is parsed as ``X`` and is never required; its
          declared default is kept, or ``None`` is used when there is none.
        * ``bool`` parameters with a non-``None`` default become ``--flag``
          switches (``store_true`` for a falsy default, ``store_false`` for a
          truthy one).
        * Positional-only / positional-or-keyword parameters become
          positional arguments; those with a default may be omitted.
        * All other parameters become ``--kebab-case`` options, required when
          they have no default.

        Args:
            parser: Parser (typically a sub-parser) to configure in place.
            func: Callable whose signature drives the argument definitions.

        Returns:
            The same ``parser`` instance.
        """
        # Passed through a dict so that the stored key is exactly "__func"
        # (the key ``run`` pops), independent of how a dunder-prefixed keyword
        # inside a class body is treated.
        parser.set_defaults(**{"__func": func})
        parser.description = func.__doc__

        none_type = type(None)
        for name, param in inspect.signature(func).parameters.items():
            if name == "self":
                continue

            arg_type = str if param.annotation is inspect.Parameter.empty else param.annotation
            optional = param.default is not inspect.Parameter.empty
            default = param.default if optional else None

            # Unwrap Optional[X] (a two-member Union containing None, in
            # either order). The declared default is deliberately preserved.
            union_args = typing.get_args(arg_type)
            if typing.get_origin(arg_type) is Union and len(union_args) == 2 and none_type in union_args:
                arg_type = union_args[0] if union_args[1] is none_type else union_args[1]
                optional = True

            is_flag = arg_type is bool and default is not None
            # store_true / store_false consume no value, so a flag is always
            # exposed as an option even if the parameter itself is positional.
            positional = not is_flag and param.kind in (
                param.POSITIONAL_ONLY,
                param.POSITIONAL_OR_KEYWORD,
            )

            if arg_type:
                # Not every annotation has __name__ (e.g. string annotations).
                help_message = getattr(arg_type, "__name__", None) or str(arg_type)
            else:
                help_message = "str"
            if optional:
                help_message += f" (default: {default})"
            elif not positional:
                help_message += " (required)"

            argparse_kwargs: Dict[str, typing.Any] = {"help": help_message}
            if is_flag:
                argparse_kwargs["action"] = "store_false" if default else "store_true"
                argparse_kwargs["default"] = default
            else:
                argparse_kwargs["type"] = arg_type
                if optional:
                    argparse_kwargs["default"] = default
                    if positional:
                        # Without nargs="?" argparse still requires the
                        # positional and the default would never be used.
                        argparse_kwargs["nargs"] = "?"
                elif not positional:
                    argparse_kwargs["required"] = True

            if positional:
                parser.add_argument(name, **argparse_kwargs)
            else:
                parser.add_argument("--" + name.replace("_", "-"), **argparse_kwargs)

        return parser

    @classmethod
    def _collect_methods(cls) -> Iterator[Callable]:
        """Yield the plain functions of ``cls`` whose name has no leading ``_``."""
        for name, func in inspect.getmembers(cls, predicate=inspect.isfunction):
            if not name.startswith("_"):
                yield func

    @classmethod
    def build_parser(cls) -> argparse.ArgumentParser:
        """Build a parser with one sub-command per public method of ``cls``.

        The class docstring is used as the parser description and each
        method docstring as the help text of its sub-command.
        """
        parser = argparse.ArgumentParser(description=cls.__doc__)
        subparsers = parser.add_subparsers()
        for func in cls._collect_methods():
            subparser = subparsers.add_parser(func.__name__, help=func.__doc__)
            cls._setup_parser(subparser, func)
        return parser

    @classmethod
    def run(cls, args: Optional[Sequence[str]] = None) -> None:
        """Parse ``args`` and invoke the selected method on a new instance.

        Args:
            args: Command-line tokens (sub-command name followed by its
                arguments). ``None`` or an empty sequence is replaced by
                ``["--help"]``.
        """
        args = args or ["--help"]
        parser = cls.build_parser()
        namespace = parser.parse_args(args)
        params = vars(namespace)
        func = params.pop("__func")
        kwargs = {k.replace("-", "_"): v for k, v in params.items()}
        func(cls(), **kwargs)