diff --git a/ofxtools/scripts/ofxget.py b/ofxtools/scripts/ofxget.py index bcb9058a..e17b8cbc 100644 --- a/ofxtools/scripts/ofxget.py +++ b/ofxtools/scripts/ofxget.py @@ -37,6 +37,7 @@ Iterator, ChainMap, ) +from pathlib import Path # 3rd party imports try: @@ -192,6 +193,9 @@ def add_subparser( default=0, help="Give more output (option can be repeated)", ) + parser.add_argument( + "--config", type=Path, default=None, help="Use custom configuration file" + ), # Higher-level configs (e.g. account #s) # imply lower-level configs (e.g. username/passwd) if stmt: @@ -1596,6 +1600,16 @@ def main() -> None: argparser.print_help() sys.exit() + if hasattr(args_, "config") and isinstance(args_.config, Path): + global USERCONFIGPATH + USERCONFIGPATH = args_.config.resolve() + if not USERCONFIGPATH.exists(): + msg = "Can't find custom configuration file" + logger.error(msg) + raise RuntimeError(msg) + USERCFG.read([CONFIGPATH, USERCONFIGPATH]) + logger.debug(f"Using custom configuration: {USERCONFIGPATH}") + args = merge_config(args_, USERCFG) REQUEST_HANDLERS[args["request"]](args) diff --git a/tests/data/custom.cfg b/tests/data/custom.cfg new file mode 100644 index 00000000..5335e3e5 --- /dev/null +++ b/tests/data/custom.cfg @@ -0,0 +1,3 @@ +[server0] +ofxhome: 0 +url: https://ofx.test.com diff --git a/tests/test_ofxget.py b/tests/test_ofxget.py index 61161352..e797bdf4 100644 --- a/tests/test_ofxget.py +++ b/tests/test_ofxget.py @@ -1457,6 +1457,25 @@ def testSavePasswdNoKeyring(self): # set_password.assert_called_once_with("ofxtools", "myserver", "t0ps3kr1t") +class CustomConfigTestCase(unittest.TestCase): + def testCustomConfig(self): + import os + + this_dir = os.path.dirname(os.path.abspath(__file__)) + config_file = os.path.join(this_dir, "data", "custom.cfg") + + def _merge_config(*args): + return {"verbose": 1, "request": "list", "server": "server0"} + + with patch("sys.argv", ["main", "list", "server0", "--config", config_file]): + with patch.multiple( + "ofxtools.scripts.ofxget", + merge_config=_merge_config, + list_fis=DEFAULT, + ) as MOCKS: + ofxget.main() + + class MainTestCase(unittest.TestCase): def testMain(self): args = argparse.Namespace(verbose=1, request="list")