Skip to content

Commit

Permalink
Add named parameters to the model
Browse files Browse the repository at this point in the history
  • Loading branch information
pashashiz committed Oct 15, 2023
1 parent 9c4dbef commit 4fdc3ca
Show file tree
Hide file tree
Showing 40 changed files with 1,033 additions and 516 deletions.
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ libraryDependencies ++= Seq(
"com.typesafe.scala-logging" %% "scala-logging" % "3.9.5",
"ch.qos.logback" % "logback-classic" % "1.4.5",
"org.scala-lang.modules" %% "scala-collection-compat" % "2.9.0",
"com.github.ben-manes.caffeine" % "caffeine" % "2.8.5",
"com.softwaremill.magnolia1_2" %% "magnolia" % "1.1.3",
"org.scalacheck" %% "scalacheck" % "1.17.0" % Test,
"org.scalatest" %% "scalatest" % "3.2.14" % Test)

Expand Down
181 changes: 142 additions & 39 deletions src/main/scala/scanet/core/Mat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,81 +11,184 @@ import scala.collection.immutable.Seq
// case type information gets lost there
trait Mat[In] {
type Out
// given MyType[Expr] deconstruct as Layout + Seq[Expr] so we could feed a session
def deconstructIn(in: In): (Layout, Seq[Expr[_]])
def constructIn(shape: Layout, expr: Seq[Expr[_]]): In
// given placeholders and layout, construct In representation so we could to pass into the session
def constructIn(layout: Layout, expr: Seq[Expr[_]]): In
// given MyType[Tensor] deconstruct as Layout + Seq[Tensor] so we could create placeholders
def deconstructOut(out: Out): (Layout, Seq[Tensor[_]])
def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out
// given raw tensors and layout, construct MyType[Tensor], that is done to get results after session run
def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out
}

sealed trait Layout {
def size: Int
}

object Layout {
case class Leaf(size: Int) extends Layout
case object Value extends Layout {
override def size: Int = 1
}
case class SeqLayout(items: Seq[Layout]) extends Layout {
override def size: Int = items.map(_.size).sum
}
case class MapLayout(items: Seq[(Any, Layout)]) extends Layout {
override def size: Int = items.map(_._2.size).sum
}
case class Struct(children: Seq[Layout]) extends Layout {
def apply(index: Int): Layout = children(index)
def size: Int = children.map(_.size).sum
}
}

class MatExpr[A: TensorType] extends Mat[Expr[A]] {
class ValueMat[A: TensorType] extends Mat[Expr[A]] {
override type Out = Tensor[A]
override def deconstructIn(in: Expr[A]): (Layout, Seq[Expr[_]]) = (Leaf(1), Seq(in))
override def constructIn(shape: Layout, expr: Seq[Expr[_]]): Expr[A] =
expr.head.asInstanceOf[Expr[A]]
override def deconstructOut(out: Tensor[A]): (Layout, Seq[Tensor[_]]) = (Leaf(1), Seq(out))
override def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out =
Tensor.wrap[A](tensors.head)
}

class MatSeqOfExpr[A: TensorType] extends Mat[Seq[Expr[A]]] {
override type Out = Seq[Tensor[A]]
override def deconstructIn(in: Seq[Expr[A]]): (Layout, Seq[Expr[_]]) = (Leaf(in.size), in)
override def constructIn(shape: Layout, expr: Seq[Expr[_]]): Seq[Expr[A]] =
expr.asInstanceOf[Seq[Expr[A]]]
override def deconstructOut(out: Seq[Tensor[A]]): (Layout, Seq[Tensor[_]]) =
(Leaf(out.size), out)
override def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out =
tensors.map(raw => Tensor.wrap[A](raw))
override def deconstructIn(in: Expr[A]): (Layout, Seq[Expr[_]]) = (Value, Seq(in))
override def constructIn(layout: Layout, expr: Seq[Expr[_]]): Expr[A] = layout match {
case Value => expr.head.asInstanceOf[Expr[A]]
case other => error(s"Unsupported layout $other")
}
override def deconstructOut(out: Tensor[A]): (Layout, Seq[Tensor[_]]) = (Value, Seq(out))
override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out =
layout match {
case Value => Tensor.wrap[A](tensors.head)
case other => error(s"Unsupported layout $other")
}
}

object Mat {

trait AllSyntax {

implicit def matExpr[A: TensorType]: MatExpr[A] = new MatExpr[A]
implicit def valueMat[A: TensorType]: ValueMat[A] = new ValueMat[A]
// Note: we have to define anonymous class so dependant types would work
// that is probably caused by the fact that we have to preserve an exact structural type

implicit def matSeqOfExpr[A: TensorType]: MatSeqOfExpr[A] = new MatSeqOfExpr[A]
implicit def seqMat[A](implicit m: Mat[A]) = new Mat[Seq[A]] {
override type Out = Seq[m.Out]
override def deconstructIn(in: Seq[A]): (Layout, Seq[Expr[_]]) = {
val (layouts, allExpr) = in.map(m.deconstructIn).unzip
(SeqLayout(layouts), allExpr.flatten)
}
override def constructIn(layout: Layout, expr: Seq[Expr[_]]): Seq[A] = {
layout match {
case SeqLayout(items) =>
items.foldLeft((Seq.empty[A], expr)) {
case ((result, exprAll), next) =>
val (consumed, rest) = exprAll.splitAt(next.size)
val constructed = m.constructIn(next, consumed)
(constructed +: result, rest)
}._1.reverse
case other => error(s"Unsupported layout $other")
}
}
override def deconstructOut(out: Seq[m.Out]): (Layout, Seq[Tensor[_]]) = {
val (layouts, allTensors) = out.map(m.deconstructOut).unzip
(SeqLayout(layouts), allTensors.flatten)
}
override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Seq[m.Out] = {
layout match {
case SeqLayout(items) =>
items.foldLeft((Seq.empty[m.Out], tensors)) {
case ((result, tensorsAll), next) =>
val (consumed, rest) = tensorsAll.splitAt(next.size)
val constructed = m.constructOutRaw(next, consumed)
(constructed +: result, rest)
}._1.reverse
case other => error(s"Unsupported layout $other")
}
}
}

// Note: we have to define anonymous class so dependant types would work
// that is probably caused by the fact that we have to preserve an original
// (m1.Out, m2.Out) which come from implicit scope
implicit def matTuple2Expr[A1, A2](implicit m1: Mat[A1], m2: Mat[A2]) =
implicit def mapMat[K, V](implicit m: Mat[V]) = new Mat[Map[K, V]] {
override type Out = Map[K, m.Out]
override def deconstructIn(in: Map[K, V]): (Layout, Seq[Expr[_]]) = {
val (layouts, allExpr) = in
.map {
case (key, value) =>
val (layout, expr) = m.deconstructIn(value)
((key, layout), expr)
}
.toList
.unzip
(MapLayout(layouts), allExpr.flatten)
}
override def constructIn(layout: Layout, expr: Seq[Expr[_]]): Map[K, V] = {
layout match {
case MapLayout(items) =>
items.foldLeft((Map.empty[K, V], expr)) {
case ((result, exprAll), (key, next)) =>
val (consumed, rest) = exprAll.splitAt(next.size)
val constructed = m.constructIn(next, consumed)
(result + (key.asInstanceOf[K] -> constructed), rest)
}._1
case other => error(s"Unsupported layout $other")
}
}
override def deconstructOut(out: Map[K, m.Out]): (Layout, Seq[Tensor[_]]) = {
val (layouts, allTensors) = out
.map {
case (key, value) =>
val (layout, tensors) = m.deconstructOut(value)
((key, layout), tensors)
}
.toList
.unzip
(MapLayout(layouts), allTensors.flatten)
}
override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out = {
layout match {
case MapLayout(items) =>
items.foldLeft((Map.empty[K, m.Out], tensors)) {
case ((result, tensorsAll), (key, next)) =>
val (consumed, rest) = tensorsAll.splitAt(next.size)
val constructed = m.constructOutRaw(next, consumed)
(result + (key.asInstanceOf[K] -> constructed), rest)
}._1
case other => error(s"Unsupported layout $other")
}
}
}

// check to see if we can implement xmap with dependant types
implicit def paramsMat[A](implicit m: Mat[A]) = new Mat[Params[A]] {
override type Out = Params[m.Out]
private val map = mapMat[Path, A]
override def deconstructIn(in: Params[A]): (Layout, Seq[Expr[_]]) =
map.deconstructIn(in.params)
override def constructIn(layout: Layout, expr: Seq[Expr[_]]): Params[A] =
Params(map.constructIn(layout, expr))
override def deconstructOut(out: Params[m.Out]): (Layout, Seq[Tensor[_]]) =
map.deconstructOut(out.params)
override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out =
Params(map.constructOutRaw(layout, tensors))
}

implicit def tuple2Mat[A1, A2](implicit m1: Mat[A1], m2: Mat[A2]) =
new Mat[(A1, A2)] {
override type Out = (m1.Out, m2.Out)
override def deconstructIn(in: (A1, A2)): (Layout, Seq[Expr[_]]) = {
val (shape1, expr1) = m1.deconstructIn(in._1)
val (shape2, expr2) = m2.deconstructIn(in._2)
(Struct(Seq(shape1, shape2)), expr1 ++ expr2)
}
override def constructIn(shape: Layout, expr: Seq[Expr[_]]): (A1, A2) = {
shape match {
override def constructIn(layout: Layout, expr: Seq[Expr[_]]): (A1, A2) = {
layout match {
case Struct(t1 :: t2 :: Nil) =>
val (s1, s2) = (t1.size, t2.size)
(
m1.constructIn(t1, expr.slice(0, s1)),
m2.constructIn(t2, expr.slice(s1, s1 + s2)))
case _ => error("StructShape of size 2 is required")
case _ => error("Struct layout of size 2 is required")
}
}
override def deconstructOut(out: (m1.Out, m2.Out)): (Layout, Seq[Tensor[_]]) = {
val (shape1, tensor1) = m1.deconstructOut(out._1)
val (shape2, tensor2) = m2.deconstructOut(out._2)
(Struct(Seq(shape1, shape2)), tensor1 ++ tensor2)
}
override def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out = {
shape match {
override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out = {
layout match {
case Struct(t1 :: t2 :: Nil) =>
val (s1, s2) = (t1.size, t2.size)
(
Expand All @@ -96,7 +199,7 @@ object Mat {
}
}

implicit def matTuple3Expr[A1, A2, A3](implicit m1: Mat[A1], m2: Mat[A2], m3: Mat[A3]) =
implicit def tuple3Mat[A1, A2, A3](implicit m1: Mat[A1], m2: Mat[A2], m3: Mat[A3]) =
new Mat[(A1, A2, A3)] {
override type Out = (m1.Out, m2.Out, m3.Out)
override def deconstructIn(in: (A1, A2, A3)): (Layout, Seq[Expr[_]]) = {
Expand All @@ -105,15 +208,15 @@ object Mat {
val (t3, expr3) = m3.deconstructIn(in._3)
(Struct(Seq(t1, t2, t3)), expr1 ++ expr2 ++ expr3)
}
override def constructIn(shape: Layout, expr: Seq[Expr[_]]): (A1, A2, A3) = {
shape match {
override def constructIn(layout: Layout, expr: Seq[Expr[_]]): (A1, A2, A3) = {
layout match {
case Struct(t1 :: t2 :: t3 :: Nil) =>
val (s1, s2, s3) = (t1.size, t2.size, t3.size)
(
m1.constructIn(t1, expr.slice(0, s1)),
m2.constructIn(t2, expr.slice(s1, s1 + s2)),
m3.constructIn(t3, expr.slice(s1 + s2, s1 + s2 + s3)))
case _ => error("StructShape of size 3 is required")
case _ => error("Struct layout of size 3 is required")
}
}
override def deconstructOut(out: (m1.Out, m2.Out, m3.Out)): (Layout, Seq[Tensor[_]]) = {
Expand All @@ -122,15 +225,15 @@ object Mat {
val (t3, tensor3) = m3.deconstructOut(out._3)
(Struct(Seq(t1, t2, t3)), tensor1 ++ tensor2 ++ tensor3)
}
override def constructOutRaw(shape: Layout, tensors: Seq[RawTensor]): Out = {
shape match {
override def constructOutRaw(layout: Layout, tensors: Seq[RawTensor]): Out = {
layout match {
case Struct(t1 :: t2 :: t3 :: Nil) =>
val (s1, s2, s3) = (t1.size, t2.size, t3.size)
(
m1.constructOutRaw(t1, tensors.slice(0, s1)),
m2.constructOutRaw(t2, tensors.slice(s1, s1 + s2)),
m3.constructOutRaw(t3, tensors.slice(s1 + s2, s1 + s2 + s3)))
case _ => error("StructShape of size 3 is required")
case _ => error("Struct layout of size 3 is required")
}
}
}
Expand Down
72 changes: 72 additions & 0 deletions src/main/scala/scanet/core/Params.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package scanet.core

case class Path(segments: Seq[String]) {
def /(child: Path): Path = Path(segments ++ child.segments)
def startsWith(parent: Path): Boolean = segments.startsWith(parent.segments)
def endsWith(parent: Path): Boolean = segments.endsWith(parent.segments)
def relativeTo(parent: Path): Option[Path] =
if (startsWith(parent)) Some(Path(segments.drop(parent.segments.size))) else None
override def toString: String = segments.mkString("/")
}

object Path {
def apply(s1: String, sn: String*): Path = new Path(s1 :: sn.toList)
def parse(path: String): Path = new Path(path.split("/").toList)
implicit def stringIsPath(string: String): Path = Path.parse(string)
}

// convenient wrapper over Map[Path, A] to get DSL like API
case class Params[A](params: Map[Path, A]) {
def apply(path: Path): A =
params.getOrElse(path, error(s"missing $path param"))
def paths: Set[Path] = params.keySet
def children(parent: Path): Params[A] = {
val childrenParams = params.flatMap {
case (path, value) =>
path.relativeTo(parent).map(child => child -> value)
}
Params(childrenParams)
}
def +(other: (Path, A)): Params[A] = Params(params + other)
def -(other: Path): Params[A] = Params(params - other)
def ++(other: Params[A]): Params[A] =
Params(params ++ other.params)
def map[B](f: (Path, A) => (Path, B)): Params[B] =
Params(params.map(f.tupled))
def mapValues[B](f: A => B): Params[B] =
Params(params.map { case (k, v) => (k, f(v)) })
def filter(f: (Path, A) => Boolean): Params[A] =
Params(params.filter(f.tupled))
def filterPaths(f: Path => Boolean): Params[A] =
Params(params.filter { case (k, _) => f(k) })
def filterValues(f: A => Boolean): Params[A] =
Params(params.filter { case (_, v) => f(v) })
def partition(by: (Path, A) => Boolean): (Params[A], Params[A]) = {
val (left, right) = params.partition(by.tupled)
(Params(left), Params(right))
}
def partitionPaths(by: Path => Boolean): (Params[A], Params[A]) =
partition { case (k, _) => by(k) }
def partitionValues(by: A => Boolean): (Params[A], Params[A]) =
partition { case (k, v) => by(v) }
def values: Iterable[A] = params.values
def join[B](other: Params[B]): Params[(A, B)] = {
val allPaths = paths ++ other.paths
val joinedItems = allPaths.map(path => (path, (this(path), other(path))))
Params(joinedItems.toMap)
}
def size: Int = params.size
def isEmpty: Boolean = params.isEmpty
def prependPath(path: Path): Params[A] =
Params(params.map { case (k, v) => path / k -> v })
def weights: A = apply(Params.Weights)
}

object Params {
def apply[A](elems: (Path, A)*): Params[A] =
new Params[A](Map(elems: _*))
def empty[A]: Params[A] =
new Params[A](Map.empty)
val Weights: Path = "weights"
val State: Path = "state"
}
Loading

0 comments on commit 4fdc3ca

Please sign in to comment.