-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add some very basic GADT constraints from type cases
- Loading branch information
Showing
8 changed files
with
287 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import scala.compiletime.ops.string.+ | ||
import scala.compiletime.ops.any | ||
|
||
type Index = Int & Singleton | ||
|
||
sealed trait Indices | ||
|
||
final case class :::[+H <: Index, +T <: Indices](head: H, tail: T) extends Indices: | ||
override def toString = s"$head ::: $tail" | ||
|
||
sealed trait INil extends Indices | ||
case object INil extends INil | ||
|
||
object Indices: | ||
type ToString[X <: Indices] <: String = X match | ||
case INil => "INil" | ||
case head ::: tail => any.ToString[head] + " ::: " + ToString[tail] | ||
|
||
type Contains[Haystack <: Indices, Needle <: Index] <: Boolean = Haystack match | ||
case head ::: tail => head match | ||
case Needle => true | ||
case _ => Contains[tail, Needle] | ||
case INil => false | ||
|
||
type RemoveValue[RemoveFrom <: Indices, Value <: Index] <: Indices = RemoveFrom match | ||
case INil => INil | ||
case head ::: tail => head match | ||
case Value => RemoveValue[tail, Value] | ||
case _ => head ::: RemoveValue[tail, Value] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import scala.compiletime.ops.int.{S, +, <, <=, *} | ||
import scala.compiletime.ops.boolean.&& | ||
|
||
type Dimension = Int & Singleton | ||
|
||
sealed trait Shape extends Product with Serializable | ||
|
||
final case class #:[+H <: Dimension, +T <: Shape](head: H, tail: T) extends Shape: | ||
override def toString = (head: Any) match | ||
case _ #: _ => s"($head) #: $tail" | ||
case _ => s"$head #: $tail" | ||
|
||
sealed trait SNil extends Shape | ||
case object SNil extends SNil | ||
|
||
object Shape: | ||
def scalar: SNil = SNil | ||
|
||
type Concat[X <: Shape, Y <: Shape] <: Shape = X match | ||
case SNil => Y | ||
case head #: tail => head #: Concat[tail, Y] | ||
|
||
type Reverse[X <: Shape] <: Shape = X match | ||
case SNil => SNil | ||
case head #: tail => Concat[Reverse[tail], head #: SNil] | ||
|
||
type NumElements[X <: Shape] <: Int = X match | ||
case SNil => 1 | ||
case head #: tail => head * NumElements[tail] | ||
|
||
type Rank[X <: Shape] <: Int = X match | ||
case SNil => 0 | ||
case head #: tail => Rank[tail] + 1 | ||
|
||
type IsEmpty[X <: Shape] <: Boolean = X match | ||
case SNil => true | ||
case _ #: _ => false | ||
|
||
type Head[X <: Shape] <: Dimension = X match { case head #: _ => head } | ||
type Tail[X <: Shape] <: Shape = X match { case _ #: tail => tail } | ||
|
||
type Reduce[S <: Shape, Axes <: None.type | Indices] <: Shape = Axes match | ||
case None.type => SNil | ||
case Indices => ReduceLoop[S, Axes, 0] | ||
|
||
protected type ReduceLoop[RemoveFrom <: Shape, ToRemove <: Indices, I <: Index] <: Shape = RemoveFrom match | ||
case head #: tail => Indices.Contains[ToRemove, I] match | ||
case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]] | ||
case false => head #: ReduceLoop[tail, ToRemove, S[I]] | ||
case SNil => ToRemove match { case INil => SNil } | ||
|
||
type WithinBounds[I <: Index, S <: Shape] = (0 <= I && I < Rank[S]) | ||
|
||
type RemoveIndex[RemoveFrom <: Shape, I <: Index] <: Shape = WithinBounds[I, RemoveFrom] match | ||
case true => RemoveIndexLoop[RemoveFrom, I, 0] | ||
|
||
protected type RemoveIndexLoop[RemoveFrom <: Shape, I <: Index, Current <: Index] <: Shape = RemoveFrom match | ||
case head #: tail => Current match | ||
case I => tail | ||
case _ => head #: RemoveIndexLoop[tail, I, S[Current]] | ||
|
||
type Map[X <: Shape, F[_ <: Dimension] <: Dimension] <: Shape = X match | ||
case SNil => SNil | ||
case head #: tail => F[head] #: Map[tail, F] | ||
|
||
type FoldLeft[B, X <: Shape, Z <: B, F[_ <: B, _ <: Int] <: B] <: B = X match | ||
case SNil => Z | ||
case head #: tail => FoldLeft[B, tail, F[Z, head], F] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import scala.compiletime.ops.int.S | ||
|
||
type DimensionDenotation = String & Singleton | ||
|
||
sealed trait TensorShapeDenotation extends Product with Serializable | ||
|
||
final case class ##:[+H <: DimensionDenotation, +T <: TensorShapeDenotation](head: H, tail: T) extends TensorShapeDenotation: | ||
override def toString = (head: Any) match | ||
case _ ##: _ => s"($head) ##: $tail" | ||
case _ => s"$head ##: $tail" | ||
|
||
sealed trait TSNil extends TensorShapeDenotation | ||
case object TSNil extends TSNil | ||
|
||
object TensorShapeDenotation: | ||
type Reduce[S <: TensorShapeDenotation, Axes <: None.type | Indices] <: TensorShapeDenotation = Axes match | ||
case None.type => TSNil | ||
case Indices => ReduceLoop[S, Axes, 0] | ||
|
||
protected type ReduceLoop[RemoveFrom <: TensorShapeDenotation, ToRemove <: Indices, I <: Index] <: TensorShapeDenotation = RemoveFrom match | ||
case head ##: tail => Indices.Contains[ToRemove, I] match | ||
case true => ReduceLoop[tail, Indices.RemoveValue[ToRemove, I], S[I]] | ||
case false => head ##: ReduceLoop[tail, ToRemove, S[I]] | ||
case TSNil => ToRemove match { case INil => TSNil } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import scala.compiletime.ops.int.* | ||
|
||
object Tensors: | ||
import Shape.Reverse | ||
|
||
type Supported = Int | Long | Float | Double | Byte | Short | Boolean | String | ||
|
||
type TensorTypeDenotation = String & Singleton | ||
|
||
type Axes = Tuple3[TensorTypeDenotation, TensorShapeDenotation, Shape] | ||
|
||
opaque type Tensor[T <: Supported, +Ax <: Axes] = Tuple2[Array[T], Ax] | ||
|
||
type SparseTensor[T <: Supported, A <: Axes] = Tensor[T, A] | ||
|
||
type KeepOrReduceDims[S <: Shape, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: Shape = KeepDims match | ||
case true => ReduceKeepDims[S, AxisIndices] | ||
case false => Shape.Reduce[S, AxisIndices] | ||
|
||
type KeepOrReduceDimDenotations[Td <: TensorShapeDenotation, AxisIndices <: None.type | Indices, KeepDims <: (Boolean & Singleton)] <: TensorShapeDenotation = KeepDims match | ||
case true => Td | ||
case false => TensorShapeDenotation.Reduce[Td, AxisIndices] | ||
|
||
type ReduceKeepDims[S <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match | ||
case None.type => SNil | ||
case Indices => ReduceKeepDimsLoop[S, AxisIndices, 0] | ||
|
||
protected type ReduceKeepDimsLoop[ReplaceFrom <: Shape, ToReplace <: Indices, I <: Index] <: Shape = ReplaceFrom match | ||
case head #: tail => Indices.Contains[ToReplace, I] match | ||
case true => 1 #: ReduceKeepDimsLoop[tail, Indices.RemoveValue[ToReplace, I], S[I]] | ||
case false => head #: ReduceKeepDimsLoop[tail, ToReplace, S[I]] | ||
case SNil => ToReplace match { case INil => SNil } | ||
|
||
type AddGivenAxisSize[S <: Shape, S1 <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match | ||
case None.type => SNil | ||
case Indices => AddGivenAxisSizeLoop[S, S1, AxisIndices, 0] | ||
|
||
protected type AddGivenAxisSizeLoop[First <: Shape, Second <: Shape, AxisIndex <: Indices, I <: Index] <: Shape = First match | ||
case head #: tail => Indices.Contains[AxisIndex, I] match | ||
case true => Second match | ||
case secondHead #: secondTail => (head + secondHead) #: AddGivenAxisSizeLoop[tail, secondTail, Indices.RemoveValue[AxisIndex, I], S[I]] | ||
case SNil => AxisIndex match { case INil => SNil } | ||
case false => Second match | ||
case secondHead #: secondTail => (head) #: AddGivenAxisSizeLoop[tail, secondTail, AxisIndex, S[I]] | ||
case SNil => AxisIndex match { case INil => SNil } | ||
|
||
type UnsqueezeShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match | ||
case None.type => SNil | ||
case Indices => UnsqueezeShapeLoop[S, AxisIndex, 0] | ||
|
||
protected type UnsqueezeShapeLoop[ToUnsqueeze <: Shape, AxisIndex <: Indices, I <: Index] <: Shape = ToUnsqueeze match | ||
case head #: tail => Indices.Contains[AxisIndex, I] match | ||
case true => 1 #: head #: UnsqueezeShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I]] | ||
case false => head #: UnsqueezeShapeLoop[tail, AxisIndex, S[I]] | ||
case SNil => AxisIndex match { case INil => SNil } | ||
|
||
type GatheredShape[S <: Shape, AxisIndex <: None.type | Indices, AxisIndices <: Indices] <: Shape = AxisIndex match | ||
case None.type => SNil | ||
case Indices => GatheredShapeLoop[S, AxisIndex, 0, AxisIndices] | ||
|
||
protected type GatheredShapeLoop[ToGather <: Shape, AxisIndex <: Indices, I <: Index, AxisIndices <: Indices] <: Shape = ToGather match | ||
case head #: tail => Indices.Contains[AxisIndex, I] match | ||
case true => IndicesSize[AxisIndices] #: GatheredShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], AxisIndices] | ||
case false => head #: GatheredShapeLoop[tail, AxisIndex, S[I], AxisIndices] | ||
case SNil => AxisIndex match { case INil => SNil } | ||
|
||
type IndicesSize[AxisIndices <: Indices] = IndicesSizeLoop[AxisIndices, 0] | ||
|
||
type IndicesSizeLoop[AxisIndices <: Indices, Acc <: Dimension] <: Dimension = AxisIndices match | ||
case head ::: tail => IndicesSizeLoop[tail, S[Acc]] | ||
case INil => Acc | ||
|
||
type FlattenedShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match | ||
case None.type => SNil | ||
case Indices => FlattenedShapeLoop[S, AxisIndex, 0, 1] | ||
|
||
protected type FlattenedShapeLoop[ToFlatten <: Shape, AxisIndex <: Indices, I <: Index, Acc <: Index] <: Shape = ToFlatten match | ||
case head #: tail => Indices.Contains[AxisIndex, I] match | ||
case true => Acc #: FlattenedShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], head] | ||
case false => FlattenedShapeLoop[tail, AxisIndex, S[I], head * Acc] | ||
case SNil => AxisIndex match { case INil => Acc #: SNil } | ||
|
||
type SlicedShape[AxisIndicesStarts <: None.type | Indices, AxisIndicesEnds <: None.type | Indices] <: Shape = AxisIndicesStarts match | ||
case None.type => SNil | ||
case Indices => AxisIndicesEnds match | ||
case None.type => SNil | ||
case Indices => SlicedShapeLoop[AxisIndicesStarts, AxisIndicesEnds] | ||
|
||
protected type SlicedShapeLoop[Starts <: Indices, Ends <: Indices] <: Shape = Starts match | ||
case head ::: tail => Ends match | ||
case endsHead ::: endsTail => (endsHead - head) #: SlicedShapeLoop[tail, endsTail] | ||
case INil => SNil | ||
case INil => Ends match { case INil => SNil } | ||
|
||
type PaddedShape[PadFrom <: Shape, AxisBefore <: None.type | Shape, AxisAfter <: None.type | Shape] <: Shape = AxisBefore match | ||
case None.type => PadFrom | ||
case Shape => AxisAfter match | ||
case None.type => PadFrom | ||
case Shape => Reverse[PaddedShapeLoop[Reverse[PadFrom], Reverse[AxisBefore], Reverse[AxisAfter]]] | ||
|
||
protected type PaddedShapeLoop[PadFrom <: Shape, Before <: Shape, After <: Shape] <: Shape = Before match | ||
case head #: tail => After match | ||
case afterHead #: afterTail => PadFrom match | ||
case padFromHead #: padFromTail => (head + padFromHead + afterHead) #: PaddedShapeLoop[padFromTail, tail, afterTail] | ||
case SNil => SNil | ||
case SNil => SNil | ||
case SNil => After match | ||
case SNil => PadFrom match | ||
case padFromHead #: padFromTail => padFromHead #: PaddedShapeLoop[padFromTail, SNil, SNil] | ||
case SNil => SNil | ||
|
||
type TiledShape[TileFrom <: Shape, AxisRepeats <: None.type | Indices] <: Shape = AxisRepeats match | ||
case None.type => SNil | ||
case Indices => TiledShapeLoop[TileFrom, AxisRepeats] | ||
|
||
protected type TiledShapeLoop[TileFrom <: Shape, Repeats <: Indices] <: Shape = Repeats match | ||
case head ::: tail => TileFrom match | ||
case tileFromHead #: tileFromTail => (head * tileFromHead) #: TiledShapeLoop[tileFromTail, tail] | ||
case SNil => SNil | ||
case INil => SNil | ||
|
||
type PoolShape[From <: Shape, KernelShape <: None.type | Shape] <: Shape = KernelShape match | ||
case None.type => SNil | ||
case Shape => Reverse[PoolShapeLoop[Reverse[From], Reverse[KernelShape]]] | ||
|
||
protected type PoolShapeLoop[From <: Shape, KernelShape <: Shape] <: Shape = KernelShape match | ||
case head #: tail => From match | ||
case fromHead #: fromTail => (fromHead - head + 1) #: PoolShapeLoop[fromTail, tail] | ||
case SNil => SNil | ||
case SNil => From |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import scala.compiletime.ops.int.* | ||
|
||
type Index = Int & Singleton | ||
type Dimension = Int & Singleton | ||
|
||
sealed trait Indices extends Product with Serializable | ||
sealed trait Shape extends Product with Serializable | ||
final case class :::[+H <: Index, +T <: Indices](head: H, tail: T) extends Indices | ||
final case class #:[+H <: Dimension, +T <: Shape ](head: H, tail: T) extends Shape | ||
sealed trait INil extends Indices; case object INil extends INil | ||
sealed trait SNil extends Shape; case object SNil extends SNil | ||
|
||
object Ts: | ||
type ReduceKeepDims[S <: Shape, AxisIndices <: None.type | Indices] <: Shape = AxisIndices match | ||
case None.type => SNil | ||
case Indices => ReduceKeepDimsLoop[S, AxisIndices, 0] | ||
|
||
protected type ReduceKeepDimsLoop[ReplaceFrom <: Shape, ToReplace <: Indices, I <: Index] <: Shape = ReplaceFrom match | ||
case head #: tail => ReduceKeepDimsLoop[tail, ToReplace, S[I]] | ||
case SNil => ToReplace match { case INil => SNil } |