Skip to content

Commit

Permalink
support echo @starwhale.argument help text
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut committed Jan 2, 2024
1 parent e6a498b commit f7aeb9f
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 5 deletions.
55 changes: 50 additions & 5 deletions client/starwhale/api/_impl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dataclasses
from enum import Enum
from functools import wraps
from collections import defaultdict

import click

Expand All @@ -28,6 +29,46 @@ def get(cls) -> t.List[str]:
return cls._args or []


class ArgumentContext:
_instance = None
_lock = threading.Lock()

def __init__(self) -> None:
self._click_ctx = click.Context(click.Command("Starwhale Argument Decorator"))
self._options: t.Dict[str, list] = defaultdict(list)

@classmethod
def get_current_context(cls) -> ArgumentContext:
with cls._lock:
if cls._instance is None:
cls._instance = ArgumentContext()
return cls._instance

def add_option(self, option: click.Option, group: str) -> None:
with self._lock:
self._options[group].append(option)

def echo_help(self) -> None:
if not self._options:
click.echo("No options")
return

formatter = self._click_ctx.make_formatter()
formatter.write_heading("\nOptions from Starwhale Argument Decorator")

for group, options in self._options.items():
help_records = []
for option in options:
record = option.get_help_record(self._click_ctx)
if record:
help_records.append(record)

with formatter.section(f"** {group}"):
formatter.write_dl(help_records)

click.echo(formatter.getvalue().rstrip("\n"))


def argument(dataclass_types: t.Any, inject_name: str = "argument") -> t.Any:
"""argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune).
Expand Down Expand Up @@ -68,9 +109,7 @@ def evaluate_summary(predict_result_iter, starwhale_arguments: EvaluationArgumen
is_sequence = False

def _register_wrapper(func: t.Callable) -> t.Any:
# TODO: add `--help` for the arguments
# TODO: dump parser to json file when model building
# TODO: `@handler` decorator function supports @argument decorator
parser = get_parser_from_dataclasses(dataclass_types)

@wraps(func)
Expand Down Expand Up @@ -113,12 +152,14 @@ def init_dataclasses_values(
for k in inputs:
del args_map[k]
ret.append(dtype(**inputs))

if args_map:
console.warn(f"Unused args from command line: {args_map}")
return ret


def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser:
argument_ctx = ArgumentContext.get_current_context()
parser = click.OptionParser()
for dtype in dataclass_types:
if not dataclasses.is_dataclass(dtype):
Expand All @@ -129,13 +170,17 @@ def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser:
if not field.init:
continue
field.type = type_hints[field.name]
add_field_into_parser(parser, field)
option = convert_field_to_option(field)
option.add_to_parser(parser=parser, ctx=parser.ctx) # type: ignore
argument_ctx.add_option(
option=option, group=f"{dtype.__module__}.{dtype.__qualname__}"
)

parser.ignore_unknown_options = True
return parser


def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None:
def convert_field_to_option(field: dataclasses.Field) -> click.Option:
# TODO: field.name need format for click option?
decls = [f"--{field.name}"]
if "_" in field.name:
Expand Down Expand Up @@ -220,4 +265,4 @@ def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field)
else:
kw["required"] = True

click.Option(**kw).add_to_parser(parser=parser, ctx=None) # type: ignore
return click.Option(**kw)
30 changes: 30 additions & 0 deletions client/starwhale/core/model/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import sys
import typing as t
from pathlib import Path

Expand Down Expand Up @@ -472,6 +473,13 @@ def _recover(model: str, force: bool) -> None:
multiple=True,
help="module name, the format is python module import path, handlers will be searched in the module. The option supports set multiple times.",
)
@optgroup.option( # type: ignore[no-untyped-call]
"-sa",
"--show-argument",
is_flag=True,
default=False,
help="[ONLY STANDALONE]Show the argument help info by the @starwhale.argument decorator registered arguments. The help info only analysis the imported modules.",
)
@optgroup.option( # type: ignore[no-untyped-call]
"-f",
"--model-yaml",
Expand Down Expand Up @@ -609,6 +617,7 @@ def _run(
forbid_packaged_runtime: bool,
forbid_snapshot: bool,
cleanup_snapshot: bool,
show_argument: bool,
) -> None:
"""Run Model.
Model Package and the model source directory are supported.
Expand Down Expand Up @@ -644,9 +653,15 @@ def _run(
\b
# --> run with finetune validation dataset
swcli model run --workdir . -m mnist.finetune --dataset mnist --val-dataset mnist-val
\b
# --> echo the argument help info by the @starwhale argument decorator
swcli model run --workdir . -m mnist.finetune --show-argument
swcli model run --uri mnist --show-argument
"""
from starwhale.api.argument import ExtraCliArgsRegistry

# TODO: currently, ExtraCliArgsRegistry must be set before the model run. We will find a better way to set it, such as ctx hooking.
ExtraCliArgsRegistry.set(ctx.args)

# TODO: support run model in cluster mode
Expand Down Expand Up @@ -698,6 +713,21 @@ def _run(
forbid_packaged_runtime=forbid_packaged_runtime,
)

if show_argument:
search_modules = model_config.run.modules
if not search_modules:
click.echo(
"no modules specified, please use `--module` option to set search modules"
)
sys.exit(1)

ModelTermView.show_argument(
model_src_dir=model_src_dir,
search_modules=search_modules,
runtime_uri=runtime_uri,
)
return

if in_container:
ModelTermView.run_in_container(
model_src_dir=model_src_dir,
Expand Down
20 changes: 20 additions & 0 deletions client/starwhale/core/model/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,26 @@ def run_in_server(

return ok, version_or_reason

@classmethod
@BaseTermView._only_standalone
def show_argument(
cls,
model_src_dir: Path | str,
search_modules: t.List[str],
runtime_uri: t.Optional[Resource] = None,
) -> None:
if runtime_uri:
RuntimeProcess(uri=runtime_uri).run()
else:
from starwhale.api._impl.argument import ArgumentContext
from starwhale.api._impl.job.handler import Handler

Handler._preload_registering_handlers(
search_modules=search_modules, package_dir=Path(model_src_dir)
)
ctx = ArgumentContext.get_current_context()
ctx.echo_help()

@classmethod
@BaseTermView._only_standalone
def run_in_host(
Expand Down

0 comments on commit f7aeb9f

Please sign in to comment.