From 4fdc3cafa2bf6de23b9e84dd1a5eca96af557bd4 Mon Sep 17 00:00:00 2001 From: Pavlo Pohrebnyi Date: Sun, 15 Oct 2023 16:59:48 +0300 Subject: [PATCH] Add named parameters to the model --- build.sbt | 2 + src/main/scala/scanet/core/Mat.scala | 181 +++++++++++---- src/main/scala/scanet/core/Params.scala | 72 ++++++ src/main/scala/scanet/core/Session.scala | 27 ++- src/main/scala/scanet/core/TF.scala | 91 +++++++- .../scala/scanet/estimators/package.scala | 14 +- src/main/scala/scanet/math/grad/package.scala | 28 ++- .../scala/scanet/models/Aggregation.scala | 19 ++ src/main/scala/scanet/models/Math.scala | 18 +- src/main/scala/scanet/models/Model.scala | 157 ++++++------- src/main/scala/scanet/models/ParamDef.scala | 12 + .../scala/scanet/models/layer/Activate.scala | 7 +- src/main/scala/scanet/models/layer/Bias.scala | 22 +- .../scala/scanet/models/layer/Composed.scala | 53 ++--- .../scala/scanet/models/layer/Conv2D.scala | 31 ++- .../scala/scanet/models/layer/Dense.scala | 28 +-- .../scala/scanet/models/layer/Flatten.scala | 6 +- .../scala/scanet/models/layer/Layer.scala | 36 +-- .../scala/scanet/models/layer/LayerInfo.scala | 10 +- .../scala/scanet/models/layer/Pool2D.scala | 6 +- src/main/scala/scanet/models/layer/RNN.scala | 199 +++++++++-------- src/main/scala/scanet/optimizers/Effect.scala | 2 +- .../scanet/optimizers/KryoSerializers.scala | 4 +- .../scala/scanet/optimizers/Optimizer.scala | 206 +++++++++++------- src/main/scala/scanet/research/Neuron.scala | 16 +- src/test/scala/scanet/core/KernelsSpec.scala | 46 +++- src/test/scala/scanet/core/SessionSpec.scala | 8 +- src/test/scala/scanet/core/TFSpec.scala | 20 ++ src/test/scala/scanet/models/CNNSpec.scala | 33 ++- src/test/scala/scanet/models/RNNSpec.scala | 2 +- .../scala/scanet/models/RegressionSpec.scala | 65 ++++-- .../models/layer/ActivateLayerSpec.scala | 6 +- .../scanet/models/layer/BiasLayerSpec.scala | 7 +- .../models/layer/ComposedLayerSpec.scala | 38 +++- .../scanet/models/layer/Conv2DLayerSpec.scala | 7 +- .../scanet/models/layer/DenseLayerSpec.scala | 18 +- .../models/layer/FlattenLayerSpec.scala | 6 +- .../scanet/models/layer/Pool2DLayerSpec.scala | 6 +- .../scanet/models/layer/RNNLayerSpec.scala | 35 ++- .../scala/scanet/optimizers/SGDSpec.scala | 5 +- 40 files changed, 1033 insertions(+), 516 deletions(-) create mode 100644 src/main/scala/scanet/core/Params.scala create mode 100644 src/main/scala/scanet/models/Aggregation.scala create mode 100644 src/main/scala/scanet/models/ParamDef.scala diff --git a/build.sbt b/build.sbt index a3d4a66..ed05d07 100644 --- a/build.sbt +++ b/build.sbt @@ -17,6 +17,8 @@ libraryDependencies ++= Seq( "com.typesafe.scala-logging" %% "scala-logging" % "3.9.5", "ch.qos.logback" % "logback-classic" % "1.4.5", "org.scala-lang.modules" %% "scala-collection-compat" % "2.9.0", + "com.github.ben-manes.caffeine" % "caffeine" % "2.8.5", + "com.softwaremill.magnolia1_2" %% "magnolia" % "1.1.3", "org.scalacheck" %% "scalacheck" % "1.17.0" % Test, "org.scalatest" %% "scalatest" % "3.2.14" % Test) diff --git a/src/main/scala/scanet/core/Mat.scala b/src/main/scala/scanet/core/Mat.scala index e569fc0..0df8625 100644 --- a/src/main/scala/scanet/core/Mat.scala +++ b/src/main/scala/scanet/core/Mat.scala @@ -11,10 +11,14 @@ import scala.collection.immutable.Seq // case type information gets lost there trait Mat[In] { type Out + // given MyType[Expr] deconstruct as Layout + Seq[Expr] so we could feed a session def deconstructIn(in: In): (Layout, Seq[Expr[_]]) - def constructIn(shape: Layout, expr: 
Seq[Expr[_]]): In + // given placeholders and layout, construct In representation so we could pass it into the session + def constructIn(layout: Layout, expr: Seq[Expr[_]]): In + // given MyType[Tensor] deconstruct as Layout + Seq[Tensor] so we could create placeholders def deconstructOut(out: Out): (Layout, Seq[Tensor[_]]) - def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out + // given raw tensors and layout, construct MyType[Tensor], which is done to get the results after a session run + def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out } sealed trait Layout { @@ -22,46 +26,145 @@ object Layout { - case class Leaf(size: Int) extends Layout + case object Value extends Layout { + override def size: Int = 1 + } + case class SeqLayout(items: Seq[Layout]) extends Layout { + override def size: Int = items.map(_.size).sum + } + case class MapLayout(items: Seq[(Any, Layout)]) extends Layout { + override def size: Int = items.map(_._2.size).sum + } case class Struct(children: Seq[Layout]) extends Layout { def apply(index: Int): Layout = children(index) def size: Int = children.map(_.size).sum } } -class MatExpr[A: TensorType] extends Mat[Expr[A]] { +class ValueMat[A: TensorType] extends Mat[Expr[A]] { override type Out = Tensor[A] - override def deconstructIn(in: Expr[A]): (Layout, Seq[Expr[_]]) = (Leaf(1), Seq(in)) - override def constructIn(shape: Layout, expr: Seq[Expr[_]]): Expr[A] = - expr.head.asInstanceOf[Expr[A]] - override def deconstructOut(out: Tensor[A]): (Layout, Seq[Tensor[_]]) = (Leaf(1), Seq(out)) - override def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out = - Tensor.wrap[A](tensors.head) -} - -class MatSeqOfExpr[A: TensorType] extends Mat[Seq[Expr[A]]] { - override type Out = Seq[Tensor[A]] - override def deconstructIn(in: Seq[Expr[A]]): (Layout, Seq[Expr[_]]) = (Leaf(in.size), in) - override def constructIn(shape: Layout, expr: Seq[Expr[_]]): Seq[Expr[A]] = - expr.asInstanceOf[Seq[Expr[A]]] - override def deconstructOut(out: Seq[Tensor[A]]): (Layout, Seq[Tensor[_]]) = - (Leaf(out.size), out) - override def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out = - tensors.map(raw => Tensor.wrap[A](raw)) + override def deconstructIn(in: Expr[A]): (Layout, Seq[Expr[_]]) = (Value, Seq(in)) + override def constructIn(layout: Layout, expr: Seq[Expr[_]]): Expr[A] = layout match { + case Value => expr.head.asInstanceOf[Expr[A]] + case other => error(s"Unsupported layout $other") + } + override def deconstructOut(out: Tensor[A]): (Layout, Seq[Tensor[_]]) = (Value, Seq(out)) + override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out = + layout match { + case Value => Tensor.wrap[A](tensors.head) + case other => error(s"Unsupported layout $other") + } } object Mat { trait AllSyntax { - implicit def matExpr[A: TensorType]: MatExpr[A] = new MatExpr[A] + implicit def valueMat[A: TensorType]: ValueMat[A] = new ValueMat[A] + // Note: we have to define an anonymous class so dependent types would work + // that is probably caused by the fact that we have to preserve an exact structural type - implicit def matSeqOfExpr[A: TensorType]: MatSeqOfExpr[A] = new MatSeqOfExpr[A] + implicit def seqMat[A](implicit m: Mat[A]) = new Mat[Seq[A]] { + override type Out = Seq[m.Out] + override def deconstructIn(in: Seq[A]): (Layout, Seq[Expr[_]]) = { + val (layouts, allExpr) = in.map(m.deconstructIn).unzip + (SeqLayout(layouts), allExpr.flatten) + } + override def constructIn(layout: Layout, expr: Seq[Expr[_]]): Seq[A] = { + 
layout match { + case SeqLayout(items) => + items.foldLeft((Seq.empty[A], expr)) { + case ((result, exprAll), next) => + val (consumed, rest) = exprAll.splitAt(next.size) + val constructed = m.constructIn(next, consumed) + (constructed +: result, rest) + }._1.reverse + case other => error(s"Unsupported layout $other") + } + } + override def deconstructOut(out: Seq[m.Out]): (Layout, Seq[Tensor[_]]) = { + val (layouts, allTensors) = out.map(m.deconstructOut).unzip + (SeqLayout(layouts), allTensors.flatten) + } + override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Seq[m.Out] = { + layout match { + case SeqLayout(items) => + items.foldLeft((Seq.empty[m.Out], tensors)) { + case ((result, tensorsAll), next) => + val (consumed, rest) = tensorsAll.splitAt(next.size) + val constructed = m.constructOutRaw(next, consumed) + (constructed +: result, rest) + }._1.reverse + case other => error(s"Unsupported layout $other") + } + } + } - // Note: we have to define anonymous class so dependant types would work - // that is probably caused by the fact that we have to preserve an original - // (m1.Out, m2.Out) which come from implicit scope - implicit def matTuple2Expr[A1, A2](implicit m1: Mat[A1], m2: Mat[A2]) = + implicit def mapMat[K, V](implicit m: Mat[V]) = new Mat[Map[K, V]] { + override type Out = Map[K, m.Out] + override def deconstructIn(in: Map[K, V]): (Layout, Seq[Expr[_]]) = { + val (layouts, allExpr) = in + .map { + case (key, value) => + val (layout, expr) = m.deconstructIn(value) + ((key, layout), expr) + } + .toList + .unzip + (MapLayout(layouts), allExpr.flatten) + } + override def constructIn(layout: Layout, expr: Seq[Expr[_]]): Map[K, V] = { + layout match { + case MapLayout(items) => + items.foldLeft((Map.empty[K, V], expr)) { + case ((result, exprAll), (key, next)) => + val (consumed, rest) = exprAll.splitAt(next.size) + val constructed = m.constructIn(next, consumed) + (result + (key.asInstanceOf[K] -> constructed), rest) + }._1 + case other => error(s"Unsupported layout $other") + } + } + override def deconstructOut(out: Map[K, m.Out]): (Layout, Seq[Tensor[_]]) = { + val (layouts, allTensors) = out + .map { + case (key, value) => + val (layout, tensors) = m.deconstructOut(value) + ((key, layout), tensors) + } + .toList + .unzip + (MapLayout(layouts), allTensors.flatten) + } + override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out = { + layout match { + case MapLayout(items) => + items.foldLeft((Map.empty[K, m.Out], tensors)) { + case ((result, tensorsAll), (key, next)) => + val (consumed, rest) = tensorsAll.splitAt(next.size) + val constructed = m.constructOutRaw(next, consumed) + (result + (key.asInstanceOf[K] -> constructed), rest) + }._1 + case other => error(s"Unsupported layout $other") + } + } + } + + // check to see if we can implement xmap with dependant types + implicit def paramsMat[A](implicit m: Mat[A]) = new Mat[Params[A]] { + override type Out = Params[m.Out] + private val map = mapMat[Path, A] + override def deconstructIn(in: Params[A]): (Layout, Seq[Expr[_]]) = + map.deconstructIn(in.params) + override def constructIn(layout: Layout, expr: Seq[Expr[_]]): Params[A] = + Params(map.constructIn(layout, expr)) + override def deconstructOut(out: Params[m.Out]): (Layout, Seq[Tensor[_]]) = + map.deconstructOut(out.params) + override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out = + Params(map.constructOutRaw(layout, tensors)) + } + + implicit def tuple2Mat[A1, A2](implicit m1: Mat[A1], m2: Mat[A2]) = new Mat[(A1, A2)] { 
override type Out = (m1.Out, m2.Out) override def deconstructIn(in: (A1, A2)): (Layout, Seq[Expr[_]]) = { @@ -69,14 +172,14 @@ object Mat { val (shape2, expr2) = m2.deconstructIn(in._2) (Struct(Seq(shape1, shape2)), expr1 ++ expr2) } - override def constructIn(shape: Layout, expr: Seq[Expr[_]]): (A1, A2) = { - shape match { + override def constructIn(layout: Layout, expr: Seq[Expr[_]]): (A1, A2) = { + layout match { case Struct(t1 :: t2 :: Nil) => val (s1, s2) = (t1.size, t2.size) ( m1.constructIn(t1, expr.slice(0, s1)), m2.constructIn(t2, expr.slice(s1, s1 + s2))) - case _ => error("StructShape of size 2 is required") + case _ => error("Struct layout of size 2 is required") } } override def deconstructOut(out: (m1.Out, m2.Out)): (Layout, Seq[Tensor[_]]) = { @@ -84,8 +187,8 @@ object Mat { val (shape2, tensor2) = m2.deconstructOut(out._2) (Struct(Seq(shape1, shape2)), tensor1 ++ tensor2) } - override def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out = { - shape match { + override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out = { + layout match { case Struct(t1 :: t2 :: Nil) => val (s1, s2) = (t1.size, t2.size) ( @@ -96,7 +199,7 @@ object Mat { } } - implicit def matTuple3Expr[A1, A2, A3](implicit m1: Mat[A1], m2: Mat[A2], m3: Mat[A3]) = + implicit def tuple3Mat[A1, A2, A3](implicit m1: Mat[A1], m2: Mat[A2], m3: Mat[A3]) = new Mat[(A1, A2, A3)] { override type Out = (m1.Out, m2.Out, m3.Out) override def deconstructIn(in: (A1, A2, A3)): (Layout, Seq[Expr[_]]) = { @@ -105,15 +208,15 @@ object Mat { val (t3, expr3) = m3.deconstructIn(in._3) (Struct(Seq(t1, t2, t3)), expr1 ++ expr2 ++ expr3) } - override def constructIn(shape: Layout, expr: Seq[Expr[_]]): (A1, A2, A3) = { - shape match { + override def constructIn(layout: Layout, expr: Seq[Expr[_]]): (A1, A2, A3) = { + layout match { case Struct(t1 :: t2 :: t3 :: Nil) => val (s1, s2, s3) = (t1.size, t2.size, t3.size) ( m1.constructIn(t1, expr.slice(0, s1)), m2.constructIn(t2, expr.slice(s1, s1 + s2)), m3.constructIn(t3, expr.slice(s1 + s2, s1 + s2 + s3))) - case _ => error("StructShape of size 3 is required") + case _ => error("Struct layout of size 3 is required") } } override def deconstructOut(out: (m1.Out, m2.Out, m3.Out)): (Layout, Seq[Tensor[_]]) = { @@ -122,15 +225,15 @@ object Mat { val (t3, tensor3) = m3.deconstructOut(out._3) (Struct(Seq(t1, t2, t3)), tensor1 ++ tensor2 ++ tensor3) } - override def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out = { - shape match { + override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out = { + layout match { case Struct(t1 :: t2 :: t3 :: Nil) => val (s1, s2, s3) = (t1.size, t2.size, t3.size) ( m1.constructOutRaw(t1, tensors.slice(0, s1)), m2.constructOutRaw(t2, tensors.slice(s1, s1 + s2)), m3.constructOutRaw(t3, tensors.slice(s1 + s2, s1 + s2 + s3))) - case _ => error("StructShape of size 3 is required") + case _ => error("Struct layout of size 3 is required") } } } diff --git a/src/main/scala/scanet/core/Params.scala b/src/main/scala/scanet/core/Params.scala new file mode 100644 index 0000000..936360b --- /dev/null +++ b/src/main/scala/scanet/core/Params.scala @@ -0,0 +1,72 @@ +package scanet.core + +case class Path(segments: Seq[String]) { + def /(child: Path): Path = Path(segments ++ child.segments) + def startsWith(parent: Path): Boolean = segments.startsWith(parent.segments) + def endsWith(parent: Path): Boolean = segments.endsWith(parent.segments) + def relativeTo(parent: Path): Option[Path] = + if (startsWith(parent)) 
Some(Path(segments.drop(parent.segments.size))) else None + override def toString: String = segments.mkString("/") +} + +object Path { + def apply(s1: String, sn: String*): Path = new Path(s1 :: sn.toList) + def parse(path: String): Path = new Path(path.split("/").toList) + implicit def stringIsPath(string: String): Path = Path.parse(string) +} + +// convenient wrapper over Map[Path, A] to get DSL like API +case class Params[A](params: Map[Path, A]) { + def apply(path: Path): A = + params.getOrElse(path, error(s"missing $path param")) + def paths: Set[Path] = params.keySet + def children(parent: Path): Params[A] = { + val childrenParams = params.flatMap { + case (path, value) => + path.relativeTo(parent).map(child => child -> value) + } + Params(childrenParams) + } + def +(other: (Path, A)): Params[A] = Params(params + other) + def -(other: Path): Params[A] = Params(params - other) + def ++(other: Params[A]): Params[A] = + Params(params ++ other.params) + def map[B](f: (Path, A) => (Path, B)): Params[B] = + Params(params.map(f.tupled)) + def mapValues[B](f: A => B): Params[B] = + Params(params.map { case (k, v) => (k, f(v)) }) + def filter(f: (Path, A) => Boolean): Params[A] = + Params(params.filter(f.tupled)) + def filterPaths(f: Path => Boolean): Params[A] = + Params(params.filter { case (k, _) => f(k) }) + def filterValues(f: A => Boolean): Params[A] = + Params(params.filter { case (_, v) => f(v) }) + def partition(by: (Path, A) => Boolean): (Params[A], Params[A]) = { + val (left, right) = params.partition(by.tupled) + (Params(left), Params(right)) + } + def partitionPaths(by: Path => Boolean): (Params[A], Params[A]) = + partition { case (k, _) => by(k) } + def partitionValues(by: A => Boolean): (Params[A], Params[A]) = + partition { case (k, v) => by(v) } + def values: Iterable[A] = params.values + def join[B](other: Params[B]): Params[(A, B)] = { + val allPaths = paths ++ other.paths + val joinedItems = allPaths.map(path => (path, (this(path), other(path)))) + Params(joinedItems.toMap) + } + def size: Int = params.size + def isEmpty: Boolean = params.isEmpty + def prependPath(path: Path): Params[A] = + Params(params.map { case (k, v) => path / k -> v }) + def weights: A = apply(Params.Weights) +} + +object Params { + def apply[A](elems: (Path, A)*): Params[A] = + new Params[A](Map(elems: _*)) + def empty[A]: Params[A] = + new Params[A](Map.empty) + val Weights: Path = "weights" + val State: Path = "state" +} diff --git a/src/main/scala/scanet/core/Session.scala b/src/main/scala/scanet/core/Session.scala index c02fd9a..fb340ea 100644 --- a/src/main/scala/scanet/core/Session.scala +++ b/src/main/scala/scanet/core/Session.scala @@ -17,7 +17,8 @@ case class Runner(session: Session, feed: Map[Expr[_], Tensor[_]] = Map()) { def feed(elems: (Expr[_], Tensor[_])*): Runner = copy(feed = feed ++ Map(elems: _*)) def evalUnsafe(outs: Seq[Expr[_]]): Seq[RawTensor] = { - session.eval(outs, feed) + if (outs.isEmpty) Seq.empty + else session.eval(outs, feed) } def eval[A](value: A)(implicit m: Mat[A]): m.Out = { @@ -53,13 +54,13 @@ case class SessionState(scope: NativeScope, cache: Map[String, LabeledOperationO */ class Session(verbose: Boolean = false) extends AutoCloseable { - val nGraph = new Graph() - val config: ConfigProto = ConfigProto + private val nGraph = new Graph() + private val config = ConfigProto .newBuilder(ConfigProto.getDefaultInstance) .setLogDevicePlacement(verbose) .build() - val nSession = new NativeSession(nGraph, config) - var state = SessionState(new OpScope(nGraph), 
Map.empty) + private val nSession = new NativeSession(nGraph, config) + @volatile private var state = SessionState(new OpScope(nGraph), Map.empty) def runner: Runner = Runner(this) @@ -94,7 +95,7 @@ class Session(verbose: Boolean = false) extends AutoCloseable { outputs.uncompress(outTensors) } - def nativeHandle: TF_Session = { + private[core] def nativeHandle: TF_Session = { val nativeHandleField = classOf[NativeSession].getDeclaredField("nativeHandle") nativeHandleField.setAccessible(true) nativeHandleField.get(nSession).asInstanceOf[TF_Session] @@ -143,27 +144,25 @@ object Session { } } -class LazySession { - lazy val get = new Session() -} - -class SessionPool(val maxSize: Int) { +class SessionPool(val maxSize: Int) extends AutoCloseable { - private val pool: BlockingDeque[LazySession] = new LinkedBlockingDeque[LazySession](maxSize) + private val pool: BlockingDeque[Session] = new LinkedBlockingDeque[Session](maxSize) private val inUse: AtomicInteger = new AtomicInteger(0) - pool.addAll(List.fill(maxSize)(new LazySession()).asJavaCollection) + pool.addAll(List.fill(maxSize)(new Session()).asJavaCollection) def used: Int = inUse.get def within[R](f: Session => R): R = { val session = pool.takeFirst() inUse.incrementAndGet() - val result = Try { f(session.get) } + val result = Try { f(session) } pool.addFirst(session) inUse.decrementAndGet() result.get } + + override def close(): Unit = pool.forEach(_.close()) } object SessionPool { diff --git a/src/main/scala/scanet/core/TF.scala b/src/main/scala/scanet/core/TF.scala index 9f6d213..d0087b7 100644 --- a/src/main/scala/scanet/core/TF.scala +++ b/src/main/scala/scanet/core/TF.scala @@ -27,6 +27,10 @@ object TF { def tf: TF5[A1, A2, A3, A4, A5, R] = TF5(f) } + class TF6Ops[A1, A2, A3, A4, A5, A6, R](f: (A1, A2, A3, A4, A5, A6) => R) { + def tf: TF6[A1, A2, A3, A4, A5, A6, R] = TF6(f) + } + trait AllSyntax { // implicit conversions to TF when, useful when we want to compile regular function @@ -36,6 +40,8 @@ object TF { implicit def toTF4[A1, A2, A3, A4, R](f: (A1, A2, A3, A4) => R): TF4[A1, A2, A3, A4, R] = TF4(f) implicit def toTF5[A1, A2, A3, A4, A5, R](f: (A1, A2, A3, A4, A5) => R) : TF5[A1, A2, A3, A4, A5, R] = TF5(f) + implicit def toTF6[A1, A2, A3, A4, A5, A6, R](f: (A1, A2, A3, A4, A5, A6) => R) + : TF6[A1, A2, A3, A4, A5, A6, R] = TF6(f) // ops to by calling f.tf we can convert a function to FT on demand implicit def toTF1Ops[A1, R](f: A1 => R): TF1Ops[A1, R] = new TF1Ops(f) @@ -46,7 +52,8 @@ object TF { new TF4Ops(f) implicit def toTF5Ops[A1, A2, A3, A4, A5, R](f: (A1, A2, A3, A4, A5) => R) : TF5Ops[A1, A2, A3, A4, A5, R] = new TF5Ops(f) - + implicit def toTF6Ops[A1, A2, A3, A4, A5, A6, R](f: (A1, A2, A3, A4, A5, A6) => R) + : TF6Ops[A1, A2, A3, A4, A5, A6, R] = new TF6Ops(f) } object syntax extends AllSyntax @@ -349,3 +356,85 @@ object TF5 { } } } + +trait TF6[A1, A2, A3, A4, A5, A6, R] { + def compileWith(session: Session)( + implicit a1Mat: Mat[A1], + a2Mat: Mat[A2], + a3Mat: Mat[A3], + a4Mat: Mat[A4], + a5Mat: Mat[A5], + a6Mat: Mat[A6], + rMat: Mat[R]): (a1Mat.Out, a2Mat.Out, a3Mat.Out, a4Mat.Out, a5Mat.Out, a6Mat.Out) => rMat.Out + def compile( + implicit a1Mat: Mat[A1], + a2Mat: Mat[A2], + a3Mat: Mat[A3], + a4Mat: Mat[A4], + a5Mat: Mat[A5], + a6Mat: Mat[A6], + rMat: Mat[R]) + : (a1Mat.Out, a2Mat.Out, a3Mat.Out, a4Mat.Out, a5Mat.Out, a6Mat.Out) => rMat.Out = + compileWith(new Session())(a1Mat, a2Mat, a3Mat, a4Mat, a5Mat, a6Mat, rMat) +} + +object TF6 { + + def apply[A1, A2, A3, A4, A5, A6, R](func: (A1, A2, A3, A4, A5, A6) => 
R): TF6[A1, A2, A3, A4, A5, A6, R] = + new TF6Cached[A1, A2, A3, A4, A5, A6, R](func) + + class TF6Cached[A1, A2, A3, A4, A5, A6, R](func: (A1, A2, A3, A4, A5, A6) => R) + extends TF6[A1, A2, A3, A4, A5, A6, R] { + + private val cache = TF.Cache[R]() + + override def compileWith(session: Session)( + implicit a1Mat: Mat[A1], + a2Mat: Mat[A2], + a3Mat: Mat[A3], + a4Mat: Mat[A4], + a5Mat: Mat[A5], + a6Mat: Mat[A6], + rMat: Mat[R]): (a1Mat.Out, a2Mat.Out, a3Mat.Out, a4Mat.Out, a5Mat.Out, a6Mat.Out) => rMat.Out = { + (a1Out: a1Mat.Out, a2Out: a2Mat.Out, a3Out: a3Mat.Out, a4Out: a4Mat.Out, a5Out: a5Mat.Out, a6Out: a6Mat.Out) => + { + val (a1Layout, a1T) = a1Mat.deconstructOut(a1Out) + val (a2Layout, a2T) = a2Mat.deconstructOut(a2Out) + val (a3Layout, a3T) = a3Mat.deconstructOut(a3Out) + val (a4Layout, a4T) = a4Mat.deconstructOut(a4Out) + val (a5Layout, a5T) = a5Mat.deconstructOut(a5Out) + val (a6Layout, a6T) = a6Mat.deconstructOut(a6Out) + val aTAll = a1T ++ a2T ++ a3T ++ a4T ++ a5T ++ a6T + val a1Type = a1T.map(tensor => (tensor.`type`, tensor.shape)) + val a2Type = a2T.map(tensor => (tensor.`type`, tensor.shape)) + val a3Type = a3T.map(tensor => (tensor.`type`, tensor.shape)) + val a4Type = a4T.map(tensor => (tensor.`type`, tensor.shape)) + val a5Type = a5T.map(tensor => (tensor.`type`, tensor.shape)) + val a6Type = a6T.map(tensor => (tensor.`type`, tensor.shape)) + val aShapes = a1Type.map(_._2) ++ a2Type.map(_._2) ++ a3Type.map(_._2) ++ + a4Type.map(_._2) ++ a5Type.map(_._2) ++ a6Type.map(_._2) + + val (pAll, r) = cache.getOrCompute(aShapes) { + val p1 = a1Type.map { case (t, s) => placeholderRaw(t, s) } + val p2 = a2Type.map { case (t, s) => placeholderRaw(t, s) } + val p3 = a3Type.map { case (t, s) => placeholderRaw(t, s) } + val p4 = a4Type.map { case (t, s) => placeholderRaw(t, s) } + val p5 = a5Type.map { case (t, s) => placeholderRaw(t, s) } + val p6 = a6Type.map { case (t, s) => placeholderRaw(t, s) } + val a1 = a1Mat.constructIn(a1Layout, p1) + val a2 = a2Mat.constructIn(a2Layout, p2) + val a3 = a3Mat.constructIn(a3Layout, p3) + val a4 = a4Mat.constructIn(a4Layout, p4) + val a5 = a5Mat.constructIn(a5Layout, p5) + val a6 = a6Mat.constructIn(a6Layout, p6) + val r = func(a1, a2, a3, a4, a5, a6) + (p1 ++ p2 ++ p3 ++ p4 ++ p5 ++ p6, r) + } + + val (rLayout, rIn) = rMat.deconstructIn(r) + val rRaw = session.runner.feed(pAll zip aTAll: _*).evalUnsafe(rIn) + rMat.constructOutRaw(rLayout, rRaw) + } + } + } +} diff --git a/src/main/scala/scanet/estimators/package.scala b/src/main/scala/scanet/estimators/package.scala index 596f4d2..e624d63 100644 --- a/src/main/scala/scanet/estimators/package.scala +++ b/src/main/scala/scanet/estimators/package.scala @@ -5,7 +5,7 @@ import scanet.core.{Expr, Numeric, Session, Shape} import scala.{math => m} import scanet.math.syntax._ -import scanet.models.TrainedModel +import scanet.models.{TrainedModel_} import scanet.optimizers.Iterators.Partial import scanet.optimizers.syntax._ import scanet.optimizers.Record @@ -16,7 +16,7 @@ import scala.collection.immutable.Seq package object estimators { def accuracy[A: Numeric]( - model: TrainedModel[A], + model: TrainedModel_[A], ds: Dataset[Record[A]], batch: Int = 1000): Float = { import ds.sparkSession.implicits._ @@ -47,13 +47,13 @@ package object estimators { } def RMSE[A: Numeric]( - model: TrainedModel[A], + model: TrainedModel_[A], ds: Dataset[Record[A]], batch: Int = 1000): Float = m.sqrt(MSE(model, ds, batch)).toFloat def MSE[A: Numeric]( - model: TrainedModel[A], + model: TrainedModel_[A], ds: 
Dataset[Record[A]], batch: Int = 1000): Float = meanError(model, ds, batch) { @@ -61,7 +61,7 @@ package object estimators { } def MAE[A: Numeric]( - model: TrainedModel[A], + model: TrainedModel_[A], ds: Dataset[Record[A]], batch: Int = 1000): Float = meanError(model, ds, batch) { @@ -69,7 +69,7 @@ package object estimators { } private def meanError[A: Numeric]( - model: TrainedModel[A], + model: TrainedModel_[A], ds: Dataset[Record[A]], batch: Int)( error: (Expr[A], Expr[A]) => Expr[A]): Float = { @@ -100,7 +100,7 @@ package object estimators { } def R2Score[A: Numeric]( - model: TrainedModel[A], + model: TrainedModel_[A], ds: Dataset[Record[A]], batch: Int = 1000): Float = { require(ds.labelsShape == Shape(1), "labels should have shape (1)") diff --git a/src/main/scala/scanet/math/grad/package.scala b/src/main/scala/scanet/math/grad/package.scala index 5bb5592..1154651 100644 --- a/src/main/scala/scanet/math/grad/package.scala +++ b/src/main/scala/scanet/math/grad/package.scala @@ -1,6 +1,6 @@ package scanet.math.grad -import scanet.core.{Expr, Floating, Node, Numeric, Shape} +import scanet.core.{Expr, Floating, Node, Numeric, Params, Shape} import scanet.math.alg.kernels.syntax._ import scala.collection.immutable.Seq @@ -41,7 +41,14 @@ class GradCalc[A: Numeric, R: Floating](out: Expr[A]) { } } -class GradCalcNOps[A: Numeric, B: Numeric]( +class GradCalcOps[A: Numeric, B: Numeric]( + out: Expr[A], + withRespectTo: Expr[B]) { + def returns[R: Floating]: Expr[R] = + new GradCalc[A, R](out).calc[B](withRespectTo) +} + +class GradCalcSeqOps[A: Numeric, B: Numeric]( out: Expr[A], withRespectTo: Seq[Expr[B]]) { def returns[R: Floating]: Seq[Expr[R]] = { @@ -50,11 +57,13 @@ class GradCalcNOps[A: Numeric, B: Numeric]( } } -class GradCalcOps[A: Numeric, B: Numeric]( +class GradCalcParamsOps[A: Numeric, B: Numeric]( out: Expr[A], - withRespectTo: Expr[B]) { - def returns[R: Floating]: Expr[R] = - new GradCalc[A, R](out).calc[B](withRespectTo) + withRespectTo: Params[Expr[B]]) { + def returns[R: Floating]: Params[Expr[R]] = { + val calc = new GradCalc[A, R](out) + withRespectTo.mapValues(calc.calc[B]) + } } class GradOps[A: Numeric](expr: Expr[A]) { @@ -62,8 +71,11 @@ class GradOps[A: Numeric](expr: Expr[A]) { def grad[B: Numeric](withRespectTo: Expr[B]): GradCalcOps[A, B] = new GradCalcOps[A, B](expr, withRespectTo) - def grad[B: Numeric](withRespectTo: Seq[Expr[B]]): GradCalcNOps[A, B] = - new GradCalcNOps[A, B](expr, withRespectTo) + def grad[B: Numeric](withRespectTo: Seq[Expr[B]]): GradCalcSeqOps[A, B] = + new GradCalcSeqOps[A, B](expr, withRespectTo) + + def grad[B: Numeric](withRespectTo: Params[Expr[B]]): GradCalcParamsOps[A, B] = + new GradCalcParamsOps[A, B](expr, withRespectTo) } trait GradSyntax { diff --git a/src/main/scala/scanet/models/Aggregation.scala b/src/main/scala/scanet/models/Aggregation.scala new file mode 100644 index 0000000..d6037dc --- /dev/null +++ b/src/main/scala/scanet/models/Aggregation.scala @@ -0,0 +1,19 @@ +package scanet.models + +import scanet.core.{Expr, Floating} +import scanet.math.syntax._ + +trait Aggregation { + def build[E: Floating](inputs: Seq[Expr[E]]): Expr[E] +} + +object Aggregation { + case object Sum extends Aggregation { + override def build[E: Floating](inputs: Seq[Expr[E]]): Expr[E] = + plus(inputs) + } + case object Avg extends Aggregation { + override def build[E: Floating](inputs: Seq[Expr[E]]): Expr[E] = + plus(inputs) / inputs.size.const.cast[E] + } +} diff --git a/src/main/scala/scanet/models/Math.scala 
b/src/main/scala/scanet/models/Math.scala index c1077aa..99cc672 100644 --- a/src/main/scala/scanet/models/Math.scala +++ b/src/main/scala/scanet/models/Math.scala @@ -1,7 +1,9 @@ package scanet.models -import scanet.core.{Expr, Floating, Shape} +import scanet.core.Params.Weights +import scanet.core.{Expr, Floating, Params, Shape} import scanet.math.syntax._ +import scanet.models.Aggregation.Avg import scanet.models.layer.StatelessLayer import scala.collection.immutable.Seq @@ -10,17 +12,15 @@ object Math { case object `x^2` extends StatelessLayer { - override def build[A: Floating]( - input: Expr[A], - weights: Seq[Expr[A]]): Expr[A] = - weights.head * weights.head + override def params_(input: Shape): Params[ParamDef] = + Params(Weights -> ParamDef(Shape(), Initializer.Zeros, Some(Avg), trainable = true)) - override def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]) = zeros[E](Shape()) + override def buildStateless_[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] = + pow(params(Weights), 2) - override def weightsShapes(input: Shape): Seq[Shape] = Seq(Shape()) + override def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] = + zeros[E](Shape()) override def outputShape(input: Shape): Shape = input - - override def initWeights[E: Floating](input: Shape): Seq[Expr[E]] = Seq(zeros[E](Shape())) } } diff --git a/src/main/scala/scanet/models/Model.scala b/src/main/scala/scanet/models/Model.scala index cea7e4b..a109af2 100644 --- a/src/main/scala/scanet/models/Model.scala +++ b/src/main/scala/scanet/models/Model.scala @@ -11,49 +11,43 @@ abstract class Model extends Serializable { def name: String = getClass.getSimpleName.replace("$", "") - def build[E: Floating](input: Expr[E], weights: Seq[Expr[E]]): Expr[E] = - buildStateful(input, weights, stateShapes(input.shape).map(s => zeros[E](s)))._1 + /** Model params, both trainable and non-trainable (model state) + * @param input input shape + * @return param definitions + */ + def params_(input: Shape): Params[ParamDef] /** Build a model * - * @param input training set, where first dimension equals to number of samples (batch size) - * @param weights model weights - * @param state model state - * @return model + * @param input training set, where first dimension equals to number of samples (batch size) + * @param params initialized or calculated model params + * @return tuple where the first element is model output and second is changed params */ - def buildStateful[E: Floating]( - input: Expr[E], - weights: Seq[Expr[E]], - state: Seq[Expr[E]]): (Expr[E], Seq[Expr[E]]) + def build_[E: Floating](input: Expr[E], params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) /** Additional model penalty to be added to the loss * - * @param weights model weights + * @param params initialized or calculated model params * @return penalty */ - def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]): Expr[E] + def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] - def result[E: Floating]: (Expr[E], Seq[Expr[E]]) => Expr[E] = - (input, weights) => - buildStateful(input, weights, stateShapes(input.shape).map(s => zeros[E](s)))._1 + def result_[E: Floating]: (Expr[E], Params[Expr[E]]) => Expr[E] = + (input, params) => build_(input, params)._1 - def resultStateful[E: Floating] - : (Expr[E], Seq[Expr[E]], Seq[Expr[E]]) => (Expr[E], Seq[Expr[E]]) = - (input, weights, state) => buildStateful(input, weights, state) + def resultStateful_[E: Floating]: (Expr[E], Params[Expr[E]]) => (Expr[E], 
Params[Expr[E]]) = + (input, params) => build_(input, params) + // do we really need that??? def outputShape(input: Shape): Shape - def weightsShapes(input: Shape): Seq[Shape] - def initWeights[E: Floating](input: Shape): Seq[Expr[E]] - - def stateShapes(input: Shape): Seq[Shape] - def withLoss(loss: Loss): LossModel = LossModel(this, loss) - private def makeGraph[E: Floating](input: Shape) = - build( - placeholder[E](input), - weightsShapes(input).map(s => placeholder[E](s))) + private def makeGraph[E: Floating](input: Shape): Expr[E] = + build_( + input = placeholder[E](input), + params = params_(input).mapValues(paramDef => placeholder[E](paramDef.shape))) + ._1 def displayResult[E: Floating](input: Shape, dir: String = ""): Unit = makeGraph[E](input).as("result").display(dir) @@ -61,101 +55,118 @@ abstract class Model extends Serializable { def printResult[E: Floating](input: Shape): Unit = println(makeGraph[E](input).as("result").toString) - def info(input: Shape): Seq[LayerInfo] = - Seq(LayerInfo(toString, weightsShapes(input), outputShape(input))) + def info(input: Shape): Seq[LayerInfo] = { + val (weights, state) = params_(input).partitionValues(_.trainable) + Seq(LayerInfo( + toString, + weights.values.map(_.shape).toList, + state.values.map(_.shape).toList, + outputShape(input))) + } def describe[E: Floating](input: Shape): String = { val layersInfo = info(input) - val layers = (LayerInfo("Input", Seq.empty, input) +: layersInfo).map(_.toRow) - val table = Tabulator.format(Seq("name", "weights", "params", "output") +: layers) - val total = layersInfo.flatMap(info => info.weights.map(_.power)).sum - val size = Bytes.formatSize(TensorType[E].codec.sizeOf(total)) - s"$table\nTotal params: $total ($size)" + val layers = (LayerInfo("Input", Seq.empty, Seq.empty, input) +: layersInfo).map(_.toRow) + val table = + Tabulator.format(Seq("name", "weights", "weights params", "state params", "output") +: layers) + val weightTotal = layersInfo.map(info => info.weightsTotal).sum + val weightSize = Bytes.formatSize(TensorType[E].codec.sizeOf(weightTotal)) + val stateTotal = layersInfo.map(info => info.stateTotal).sum + val stateSize = Bytes.formatSize(TensorType[E].codec.sizeOf(stateTotal)) + s"$table\nTotal weight params: $weightTotal ($weightSize), state params: $stateTotal ($stateSize)" } } case class LossModel(model: Model, lossF: Loss) extends Serializable { - def build[E: Floating](input: Expr[E], output: Expr[E], weights: Seq[Expr[E]]): Expr[E] = - lossF.build(model.build(input, weights), output) plus model.penalty(input.shape, weights) + def build_[E: Floating]( + input: Expr[E], + output: Expr[E], + params: Params[Expr[E]]): Expr[E] = + buildStateful_(input, output, params)._1 - def buildStateful[E: Floating]( + def buildStateful_[E: Floating]( input: Expr[E], output: Expr[E], - weights: Seq[Expr[E]], - state: Seq[Expr[E]]): (Expr[E], Seq[Expr[E]]) = { - val (result, nextState) = model.buildStateful(input, weights, state) - val loss = lossF.build(result, output) plus model.penalty(input.shape, weights) - (loss, nextState) + params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = { + val (result, nextParams) = model.build_(input, params) + val loss = lossF.build(result, output) plus model.penalty_(input.shape, params) + (loss, nextParams) } - def loss[E: Floating]: (Expr[E], Expr[E], Seq[Expr[E]]) => Expr[E] = - (input, output, weights) => build(input, output, weights) + def loss_[E: Floating]: (Expr[E], Expr[E], Params[Expr[E]]) => Expr[E] = + (input, output, params) => 
buildStateful_(input, output, params)._1 - def lossStateful[E: Floating] - : (Expr[E], Expr[E], Seq[Expr[E]], Seq[Expr[E]]) => (Expr[E], Seq[Expr[E]]) = - (input, output, weights, state) => buildStateful(input, output, weights, state) + def lossStateful_[E: Floating] + : (Expr[E], Expr[E], Params[Expr[E]]) => (Expr[E], Params[Expr[E]]) = + (input, output, params) => buildStateful_(input, output, params) - def grad[E: Floating]: (Expr[E], Expr[E], Seq[Expr[E]]) => Seq[Expr[E]] = - (input, output, weights) => build(input, output, weights).grad(weights).returns[E] + def grad_[E: Floating]: (Expr[E], Expr[E], Params[Expr[E]]) => Params[Expr[E]] = + (input, output, weights) => { + val loss = build_(input, output, weights) + loss.grad(weights).returns[E] + } - def gradStateful[E: Floating] - : (Expr[E], Expr[E], Seq[Expr[E]], Seq[Expr[E]]) => (Seq[Expr[E]], Seq[Expr[E]]) = + def gradStateful_[E: Floating] + : (Expr[E], Expr[E], Params[Expr[E]], Params[Expr[E]]) => (Params[Expr[E]], Params[Expr[E]]) = (input, output, weights, state) => { - val (loss, nextState) = buildStateful(input, output, weights, state) + val (loss, nextState) = buildStateful_(input, output, weights ++ state) val grad = loss.grad(weights).returns[E] (grad, nextState) } - def trained[E: Floating](weights: Seq[Tensor[E]]) = new TrainedModel(this, weights) + def trained_[E: Floating](params: Params[Tensor[E]]) = new TrainedModel_(this, params) def displayLoss[E: Floating](input: Shape, dir: String = ""): Unit = { - build( - placeholder[E](input), - placeholder[E](model.outputShape(input)), - model.weightsShapes(input).map(s => placeholder[E](s))) + val params = model.params_(input) + build_( + input = placeholder[E](input), + output = placeholder[E](model.outputShape(input)), + params = params.mapValues(paramDef => placeholder[E](paramDef.shape))) .as("loss") .display(dir) } def displayGrad[E: Floating](input: Shape, dir: String = ""): Unit = { - grad[E].apply( + val (weights, state) = model.params_(input).partitionValues(_.trainable) + val (grad, _) = gradStateful_[E].apply( placeholder[E](input), placeholder[E](model.outputShape(input)), - model.weightsShapes(input).map(s => placeholder[E](s))) - .zipWithIndex - .map { case (w, i) => w.as(s"loss_grad_${i}_layer") } + weights.mapValues(paramDef => placeholder[E](paramDef.shape)), + state.mapValues(paramDef => placeholder[E](paramDef.shape))) + grad.params + .map { case (path, w) => (path, w.as(s"loss_grad_${path}_layer")) } .display(dir) } override def toString: String = s"$lossF($model)" } -class TrainedModel[E: Floating](val lossModel: LossModel, val weights: Seq[Tensor[E]]) { +class TrainedModel_[E: Floating](val lossModel: LossModel, val params: Params[Tensor[E]]) { - def buildResult(input: Expr[E]): Expr[E] = lossModel.model.build(input, weights.map(_.const)) + def buildResult(input: Expr[E]): Expr[E] = + buildResultStateful(input)._1 - def buildResultStateful(input: Expr[E], state: Seq[Expr[E]]): (Expr[E], Seq[Expr[E]]) = - lossModel.model.buildStateful(input, weights.map(_.const), state) + def buildResultStateful(input: Expr[E]): (Expr[E], Params[Expr[E]]) = + lossModel.model.build_(input, params.mapValues(_.const)) def result: Expr[E] => Expr[E] = (input: Expr[E]) => buildResult(input) - def resultStateful: (Expr[E], Seq[Expr[E]]) => (Expr[E], Seq[Expr[E]]) = - (input: Expr[E], state: Seq[Expr[E]]) => buildResultStateful(input, state) + def resultStateful: Expr[E] => (Expr[E], Params[Expr[E]]) = + (input: Expr[E]) => buildResultStateful(input) def buildLoss(input: 
Expr[E], output: Expr[E]): Expr[E] = - lossModel.build(input, output, weights.map(_.const)) + buildLossStateful(input, output)._1 def buildLossStateful( input: Expr[E], - output: Expr[E], - state: Seq[Expr[E]]): (Expr[E], Seq[Expr[E]]) = - lossModel.buildStateful(input, output, weights.map(_.const), state) + output: Expr[E]): (Expr[E], Params[Expr[E]]) = + lossModel.buildStateful_(input, output, params.mapValues(_.const)) def loss: (Expr[E], Expr[E]) => Expr[E] = (input, output) => buildLoss(input, output) - def lossStateful: (Expr[E], Expr[E], Seq[Expr[E]]) => (Expr[E], Seq[Expr[E]]) = - (input, output, state) => buildLossStateful(input, output, state) + def lossStateful: (Expr[E], Expr[E]) => (Expr[E], Params[Expr[E]]) = + (input, output) => buildLossStateful(input, output) def outputShape(input: Shape): Shape = lossModel.model.outputShape(input) } diff --git a/src/main/scala/scanet/models/ParamDef.scala b/src/main/scala/scanet/models/ParamDef.scala new file mode 100644 index 0000000..89c9670 --- /dev/null +++ b/src/main/scala/scanet/models/ParamDef.scala @@ -0,0 +1,12 @@ +package scanet.models + +import scanet.core.{Floating, Shape} + +case class ParamDef( + shape: Shape, + initializer: Initializer = Initializer.Zeros, + aggregation: Option[Aggregation] = None, + trainable: Boolean = false) { + def nonTrainable: Boolean = !trainable + def initialize[E: Floating] = initializer.build[E](shape) +} diff --git a/src/main/scala/scanet/models/layer/Activate.scala b/src/main/scala/scanet/models/layer/Activate.scala index 7a1c6ef..aadde1c 100644 --- a/src/main/scala/scanet/models/layer/Activate.scala +++ b/src/main/scala/scanet/models/layer/Activate.scala @@ -2,7 +2,6 @@ package scanet.models.layer import scanet.core.{Expr, Floating, Shape} import scanet.models.Activation -import scala.collection.immutable.Seq /** A layer which applies activation function to the input. 
* @@ -10,14 +9,12 @@ import scala.collection.immutable.Seq * * @param activation activation function */ -case class Activate(activation: Activation) extends WeightlessLayer { +case class Activate(activation: Activation) extends NotTrainableLayer { override def name: String = activation.toString - override def build[E: Floating](input: Expr[E], weights: Seq[Expr[E]]): Expr[E] = { - require(weights.isEmpty, "Activate layer does not require weights") + override def build_[E: Floating](input: Expr[E]): Expr[E] = activation.build(input) - } override def outputShape(input: Shape): Shape = input diff --git a/src/main/scala/scanet/models/layer/Bias.scala b/src/main/scala/scanet/models/layer/Bias.scala index e330a14..6d7f745 100644 --- a/src/main/scala/scanet/models/layer/Bias.scala +++ b/src/main/scala/scanet/models/layer/Bias.scala @@ -1,9 +1,11 @@ package scanet.models.layer -import scanet.core.{Expr, Floating, Shape} +import scanet.core.Params._ +import scanet.core.{Expr, Floating, Params, Shape} +import scanet.models.Aggregation.Avg import scanet.models.Initializer.Zeros -import scanet.models.{Initializer, Regularization} import scanet.models.Regularization.Zero +import scanet.models.{Initializer, ParamDef, Regularization} import scanet.syntax._ import scala.collection.immutable.Seq @@ -21,17 +23,15 @@ import scala.collection.immutable.Seq case class Bias(features: Int, reg: Regularization = Zero, initializer: Initializer = Zeros) extends StatelessLayer { - override def build[E: Floating](input: Expr[E], weights: Seq[Expr[E]]): Expr[E] = { - require(weights.size == 1, "Bias layer can have only one set of weights") - input + weights.head - } + override def params_(input: Shape): Params[ParamDef] = + Params(Weights -> ParamDef(Shape(features), initializer, Some(Avg), trainable = true)) - override def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]): Expr[E] = reg.build(weights.head) + override def buildStateless_[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] = + input + params.weights - override def weightsShapes(input: Shape): Seq[Shape] = Seq(Shape(features)) - - override def initWeights[E: Floating](input: Shape): Seq[Expr[E]] = - Seq(initializer.build[E](weightsShapes(input).head)) + override def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] = + reg.build(params.weights) override def outputShape(input: Shape): Shape = input + } diff --git a/src/main/scala/scanet/models/layer/Composed.scala b/src/main/scala/scanet/models/layer/Composed.scala index 994a6f5..eb54cbc 100644 --- a/src/main/scala/scanet/models/layer/Composed.scala +++ b/src/main/scala/scanet/models/layer/Composed.scala @@ -1,7 +1,8 @@ package scanet.models.layer -import scanet.core.{Expr, Floating, Shape} +import scanet.core.{Expr, Floating, Params, Shape} import scanet.math.syntax._ +import scanet.models.ParamDef import scala.collection.immutable.Seq @@ -12,41 +13,34 @@ import scala.collection.immutable.Seq */ case class Composed(left: Layer, right: Layer) extends Layer { - override def buildStateful[E: Floating]( - input: Expr[E], - weights: Seq[Expr[E]], - state: Seq[Expr[E]]): (Expr[E], Seq[Expr[E]]) = { - val (leftWeights, rightWeights) = weights.splitAt(left.weightsShapes(input.shape).size) - val (leftState, rightState) = state.splitAt(left.stateShapes(input.shape).size) - val (leftOutput, leftNewState) = left.buildStateful(input, leftWeights, leftState) - val (rightOutput, rightNewState) = right.buildStateful(leftOutput, rightWeights, rightState) - (rightOutput, leftNewState 
++ rightNewState) + override def params_(input: Shape): Params[ParamDef] = { + // todo: flatten + val leftParams = left.params_(input).prependPath("l") + val rightParams = right.params_(left.outputShape(input)).prependPath("r") + leftParams ++ rightParams } - override def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]) = { - val (leftWeights, rightWeights) = weights.splitAt(left.weightsShapes(input).size) - left.penalty(input, leftWeights) plus right.penalty(left.outputShape(input), rightWeights) + override def build_[E: Floating]( + input: Expr[E], + params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = { + val leftParams = params.children("l") + val rightParams = params.children("r") + val (leftOutput, leftState) = left.build_(input, leftParams) + val (rightOutput, rightState) = right.build_(leftOutput, rightParams) + (rightOutput, leftState.prependPath("l") ++ rightState.prependPath("r")) } - override def outputShape(input: Shape): Shape = right.outputShape(left.outputShape(input)) - - override def weightsShapes(input: Shape): Seq[Shape] = { - val leftShapes = left.weightsShapes(input) - val rightShapes = right.weightsShapes(left.outputShape(input)) - leftShapes ++ rightShapes + override def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] = { + val leftParams = params.children("l") + val rightParams = params.children("r") + left.penalty_(input, leftParams) plus right.penalty_(left.outputShape(input), rightParams) } - override def initWeights[E: Floating](input: Shape): Seq[Expr[E]] = { - val leftShapes = left.initWeights[E](input) - val rightShapes = right.initWeights[E](left.outputShape(input)) - leftShapes ++ rightShapes - } + override def outputShape(input: Shape): Shape = + right.outputShape(left.outputShape(input)) - override def stateShapes(input: Shape): Seq[Shape] = { - val leftShapes = left.stateShapes(input) - val rightShapes = right.stateShapes(left.outputShape(input)) - leftShapes ++ rightShapes - } + override def stateful: Boolean = + left.stateful || right.stateful override def info(input: Shape): Seq[LayerInfo] = { val rightInput = left.outputShape(input) @@ -54,5 +48,4 @@ case class Composed(left: Layer, right: Layer) extends Layer { } override def toString: String = s"$left >> $right" - } diff --git a/src/main/scala/scanet/models/layer/Conv2D.scala b/src/main/scala/scanet/models/layer/Conv2D.scala index ba28c25..45d30a3 100644 --- a/src/main/scala/scanet/models/layer/Conv2D.scala +++ b/src/main/scala/scanet/models/layer/Conv2D.scala @@ -1,12 +1,14 @@ package scanet.models.layer -import scanet.core.{Expr, Floating, Shape} +import scanet.core.Params.Weights +import scanet.core.{Expr, Floating, Params, Shape} import scanet.math.nn.ConvFormat._ import scanet.math.nn.Padding._ import scanet.math.nn.{ConvFormat, Padding} import scanet.math.syntax.zeros -import scanet.models.{Activation, Initializer, Regularization} +import scanet.models.{Activation, Initializer, ParamDef, Regularization} import scanet.models.Activation.Identity +import scanet.models.Aggregation.Avg import scanet.models.Initializer.{GlorotUniform, Zeros} import scanet.models.Regularization.Zero import scanet.syntax._ @@ -78,33 +80,30 @@ case class Conv2D private ( def filterHeight: Int = kernel._1 def filterWidth: Int = kernel._2 - override def build[E: Floating](input: Expr[E], weights: Seq[Expr[E]]): Expr[E] = { - require(weights.size == 1, "Conv2D layer can have only one set of weights") + override def params_(input: Shape): Params[ParamDef] = { + require( + input.rank == 4, 
+ s"Conv2D input should have a shape (NHWC) or (NCHW) but was $input") + val shape = Shape(filterHeight, filterWidth, input(format.cAxis), filters) + Params(Weights -> ParamDef(shape, initializer, Some(Avg), trainable = true)) + } + + override def buildStateless_[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] = { // Conv2D example: // input = (batch_shape, in_height, in_width, in_channels) = (1, 5, 5, 1) // filters = (filter_height, filter_width, in_channels, out_channels) = (2, 2, 1, 1) // output = (batch_shape, out_height, out_width, out_channels) = (1, 5, 5, 1) conv2D[E]( input = input, - filters = weights.head, + filters = params.weights, strides = Seq(strides._1, strides._2), padding = padding, format = format) } - override def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]): Expr[E] = + override def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] = zeros[E](Shape()) - override def weightsShapes(input: Shape): Seq[Shape] = { - require( - input.rank == 4, - s"Conv2D input should have a shape (NHWC) or (NCHW) but was $input") - Seq(Shape(filterHeight, filterWidth, input(format.cAxis), filters)) - } - - override def initWeights[E: Floating](input: Shape): Seq[Expr[E]] = - Seq(initializer.build[E](weightsShapes(input).head)) - override def outputShape(input: Shape): Shape = { require( input.rank == 4, diff --git a/src/main/scala/scanet/models/layer/Dense.scala b/src/main/scala/scanet/models/layer/Dense.scala index 0feb32f..14603c7 100644 --- a/src/main/scala/scanet/models/layer/Dense.scala +++ b/src/main/scala/scanet/models/layer/Dense.scala @@ -1,9 +1,11 @@ package scanet.models.layer -import scanet.core.{Expr, Floating, Shape} +import scanet.core.Params.Weights +import scanet.core.{Expr, Floating, Params, Shape} +import scanet.models.Aggregation.Avg import scanet.models.Initializer.{GlorotUniform, Zeros} import scanet.models.Regularization.Zero -import scanet.models.{Activation, Initializer, Regularization} +import scanet.models.{Activation, Initializer, ParamDef, Regularization} import scanet.syntax._ import scala.collection.immutable.Seq @@ -42,25 +44,17 @@ object Dense { case class Dense private (outputs: Int, reg: Regularization, initializer: Initializer) extends StatelessLayer { - override def build[E: Floating](input: Expr[E], weights: Seq[Expr[E]]) = { - require(weights.size == 1, "Dense layer can have only one set of weights") + override def params_(input: Shape): Params[ParamDef] = + Params(Weights -> ParamDef(Shape(input(1), outputs), initializer, Some(Avg), trainable = true)) + + override def buildStateless_[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] = // x:(samples, features) // w:(features, outputs) // x * w -> (samples, features) * (features, outputs) -> (samples, outputs) - input matmul weights.head - } - - override def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]) = - reg.build(weights.head) + input matmul params.weights - override def weightsShapes(input: Shape): Seq[Shape] = { - require(input.rank == 2, "features should have a shape (batch, features)") - Seq(Shape(input(1), outputs)) - } - - override def initWeights[E: Floating](input: Shape): Seq[Expr[E]] = - Seq(initializer.build[E](weightsShapes(input).head)) + override def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] = + reg.build(params.weights) override def outputShape(input: Shape): Shape = Shape(input.head, outputs) - } diff --git a/src/main/scala/scanet/models/layer/Flatten.scala 
b/src/main/scala/scanet/models/layer/Flatten.scala index 0392864..9961f53 100644 --- a/src/main/scala/scanet/models/layer/Flatten.scala +++ b/src/main/scala/scanet/models/layer/Flatten.scala @@ -2,16 +2,14 @@ package scanet.models.layer import scanet.core.syntax._ import scanet.core.{Expr, Floating, Shape} -import scala.collection.immutable.Seq /** A layer which flattens the input tensor of any shape to 2 dims matrix. * * Given an input tensor `Shape(N, H, W, C)`, after flattening we will get `Shape(N, H*W*C)` */ -case object Flatten extends WeightlessLayer { +case object Flatten extends NotTrainableLayer { - override def build[E: Floating](input: Expr[E], weights: Seq[Expr[E]]): Expr[E] = { - require(weights.isEmpty, "Flatten layer does not require weights") + override def build_[E: Floating](input: Expr[E]): Expr[E] = { val shape = input.shape require(shape.rank >= 2, s"rank should be >= 2, but was ${shape.rank}") val batch = shape(0) diff --git a/src/main/scala/scanet/models/layer/Layer.scala b/src/main/scala/scanet/models/layer/Layer.scala index 00570cc..9e32573 100644 --- a/src/main/scala/scanet/models/layer/Layer.scala +++ b/src/main/scala/scanet/models/layer/Layer.scala @@ -1,16 +1,16 @@ package scanet.models.layer -import scanet.core.{Expr, Floating, Shape} +import scanet.core.{Expr, Floating, Params, Shape} import scanet.math.syntax.zeros -import scanet.models.Model +import scanet.models.{Model, ParamDef} import scala.collection.immutable.Seq - import scala.annotation.nowarn trait Layer extends Model { def trainable: Boolean = true + def stateful: Boolean /** Compose `right` layer with `this` (`left`) layer. * @@ -31,21 +31,29 @@ trait Layer extends Model { trait StatelessLayer extends Layer { - def build[E: Floating](input: Expr[E], weights: Seq[Expr[E]]): Expr[E] + override def stateful: Boolean = false - override def buildStateful[E: Floating]( - input: Expr[E], - weights: Seq[Expr[E]], - state: Seq[Expr[E]]): (Expr[E], Seq[Expr[E]]) = - (build(input, weights), Seq.empty) + def buildStateless_[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] - override def stateShapes(input: Shape): Seq[Shape] = Seq.empty + override def build_[E: Floating]( + input: Expr[E], + params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = { + (buildStateless_(input, params), Params.empty) + } } -trait WeightlessLayer extends StatelessLayer { +trait NotTrainableLayer extends StatelessLayer { + override def trainable: Boolean = false - override def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]): Expr[E] = + + override def params_(input: Shape): Params[ParamDef] = Params.empty + override def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] = zeros[E](Shape()) - override def weightsShapes(input: Shape): Seq[Shape] = Seq.empty - override def initWeights[E: Floating](input: Shape): Seq[Expr[E]] = Seq.empty + + def build_[E: Floating](input: Expr[E]): Expr[E] + + override def buildStateless_[E: Floating](input: Expr[E], params: Params[Expr[E]]): Expr[E] = { + require(params.isEmpty, s"$this layer does not require params") + build_(input) + } } diff --git a/src/main/scala/scanet/models/layer/LayerInfo.scala b/src/main/scala/scanet/models/layer/LayerInfo.scala index 9c89d85..bdb4564 100644 --- a/src/main/scala/scanet/models/layer/LayerInfo.scala +++ b/src/main/scala/scanet/models/layer/LayerInfo.scala @@ -5,7 +5,7 @@ import scanet.core.Shape import scala.annotation.tailrec import scala.collection.immutable.Seq -case class LayerInfo(name: String, weights: Seq[Shape], 
output: Shape) { +case class LayerInfo(name: String, weights: Seq[Shape], state: Seq[Shape], output: Shape) { private def group[A](input: Seq[A]): (Seq[A], Int) = { def repeated(size: Int): Boolean = { @@ -36,10 +36,14 @@ case class LayerInfo(name: String, weights: Seq[Shape], output: Shape) { if (size > 1) s"[$value]x$size" else value } + def weightsTotal: Int = weights.map(_.power).sum + def stateTotal: Int = state.map(_.power).sum + def toRow: Seq[String] = { val weightsStr = groupConcat(weights) - val params = groupConcat(weights.map(_.power)) + val weightParams = groupConcat(weights.map(_.power)) + val stateParams = groupConcat(state.map(_.power)) val outputStr = ("_" +: output.tail.dims.map(_.toString)).mkString("(", ",", ")") - Seq(name, weightsStr, params, outputStr) + Seq(name, weightsStr, weightParams, stateParams, outputStr) } } diff --git a/src/main/scala/scanet/models/layer/Pool2D.scala b/src/main/scala/scanet/models/layer/Pool2D.scala index 2a077f3..c47efa6 100644 --- a/src/main/scala/scanet/models/layer/Pool2D.scala +++ b/src/main/scala/scanet/models/layer/Pool2D.scala @@ -31,17 +31,15 @@ case class Pool2D( padding: Padding = Valid, format: ConvFormat = NHWC, reduce: Reduce = Reduce.Max) - extends WeightlessLayer { + extends NotTrainableLayer { - override def build[E: Floating](input: Expr[E], weights: Seq[Expr[E]]): Expr[E] = { - require(weights.isEmpty, "Pool2D layer does not require weights") + override def build_[E: Floating](input: Expr[E]): Expr[E] = pool2D[E]( input = input, window = Seq(window._1, window._2), strides = Seq(strides._1, strides._2), padding = padding, format = format) - } override def outputShape(input: Shape): Shape = { require( diff --git a/src/main/scala/scanet/models/layer/RNN.scala b/src/main/scala/scanet/models/layer/RNN.scala index 2bf218e..c5e643a 100644 --- a/src/main/scala/scanet/models/layer/RNN.scala +++ b/src/main/scala/scanet/models/layer/RNN.scala @@ -1,9 +1,12 @@ package scanet.models.layer -import scanet.core.{Expr, Floating, Shape} + +import scanet.core.{Expr, Floating, Params, Path, Shape} +import scanet.math.syntax.zeros import scanet.models.Activation.{Sigmoid, Tanh} +import scanet.models.Aggregation.Avg import scanet.models.Initializer.{GlorotUniform, Ones, Orthogonal, Zeros} import scanet.models.Regularization.Zero -import scanet.models.{Activation, Initializer, Regularization} +import scanet.models.{Activation, Initializer, ParamDef, Regularization} import scanet.syntax._ import scala.annotation.tailrec @@ -21,39 +24,57 @@ import scala.collection.immutable.Seq * @param returnSequence Whether to return the last output in the output sequence, or the full sequence. 
* To stack multiple layers set `returnSequence=true` */ -case class RNN(cell: Layer, returnSequence: Boolean = false) extends Layer { +case class RNN(cell: Layer, returnSequence: Boolean = false, stateful: Boolean = false) + extends Layer { + + // todo: better params management + + override def params_(input: Shape): Params[ParamDef] = { + val (weights, state) = paramsPartitioned(input) + if (stateful) state ++ weights else weights + } + + private def paramsPartitioned(input: Shape): (Params[ParamDef], Params[ParamDef]) = + cell.params_(dropTime(input)).partitionValues(_.trainable) - override def buildStateful[E: Floating]( + override def build_[E: Floating]( input: Expr[E], - weights: Seq[Expr[E]], - state: Seq[Expr[E]]): (Expr[E], Seq[Expr[E]]) = { + params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = { require(input.rank == 3, "RNN requires input to have Shape(batch, time, features)") // input: (batch, time, features) -> stepsInput: (time, batch, features) val stepsInput = input.transpose(Seq(1, 0) ++ (2 until input.rank)) val timeSteps = stepsInput.shape(0) + val (weightParamsDef, stateParamsDef) = paramsPartitioned(input.shape) + val weightParamsNames = weightParamsDef.params.keySet + val (weightParams, stateParams) = params.params + .partition { case (path, _) => weightParamsNames(path) } + @tailrec def stackCells( step: Int, outputs: Seq[Expr[E]], - state: Seq[Expr[E]]): (Seq[Expr[E]], Seq[Expr[E]]) = { - val (output, outputState) = cell.buildStateful(stepsInput.slice(step), weights, state) + state: Params[Expr[E]]): (Seq[Expr[E]], Params[Expr[E]]) = { + val (output, outputState) = cell.build_(stepsInput.slice(step), state ++ Params(weightParams)) if (step < timeSteps - 1) { stackCells(step + 1, outputs :+ output, outputState) } else { (outputs :+ output, outputState) } } - val (outputs, lastOutputState) = stackCells(0, Seq.empty, state) + + val inputState = + if (stateful) Params(stateParams) else stateParamsDef.mapValues(d => zeros[E](d.shape)) + val (outputs, lastOutputState) = stackCells(0, Seq.empty, inputState) val output = if (returnSequence) joinAlong(outputs, 1).reshape(outputShape(input.shape)) else outputs.last (output, lastOutputState) } - private def dropTime(input: Shape): Shape = input.remove(1) + override def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] = + cell.penalty_(dropTime(input), params) - override def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]): Expr[E] = - cell.penalty(dropTime(input), weights) + private def dropTime(input: Shape): Shape = input.remove(1) override def outputShape(input: Shape): Shape = { val cellOutput = cell.outputShape(dropTime(input)) @@ -62,13 +83,6 @@ case class RNN(cell: Layer, returnSequence: Boolean = false) extends Layer { else cell.outputShape(cellOutput) } - - override def weightsShapes(input: Shape): Seq[Shape] = cell.weightsShapes(dropTime(input)) - - override def initWeights[E: Floating](input: Shape): Seq[Expr[E]] = - cell.initWeights(dropTime(input)) - - override def stateShapes(input: Shape): Seq[Shape] = cell.stateShapes(dropTime(input)) } object SimpleRNN { @@ -98,7 +112,8 @@ object SimpleRNN { kernelReg: Regularization = Zero, recurrentReg: Regularization = Zero, biasReg: Regularization = Zero, - returnSequence: Boolean = false): RNN = + returnSequence: Boolean = false, + stateful: Boolean = false): RNN = RNN( SimpleRNNCell( units, @@ -110,10 +125,12 @@ object SimpleRNN { kernelReg, recurrentReg, biasReg), - returnSequence) + returnSequence, + stateful) } object SimpleRNNCell { + 
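  // Editor's sketch (illustrative, not part of the original patch): the Path constants
  // introduced below replace the old positional layout, where the cell received its
  // weights as Seq(kernel, recurrent) and its state as a single-element Seq. With named
  // params the cell reads and writes entries by path, roughly:
  //
  //   val out = (input matmul params(Kernel)) + (params(State) matmul params(Recurrent))
  //   (out, Params(State -> out))   // the next state is the cell output itself
  //
  // which is the shape of SimpleRNNCell.build_ further below.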
def apply( units: Int, activation: Activation = Tanh, @@ -129,6 +146,10 @@ object SimpleRNNCell { new SimpleRNNCell(units, kernelInitializer, recurrentInitializer, kernelReg, recurrentReg) cell ?>> (bias, Bias(units, biasReg, biasInitializer)) ?>> (activation.ni, activation.layer) } + + val Kernel: Path = "kernel_weights" + val Recurrent: Path = "recurrent_weights" + val State: Path = "state" } case class SimpleRNNCell( @@ -138,37 +159,30 @@ case class SimpleRNNCell( kernelReg: Regularization, recurrentReg: Regularization) extends Layer { + import SimpleRNNCell._ + + override def stateful: Boolean = true - override def buildStateful[E: Floating]( + override def params_(input: Shape): Params[ParamDef] = Params( + // (features, units) + Kernel -> ParamDef(Shape(input(1), units), kernelInitializer, Some(Avg), trainable = true), + // (units, units) + Recurrent -> ParamDef(Shape(units, units), recurrentInitializer, Some(Avg), trainable = true), + // state, keeps previous output + State -> ParamDef(outputShape(input), Initializer.Zeros)) + + override def build_[E: Floating]( input: Expr[E], - weights: Seq[Expr[E]], - state: Seq[Expr[E]]): (Expr[E], Seq[Expr[E]]) = { + params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = { require(input.rank >= 2, "SimpleRNNCell requires input Seq(batch, features)") - require(weights.size == 2, "SimpleRNNCell requires weights Seq(kernel, recurrent)") - require(state.size == 1, "SimpleRNNCell requires single state") - val kernel = weights.head - val recurrent = weights(1) - val result = (input matmul kernel) + (state.head matmul recurrent) - (result, Seq(result)) + val result = (input matmul params(Kernel)) + (params(State) matmul params(Recurrent)) + (result, Params(State -> result)) } - override def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]): Expr[E] = - kernelReg.build(weights.head) + recurrentReg.build(weights(1)) + override def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] = + kernelReg.build(params(Kernel)) + recurrentReg.build(params(Recurrent)) override def outputShape(input: Shape): Shape = Shape(input.head, units) - - override def weightsShapes(input: Shape): Seq[Shape] = { - val wx = Shape(input(1), units) // (features, units) - val wh = Shape(units, units) // (units, units) - Seq(wx, wh) - } - - override def initWeights[E: Floating](input: Shape): Seq[Expr[E]] = { - val Seq(wx, wh) = weightsShapes(input) - Seq(kernelInitializer.build[E](wx), recurrentInitializer.build[E](wh)) - } - - override def stateShapes(input: Shape): Seq[Shape] = Seq(outputShape(input)) } object LSTM { @@ -199,7 +213,8 @@ object LSTM { kernelReg: Regularization = Zero, recurrentReg: Regularization = Zero, biasReg: Regularization = Zero, - returnSequence: Boolean = false): RNN = + returnSequence: Boolean = false, + stateful: Boolean = false): RNN = RNN( LSTMCell( units, @@ -213,7 +228,10 @@ object LSTM { kernelReg, recurrentReg, biasReg), - returnSequence) + returnSequence, + stateful) + val CellState: Path = "cell_state" + val HiddenState: Path = "hidden_state" } case class LSTMCell( @@ -229,9 +247,10 @@ case class LSTMCell( recurrentReg: Regularization = Zero, biasReg: Regularization = Zero) extends Layer { + import LSTM._ - private def cell(activation: Activation, useBias: Initializer) = - SimpleRNNCell( + private def cell(path: Path, activation: Activation, useBias: Initializer) = + path -> SimpleRNNCell( units, activation, bias, @@ -242,11 +261,28 @@ case class LSTMCell( recurrentReg, biasReg) - private val fCell = 
cell(recurrentActivation, useBias = biasForgetInitializer) // forget - private val iCell = cell(recurrentActivation, useBias = biasInitializer) // input - private val gCell = cell(activation, useBias = biasInitializer) // gate - private val oCell = cell(recurrentActivation, useBias = biasInitializer) // output - private val cells = Seq(fCell, iCell, gCell, oCell) + private val fCell = cell("forget", recurrentActivation, useBias = biasForgetInitializer) + private val iCell = cell("input", recurrentActivation, useBias = biasInitializer) + private val gCell = cell("gate", activation, useBias = biasInitializer) + private val oCell = cell("output", recurrentActivation, useBias = biasInitializer) + private val cells = Seq(fCell, iCell, gCell, oCell).map(_._2) + private val cells_ = Map(fCell, iCell, gCell, oCell) + + override def stateful: Boolean = true + + override def params_(input: Shape): Params[ParamDef] = { + val weights = cells_ + .map { + case (name, cell) => + val params = cell.params_(input) + val onlyWeights = params.filterPaths(path => !path.endsWith(SimpleRNNCell.State)) + onlyWeights.prependPath(name) + } + .reduce(_ ++ _) + val states = Params(Seq(CellState, HiddenState) + .map(path => path -> ParamDef(outputShape(input), Zeros)): _*) + weights ++ states + } /** Shapes: * - input c t-1: (batch, features) @@ -260,48 +296,35 @@ case class LSTMCell( * - output h: (batch, units) * - output y: (batch, units) */ - override def buildStateful[E: Floating]( + override def build_[E: Floating]( input: Expr[E], - weights: Seq[Expr[E]], - state: Seq[Expr[E]]): (Expr[E], Seq[Expr[E]]) = { + params: Params[Expr[E]]): (Expr[E], Params[Expr[E]]) = { require(input.rank >= 2, "LSTMCell requires input Seq(batch, features)") - require( - weights.size == 8 + (if (bias) 4 else 0), - "LSTMCell requires weights (Seq(kernel, recurrent, bias?) 
* 4") - require(state.size == 2, "LSTMCell requires state Seq(c_prev, h_prev)") - val Seq(cPrev, hPrev) = state - val Seq(f, i, g, o) = cells.zip(unpackWeights(input.shape, weights)).map { - case (cell, weights) => cell.buildStateful(input, weights, Seq(hPrev))._1 + val cPrev = params(CellState) + val hPrev = params(HiddenState) + val List(f, i, g, o) = cells_.toList.map { + case (name, cell) => + val cellState = cell.params_(input.shape) + .filter { + case (path, param) => + path.endsWith(Params.State) && param.nonTrainable && param.shape == hPrev.shape + } + .mapValues(_ => hPrev) + val cellParams = params.children(name) ++ cellState + cell.build_(input, cellParams)._1 } val c = cPrev * f + i * g val h = o * activation.build(c) - (h, Seq(c, h)) - } - - private def unpackWeights[E](input: Shape, weights: Seq[Expr[E]]): Seq[Seq[Expr[E]]] = { - val (_, unpacked) = cells.foldLeft((weights, Seq.empty[Seq[Expr[E]]])) { - case ((weights, acc), cell) => - val size = cell.weightsShapes(input).size - val (cellWeights, remainWeights) = weights.splitAt(size) - (remainWeights, acc :+ cellWeights) - } - unpacked + (h, Params(CellState -> c, HiddenState -> h)) } - override def penalty[E: Floating](input: Shape, weights: Seq[Expr[E]]): Expr[E] = - cells.zip(unpackWeights(input, weights)).foldLeft(Floating[E].zero.const) { - case (sum, (cell, weights)) => sum + cell.penalty(input, weights) + override def penalty_[E: Floating](input: Shape, params: Params[Expr[E]]): Expr[E] = + cells_.foldLeft(Floating[E].zero.const) { + case (sum, (name, cell)) => + val cellParams = params.children(name) + sum + cell.penalty_(input, cellParams) } override def outputShape(input: Shape): Shape = - oCell.outputShape(input) - - override def weightsShapes(input: Shape): Seq[Shape] = - cells.flatMap(_.weightsShapes(input)) - - override def initWeights[E: Floating](input: Shape): Seq[Expr[E]] = - cells.flatMap(_.initWeights[E](input)) - - override def stateShapes(input: Shape): Seq[Shape] = - Seq.fill(2)(outputShape(input)) + oCell._2.outputShape(input) } diff --git a/src/main/scala/scanet/optimizers/Effect.scala b/src/main/scala/scanet/optimizers/Effect.scala index 8a5dbb8..6f0f325 100644 --- a/src/main/scala/scanet/optimizers/Effect.scala +++ b/src/main/scala/scanet/optimizers/Effect.scala @@ -55,7 +55,7 @@ object Effect { tensorboard: Boolean = false) extends Effect[E] { override def apply(state: State, next: StepContext[E]): State = { - val trained = next.lossModel.trained(next.result.weights) + val trained = next.lossModel.trained_(next.result.params) val a = accuracy(trained, ds, next.step.batch) if (tensorboard) state.board.addScalar("accuracy", a, next.step.iter) diff --git a/src/main/scala/scanet/optimizers/KryoSerializers.scala b/src/main/scala/scanet/optimizers/KryoSerializers.scala index e35cf99..850cdec 100644 --- a/src/main/scala/scanet/optimizers/KryoSerializers.scala +++ b/src/main/scala/scanet/optimizers/KryoSerializers.scala @@ -4,7 +4,7 @@ import com.esotericsoftware.kryo.Kryo import com.twitter.chill._ import org.apache.spark.serializer.KryoRegistrator import scanet.core.{Shape, Tensor, TensorType} -import scanet.models.TrainedModel +import scanet.models.TrainedModel_ import org.tensorflow.proto.framework.DataType.DT_STRING import scanet.core.syntax._ @@ -12,7 +12,7 @@ class KryoSerializers extends KryoRegistrator { override def registerClasses(kryo: Kryo): Unit = { kryo.forClass[Tensor[_]](new TensorSerializer()) - kryo.register(classOf[TrainedModel[_]]) + kryo.register(classOf[TrainedModel_[_]]) } } diff 
--git a/src/main/scala/scanet/optimizers/Optimizer.scala b/src/main/scala/scanet/optimizers/Optimizer.scala index 8da68d9..11eeec8 100644 --- a/src/main/scala/scanet/optimizers/Optimizer.scala +++ b/src/main/scala/scanet/optimizers/Optimizer.scala @@ -1,14 +1,17 @@ package scanet.optimizers +import com.github.benmanes.caffeine.cache.Caffeine import org.apache.spark.sql.Dataset import scanet.core.{Tensor, _} import scanet.math.syntax._ -import scanet.models.{Loss, LossModel, Model, TrainedModel} -import scanet.optimizers.syntax._ +import scanet.models._ import scanet.optimizers.Condition.always import scanet.optimizers.Optimizer.BuilderState._ -import scanet.optimizers.Optimizer.{sessionsPool, tfCache} +import scanet.optimizers.syntax._ +import java.time.Duration +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap import scala.annotation.{nowarn, tailrec} import scala.collection._ import scala.collection.immutable.Seq @@ -22,9 +25,13 @@ case class Step[A: Numeric](batch: Int, epoch: Int = 0, iter: Int = 0) { case class StepResult[A: Numeric]( iterations: Int, - weights: Seq[Tensor[A]], - meta: Seq[Tensor[A]], - loss: A) + paramsDef: Params[ParamDef], + params: Params[Tensor[A]], + meta: Params[Tensor[A]], + loss: A) { + def paramsWithDef: Params[(ParamDef, Tensor[A])] = + paramsDef.join(params) +} case class StepContext[A: Numeric]( step: Step[A], @@ -38,7 +45,7 @@ case class Optimizer[A: Floating]( alg: Algorithm, model: Model, loss: Loss, - initWeights: () => Option[Seq[Tensor[A]]], + initParams: () => Option[Params[Tensor[A]]], dataset: Dataset[Record[A]], batchSize: Int, minimizing: Boolean, @@ -48,7 +55,8 @@ case class Optimizer[A: Floating]( private val lossModel = model.withLoss(loss) - def run(): TrainedModel[A] = { + def run(): TrainedModel_[A] = { + val jobId = UUID.randomUUID().toString val ds: Dataset[Record[A]] = dataset.cache() val sc = ds.sparkSession.sparkContext val board = TensorBoard(boardDir) @@ -58,21 +66,17 @@ case class Optimizer[A: Floating]( @tailrec def optimize( prevStep: Step[A], - weights: Seq[Tensor[A]], - meta: Seq[Tensor[A]]): Seq[Tensor[A]] = { - val weightsBr = sc.broadcast(weights) + params: Params[Tensor[A]], + meta: Params[Tensor[A]]): Params[Tensor[A]] = { + val weightsBr = sc.broadcast(params) val metaBr = sc.broadcast(meta) val start = System.currentTimeMillis() val result = ds.rdd .mapPartitions { it => - Iterator(optimizeOnPartition( - it, - shapes, - prevStep.iter, - weightsBr.value, - metaBr.value)) + Iterator( + optimizeOnPartition(jobId, it, shapes, prevStep.iter, weightsBr.value, metaBr.value)) } - .treeReduce(averageMetaAndWeights) + .treeReduce(aggParamsAndMeta(jobId)) val finish = System.currentTimeMillis() val step: Step[A] = prevStep.nextEpoch.incIter(result.iterations) val stepCtx = StepContext(step, result, lossModel, finish - start) @@ -82,112 +86,154 @@ case class Optimizer[A: Floating]( .run() if (stop(stepCtx)) { - result.weights + result.params } else { - optimize(step, result.weights, result.meta) + optimize(step, result.params, result.meta) } } - val weights = optimize(Step(batchSize), Seq(), Seq()) - lossModel.trained(weights) + val params = optimize(Step(batchSize), Params.empty, Params.empty) + + lossModel.trained_(params) } private def optimizeOnPartition( + jobId: String, it: scala.Iterator[Record[A]], shapes: (Shape, Shape), globalIter: Int, - weights: Seq[Tensor[A]], - meta: Seq[Tensor[A]]): StepResult[A] = { - val result = sessionsPool.within(session => { + params: Params[Tensor[A]], + meta: 
Params[Tensor[A]]): StepResult[A] = { + val resource = Optimizer.resource(jobId) + val result = resource.sessionPool.within(session => { val batches = TensorIterator(it, shapes, batchSize) - val batchedShapes = (batchSize +: shapes._1, batchSize +: shapes._2) - val (weightsInitialized, metaInitialized) = + val batchedInputShape = batchSize +: shapes._1 + val paramsDef = model.params_(batchedInputShape) + val weightNames = paramsDef.filterValues(_.trainable).paths + val (paramsInitialized, metaInitialized) = if (globalIter == 0) { - val weights = initWeights().getOrElse(model.initWeights[A](batchedShapes._1).eval) - val weightShapes = model.weightsShapes(batchedShapes._1) - val meta = weightShapes.map(alg.initMeta[A](_)) - (weights, meta) + val params = initParams() + .getOrElse(paramsDef.mapValues(param => param.initialize[A].eval)) + val meta = paramsDef.filterValues(_.trainable) + .mapValues(param => alg.initMeta[A](param.shape)) + (params, meta) } else { - (weights, meta) + (params, meta) } - val loss = compileLoss(session) - val calc = compileCalc(session) + val backprop = compileBackprop(resource.tfCache, session) + val loss: (Tensor[A], Tensor[A], Params[Tensor[A]]) => (Tensor[A], Params[Tensor[A]]) = + compileLoss(resource.tfCache, session) @tailrec - def optimize(iter: Int, weights: Seq[Tensor[A]], meta: Seq[Tensor[A]]): StepResult[A] = { + def optimize( + iter: Int, + params: Params[Tensor[A]], + meta: Params[Tensor[A]]): StepResult[A] = { val (x, y) = batches.next() - /*_*/ - val (nextWeights, nextMeta) = - calc(x, y, weights, meta, Tensor.scalar(globalIter + iter + 1)) - /*_*/ + val (weights, state) = params.partitionPaths(weightNames.contains) + val (nextWeights, nextMeta, nextState) = + backprop(x, y, weights, meta, state, Tensor.scalar(globalIter + iter + 1)) + val nextParams = nextWeights ++ nextState if (batches.hasNext) { - optimize(iter + 1, nextWeights, nextMeta) + optimize(iter + 1, nextParams, nextMeta) } else { - StepResult(iter + 1, nextWeights, nextMeta, loss(x, y, nextWeights).toScalar) + val (lossResult, _) = loss(x, y, nextParams) + StepResult(iter + 1, paramsDef, nextWeights, nextMeta, lossResult.toScalar) } } - optimize(0, weightsInitialized, metaInitialized) + optimize(0, paramsInitialized, metaInitialized) }) result } - private def compileLoss(session: Session) = { - tfCache.getOrCompute( + private def compileLoss(cache: Optimizer.Cache, session: Session) = { + cache.getOrCompute( s"$lossModel:loss[${TensorType[A].classTag}]", - lossModel.loss[A].tf) compileWith session + lossModel.lossStateful_[A].tf) compileWith session } - private def compileCalc(session: Session) = { - def newOutputSeq = Seq[Expr[A]]() - tfCache.getOrCompute( + private def compileBackprop(cache: Optimizer.Cache, session: Session) = { + def newOutputSeq = Params[Expr[A]]() + val tf = cache.getOrCompute( s"$lossModel:$alg:calc[${TensorType[A].classTag}]]", { - val func = - (x: Expr[A], y: Expr[A], ws: Seq[Expr[A]], metas: Seq[Expr[A]], iter: Expr[Int]) => { - val gs = lossModel.grad[A].apply(x, y, ws) - ws.zip(gs).zip(metas).foldLeft((newOutputSeq, newOutputSeq))((acc, next) => { + val func = ( + x: Expr[A], + y: Expr[A], + weights: Params[Expr[A]], + metas: Params[Expr[A]], + state: Params[Expr[A]], + iter: Expr[Int]) => { + val (grads, nextState) = lossModel.gradStateful_[A].apply(x, y, weights, state) + val (nextAcc, nextMeta) = weights.join(grads).join(metas).params + .foldLeft((newOutputSeq, newOutputSeq)) { (acc, next) => val (gAcc, metaAcc) = acc - val ((w, g), meta) = next + val 
(path, ((w, g), meta)) = next val Delta(del, metaNext) = alg.delta[A](g, meta, iter) val d = del.cast[A] val gNext = if (minimizing) w - d else w + d - (gAcc :+ gNext, metaAcc :+ metaNext) - }) - } + (gAcc + (path -> gNext), metaAcc + (path -> metaNext)) + } + (nextAcc, nextMeta, nextState) + } func.tf - }) compileWith session + }) + tf compileWith session } - private def averageMetaAndWeights(left: StepResult[A], right: StepResult[A]): StepResult[A] = { - sessionsPool.within { session => - val weightsAvg = tfCache.getOrCompute("weightsAvg", avg[A].tf) compileWith session - val metaAvg = tfCache.getOrCompute("metaAvg", avg[A].tf) compileWith session + private def aggParamsAndMeta(jobId: String)( + left: StepResult[A], + right: StepResult[A]): StepResult[A] = { + val resource = Optimizer.resource(jobId) + resource.sessionPool.within { session => + def buildParamsAgg( + leftParams: Params[Expr[A]], + rightParams: Params[Expr[A]]): Params[Expr[A]] = { + left.paramsDef.join(leftParams.join(rightParams)).mapValues { + case (paramDef, (l, r)) => + paramDef.aggregation match { + case Some(agg) => agg.build(Seq(l, r)) + case None => paramDef.initializer.build[A](l.shape) + } + } + } + def buildMetaAgg(leftMeta: Params[Expr[A]], rightMeta: Params[Expr[A]]) = { + (leftMeta join rightMeta).mapValues { case (l, r) => (l + r) / 2.0f.const.cast[A] } + } + val weightsAvg = + resource.tfCache.getOrCompute("paramsAgg", (buildParamsAgg _).tf) compileWith session + val metaAvg = + resource.tfCache.getOrCompute("metaAvg", (buildMetaAgg _).tf) compileWith session val lossAvg = (left.loss plus right.loss) / c.convert(2) StepResult( left.iterations + right.iterations, - weightsAvg(left.weights, right.weights), + left.paramsDef, + weightsAvg(left.params, right.params), metaAvg(left.meta, right.meta), lossAvg) } } - private def avg[X: Numeric] = - (arg1: Seq[Expr[X]], arg2: Seq[Expr[X]]) => { - (arg1 zip arg2).map { case (l, r) => (l + r) / 2.0f.const.cast[X] } - } } object Optimizer { + case class Resource(sessionPool: SessionPool, tfCache: Cache) extends AutoCloseable { + override def close(): Unit = sessionPool.close() + } + class Cache { - private val map = concurrent.TrieMap[String, Any]() - def getOrCompute[A](key: String, op: => A): A = { - map.get(key) match { - case Some(v) => v.asInstanceOf[A] - case None => val d = op; map(key) = d; d - } - } + private val map = new ConcurrentHashMap[String, Any]() + def getOrCompute[A](key: String, op: => A): A = + map.computeIfAbsent(key, _ => op).asInstanceOf[A] } - val sessionsPool = new SessionPool(64) - val tfCache = new Cache + private def CPUs: Int = Runtime.getRuntime.availableProcessors() + + private val resources = Caffeine.newBuilder() + .maximumSize(5) + .expireAfterAccess(Duration.ofSeconds(300)) + .removalListener((_: String, resource: Resource, _) => resource.close()) + .build((_: String) => Resource(new SessionPool(CPUs), new Cache)) + + def resource(id: String): Resource = resources.get(id) sealed trait BuilderState @@ -209,8 +255,8 @@ object Optimizer { def using(alg: Algorithm): Builder[A, State with WithAlg] = copy(optimizer = optimizer.copy(alg = alg)) - def initWeights(args: => Seq[Tensor[A]]): Builder[A, State] = - copy(optimizer = optimizer.copy(initWeights = () => Some(args))) + def initParams(args: => Params[Tensor[A]]): Builder[A, State] = + copy(optimizer = optimizer.copy(initParams = () => Some(args))) def on(dataset: Dataset[Record[A]]): Builder[A, State with WithDataset] = copy(optimizer = optimizer.copy(dataset = dataset)) @@ -242,7 
+288,7 @@ object Optimizer { @nowarn def build(implicit ev: State =:= Complete): Optimizer[A] = optimizer - def run()(implicit ev: State =:= Complete): TrainedModel[A] = build.run() + def run()(implicit ev: State =:= Complete): TrainedModel_[A] = build.run() } def minimize[R: Floating](model: Model)( @@ -252,7 +298,7 @@ object Optimizer { alg = null, model = model, loss = null, - initWeights = () => None, + initParams = () => None, dataset = null, batchSize = 10000, minimizing = true, @@ -266,7 +312,7 @@ object Optimizer { alg = null, model = model, loss = null, - initWeights = () => None, + initParams = () => None, dataset = null, batchSize = 10000, minimizing = false, diff --git a/src/main/scala/scanet/research/Neuron.scala b/src/main/scala/scanet/research/Neuron.scala index bb537bb..44c03a9 100644 --- a/src/main/scala/scanet/research/Neuron.scala +++ b/src/main/scala/scanet/research/Neuron.scala @@ -1,6 +1,7 @@ package scanet.research -import scanet.core.{Shape, Tensor} +import scanet.core.Params.Weights +import scanet.core.{Params, Shape, Tensor} import scanet.models.Activation.Identity import scanet.models.Loss.MeanSquaredError import scanet.models.layer.Dense @@ -11,7 +12,7 @@ import scala.collection.immutable.Seq object Neuron { def main(args: Array[String]): Unit = { - val neuron = Dense(1, Identity) + val neuron = Dense(1, Identity, bias = false) // x: dataset, where each record consists of set of features // for example, lets take a person and try to predict how much money // he has on a bank account given the time he works each week and age @@ -23,16 +24,17 @@ object Neuron { println(neuron.outputShape(Shape(2))) // (1, 3) -> 3 = 2 + 1 -> 1 is bias val w = Tensor.matrix(Array(0.6f, 0.35f, 0.9f)) + val params = Params(Weights -> w) // to make a prediction we need to run a forward pass - val result = neuron.result[Float].compile + val result = neuron.result_[Float].compile // (0.7 * 50 + 0.5 * 25 + 1 * 1) - println(result(x, Seq(w))) + println(result(x, params)) + val paramsExpr = params.mapValues(_.const) // let's calculate prediction error (loss) - val ws = Seq(w.const) - val loss = MeanSquaredError.build(neuron.build(x.const, ws), y.const) + val loss = MeanSquaredError.build(neuron.build_(x.const, paramsExpr)._1, y.const) println(loss.eval) // let's calculate a gradient - val grads = loss.grad(ws).returns[Float] + val grads = loss.grad(paramsExpr).returns[Float] println(grads.eval) // now we can subtract a fraction of a gradient from weights // and next time loss should b smaller which means more accurate prediction diff --git a/src/test/scala/scanet/core/KernelsSpec.scala b/src/test/scala/scanet/core/KernelsSpec.scala index c5455ea..f8983d2 100644 --- a/src/test/scala/scanet/core/KernelsSpec.scala +++ b/src/test/scala/scanet/core/KernelsSpec.scala @@ -27,7 +27,6 @@ class KernelsSpec extends AnyFlatSpec with Matchers { } "composite output" should "have toString which includes operators chain" in { - println(5.0f.const.reshape(1).toString) 5.0f.const.reshape(1).toString should be("Reshape(Const(5.0)[Float]:(),new_shape:Const(1)[Int]:(1))[Float]:(1)") } @@ -36,15 +35,56 @@ class KernelsSpec extends AnyFlatSpec with Matchers { 5.0f.const.eval should be(Tensor.scalar(5.0f)) } - "product of 2 ops" should "be evaluated" in { + "product of 2 expr" should "be evaluated" in { (1.const, 2.const).eval should be((Tensor.scalar(1), Tensor.scalar(2))) } - "product of 3 ops" should "be evaluated" in { + "product of 3 expr" should "be evaluated" in { (1.const, 2.const, 3.const).eval should be( 
(Tensor.scalar(1), Tensor.scalar(2), Tensor.scalar(3))) } + "seq of expr" should "be evaluated" in { + Seq(1.const, 2.const).eval should be(Seq(Tensor.scalar(1), Tensor.scalar(2))) + } + + "seq of tuple expr" should "be evaluated" in { + Seq((1.const, 2.const), (3.const, 4.const)).eval should be( + Seq((Tensor.scalar(1), Tensor.scalar(2)), (Tensor.scalar(3), Tensor.scalar(4)))) + } + + "seq of seq expr" should "be evaluated" in { + Seq(Seq(1.const, 2.const), Seq(3.const, 4.const, 5.const)).eval should be( + Seq( + Seq(Tensor.scalar(1), Tensor.scalar(2)), + Seq(Tensor.scalar(3), Tensor.scalar(4), Tensor.scalar(5)))) + } + + "map of expr" should "be evaluated" in { + Map("a" -> 1.const, "b" -> 2.const).eval should be( + Map("a" -> Tensor.scalar(1), "b" -> Tensor.scalar(2))) + } + + "map of tuple expr" should "be evaluated" in { + Map("a" -> (1.const, 2.const), "b" -> (3.const, 4.const)).eval should be( + Map("a" -> (Tensor.scalar(1), Tensor.scalar(2)), "b" -> (Tensor.scalar(3), Tensor.scalar(4)))) + } + + "map of map expr" should "be evaluated" in { + val in = Map( + 1 -> Map("a" -> 1.const, "b" -> 2.const), + 2 -> Map("a" -> 3.const, "b" -> 4.const, "c" -> 5.const)) + val out = Map( + 1 -> Map("a" -> Tensor.scalar(1), "b" -> Tensor.scalar(2)), + 2 -> Map("a" -> Tensor.scalar(3), "b" -> Tensor.scalar(4), "c" -> Tensor.scalar(5))) + in.eval should be(out) + } + + "params of expr" should "be evaluated" in { + Params(Path("a") -> 1.const, Path("b") -> 2.const).eval should be( + Params(Path("a") -> Tensor.scalar(1), Path("b") -> Tensor.scalar(2))) + } + "reshape" should "transform vector into matrix" in { Tensor.vector(1, 2, 3, 4).const.reshape(2, 2).eval should be( Tensor.matrix(Array(1, 2), Array(3, 4))) diff --git a/src/test/scala/scanet/core/SessionSpec.scala b/src/test/scala/scanet/core/SessionSpec.scala index cdff36e..a1d33e8 100644 --- a/src/test/scala/scanet/core/SessionSpec.scala +++ b/src/test/scala/scanet/core/SessionSpec.scala @@ -24,6 +24,13 @@ class SessionSpec extends AnyFlatSpec with CustomMatchers { } } + it should "eval an empty sequence of outputs" in { + withing { session => + session.runner.eval[Seq[Expr[Int]]](Seq.empty) should + be(Seq.empty[Tensor[Int]]) + } + } + it should "eval a tuple2 of outputs" in { withing { session => session.runner.eval((5.const, 10.const)) should @@ -33,7 +40,6 @@ class SessionSpec extends AnyFlatSpec with CustomMatchers { it should "eval a tuple2 with one output and sequence of outputs" in { withing { session => - println(session.devices) session.runner.eval[(Expr[Int], Seq[Expr[Int]])]((1.const, Seq(5.const, 10.const))) should be((scalar(1), Seq(scalar(5), scalar(10)))) } diff --git a/src/test/scala/scanet/core/TFSpec.scala b/src/test/scala/scanet/core/TFSpec.scala index 87b82d4..644a698 100644 --- a/src/test/scala/scanet/core/TFSpec.scala +++ b/src/test/scala/scanet/core/TFSpec.scala @@ -35,6 +35,18 @@ class TFSpec extends AnyFlatSpec with CustomMatchers { func(scalar(5)) should be(Seq(scalar(5), scalar(6))) } + it should "work with map of outputs as arg" in { + val sum = (arg: Map[String, Expr[Int]]) => arg("a") plus arg("b") + val func = sum.compile + func(Map("a" -> scalar(1), "b" -> scalar(2))) should be(scalar(3)) + } + + it should "return map of outputs" in { + val double = (arg: Expr[Int]) => Map("a" -> (arg + 0.const), "b" -> (arg + 1.const)) + val func = double.compile + func(scalar(5)) should be(Map("a" -> scalar(5), "b" -> scalar(6))) + } + it should "return complex tuple" in { val complex = (arg: Expr[Int]) => (arg + 0.const, 
Seq(arg + 0.const, arg + 2.const)) val func = complex.compile @@ -65,4 +77,12 @@ class TFSpec extends AnyFlatSpec with CustomMatchers { val func = plus.compile func(scalar(1), scalar(2), scalar(3), scalar(4), scalar(5)) should be(scalar(15)) } + + "tensor function of 6 args" should "work" in { + val plus = + (a1: Expr[Int], a2: Expr[Int], a3: Expr[Int], a4: Expr[Int], a5: Expr[Int], a6: Expr[Int]) => + a1 + a2 + a3 + a4 + a5 + a6 + val func = plus.compile + func(scalar(1), scalar(2), scalar(3), scalar(4), scalar(5), scalar(6)) should be(scalar(21)) + } } diff --git a/src/test/scala/scanet/models/CNNSpec.scala b/src/test/scala/scanet/models/CNNSpec.scala index 05804a1..7e07e5d 100644 --- a/src/test/scala/scanet/models/CNNSpec.scala +++ b/src/test/scala/scanet/models/CNNSpec.scala @@ -42,24 +42,23 @@ class CNNSpec extends AnyWordSpec with CustomMatchers with SharedSpark with Data Conv2D(64, activation = ReLU()) >> Pool2D() >> Flatten >> Dense(10, Softmax) val expected = - """#+-----------------------------------------------+--------------+------+------------+ - #|name |weights |params|output | - #+-----------------------------------------------+--------------+------+------------+ - #|Input | | |(_,24,24,1) | - #|Conv2D(32,(3,3),(1,1),Valid,NHWC,GlorotUniform)|(3, 3, 1, 32) |288 |(_,22,22,32)| - #|ReLU(0.0) | | |(_,22,22,32)| - #|Pool2D((2,2),(1,1),Valid,NHWC,Max) | | |(_,21,21,32)| - #|Conv2D(64,(3,3),(1,1),Valid,NHWC,GlorotUniform)|(3, 3, 32, 64)|18432 |(_,19,19,64)| - #|ReLU(0.0) | | |(_,19,19,64)| - #|Pool2D((2,2),(1,1),Valid,NHWC,Max) | | |(_,18,18,64)| - #|Flatten | | |(_,20736) | - #|Dense(10,Zero,GlorotUniform) |(20736, 10) |207360|(_,10) | - #|Bias(10,Zero,Zeros) |(10) |10 |(_,10) | - #|Softmax | | |(_,10) | - #+-----------------------------------------------+--------------+------+------------+ - #Total params: 226090 (883.2 KB)""" + """#+-----------------------------------------------+--------------+--------------+------------+------------+ + #|name |weights |weights params|state params|output | + #+-----------------------------------------------+--------------+--------------+------------+------------+ + #|Input | | | |(_,24,24,1) | + #|Conv2D(32,(3,3),(1,1),Valid,NHWC,GlorotUniform)|(3, 3, 1, 32) |288 | |(_,22,22,32)| + #|ReLU(0.0) | | | |(_,22,22,32)| + #|Pool2D((2,2),(1,1),Valid,NHWC,Max) | | | |(_,21,21,32)| + #|Conv2D(64,(3,3),(1,1),Valid,NHWC,GlorotUniform)|(3, 3, 32, 64)|18432 | |(_,19,19,64)| + #|ReLU(0.0) | | | |(_,19,19,64)| + #|Pool2D((2,2),(1,1),Valid,NHWC,Max) | | | |(_,18,18,64)| + #|Flatten | | | |(_,20736) | + #|Dense(10,Zero,GlorotUniform) |(20736, 10) |207360 | |(_,10) | + #|Bias(10,Zero,Zeros) |(10) |10 | |(_,10) | + #|Softmax | | | |(_,10) | + #+-----------------------------------------------+--------------+--------------+------------+------------+ + #Total weight params: 226090 (883.2 KB), state params: 0 (0 B)""" .stripMargin('#') - println(model.describe[Float](Shape(1, 24, 24, 1))) model.describe[Float](Shape(1, 24, 24, 1)) shouldBe expected } } diff --git a/src/test/scala/scanet/models/RNNSpec.scala b/src/test/scala/scanet/models/RNNSpec.scala index d2b35d8..757c82f 100644 --- a/src/test/scala/scanet/models/RNNSpec.scala +++ b/src/test/scala/scanet/models/RNNSpec.scala @@ -50,7 +50,7 @@ class RNNSpec extends AnyWordSpec with CustomMatchers with SharedSpark with Data // metrics.R2Score } - "train as forecast predictor using LSTM Cell" in { + "train as forecast predictor using LSTM Cell in stateless mode" in { val Array(train, test) = 
monthlySunspots(12).randomSplit(Array(0.8, 0.2), 1) val model = LSTM(2) >> Dense(1, Tanh) val trained = train diff --git a/src/test/scala/scanet/models/RegressionSpec.scala b/src/test/scala/scanet/models/RegressionSpec.scala index 06f0275..19d0461 100644 --- a/src/test/scala/scanet/models/RegressionSpec.scala +++ b/src/test/scala/scanet/models/RegressionSpec.scala @@ -1,50 +1,63 @@ package scanet.models import org.scalatest.flatspec.AnyFlatSpec -import scanet.core.Tensor +import scanet.core.Params.Weights +import scanet.core.Path.stringIsPath +import scanet.core.{Params, Tensor} import scanet.math.syntax._ import scanet.models.Loss._ import scanet.test.CustomMatchers -import scala.collection.immutable.Seq class RegressionSpec extends AnyFlatSpec with CustomMatchers { "linear regression" should "calculate loss with Float precision" in { - val loss = LinearRegression().withLoss(MeanSquaredError).loss[Float].compile + val loss = LinearRegression().withLoss(MeanSquaredError).loss_[Float].compile val x = Tensor.matrix(Array(1.0f, 2.0f), Array(2.0f, 4.0f)) val y = Tensor.matrix(Array(6.0f), Array(12.0f)) val weights = Tensor.matrix(Array(2.0f), Array(3.0f)) val bias = Tensor.vector(1.0f) - loss(x, y, Seq(weights, bias)) should be(Tensor.scalar(17f)) + val params = Params( + "l" / Weights -> weights, + "r" / Weights -> bias) + loss(x, y, params) should be(Tensor.scalar(17f)) } it should "calculate loss with Double precision" in { - val loss = LinearRegression().withLoss(MeanSquaredError).loss[Double].compile + val loss = LinearRegression().withLoss(MeanSquaredError).loss_[Double].compile val x = Tensor.matrix(Array(1.0, 2.0), Array(2.0, 4.0)) val y = Tensor.matrix(Array(6.0), Array(12.0)) val weights = Tensor.matrix(Array(2.0), Array(3.0)) val bias = Tensor.vector(1.0) - loss(x, y, Seq(weights, bias)) should be(Tensor.scalar(17)) + val params = Params( + "l" / Weights -> weights, + "r" / Weights -> bias) + loss(x, y, params) should be(Tensor.scalar(17)) } it should "calculate result" in { - val result = LinearRegression().result[Float].compile + val result = LinearRegression().result_[Float].compile val x = Tensor.matrix(Array(1.0f, 2.0f), Array(2.0f, 4.0f)) val y = Tensor.matrix(Array(9.0f), Array(17.0f)) val weights = Tensor.matrix(Array(2.0f), Array(3.0f)) val bias = Tensor.vector(1.0f) - result(x, Seq(weights, bias)) should be(y) + val params = Params( + "l" / Weights -> weights, + "r" / Weights -> bias) + result(x, params) should be(y) } it should "calculate gradient" in { - val grad = LinearRegression().withLoss(MeanSquaredError).grad[Float].compile + val grad = LinearRegression().withLoss(MeanSquaredError).grad_[Float].compile val x = Tensor.matrix(Array(1.0f, 2.0f), Array(2.0f, 4.0f)) val y = Tensor.matrix(Array(6.0f), Array(12.0f)) val weights = Tensor.matrix(Array(0.0f), Array(0.0f)) val bias = Tensor.vector(0.0f) - grad(x, y, Seq(weights, bias)) should be(Seq( - Tensor.matrix(Array(-30.0f), Array(-60.0f)), - Tensor.vector(-18.0f))) + val params = Params( + "l" / Weights -> weights, + "r" / Weights -> bias) + grad(x, y, params) shouldBe Params( + "l" / Weights -> Tensor.matrix(Array(-30.0f), Array(-60.0f)), + "r" / Weights -> Tensor.vector(-18.0f)) } it should "produce unique toString to be used as a cache key" in { @@ -52,35 +65,43 @@ class RegressionSpec extends AnyFlatSpec with CustomMatchers { } "logistic regression" should "calculate loss" in { - val regression = LogisticRegression().withLoss(BinaryCrossentropy).loss[Float].compile + val regression = 
LogisticRegression().withLoss(BinaryCrossentropy).loss_[Float].compile val x = Tensor.matrix(Array(0.34f, 0.78f), Array(0.6f, 0.86f)) val y = Tensor.matrix(Array(0.402f), Array(0.47800002f)) val weights = Tensor.matrix(Array(0.2f), Array(0.3f)) val bias = Tensor.vector(0.1f) - regression(x, y, Seq(weights, bias)).toScalar should be(0.7422824f +- 1e-6f) + val params = Params( + "l" / "l" / Weights -> weights, + "l" / "r" / Weights -> bias) + regression(x, y, params).toScalar should be(0.7422824f +- 1e-6f) } it should "calculate result" in { - val result = LogisticRegression().result[Float].compile + val result = LogisticRegression().result_[Float].compile val x = Tensor.matrix(Array(0.34f, 0.78f), Array(0.6f, 0.86f)) val y = Tensor.matrix(Array(0.599168f), Array(0.617276f)) val weights = Tensor.matrix(Array(0.2f), Array(0.3f)) val bias = Tensor.vector(0.1f) - val predicted = result(x, Seq(weights, bias)) + val params = Params( + "l" / "l" / Weights -> weights, + "l" / "r" / Weights -> bias) + val predicted = result(x, params) predicted.const.roundAt(6).eval should be(y) } it should "calculate gradient " in { - val grad = LogisticRegression().withLoss(BinaryCrossentropy).grad[Float].compile + val grad = LogisticRegression().withLoss(BinaryCrossentropy).grad_[Float].compile val x = Tensor.matrix(Array(0.34f, 0.78f), Array(0.6f, 0.86f)) val y = Tensor.matrix(Array(0.402f), Array(0.47800002f)) val weights = Tensor.matrix(Array(0.2f), Array(0.3f)) val bias = Tensor.vector(0.1f) - val result = grad(x, y, Seq(weights, bias)) - .map(_.const.roundAt(6).eval) - result should be(Seq( - Tensor.matrix(Array(0.075301f), Array(0.136784f)), - Tensor.vector(0.168222f))) + val params = Params( + "l" / "l" / Weights -> weights, + "l" / "r" / Weights -> bias) + val result = grad(x, y, params).mapValues(_.const.roundAt(6).eval) + result shouldBe Params( + "l" / "l" / Weights -> Tensor.matrix(Array(0.075301f), Array(0.136784f)), + "l" / "r" / Weights -> Tensor.vector(0.168222f)) } it should "produce unique toString to be used as a cache key" in { diff --git a/src/test/scala/scanet/models/layer/ActivateLayerSpec.scala b/src/test/scala/scanet/models/layer/ActivateLayerSpec.scala index 6d4ddbd..d2c4446 100644 --- a/src/test/scala/scanet/models/layer/ActivateLayerSpec.scala +++ b/src/test/scala/scanet/models/layer/ActivateLayerSpec.scala @@ -1,7 +1,7 @@ package scanet.models.layer import org.scalatest.wordspec.AnyWordSpec -import scanet.core.Tensor +import scanet.core.{Params, Tensor} import scanet.models.Activation.Sigmoid import scanet.syntax._ import scanet.test.CustomMatchers @@ -19,13 +19,13 @@ class ActivateLayerSpec extends AnyWordSpec with CustomMatchers { Array(1f, 0f, 1f), Array(1f, 1f, 1f)) val model = Sigmoid.layer - val result = model.result[Float].compile + val result = model.result_[Float].compile val y = Tensor.matrix( Array(0.5f, 0.5f, 0.7310586f), Array(0.5f, 0.7310586f, 0.7310586f), Array(0.7310586f, 0.5f, 0.7310586f), Array(0.7310586f, 0.7310586f, 0.7310586f)) - result(x, Seq.empty) should be(y) + result(x, Params.empty) should be(y) } "have string repr" in { diff --git a/src/test/scala/scanet/models/layer/BiasLayerSpec.scala b/src/test/scala/scanet/models/layer/BiasLayerSpec.scala index 5d6fdbf..cb7f95d 100644 --- a/src/test/scala/scanet/models/layer/BiasLayerSpec.scala +++ b/src/test/scala/scanet/models/layer/BiasLayerSpec.scala @@ -1,7 +1,8 @@ package scanet.models.layer import org.scalatest.wordspec.AnyWordSpec -import scanet.core.Tensor +import scanet.core.Params.Weights +import 
scanet.core.{Params, Tensor} import scanet.syntax._ import scanet.test.CustomMatchers @@ -19,13 +20,13 @@ class BiasLayerSpec extends AnyWordSpec with CustomMatchers { Array(1f, 1f, 1f)) val b = Tensor.vector(1f, 2f, 3f) val model = Bias(3) - val result = model.result[Float].compile + val result = model.result_[Float].compile val y = Tensor.matrix( Array(1f, 2f, 4f), Array(1f, 3f, 4f), Array(2f, 2f, 4f), Array(2f, 3f, 4f)) - result(x, Seq(b)) should be(y) + result(x, Params(Weights -> b)) should be(y) } "have string repr" in { diff --git a/src/test/scala/scanet/models/layer/ComposedLayerSpec.scala b/src/test/scala/scanet/models/layer/ComposedLayerSpec.scala index edabd70..e00fa1a 100644 --- a/src/test/scala/scanet/models/layer/ComposedLayerSpec.scala +++ b/src/test/scala/scanet/models/layer/ComposedLayerSpec.scala @@ -1,7 +1,9 @@ package scanet.models.layer import org.scalatest.wordspec.AnyWordSpec -import scanet.core.{Shape, Tensor} +import scanet.core.Params.Weights +import scanet.core.Path._ +import scanet.core.{Params, Shape, Tensor} import scanet.models.Activation._ import scanet.models.Loss.MeanSquaredError import scanet.models.Regularization.L2 @@ -32,13 +34,18 @@ class ComposedLayerSpec extends AnyWordSpec with CustomMatchers with SharedSpark Array(1f), Array(0f)) val b2 = Tensor.vector(0f) - val forward = model.result[Float].compile + val forward = model.result_[Float].compile val expected = Tensor.matrix( Array(0.705357f), Array(0.770136f), Array(0.762753f), Array(0.801886f)) - val result = forward(x, Seq(w1, b1, w2, b2)).const.roundAt(6).eval + val params = Params( + "l" / "l" / "l" / Weights -> w1, + "l" / "l" / "r" / Weights -> b1, + "r" / "l" / "l" / Weights -> w2, + "r" / "l" / "r" / Weights -> b2) + val result = forward(x, params).const.roundAt(6).eval result should be(expected) } } @@ -54,7 +61,12 @@ class ComposedLayerSpec extends AnyWordSpec with CustomMatchers with SharedSpark val w2 = Tensor.matrix( Array(0.1f, 0.5f, 1f, 0f)) val b2 = Tensor.vector(0f) - model.penalty(Shape(1, 4), Seq(w1.const, b1.const, w2.const, b2.const)).eval should be( + val params = Params( + "l" / "l" / "l" / Weights -> w1, + "l" / "l" / "r" / Weights -> b1, + "r" / "l" / "l" / Weights -> w2, + "r" / "l" / "r" / Weights -> b2) + model.penalty_(Shape(1, 4), params.mapValues(_.const)).eval should be( Tensor.scalar(3.83f)) } @@ -74,8 +86,13 @@ class ComposedLayerSpec extends AnyWordSpec with CustomMatchers with SharedSpark Array(0.5f), Array(1f)) val b2 = Tensor.vector(0f) - val loss = model.withLoss(MeanSquaredError).loss[Float].compile - val result = loss(x, y, Seq(w1, b1, w2, b2)).const.roundAt(6).eval + val params = Params( + "l" / "l" / "l" / Weights -> w1, + "l" / "l" / "r" / Weights -> b1, + "r" / "l" / "l" / Weights -> w2, + "r" / "l" / "r" / Weights -> b2) + val loss = model.withLoss(MeanSquaredError).loss_[Float].compile + val result = loss(x, y, params).const.roundAt(6).eval result should be(Tensor.scalar(0.339962f)) } @@ -95,7 +112,12 @@ class ComposedLayerSpec extends AnyWordSpec with CustomMatchers with SharedSpark Array(0.5f), Array(1f)) val b2 = Tensor.vector(0f) - val loss = model.withLoss(MeanSquaredError).loss[Float].compile - loss(x, y, Seq(w1, b1, w2, b2)) should be(Tensor.scalar(3.6199622f)) + val params = Params( + "l" / "l" / "l" / Weights -> w1, + "l" / "l" / "r" / Weights -> b1, + "r" / "l" / "l" / Weights -> w2, + "r" / "l" / "r" / Weights -> b2) + val loss = model.withLoss(MeanSquaredError).loss_[Float].compile + loss(x, y, params) should be(Tensor.scalar(3.6199622f)) } 
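  // Editor's note (illustrative, not part of the original patch): the nested paths used in
  // the tests above come from layer composition. Each `>>` yields a composed layer that
  // prefixes the left side's params with "l" and the right side's with "r", and a Dense
  // layer is itself a composition of a kernel layer, Bias and the activation. So for a
  // two-layer model such as
  //
  //   Dense(4, Sigmoid) >> Dense(1, Sigmoid)
  //
  // the first kernel is expected at "l" / "l" / "l" / Weights and its bias at
  // "l" / "l" / "r" / Weights, which is how the Params keys above are constructed.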
} diff --git a/src/test/scala/scanet/models/layer/Conv2DLayerSpec.scala b/src/test/scala/scanet/models/layer/Conv2DLayerSpec.scala index fc86727..825d84b 100644 --- a/src/test/scala/scanet/models/layer/Conv2DLayerSpec.scala +++ b/src/test/scala/scanet/models/layer/Conv2DLayerSpec.scala @@ -1,7 +1,8 @@ package scanet.models.layer import org.scalatest.wordspec.AnyWordSpec -import scanet.core.Tensor +import scanet.core.Params.Weights +import scanet.core.{Params, Tensor} import scanet.syntax._ import scanet.test.CustomMatchers @@ -32,8 +33,8 @@ class Conv2DLayerSpec extends AnyWordSpec with CustomMatchers { Array(7.0, 11.0, 16.0, 7.0), Array(10.0, 7.0, 4.0, 7.0)) .reshape(1, 4, 4, 1) - val result = model.result[Double].compile - result(input, Seq(filters)).const.roundAt(2).eval shouldBe output + val result = model.result_[Double].compile + result(input, Params(Weights -> filters)).const.roundAt(2).eval shouldBe output } "have string repr" in { diff --git a/src/test/scala/scanet/models/layer/DenseLayerSpec.scala b/src/test/scala/scanet/models/layer/DenseLayerSpec.scala index 3dc921c..6d40bbb 100644 --- a/src/test/scala/scanet/models/layer/DenseLayerSpec.scala +++ b/src/test/scala/scanet/models/layer/DenseLayerSpec.scala @@ -1,7 +1,9 @@ package scanet.models.layer import org.scalatest.wordspec.AnyWordSpec -import scanet.core.{Shape, Tensor} +import scanet.core.Params.Weights +import scanet.core.{Params, Shape, Tensor} +import scanet.core.Path._ import scanet.models.Activation._ import scanet.models.Loss._ import scanet.models.Regularization.L2 @@ -34,8 +36,9 @@ class DenseLayerSpec extends AnyWordSpec with CustomMatchers { Array(0.880797f, 0.622459f, 0.768525f, 0.598688f), Array(0.890903f, 0.817574f, 0.900250f, 0.802184f)) val model = Dense(4, Sigmoid) - val forward = model.result[Float].compile - val y = forward(x, Seq(w, b)).const.roundAt(6).eval + val forward = model.result_[Float].compile + val params = Params("l" / "l" / Weights -> w, "l" / "r" / Weights -> b) + val y = forward(x, params).const.roundAt(6).eval y should be(yExpected) } @@ -47,12 +50,13 @@ class DenseLayerSpec extends AnyWordSpec with CustomMatchers { Array(1f, 0f)) val b = Tensor.vector(0f, 0f) val model = Dense(4, Sigmoid, reg = L2(lambda = 1)) - model.penalty(Shape(1, 4), Seq(w.const, b.const)).eval should be(Tensor.scalar(1.63f)) + val params = Params("l" / "l" / Weights -> w, "l" / "r" / Weights -> b) + model.penalty_(Shape(1, 4), params.mapValues(_.const)).eval should be(Tensor.scalar(1.63f)) } "produce gradient when combined with loss function" in { val loss = Dense(4, Sigmoid).withLoss(BinaryCrossentropy) - val grad = loss.grad[Float].compile + val grad = loss.grad_[Float].compile val x = Tensor.matrix( Array(0f, 0f, 1f), Array(0f, 1f, 1f), @@ -70,7 +74,9 @@ class DenseLayerSpec extends AnyWordSpec with CustomMatchers { Array(-0.040072710f, -0.034289565f, -0.041798398f, -0.036751173f), Array(-0.078313690f, -0.041943270f, -0.061695820f, -0.047571808f)) val biasGrad = Tensor.vector(-0.078313690f, -0.041943270f, -0.061695820f, -0.047571808f) - grad(x, y, Seq(weights, bias)) should be(Seq(weightsGrad, biasGrad)) + val before = Params("l" / "l" / Weights -> weights, "l" / "r" / Weights -> bias) + val after = Params("l" / "l" / Weights -> weightsGrad, "l" / "r" / Weights -> biasGrad) + grad(x, y, before) should be(after) } } } diff --git a/src/test/scala/scanet/models/layer/FlattenLayerSpec.scala b/src/test/scala/scanet/models/layer/FlattenLayerSpec.scala index 4b250bb..38e5980 100644 --- 
a/src/test/scala/scanet/models/layer/FlattenLayerSpec.scala +++ b/src/test/scala/scanet/models/layer/FlattenLayerSpec.scala @@ -1,7 +1,7 @@ package scanet.models.layer import org.scalatest.wordspec.AnyWordSpec -import scanet.core.Tensor +import scanet.core.{Params, Tensor} import scanet.syntax._ import scanet.test.CustomMatchers @@ -13,8 +13,8 @@ class FlattenLayerSpec extends AnyWordSpec with CustomMatchers { "flatten the output from (batch, features_1, ... features_n) into (batch, features) tensor" in { val x = Tensor.ones[Float](10, 5, 5) val model = Flatten - val result = model.result[Float].compile - result(x, Seq.empty) shouldBe x.reshape(10, 25) + val result = model.result_[Float].compile + result(x, Params.empty) shouldBe x.reshape(10, 25) } } } diff --git a/src/test/scala/scanet/models/layer/Pool2DLayerSpec.scala b/src/test/scala/scanet/models/layer/Pool2DLayerSpec.scala index 15f24a7..013e6dc 100644 --- a/src/test/scala/scanet/models/layer/Pool2DLayerSpec.scala +++ b/src/test/scala/scanet/models/layer/Pool2DLayerSpec.scala @@ -1,7 +1,7 @@ package scanet.models.layer import org.scalatest.wordspec.AnyWordSpec -import scanet.core.Tensor +import scanet.core.{Params, Tensor} import scanet.syntax._ import scanet.test.CustomMatchers @@ -27,8 +27,8 @@ class Pool2DLayerSpec extends AnyWordSpec with CustomMatchers { Array(2.0, 3.0, 3.0, 3.0), Array(2.0, 3.0, 3.0, 2.0)) .reshape(1, 4, 4, 1) - val result = model.result[Double].compile - result(input, Seq.empty).const.roundAt(2).eval shouldBe output + val result = model.result_[Double].compile + result(input, Params.empty).const.roundAt(2).eval shouldBe output } "have string repr" in { diff --git a/src/test/scala/scanet/models/layer/RNNLayerSpec.scala b/src/test/scala/scanet/models/layer/RNNLayerSpec.scala index 904cada..e5dca1b 100644 --- a/src/test/scala/scanet/models/layer/RNNLayerSpec.scala +++ b/src/test/scala/scanet/models/layer/RNNLayerSpec.scala @@ -1,7 +1,9 @@ package scanet.models.layer import org.scalatest.wordspec.AnyWordSpec -import scanet.core.{Shape, Tensor} +import scanet.core.Params.Weights +import scanet.core.Path._ +import scanet.core.{Params, Shape, Tensor} import scanet.models.Activation.Identity import scanet.syntax._ import scanet.test.CustomMatchers @@ -22,14 +24,18 @@ class RNNLayerSpec extends AnyWordSpec with CustomMatchers { Array(-0.6313778f, 0.7754754f)) val b = Tensor.vector(0f, 0f) val expected = Tensor.matrix(Array(0.222901f, -6.019066f)) - val result = layer.result[Float].compile - val prediction = result(input, Seq(wx, wh, b)).const.roundAt(6).eval + val result = layer.result_[Float].compile + val params = Params( + "l" / "kernel_weights" -> wx, + "l" / "recurrent_weights" -> wh, + "r" / Weights -> b) + val prediction = result(input, params).const.roundAt(6).eval prediction shouldBe expected } "have string repr" in { val model = RNN(SimpleRNNCell(units = 2)) - model.toString shouldBe "RNN(SimpleRNNCell(2,GlorotUniform,Orthogonal,Zero,Zero) >> Bias(2,Zero,Zeros) >> Tanh,false)" + model.toString shouldBe "RNN(SimpleRNNCell(2,GlorotUniform,Orthogonal,Zero,Zero) >> Bias(2,Zero,Zeros) >> Tanh,false,false)" } } @@ -54,15 +60,28 @@ class RNNLayerSpec extends AnyWordSpec with CustomMatchers { Tensor.matrix(Array(-0.02363521f, 0.1830315f)), Tensor.matrix(Array(-0.37248448f, 0.07327155f), Array(-0.70414686f, -0.3490578f)), Tensor.vector(0f, 0f)) - val w = wf ++ wi ++ wg ++ wo - val result = layer.result[Float].compile - val prediction = result(input, w).const.roundAt(6).eval + val params = Params( + "forget" / "l" / 
"l" / "kernel_weights" -> wf(0), + "forget" / "l" / "l" / "recurrent_weights" -> wf(1), + "forget" / "l" / "r" / Weights -> wf(2), + "input" / "l" / "l" / "kernel_weights" -> wi(0), + "input" / "l" / "l" / "recurrent_weights" -> wi(1), + "input" / "l" / "r" / Weights -> wi(2), + "gate" / "l" / "l" / "kernel_weights" -> wg(0), + "gate" / "l" / "l" / "recurrent_weights" -> wg(1), + "gate" / "l" / "r" / Weights -> wg(2), + "output" / "l" / "l" / "kernel_weights" -> wo(0), + "output" / "l" / "l" / "recurrent_weights" -> wo(1), + "output" / "l" / "r" / Weights -> wo(2), + ) + val result = layer.result_[Float].compile + val prediction = result(input, params).const.roundAt(6).eval prediction shouldBe Tensor.matrix(Array(0.382158f, 0.029766f)) } "have string repr" in { val model = RNN(LSTMCell(units = 2)) - model.toString shouldBe "RNN(LSTMCell(2,Tanh,Sigmoid,true,GlorotUniform,Orthogonal,Zeros,Ones,Zero,Zero,Zero),false)" + model.toString shouldBe "RNN(LSTMCell(2,Tanh,Sigmoid,true,GlorotUniform,Orthogonal,Zeros,Ones,Zero,Zero,Zero),false,false)" } } } diff --git a/src/test/scala/scanet/optimizers/SGDSpec.scala b/src/test/scala/scanet/optimizers/SGDSpec.scala index 469d28d..3eadc07 100644 --- a/src/test/scala/scanet/optimizers/SGDSpec.scala +++ b/src/test/scala/scanet/optimizers/SGDSpec.scala @@ -1,7 +1,8 @@ package scanet.optimizers import org.scalatest.flatspec.AnyFlatSpec -import scanet.core.Tensor +import scanet.core.Params.Weights +import scanet.core.{Params, Tensor} import scanet.math.syntax._ import scanet.models.LinearRegression import scanet.models.Loss._ @@ -19,7 +20,7 @@ class SGDSpec extends AnyFlatSpec with CustomMatchers with SharedSpark with Data .minimize[Float](`x^2`) .loss(Identity) .using(SGD(rate = 0.1f)) - .initWeights(Seq(Tensor.scalar(5.0f))) + .initParams(Params(Weights -> Tensor.scalar(5.0f))) .each(1.epochs, RecordLoss()) .on(zero) .batch(1)