Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): support echo @starwhale.argument help text #3101

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions client/starwhale/api/_impl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dataclasses
from enum import Enum
from functools import wraps
from collections import defaultdict

import click

Expand All @@ -28,6 +29,46 @@ def get(cls) -> t.List[str]:
return cls._args or []


class ArgumentContext:
_instance = None
_lock = threading.Lock()

def __init__(self) -> None:
self._click_ctx = click.Context(click.Command("Starwhale Argument Decorator"))
self._options: t.Dict[str, list] = defaultdict(list)

@classmethod
def get_current_context(cls) -> ArgumentContext:
with cls._lock:
if cls._instance is None:
cls._instance = ArgumentContext()
return cls._instance

def add_option(self, option: click.Option, group: str) -> None:
with self._lock:
self._options[group].append(option)

def echo_help(self) -> None:
if not self._options:
click.echo("No options")
return

formatter = self._click_ctx.make_formatter()
formatter.write_heading("\nOptions from Starwhale Argument Decorator")

for group, options in self._options.items():
help_records = []
for option in options:
record = option.get_help_record(self._click_ctx)
if record:
help_records.append(record)

with formatter.section(f"** {group}"):
formatter.write_dl(help_records)

click.echo(formatter.getvalue().rstrip("\n"))


def argument(dataclass_types: t.Any, inject_name: str = "argument") -> t.Any:
"""argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune).

Expand Down Expand Up @@ -68,9 +109,7 @@ def evaluate_summary(predict_result_iter, starwhale_arguments: EvaluationArgumen
is_sequence = False

def _register_wrapper(func: t.Callable) -> t.Any:
# TODO: add `--help` for the arguments
# TODO: dump parser to json file when model building
# TODO: `@handler` decorator function supports @argument decorator
parser = get_parser_from_dataclasses(dataclass_types)

@wraps(func)
Expand Down Expand Up @@ -113,12 +152,14 @@ def init_dataclasses_values(
for k in inputs:
del args_map[k]
ret.append(dtype(**inputs))

if args_map:
console.warn(f"Unused args from command line: {args_map}")
return ret


def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser:
argument_ctx = ArgumentContext.get_current_context()
parser = click.OptionParser()
for dtype in dataclass_types:
if not dataclasses.is_dataclass(dtype):
Expand All @@ -129,13 +170,17 @@ def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser:
if not field.init:
continue
field.type = type_hints[field.name]
add_field_into_parser(parser, field)
option = convert_field_to_option(field)
option.add_to_parser(parser=parser, ctx=parser.ctx) # type: ignore
argument_ctx.add_option(
option=option, group=f"{dtype.__module__}.{dtype.__qualname__}"
)

parser.ignore_unknown_options = True
return parser


def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None:
def convert_field_to_option(field: dataclasses.Field) -> click.Option:
# TODO: field.name need format for click option?
decls = [f"--{field.name}"]
if "_" in field.name:
Expand Down Expand Up @@ -220,4 +265,4 @@ def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field)
else:
kw["required"] = True

click.Option(**kw).add_to_parser(parser=parser, ctx=None) # type: ignore
return click.Option(**kw)
30 changes: 30 additions & 0 deletions client/starwhale/core/model/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import sys
import typing as t
from pathlib import Path

Expand Down Expand Up @@ -472,6 +473,13 @@ def _recover(model: str, force: bool) -> None:
multiple=True,
help="module name, the format is python module import path, handlers will be searched in the module. The option supports set multiple times.",
)
@optgroup.option( # type: ignore[no-untyped-call]
"-sa",
"--show-argument",
is_flag=True,
default=False,
help="[ONLY STANDALONE]Show the argument help info by the @starwhale.argument decorator registered arguments. The help info only analysis the imported modules.",
)
@optgroup.option( # type: ignore[no-untyped-call]
"-f",
"--model-yaml",
Expand Down Expand Up @@ -609,6 +617,7 @@ def _run(
forbid_packaged_runtime: bool,
forbid_snapshot: bool,
cleanup_snapshot: bool,
show_argument: bool,
) -> None:
"""Run Model.
Model Package and the model source directory are supported.
Expand Down Expand Up @@ -644,9 +653,15 @@ def _run(
\b
# --> run with finetune validation dataset
swcli model run --workdir . -m mnist.finetune --dataset mnist --val-dataset mnist-val

\b
# --> echo the argument help info by the @starwhale argument decorator
swcli model run --workdir . -m mnist.finetune --show-argument
swcli model run --uri mnist --show-argument
"""
from starwhale.api.argument import ExtraCliArgsRegistry

# TODO: currently, ExtraCliArgsRegistry must be set before the model run. We will find a better way to set it, such as ctx hooking.
ExtraCliArgsRegistry.set(ctx.args)

# TODO: support run model in cluster mode
Expand Down Expand Up @@ -698,6 +713,21 @@ def _run(
forbid_packaged_runtime=forbid_packaged_runtime,
)

if show_argument:
search_modules = model_config.run.modules
if not search_modules:
click.echo(
"no modules specified, please use `--module` option to set search modules"
)
sys.exit(1)

ModelTermView.show_argument(
model_src_dir=model_src_dir,
search_modules=search_modules,
runtime_uri=runtime_uri,
)
return

if in_container:
ModelTermView.run_in_container(
model_src_dir=model_src_dir,
Expand Down
20 changes: 20 additions & 0 deletions client/starwhale/core/model/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,26 @@ def run_in_server(

return ok, version_or_reason

@classmethod
@BaseTermView._only_standalone
def show_argument(
cls,
model_src_dir: Path | str,
search_modules: t.List[str],
runtime_uri: t.Optional[Resource] = None,
) -> None:
if runtime_uri:
RuntimeProcess(uri=runtime_uri).run()
else:
from starwhale.api._impl.argument import ArgumentContext
from starwhale.api._impl.job.handler import Handler

Handler._preload_registering_handlers(
search_modules=search_modules, package_dir=Path(model_src_dir)
)
ctx = ArgumentContext.get_current_context()
ctx.echo_help()

@classmethod
@BaseTermView._only_standalone
def run_in_host(
Expand Down
44 changes: 44 additions & 0 deletions client/tests/sdk/test_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import typing as t
import dataclasses
from enum import Enum
from unittest.mock import patch, MagicMock

import click
from pyfakefs.fake_filesystem_unittest import TestCase

from starwhale.api._impl.argument import argument as argument_decorator
from starwhale.api._impl.argument import (
ArgumentContext,
ExtraCliArgsRegistry,
get_parser_from_dataclasses,
)
Expand Down Expand Up @@ -67,6 +69,7 @@ def setUp(self) -> None:

def tearDown(self) -> None:
ExtraCliArgsRegistry._args = None
ArgumentContext._instance = None

def test_argument_exceptions(self) -> None:
@argument_decorator(ScalarArguments)
Expand Down Expand Up @@ -157,6 +160,14 @@ def test_scalar_parser(self) -> None:
assert scalar_parser._long_opt["--epoch"].obj.type == click.INT
assert scalar_parser._long_opt["--epoch"].obj.default == 1

argument_ctx = ArgumentContext.get_current_context()
assert len(argument_ctx._options) == 1
options = argument_ctx._options["tests.sdk.test_argument.ScalarArguments"]
assert len(options) == 5
assert options[0].name == "batch"
assert options[-1].name == "epoch"
argument_ctx.echo_help()

def test_compose_parser(self) -> None:
compose_parser = get_parser_from_dataclasses([ComposeArguments])

Expand Down Expand Up @@ -199,3 +210,36 @@ def test_compose_parser(self) -> None:
assert not optional_list_obj.required
assert optional_list_obj.multiple
assert optional_list_obj.default is None

argument_ctx = ArgumentContext.get_current_context()
assert len(argument_ctx._options) == 1
options = argument_ctx._options["tests.sdk.test_argument.ComposeArguments"]
assert len(options) == 6
assert options[0].name == "debug"
argument_ctx.echo_help()

@patch("click.echo")
def test_argument_help_output(self, mock_echo: MagicMock):
@argument_decorator((ScalarArguments, ComposeArguments))
def mock_func(starwhale_argument: t.Tuple) -> None:
...

ArgumentContext.get_current_context().echo_help()
help_output = mock_echo.call_args[0][0]
cases = [
"tests.sdk.test_argument.ScalarArguments:",
"--batch INTEGER",
"--overwrite",
"--learning_rate, --learning-rate FLOAT",
"--half_precision_backend, --half-precision-backend TEXT",
"--epoch INTEGER",
"tests.sdk.test_argument.ComposeArguments:",
"--debug DEBUGOPTION",
"--lr_scheduler_kwargs, --lr-scheduler-kwargs DICT",
"--evaluation_strategy, --evaluation-strategy [no|steps|epoch]",
"--per_gpu_train_batch_size, --per-gpu-train-batch-size INTEGER",
"--eval_delay, --eval-delay FLOAT",
"--label_names, --label-names TEXT",
]
for case in cases:
assert case in help_output
Loading