From 3b33efc7a2989d1a962815a7995f112313750938 Mon Sep 17 00:00:00 2001 From: Lorenzo Gabriele Date: Mon, 28 Feb 2022 22:12:05 +0100 Subject: [PATCH] Partial support for ParserForClass --- ...rserForClassCompanionVersionSpecific.scala | 2 +- ...erForMethodsCompanionVersionSpecific.scala | 2 +- mainargs/src-3/Macros.scala | 100 ++++++---- ...rserForClassCompanionVersionSpecific.scala | 2 +- mainargs/test/src-2/ClassTests.scala | 101 ---------- mainargs/test/src-2/OldVarargsTests.scala | 7 +- mainargs/test/src-2/ParserTests.scala | 53 ------ mainargs/test/src-2/VersionSpecific.scala | 5 + mainargs/test/src-3/VersionSpecific.scala | 5 + mainargs/test/src-jvm-2/AmmoniteTests.scala | 8 +- mainargs/test/src-jvm-2/MillTests.scala | 2 - mainargs/test/src/ClassTests.scala | 173 ++++++++++++++++++ mainargs/test/{src-2 => src}/ManyTests.scala | 0 mainargs/test/src/NewVarargsTests.scala | 10 +- mainargs/test/src/ParserTests.scala | 61 ++++++ mainargs/test/src/TestUtils.scala | 7 + mainargs/test/src/VarargsTests.scala | 134 ++++++++------ 17 files changed, 405 insertions(+), 267 deletions(-) delete mode 100644 mainargs/test/src-2/ClassTests.scala delete mode 100644 mainargs/test/src-2/ParserTests.scala create mode 100644 mainargs/test/src-2/VersionSpecific.scala create mode 100644 mainargs/test/src-3/VersionSpecific.scala create mode 100644 mainargs/test/src/ClassTests.scala rename mainargs/test/{src-2 => src}/ManyTests.scala (100%) create mode 100644 mainargs/test/src/ParserTests.scala create mode 100644 mainargs/test/src/TestUtils.scala diff --git a/mainargs/src-2/ParserForClassCompanionVersionSpecific.scala b/mainargs/src-2/ParserForClassCompanionVersionSpecific.scala index 3649936..7e8feb4 100644 --- a/mainargs/src-2/ParserForClassCompanionVersionSpecific.scala +++ b/mainargs/src-2/ParserForClassCompanionVersionSpecific.scala @@ -4,6 +4,6 @@ import acyclic.skipped import scala.language.experimental.macros -private [mainargs] trait ParserForClassCompanionVersionSpecific { +private[mainargs] trait ParserForClassCompanionVersionSpecific { def apply[T]: ParserForClass[T] = macro Macros.parserForClass[T] } diff --git a/mainargs/src-2/ParserForMethodsCompanionVersionSpecific.scala b/mainargs/src-2/ParserForMethodsCompanionVersionSpecific.scala index cff9d25..4e43163 100644 --- a/mainargs/src-2/ParserForMethodsCompanionVersionSpecific.scala +++ b/mainargs/src-2/ParserForMethodsCompanionVersionSpecific.scala @@ -4,6 +4,6 @@ import acyclic.skipped import scala.language.experimental.macros -private [mainargs] trait ParserForMethodsCompanionVersionSpecific { +private[mainargs] trait ParserForMethodsCompanionVersionSpecific { def apply[B](base: B): ParserForMethods[B] = macro Macros.parserForMethods[B] } diff --git a/mainargs/src-3/Macros.scala b/mainargs/src-3/Macros.scala index 8500a62..4c0e790 100644 --- a/mainargs/src-3/Macros.scala +++ b/mainargs/src-3/Macros.scala @@ -3,48 +3,17 @@ package mainargs import scala.quoted._ object Macros { + private def mainAnnotation(using Quotes) = quotes.reflect.TypeRepr.of[mainargs.main].typeSymbol + private def argAnnotation(using Quotes) = quotes.reflect.TypeRepr.of[mainargs.arg].typeSymbol def parserForMethods[B](base: Expr[B])(using Quotes, Type[B]): Expr[ParserForMethods[B]] = { import quotes.reflect._ val allMethods = TypeRepr.of[B].typeSymbol.memberMethods - val mainAnnotation = TypeRepr.of[mainargs.main].typeSymbol - val argAnnotation = TypeRepr.of[mainargs.arg].typeSymbol val annotatedMethodsWithMainAnnotations = allMethods.flatMap { methodSymbol => methodSymbol.getAnnotation(mainAnnotation).map(methodSymbol -> _) }.sortBy(_._1.pos.map(_.start)) - val mainDatasExprs: Seq[Expr[MainData[Any, B]]] = annotatedMethodsWithMainAnnotations.map { (annotatedMethod, mainAnnotation) => - val params = annotatedMethod.paramSymss.headOption.getOrElse(throw new Exception("Multiple parameter lists not supported")) - val defaultParams = getDefaultParams(annotatedMethod) - val argSigs = Expr.ofList(params.map { param => - val paramTree = param.tree.asInstanceOf[ValDef] - val paramTpe = paramTree.tpt.tpe - val arg = param.getAnnotation(argAnnotation).map(_.asExpr.asInstanceOf[Expr[mainargs.arg]]).getOrElse('{ new mainargs.arg() }) - val paramType = paramTpe.asType - paramType match - case '[t] => - val defaultParam: Expr[Option[B => t]] = defaultParams.get(param) match { - case Some(v) => '{ Some(((_: B) => $v).asInstanceOf[B => t]) } - case None => '{ None } - } - val argReader = Expr.summon[mainargs.ArgReader[t]].getOrElse{ - report.error( - s"No mainargs.ArgReader of ${paramTpe.typeSymbol.fullName} found for parameter ${param.name}", - param.pos.get - ) - '{ ??? } - } - '{ ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ argReader }).asInstanceOf[mainargs.ArgSig[Any, B]] } - }) - - val invokeRaw: Expr[(B, Seq[Any]) => Any] = { - def callOf(args: Expr[Seq[Any]]) = call(annotatedMethod, '{ Seq( ${ args }) }) - '{ (b: B, params: Seq[Any]) => - ${ callOf('{ params }) } - } - } - - '{ MainData.create[Any, B](${ Expr(annotatedMethod.name) }, ${ mainAnnotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) } - } - val mainDatas = Expr.ofList(mainDatasExprs) + val mainDatas = Expr.ofList(annotatedMethodsWithMainAnnotations.map { (annotatedMethod, mainAnnotationInstance) => + createMainData[Any, B](annotatedMethod, mainAnnotationInstance) + }) '{ new ParserForMethods[B]( @@ -53,6 +22,65 @@ object Macros { } } + def parserForClass[B](using Quotes, Type[B]): Expr[ParserForClass[B]] = { + import quotes.reflect._ + val typeReprOfB = TypeRepr.of[B] + val companionModule = typeReprOfB match { + case TypeRef(a,b) => TermRef(a,b) + } + val typeSymbolOfB = typeReprOfB.typeSymbol + val companionModuleType = typeSymbolOfB.companionModule.tree.asInstanceOf[ValDef].tpt.tpe.asType + val companionModuleExpr = Ident(companionModule).asExpr + val mainAnnotationInstance = typeSymbolOfB.getAnnotation(mainAnnotation).getOrElse { + report.error( + s"cannot find @main annotation on ${companionModule.name}", + typeSymbolOfB.pos.get + ) + ??? + } + val annotatedMethod = TypeRepr.of[B].typeSymbol.companionModule.memberMethod("apply").head + companionModuleType match + case '[bCompanion] => + val mainData = createMainData[B, bCompanion](annotatedMethod, mainAnnotationInstance) + '{ + new ParserForClass[B]( + ClassMains[B](${ mainData }.asInstanceOf[MainData[B, Any]], () => ${ Ident(companionModule).asExpr }) + ) + } + } + + def createMainData[T: Type, B: Type](using Quotes)(method: quotes.reflect.Symbol, annotation: quotes.reflect.Term): Expr[MainData[T, B]] = { + import quotes.reflect.* + val params = method.paramSymss.headOption.getOrElse(throw new Exception("Multiple parameter lists not supported")) + val defaultParams = getDefaultParams(method) + val argSigs = Expr.ofList(params.map { param => + val paramTree = param.tree.asInstanceOf[ValDef] + val paramTpe = paramTree.tpt.tpe + val arg = param.getAnnotation(argAnnotation).map(_.asExpr.asInstanceOf[Expr[mainargs.arg]]).getOrElse('{ new mainargs.arg() }) + val paramType = paramTpe.asType + paramType match + case '[t] => + val defaultParam: Expr[Option[B => t]] = defaultParams.get(param) match { + case Some(v) => '{ Some(((_: B) => $v).asInstanceOf[B => t]) } + case None => '{ None } + } + val argReader = Expr.summon[mainargs.ArgReader[t]].getOrElse{ + report.error( + s"No mainargs.ArgReader of ###companionModule### found for parameter ${param.name}", + param.pos.get + ) + '{ ??? } + } + '{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ argReader })).asInstanceOf[ArgSig[Any, B]] } + }) + + val invokeRaw: Expr[(B, Seq[Any]) => T] = { + def callOf(args: Expr[Seq[Any]]) = call(method, '{ Seq( ${ args }) }) + '{ ((b: B, params: Seq[Any]) => ${ callOf('{ params }) }).asInstanceOf[(B, Seq[Any]) => T] } + } + '{ MainData.create[T, B](${ Expr(method.name) }, ${ annotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) } + } + /** Call a method given by its symbol. * * E.g. diff --git a/mainargs/src-3/ParserForClassCompanionVersionSpecific.scala b/mainargs/src-3/ParserForClassCompanionVersionSpecific.scala index 01af670..ae1ac2d 100644 --- a/mainargs/src-3/ParserForClassCompanionVersionSpecific.scala +++ b/mainargs/src-3/ParserForClassCompanionVersionSpecific.scala @@ -3,5 +3,5 @@ package mainargs import scala.language.experimental.macros private [mainargs] trait ParserForClassCompanionVersionSpecific { - inline def apply[T]: ParserForClass[T] = ??? + inline def apply[T]: ParserForClass[T] = ${ Macros.parserForClass[T] } } diff --git a/mainargs/test/src-2/ClassTests.scala b/mainargs/test/src-2/ClassTests.scala deleted file mode 100644 index ef2a655..0000000 --- a/mainargs/test/src-2/ClassTests.scala +++ /dev/null @@ -1,101 +0,0 @@ -package mainargs -import utest._ - - -object ClassTests extends TestSuite{ - - @main - case class Foo(x: Int, y: Int) - - @main - case class Bar(w: Flag = Flag(), f: Foo, @arg(short = 'z') zzzz: String) - - @main - case class Qux(moo: String, b: Bar) - - implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo] - implicit val barParser: ParserForClass[Bar] = ParserForClass[Bar] - implicit val quxParser: ParserForClass[Qux] = ParserForClass[Qux] - - object Main{ - @main - def run(bar: Bar, - bool: Boolean = false) = { - s"${bar.w.value} ${bar.f.x} ${bar.f.y} ${bar.zzzz} $bool" - } - } - - val tests = Tests { - test("simple") { - test("success"){ - fooParser.constructOrThrow(Seq("-x", "1", "-y", "2")) ==> Foo(1, 2) - } - test("missing") { - fooParser.constructRaw(Seq("-x", "1")) ==> - Result.Failure.MismatchedArguments( - Seq(ArgSig.Simple(None, Some('y'),None,None,mainargs.TokensReader.IntRead, false)), - List(), - List(), - None - ) - - } - } - - test("nested") { - test("success"){ - barParser.constructOrThrow(Seq("-w", "-x", "1", "-y", "2", "--zzzz", "xxx")) ==> - Bar(Flag(true), Foo(1, 2), "xxx") - } - test("missingInner"){ - barParser.constructRaw(Seq("-w", "-x", "1", "-z", "xxx")) ==> - Result.Failure.MismatchedArguments( - Seq(ArgSig.Simple(None,Some('y'),None,None,mainargs.TokensReader.IntRead, false)), - List(), - List(), - None - ) - } - test("missingOuter"){ - barParser.constructRaw(Seq("-w", "-x", "1", "-y", "2")) ==> - Result.Failure.MismatchedArguments( - Seq(ArgSig.Simple(Some("zzzz"),Some('z'),None,None,mainargs.TokensReader.StringRead, false)), - List(), - List(), - None - ) - } - - test("missingInnerOuter"){ - barParser.constructRaw(Seq("-w", "-x", "1")) ==> - Result.Failure.MismatchedArguments( - Seq( - ArgSig.Simple(None,Some('y'),None,None,mainargs.TokensReader.IntRead, false), - ArgSig.Simple(Some("zzzz"),Some('z'),None,None,mainargs.TokensReader.StringRead, false) - ), - List(), - List(), - None - ) - } - test("failedInnerOuter") { - assertMatch(barParser.constructRaw(Seq("-w","-x", "xxx", "-y", "hohoho", "-z", "xxx"))) { - case Result.Failure.InvalidArguments( - Seq( - Result.ParamError.Failed(ArgSig.Simple(None, Some('x'), None, None, _, false), Seq("xxx"), _), - Result.ParamError.Failed(ArgSig.Simple(None, Some('y'), None, None, _, false), Seq("hohoho"), _) - ) - ) => - } - } - } - - test("doubleNested"){ - quxParser.constructOrThrow(Seq("-w", "-x", "1", "-y", "2", "-z", "xxx", "--moo", "cow")) ==> - Qux("cow", Bar(Flag(true), Foo(1, 2), "xxx")) - } - test("success"){ - ParserForMethods(Main).runOrThrow(Seq("-x", "1", "-y", "2", "-z", "hello")) ==> "false 1 2 hello false" - } - } -} diff --git a/mainargs/test/src-2/OldVarargsTests.scala b/mainargs/test/src-2/OldVarargsTests.scala index 1534ec2..6e7bdef 100644 --- a/mainargs/test/src-2/OldVarargsTests.scala +++ b/mainargs/test/src-2/OldVarargsTests.scala @@ -1,14 +1,15 @@ package mainargs import utest._ -object OldVarargsTests extends VarargsTests{ - object Base{ +object OldVarargsTests extends VarargsTests { + object Base { @main def pureVariadic(nums: Int*) = nums.sum @main - def mixedVariadic(@arg(short = 'f') first: Int, args: String*) = first + args.mkString + def mixedVariadic(@arg(short = 'f') first: Int, args: String*) = + first + args.mkString } val check = new Checker(ParserForMethods(Base), allowPositional = true) diff --git a/mainargs/test/src-2/ParserTests.scala b/mainargs/test/src-2/ParserTests.scala deleted file mode 100644 index ebf26ce..0000000 --- a/mainargs/test/src-2/ParserTests.scala +++ /dev/null @@ -1,53 +0,0 @@ -package mainargs -import utest._ - - -object ParserTests extends TestSuite{ - - object SingleBase{ - @main(doc = "Qux is a function that does stuff") - def run(i: Int, - @arg(doc = "Pass in a custom `s` to override it") - s: String = "lols") = s * i - } - - object MultiBase{ - @main - def foo() = 1 - - @main - def bar(i: Int) = i - } - - @main - case class ClassBase(code: Option[String] = None, other: String = "hello") - - val multiMethodParser = ParserForMethods(MultiBase) - val singleMethodParser = ParserForMethods(SingleBase) - val classParser = ParserForClass[ClassBase] - val tests = Tests { - test("runEitherMulti") { - - test { - multiMethodParser.runEither(Array("foo")) ==> Right(1) - } - test { - multiMethodParser.runEither(Array("bar", "-i", "123")) ==> Right(123) - } - test { - assert( - multiMethodParser.runEither(Array("f")) - .left - .exists(_.contains("Unable to find subcommand: f")) - ) - } - } - test("runEitherSingle"){ - singleMethodParser.runEither(Array("5", "x"), allowPositional = true) ==> Right("xxxxx") - } - test("constructEither"){ - classParser.constructEither(Array("--code", "println(1)")) ==> - Right(ClassBase(code = Some("println(1)"), other = "hello")) - } - } -} diff --git a/mainargs/test/src-2/VersionSpecific.scala b/mainargs/test/src-2/VersionSpecific.scala new file mode 100644 index 0000000..807759b --- /dev/null +++ b/mainargs/test/src-2/VersionSpecific.scala @@ -0,0 +1,5 @@ +package mainargs + +object VersionSpecific { + val isScala3 = false +} diff --git a/mainargs/test/src-3/VersionSpecific.scala b/mainargs/test/src-3/VersionSpecific.scala new file mode 100644 index 0000000..aed72d3 --- /dev/null +++ b/mainargs/test/src-3/VersionSpecific.scala @@ -0,0 +1,5 @@ +package mainargs + +object VersionSpecific { + val isScala3 = true +} diff --git a/mainargs/test/src-jvm-2/AmmoniteTests.scala b/mainargs/test/src-jvm-2/AmmoniteTests.scala index f191e9f..aaa9705 100644 --- a/mainargs/test/src-jvm-2/AmmoniteTests.scala +++ b/mainargs/test/src-jvm-2/AmmoniteTests.scala @@ -69,7 +69,7 @@ object AmmoniteConfig{ @arg(doc = "Print this message") help: Flag ) - implicit val coreParser: ParserForClass[Core] = ParserForClass[Core] + implicit val coreParser = ParserForClass[Core] @main case class Predef( @@ -86,7 +86,7 @@ object AmmoniteConfig{ "choose an additional predef to use using `--predef") noHomePredef: Flag ) - implicit val predefParser: ParserForClass[Predef] = ParserForClass[Predef] + implicit val predefParser = ParserForClass[Predef] @main case class Repl( @@ -105,12 +105,12 @@ object AmmoniteConfig{ "friendliness.") classBased: Flag ) - implicit val replParser: ParserForClass[Repl] = ParserForClass[Repl] + implicit val replParser = ParserForClass[Repl] } object AmmoniteTests extends TestSuite{ - val parser: ParserForClass[AmmoniteConfig] = ParserForClass[AmmoniteConfig] + val parser = ParserForClass[AmmoniteConfig] val tests = Tests { diff --git a/mainargs/test/src-jvm-2/MillTests.scala b/mainargs/test/src-jvm-2/MillTests.scala index 03d6810..ffc78a1 100644 --- a/mainargs/test/src-jvm-2/MillTests.scala +++ b/mainargs/test/src-jvm-2/MillTests.scala @@ -1,7 +1,6 @@ // package mainargs // import utest._ - // object MillTests extends TestSuite{ // implicit object PathRead extends TokensReader[os.Path]("path", strs => Right(os.Path(strs.head, os.pwd))) @@ -107,4 +106,3 @@ // } // } // } - diff --git a/mainargs/test/src/ClassTests.scala b/mainargs/test/src/ClassTests.scala new file mode 100644 index 0000000..f12fa1e --- /dev/null +++ b/mainargs/test/src/ClassTests.scala @@ -0,0 +1,173 @@ +package mainargs +import utest._ + +object ClassTests extends TestSuite { + + @main + case class Foo(x: Int, y: Int) + + @main + case class Bar(w: Flag = Flag(), f: Foo, @arg(short = 'z') zzzz: String) + + @main + case class Qux(moo: String, b: Bar) + + implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo] + implicit val barParser: ParserForClass[Bar] = ParserForClass[Bar] + implicit val quxParser: ParserForClass[Qux] = ParserForClass[Qux] + + object Main { + @main + def run(bar: Bar, bool: Boolean = false) = { + s"${bar.w.value} ${bar.f.x} ${bar.f.y} ${bar.zzzz} $bool" + } + } + + val tests = Tests { + test("simple") { + test("success") { + fooParser.constructOrThrow(Seq("-x", "1", "-y", "2")) ==> Foo(1, 2) + } + test("missing") { + fooParser.constructRaw(Seq("-x", "1")) ==> + Result.Failure.MismatchedArguments( + Seq( + ArgSig.Simple( + None, + Some('y'), + None, + None, + mainargs.TokensReader.IntRead, + false + ) + ), + List(), + List(), + None + ) + + } + } + + test("nested") { + test("success") { + barParser.constructOrThrow( + Seq("-w", "-x", "1", "-y", "2", "--zzzz", "xxx") + ) ==> + Bar(Flag(true), Foo(1, 2), "xxx") + } + test("missingInner") { + // Blocked by https://github.com/lampepfl/dotty/issues/12492 + TestUtils.scala2Only { + barParser.constructRaw(Seq("-w", "-x", "1", "-z", "xxx")) ==> + Result.Failure.MismatchedArguments( + Seq( + ArgSig.Simple( + None, + Some('y'), + None, + None, + mainargs.TokensReader.IntRead, + false + ) + ), + List(), + List(), + None + ) + } + } + test("missingOuter") { + // Blocked by https://github.com/lampepfl/dotty/issues/12492 + TestUtils.scala2Only { + barParser.constructRaw(Seq("-w", "-x", "1", "-y", "2")) ==> + Result.Failure.MismatchedArguments( + Seq( + ArgSig.Simple( + Some("zzzz"), + Some('z'), + None, + None, + mainargs.TokensReader.StringRead, + false + ) + ), + List(), + List(), + None + ) + } + } + + test("missingInnerOuter") { + // Blocked by https://github.com/lampepfl/dotty/issues/12492 + TestUtils.scala2Only { + barParser.constructRaw(Seq("-w", "-x", "1")) ==> + Result.Failure.MismatchedArguments( + Seq( + ArgSig.Simple( + None, + Some('y'), + None, + None, + mainargs.TokensReader.IntRead, + false + ), + ArgSig.Simple( + Some("zzzz"), + Some('z'), + None, + None, + mainargs.TokensReader.StringRead, + false + ) + ), + List(), + List(), + None + ) + } + } + test("failedInnerOuter") { + TestUtils.scala2Only { + assertMatch( + barParser.constructRaw( + Seq("-w", "-x", "xxx", "-y", "hohoho", "-z", "xxx") + ) + ) { + case Result.Failure.InvalidArguments( + Seq( + Result.ParamError.Failed( + ArgSig.Simple(None, Some('x'), None, None, _, false), + Seq("xxx"), + _ + ), + Result.ParamError.Failed( + ArgSig.Simple(None, Some('y'), None, None, _, false), + Seq("hohoho"), + _ + ) + ) + ) => + } + } + } + } + + test("doubleNested") { + TestUtils.scala2Only { + quxParser.constructOrThrow( + Seq("-w", "-x", "1", "-y", "2", "-z", "xxx", "--moo", "cow") + ) ==> + Qux("cow", Bar(Flag(true), Foo(1, 2), "xxx")) + } + } + test("success") { + TestUtils.scala2Only { + ParserForMethods(Main).runOrThrow( + Seq("-x", "1", "-y", "2", "-z", "hello") + ) ==> "false 1 2 hello false" + } + } + } +} diff --git a/mainargs/test/src-2/ManyTests.scala b/mainargs/test/src/ManyTests.scala similarity index 100% rename from mainargs/test/src-2/ManyTests.scala rename to mainargs/test/src/ManyTests.scala diff --git a/mainargs/test/src/NewVarargsTests.scala b/mainargs/test/src/NewVarargsTests.scala index 21c44bb..c55065b 100644 --- a/mainargs/test/src/NewVarargsTests.scala +++ b/mainargs/test/src/NewVarargsTests.scala @@ -1,7 +1,7 @@ package mainargs import utest._ -object NewVarargsTests extends VarargsTests{ - object Base{ +object NewVarargsTests extends VarargsTests { + object Base { @main def pureVariadic(nums: Leftover[Int]) = nums.value.sum @@ -10,8 +10,10 @@ object NewVarargsTests extends VarargsTests{ first + args.value.mkString } @main - def mixedVariadicWithDefault(@arg(short = 'f') first: Int = 1337, - args: Leftover[String]) = { + def mixedVariadicWithDefault( + @arg(short = 'f') first: Int = 1337, + args: Leftover[String] + ) = { first + args.value.mkString } } diff --git a/mainargs/test/src/ParserTests.scala b/mainargs/test/src/ParserTests.scala new file mode 100644 index 0000000..ad7e4c1 --- /dev/null +++ b/mainargs/test/src/ParserTests.scala @@ -0,0 +1,61 @@ +package mainargs +import utest._ + +object ParserTests extends TestSuite { + + object SingleBase { + @main(doc = "Qux is a function that does stuff") + def run( + i: Int, + @arg(doc = "Pass in a custom `s` to override it") + s: String = "lols" + ) = s * i + } + + object MultiBase { + @main + def foo() = 1 + + @main + def bar(i: Int) = i + } + + @main + case class ClassBase(code: Option[String] = None, other: String = "hello") + + val multiMethodParser = ParserForMethods(MultiBase) + val singleMethodParser = ParserForMethods(SingleBase) + val classParser = ParserForClass[ClassBase] + val tests = Tests { + test("runEitherMulti") { + + test { + multiMethodParser.runEither(Array("foo")) ==> Right(1) + } + test { + multiMethodParser.runEither(Array("bar", "-i", "123")) ==> Right(123) + } + test { + assert( + multiMethodParser + .runEither(Array("f")) + .left + .exists(_.contains("Unable to find subcommand: f")) + ) + } + } + test("runEitherSingle") { + singleMethodParser.runEither( + Array("5", "x"), + allowPositional = true + ) ==> Right("xxxxx") + } + test("constructEither") { + TestUtils.scala2Only { + // default values in classes not working on Scala 3 + classParser.constructEither(Array("--code", "println(1)")) ==> + Right(ClassBase(code = Some("println(1)"), other = "hello")) + } + } + } +} diff --git a/mainargs/test/src/TestUtils.scala b/mainargs/test/src/TestUtils.scala new file mode 100644 index 0000000..d697bad --- /dev/null +++ b/mainargs/test/src/TestUtils.scala @@ -0,0 +1,7 @@ +package mainargs + +object TestUtils { + def scala2Only(f: => Unit): Unit = { + if (VersionSpecific.isScala3) {} else f + } +} diff --git a/mainargs/test/src/VarargsTests.scala b/mainargs/test/src/VarargsTests.scala index cfeaa0b..cea2a97 100644 --- a/mainargs/test/src/VarargsTests.scala +++ b/mainargs/test/src/VarargsTests.scala @@ -1,61 +1,69 @@ package mainargs import utest._ -trait VarargsTests extends TestSuite{ +trait VarargsTests extends TestSuite { def check: Checker[_] def isNewVarargsTests: Boolean val tests = Tests { - test("happyPathPasses"){ + test("happyPathPasses") { test - check( - List("pureVariadic", "1", "2", "3"), Result.Success(6) + List("pureVariadic", "1", "2", "3"), + Result.Success(6) ) test - check( List("mixedVariadic", "1", "2", "3", "4", "5"), Result.Success("12345") ) test - { - if (isNewVarargsTests) check( - List("mixedVariadicWithDefault"), - Result.Success("1337") - ) + if (isNewVarargsTests) + check( + List("mixedVariadicWithDefault"), + Result.Success("1337") + ) } } - test("emptyVarargsPasses"){ + test("emptyVarargsPasses") { test - check(List("pureVariadic"), Result.Success(0)) test - check( - List("mixedVariadic", "-f", "1"), Result.Success("1") + List("mixedVariadic", "-f", "1"), + Result.Success("1") ) test - check( - List("mixedVariadic", "1"), Result.Success("1") + List("mixedVariadic", "1"), + Result.Success("1") ) } - test("varargsAreAlwaysPositional"){ + test("varargsAreAlwaysPositional") { val invoked = check.parseInvoke( List("pureVariadic", "--nums", "31337") ) - test - assertMatch(invoked){ - case Result.Failure.InvalidArguments(List( - Result.ParamError.Failed( - ArgSig.Leftover("nums", _, _), - Seq("--nums"), - """java.lang.NumberFormatException: For input string: "--nums"""" | - """java.lang.NumberFormatException: --nums""" - ) - ))=> + test - assertMatch(invoked) { + case Result.Failure.InvalidArguments( + List( + Result.ParamError.Failed( + ArgSig.Leftover("nums", _, _), + Seq("--nums"), + """java.lang.NumberFormatException: For input string: "--nums"""" | + """java.lang.NumberFormatException: --nums""" + ) + ) + ) => } test - assertMatch( check.parseInvoke(List("pureVariadic", "1", "2", "3", "--nums", "4")) - ){ - case Result.Failure.InvalidArguments(List( - Result.ParamError.Failed( - ArgSig.Leftover("nums", _, _), - Seq("--nums"), - "java.lang.NumberFormatException: For input string: \"--nums\"" | - "java.lang.NumberFormatException: --nums" - ) - ))=> + ) { + case Result.Failure.InvalidArguments( + List( + Result.ParamError.Failed( + ArgSig.Leftover("nums", _, _), + Seq("--nums"), + "java.lang.NumberFormatException: For input string: \"--nums\"" | + "java.lang.NumberFormatException: --nums" + ) + ) + ) => } test - check( List("mixedVariadic", "1", "--args", "foo"), @@ -64,47 +72,51 @@ trait VarargsTests extends TestSuite{ } - test("notEnoughNormalArgsStillFails"){ - assertMatch(check.parseInvoke(List("mixedVariadic"))){ + test("notEnoughNormalArgsStillFails") { + assertMatch(check.parseInvoke(List("mixedVariadic"))) { case Result.Failure.MismatchedArguments( - Seq(ArgSig.Simple(Some("first"), _, _, _, _, _)), - Nil, - Nil, - None - ) => + Seq(ArgSig.Simple(Some("first"), _, _, _, _, _)), + Nil, + Nil, + None + ) => } } - test("multipleVarargParseFailures"){ + test("multipleVarargParseFailures") { test - assertMatch( check.parseInvoke(List("pureVariadic", "aa", "bb", "3")) - ){ - case Result.Failure.InvalidArguments(List( - Result.ParamError.Failed( - ArgSig.Leftover("nums", _, _), - Seq("aa"), - "java.lang.NumberFormatException: For input string: \"aa\"" | - "java.lang.NumberFormatException: aa" - ), - Result.ParamError.Failed( - ArgSig.Leftover("nums", _, _), - Seq("bb"), - "java.lang.NumberFormatException: For input string: \"bb\"" | - "java.lang.NumberFormatException: bb" - ) - ))=> + ) { + case Result.Failure.InvalidArguments( + List( + Result.ParamError.Failed( + ArgSig.Leftover("nums", _, _), + Seq("aa"), + "java.lang.NumberFormatException: For input string: \"aa\"" | + "java.lang.NumberFormatException: aa" + ), + Result.ParamError.Failed( + ArgSig.Leftover("nums", _, _), + Seq("bb"), + "java.lang.NumberFormatException: For input string: \"bb\"" | + "java.lang.NumberFormatException: bb" + ) + ) + ) => } test - assertMatch( check.parseInvoke(List("mixedVariadic", "aa", "bb", "3")) - ){ - case Result.Failure.InvalidArguments(List( - Result.ParamError.Failed( - ArgSig.Simple(Some("first"), _, _, _, _, _), - Seq("aa"), - "java.lang.NumberFormatException: For input string: \"aa\"" | - "java.lang.NumberFormatException: aa" - ) - ))=> + ) { + case Result.Failure.InvalidArguments( + List( + Result.ParamError.Failed( + ArgSig.Simple(Some("first"), _, _, _, _, _), + Seq("aa"), + "java.lang.NumberFormatException: For input string: \"aa\"" | + "java.lang.NumberFormatException: aa" + ) + ) + ) => } } }