diff --git a/qlib/finco/task.py b/qlib/finco/task.py index f0e38c6824..dc6434bf61 100644 --- a/qlib/finco/task.py +++ b/qlib/finco/task.py @@ -1,8 +1,10 @@ import os from pathlib import Path -from typing import Any, List +import io +from typing import Any, List, Union from jinja2 import Template +import ruamel.yaml as yaml import abc import re import logging @@ -207,7 +209,11 @@ def execute(self): new_task = [] # 1) create a workspace # TODO: we have to make choice between `sl` and `sl-cfg` - new_task.append(CMDTask(cmd_intention=f"Copy folder from {get_tpl_path() / 'sl'} to {self._context_manager.get_context('workspace')}")) + new_task.append( + CMDTask( + cmd_intention=f"Copy folder from {get_tpl_path() / 'sl'} to {self._context_manager.get_context('workspace')}" + ) + ) # 2) CURD on the workspace for name, regex in regex_dict.items(): @@ -249,6 +255,7 @@ class CMDTask(ActionTask): """ This CMD task is responsible for ensuring compatibility across different operating systems. """ + __DEFAULT_WORKFLOW_SYSTEM_PROMPT = """ You are an expert system administrator. Your task is to convert the user's intention into a specific runnable command for a particular system. @@ -271,8 +278,9 @@ def __init__(self, cmd_intention: str, cwd=None): self._output = None def execute(self): - prompt = Template(self.__DEFAULT_WORKFLOW_USER_PROMPT).render(cmd_intention=self.cmd_intention, - user_os=platform.system()) + prompt = Template(self.__DEFAULT_WORKFLOW_USER_PROMPT).render( + cmd_intention=self.cmd_intention, user_os=platform.system() + ) response = APIBackend().build_messages_and_create_chat_completion(prompt, self.__DEFAULT_WORKFLOW_SYSTEM_PROMPT) self._output = subprocess.check_output(response, shell=True, cwd=self.cwd) return [] @@ -535,6 +543,43 @@ def execute(self): return [] +class YamlEditTask(ActionTask): + """This yaml edit task will replace a specific component directly""" + + def __init__(self, file: Union[str, Path], module_path: str, updated_content: str): + """ + + Parameters + ---------- + file + a target file that needs to be modified + module_path + the path to the section that needs to be replaced with `updated_content` + updated_content + The content to replace the original content in `module_path` + """ + self.p = Path(file) + self.module_path = module_path + self.updated_content = updated_content + + def execute(self): + # 1) read original and new content + with self.p.open("r") as f: + config = yaml.safe_load(f) + update_config = yaml.safe_load(io.StringIO(self.updated_content)) + + # 2) locate the module + focus = config + module_list = self.module_path.split(".") + for k in module_list[:-1]: + focus = focus[k] + + # 3) replace the module and save + focus[module_list[-1]] = update_config + with self.p.open("w") as f: + yaml.dump(config, f) + + class SummarizeTask(Task): __DEFAULT_WORKSPACE = "./" diff --git a/qlib/finco/tpl/sl-cfg/workflow_config_ds.yaml b/qlib/finco/tpl/sl-cfg/workflow_config.yaml similarity index 100% rename from qlib/finco/tpl/sl-cfg/workflow_config_ds.yaml rename to qlib/finco/tpl/sl-cfg/workflow_config.yaml diff --git a/qlib/finco/workflow.py b/qlib/finco/workflow.py index 557456ec8e..f380b23da5 100644 --- a/qlib/finco/workflow.py +++ b/qlib/finco/workflow.py @@ -55,7 +55,6 @@ def __init__(self, workspace=None) -> None: self._context.set_context("workspace", self._workspace) self.default_user_prompt = "Please help me build a low turnover strategy that focus more on longterm return in China a stock market. I want to construct a new dataset covers longer history" - def _confirm_and_rm(self): # if workspace exists, please confirm and remove it. Otherwise exit. if self._workspace.exists(): diff --git a/tests/finco/test_cfg.py b/tests/finco/test_cfg.py index 29b5c40f14..97f83f0c27 100644 --- a/tests/finco/test_cfg.py +++ b/tests/finco/test_cfg.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import unittest +import shutil +import difflib from qlib.finco.tpl import get_tpl_path import ruamel.yaml as yaml @@ -8,9 +10,14 @@ from qlib.utils import init_instance_by_config from qlib.tests import TestAutoData +from pathlib import Path +from qlib.finco.tpl import get_tpl_path +from qlib.finco.task import YamlEditTask + +DIRNAME = Path(__file__).absolute().resolve().parent -class FincoTpl(TestAutoData): +class FincoTpl(TestAutoData): def test_tpl_consistence(self): """Motivation: make sure the configuable template is consistent with the default config""" tpl_p = get_tpl_path() @@ -18,18 +25,47 @@ def test_tpl_consistence(self): config = yaml.safe_load(fp) # init_data_handler hd: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"]) - # NOTE: The config in workflow_config_ds.yaml is generated by the following code: + # NOTE: The config in workflow_config.yaml is generated by the following code: # dump in yaml format to file without auto linebreak # print(yaml.dump(hd.data_loader.fields, width=10000, stream=open("_tmp", "w"))) - with (tpl_p / "sl-cfg" / "workflow_config_ds.yaml").open("rb") as fp: + with (tpl_p / "sl-cfg" / "workflow_config.yaml").open("rb") as fp: config = yaml.safe_load(fp) hd_ds: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"]) self.assertEqual(hd_ds.data_loader.fields, hd.data_loader.fields) - check = hd_ds.fetch().fillna(0.) == hd.fetch().fillna(0.) + check = hd_ds.fetch().fillna(0.0) == hd.fetch().fillna(0.0) self.assertTrue(check.all().all()) + def test_update_yaml(self): + p = get_tpl_path() / "sl" / "workflow_config.yaml" + p_new = DIRNAME / "_test_config.yaml" + shutil.copy(p, p_new) + updated_content = """ +class: LGBModelTest +module_path: qlib.contrib.model.gbdt +kwargs: + loss: mse + colsample_bytree: 1.8879 + learning_rate: 0.3 + subsample: 0.8790 + lambda_l1: 205.7000 + lambda_l2: 580.9769 + max_depth: 9 + num_leaves: 211 + num_threads: 21 +""" + t = YamlEditTask(p_new, "task.model", updated_content) + t.execute() + # NOTE: the formmat is changed by ruamel.yaml, so it can't be compared by text directly.. + # print the diff between p and p_new with difflib + # with p.open("r") as fp: + # content = fp.read() + # with p_new.open("r") as fp: + # content_new = fp.read() + # for line in difflib.unified_diff(content, content_new, fromfile="original", tofile="new", lineterm=""): + # print(line) + if __name__ == "__main__": unittest.main()