Skip to content

Commit

Permalink
Merge pull request #1624 from disneystreaming/nan-edge-cases
Browse files Browse the repository at this point in the history
Handle edge cases with NaN/Infinity
  • Loading branch information
Baccata authored Dec 5, 2024
2 parents e369dd4 + 3b91604 commit 2522c02
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 34 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ When adding entries, please treat them as if they could end up in a release any

Thank you!

# 0.18.27

* Fix for how `NaN` is handled for `Float` and `Double` inside of the `MetadataDecoder` and `Range` constraint `RefinementProvider`

# 0.18.26

* Optimises the conversion of empty smithy4s.Blob to fs2.Stream, to avoid performance degradation in Ember (see [#1609](https://github.com/disneystreaming/smithy4s/pull/1609))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import smithy4s.schema.Schema.unit
trait DummyServiceGen[F[_, _, _, _, _]] {
self =>

def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, slm: Option[Map[String, String]] = None): F[Queries, Nothing, Unit, Nothing, Nothing]
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, dbl: Option[Double] = None, slm: Option[Map[String, String]] = None): F[Queries, Nothing, Unit, Nothing, Nothing]
def dummyHostPrefix(label1: String, label2: String, label3: HostLabelEnum): F[HostLabelInput, Nothing, Unit, Nothing, Nothing]
def dummyPath(str: String, int: Int, ts1: Timestamp, ts2: Timestamp, ts3: Timestamp, ts4: Timestamp, b: Boolean, ie: Numbers): F[PathParams, Nothing, Unit, Nothing, Nothing]

Expand Down Expand Up @@ -69,12 +69,12 @@ sealed trait DummyServiceOperation[Input, Err, Output, StreamedInput, StreamedOu
object DummyServiceOperation {

object reified extends DummyServiceGen[DummyServiceOperation] {
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, slm: Option[Map[String, String]] = None): Dummy = Dummy(Queries(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, slm))
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, dbl: Option[Double] = None, slm: Option[Map[String, String]] = None): Dummy = Dummy(Queries(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, dbl, slm))
def dummyHostPrefix(label1: String, label2: String, label3: HostLabelEnum): DummyHostPrefix = DummyHostPrefix(HostLabelInput(label1, label2, label3))
def dummyPath(str: String, int: Int, ts1: Timestamp, ts2: Timestamp, ts3: Timestamp, ts4: Timestamp, b: Boolean, ie: Numbers): DummyPath = DummyPath(PathParams(str, int, ts1, ts2, ts3, ts4, b, ie))
}
class Transformed[P[_, _, _, _, _], P1[_ ,_ ,_ ,_ ,_]](alg: DummyServiceGen[P], f: PolyFunction5[P, P1]) extends DummyServiceGen[P1] {
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, slm: Option[Map[String, String]] = None): P1[Queries, Nothing, Unit, Nothing, Nothing] = f[Queries, Nothing, Unit, Nothing, Nothing](alg.dummy(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, slm))
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, dbl: Option[Double] = None, slm: Option[Map[String, String]] = None): P1[Queries, Nothing, Unit, Nothing, Nothing] = f[Queries, Nothing, Unit, Nothing, Nothing](alg.dummy(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, dbl, slm))
def dummyHostPrefix(label1: String, label2: String, label3: HostLabelEnum): P1[HostLabelInput, Nothing, Unit, Nothing, Nothing] = f[HostLabelInput, Nothing, Unit, Nothing, Nothing](alg.dummyHostPrefix(label1, label2, label3))
def dummyPath(str: String, int: Int, ts1: Timestamp, ts2: Timestamp, ts3: Timestamp, ts4: Timestamp, b: Boolean, ie: Numbers): P1[PathParams, Nothing, Unit, Nothing, Nothing] = f[PathParams, Nothing, Unit, Nothing, Nothing](alg.dummyPath(str, int, ts1, ts2, ts3, ts4, b, ie))
}
Expand All @@ -83,7 +83,7 @@ object DummyServiceOperation {
def apply[I, E, O, SI, SO](op: DummyServiceOperation[I, E, O, SI, SO]): P[I, E, O, SI, SO] = op.run(impl)
}
final case class Dummy(input: Queries) extends DummyServiceOperation[Queries, Nothing, Unit, Nothing, Nothing] {
def run[F[_, _, _, _, _]](impl: DummyServiceGen[F]): F[Queries, Nothing, Unit, Nothing, Nothing] = impl.dummy(input.str, input.int, input.ts1, input.ts2, input.ts3, input.ts4, input.b, input.sl, input.ie, input.on, input.ons, input.slm)
def run[F[_, _, _, _, _]](impl: DummyServiceGen[F]): F[Queries, Nothing, Unit, Nothing, Nothing] = impl.dummy(input.str, input.int, input.ts1, input.ts2, input.ts3, input.ts4, input.b, input.sl, input.ie, input.on, input.ons, input.dbl, input.slm)
def ordinal: Int = 0
def endpoint: smithy4s.Endpoint[DummyServiceOperation,Queries, Nothing, Unit, Nothing, Nothing] = Dummy
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@ import smithy4s.ShapeId
import smithy4s.ShapeTag
import smithy4s.Timestamp
import smithy4s.schema.Schema.boolean
import smithy4s.schema.Schema.double
import smithy4s.schema.Schema.int
import smithy4s.schema.Schema.string
import smithy4s.schema.Schema.struct
import smithy4s.schema.Schema.timestamp

final case class Queries(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, slm: Option[Map[String, String]] = None)
final case class Queries(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, dbl: Option[Double] = None, slm: Option[Map[String, String]] = None)

object Queries extends ShapeTag.Companion[Queries] {
val id: ShapeId = ShapeId("smithy4s.example", "Queries")

val hints: Hints = Hints.empty

// constructor using the original order from the spec
private def make(str: Option[String], int: Option[Int], ts1: Option[Timestamp], ts2: Option[Timestamp], ts3: Option[Timestamp], ts4: Option[Timestamp], b: Option[Boolean], sl: Option[List[String]], ie: Option[Numbers], on: Option[OpenNums], ons: Option[OpenNumsStr], slm: Option[Map[String, String]]): Queries = Queries(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, slm)
private def make(str: Option[String], int: Option[Int], ts1: Option[Timestamp], ts2: Option[Timestamp], ts3: Option[Timestamp], ts4: Option[Timestamp], b: Option[Boolean], sl: Option[List[String]], ie: Option[Numbers], on: Option[OpenNums], ons: Option[OpenNumsStr], dbl: Option[Double], slm: Option[Map[String, String]]): Queries = Queries(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, dbl, slm)

implicit val schema: Schema[Queries] = struct(
string.optional[Queries]("str", _.str).addHints(smithy.api.HttpQuery("str")),
Expand All @@ -33,6 +34,7 @@ object Queries extends ShapeTag.Companion[Queries] {
Numbers.schema.optional[Queries]("ie", _.ie).addHints(smithy.api.HttpQuery("nums")),
OpenNums.schema.optional[Queries]("on", _.on).addHints(smithy.api.HttpQuery("openNums")),
OpenNumsStr.schema.optional[Queries]("ons", _.ons).addHints(smithy.api.HttpQuery("openNumsStr")),
double.validated(smithy.api.Range(min = Some(scala.math.BigDecimal(0.0)), max = Some(scala.math.BigDecimal(100.0)))).optional[Queries]("dbl", _.dbl).addHints(smithy.api.HttpQuery("dbl")),
StringMap.underlyingSchema.optional[Queries]("slm", _.slm).addHints(smithy.api.HttpQueryParams()),
)(make).withId(id).addHints(hints)
}
22 changes: 22 additions & 0 deletions modules/bootstrapped/test/src/smithy4s/DocumentSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import munit._
import smithy4s.example.DefaultNullsOperationOutput
import alloy.Untagged
import smithy4s.example.TimestampOperationInput
import scala.util.Try

class DocumentSpec() extends FunSuite {

Expand Down Expand Up @@ -370,6 +371,27 @@ class DocumentSpec() extends FunSuite {
expect.same(roundTripped, Right(mapTest))
}

test("encoding NaN") {
// The Document type cannot hold a `NaN` value since it uses BigDecimal to hold numeric values
// this test exists to show this. For the same reason, a test on decoding from `NaN` is not necessary
// or possible.
implicit val schema: Schema[Double] =
double.validated(smithy.api.Range(None, Some(BigDecimal(3))))

val in = Double.NaN
val error = Try(Document.encode(in)).failed.get
val expectedMessage =
if (weaver.Platform.isJS || weaver.Platform.isNative)
"For input string: \"NaN\""
else
"Character N is neither a decimal digit number, decimal point, nor \"e\" notation exponential mark."

expect.same(
error.getMessage,
expectedMessage
)
}

test(
"optional fields for structs should decode Document.DNull"
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,28 @@ class MetadataSpec() extends FunSuite {
.left
.map(_.getMessage())
expect.same(encoded, expectedEncoding)
expect(result == Right(finished))
expect.same(result, Right(finished))
}

def checkQueryRoundTripError[A](
initial: A,
expectedEncoding: Metadata,
errorMessage: String,
allowNaN: Boolean
)(implicit
s: Schema[A],
loc: Location
): Unit = {
val encoded = Metadata.encode(initial)
val decoder =
if (allowNaN) Metadata.AwsDecoder.fromSchema(s)
else Metadata.Decoder.fromSchema(s)
val result = decoder
.decode(encoded)
.left
.map(_.getMessage())
expect.same(encoded, expectedEncoding)
expect.same(result, Left(errorMessage))
}

def checkRoundTripDefault[A](expectedDecoded: A)(implicit
Expand Down Expand Up @@ -123,6 +144,27 @@ class MetadataSpec() extends FunSuite {
checkQueryRoundTrip(queries, expected, finished)
}

// In this test the Metadata Decoder will allow NaN by creating a `Double.NaN` value.
// The Range RefinementProvider will reject this since `NaN` is not a valid `BigDecimal`
// which it uses
test("Double NaN query parameter - allow NaN in decoder") {
val queries = Queries(dbl = Some(Double.NaN))
val expected = Metadata(query = Map("dbl" -> List("NaN")))
val errorMessage =
"Field dbl, found in Query parameter dbl, failed constraint checks with message: Numeric values must not be NaN or pos/neg infinity. Found NaN"
checkQueryRoundTripError(queries, expected, errorMessage, allowNaN = true)
}

// This test is where the Metadata Decoder will reject NaN itself
// As such the RefinementProvider for Range will not be called in this test
test("Double NaN query parameter - disallow NaN in decoder") {
val queries = Queries(dbl = Some(Double.NaN))
val expected = Metadata(query = Map("dbl" -> List("NaN")))
val errorMessage =
"NaN or pos/neg infinity are not allowed for inputs of type Double"
checkQueryRoundTripError(queries, expected, errorMessage, allowNaN = false)
}

test("String query parameter with default") {
val expectedDecoded = QueriesWithDefaults(dflt = "test")
checkRoundTripDefault(expectedDecoded)
Expand Down
49 changes: 28 additions & 21 deletions modules/core/src/smithy4s/RefinementProvider.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,27 +161,34 @@ object RefinementProvider extends LowPriorityImplicits {
val N = implicitly[Numeric[N]]

(a: A) =>
val value = BigDecimal(N.toDouble(getValue(a)))
(range.min, range.max) match {
case (Some(min), Some(max)) =>
if (value >= min && value <= max) Right(())
else
Left(
s"Input must be >= $min and <= $max, but was $value"
)
case (None, Some(max)) =>
if (value <= max) Right(())
else
Left(
s"Input must be <= $max, but was $value"
)
case (Some(min), None) =>
if (value >= min) Right(())
else
Left(
s"Input must be >= $min, but was $value"
)
case (None, None) => Right(())
val doubleValue = N.toDouble(getValue(a))
if (doubleValue.isNaN || doubleValue.isInfinite) {
Left(
s"Numeric values must not be NaN or pos/neg infinity. Found $doubleValue"
)
} else {
val value = BigDecimal.apply(d = doubleValue)
(range.min, range.max) match {
case (Some(min), Some(max)) =>
if (value >= min && value <= max) Right(())
else
Left(
s"Input must be >= $min and <= $max, but was $value"
)
case (None, Some(max)) =>
if (value <= max) Right(())
else
Left(
s"Input must be <= $max, but was $value"
)
case (Some(min), None) =>
if (value >= min) Right(())
else
Left(
s"Input must be >= $min, but was $value"
)
case (None, None) => Right(())
}
}
}
}
Expand Down
23 changes: 18 additions & 5 deletions modules/core/src/smithy4s/http/Metadata.scala
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,24 @@ object Metadata {
implicit def decoderFromSchema[A: Schema]: Decoder[A] =
Decoder.derivedImplicitInstance

object Decoder extends CachedDecoderCompilerImpl(awsHeaderEncoding = false) {
object Decoder
extends CachedDecoderCompilerImpl(
awsHeaderEncoding = false,
allowNaNAndInfiniteValues = false
) {
type Compiler = CachedSchemaCompiler[Decoder]
}

private[smithy4s] object AwsDecoder
extends CachedDecoderCompilerImpl(awsHeaderEncoding = true)
extends CachedDecoderCompilerImpl(
awsHeaderEncoding = true,
allowNaNAndInfiniteValues = true
)

private[http] class CachedDecoderCompilerImpl(awsHeaderEncoding: Boolean)
extends CachedSchemaCompiler.DerivingImpl[Decoder] {
private[http] class CachedDecoderCompilerImpl(
awsHeaderEncoding: Boolean,
allowNaNAndInfiniteValues: Boolean
) extends CachedSchemaCompiler.DerivingImpl[Decoder] {
type Aux[A] = internals.MetaDecode[A]

def apply[A](implicit instance: Decoder[A]): Decoder[A] =
Expand All @@ -190,7 +199,11 @@ object Metadata {
cache: CompilationCache[internals.MetaDecode]
): Decoder[A] = {
val metaDecode =
new SchemaVisitorMetadataReader(cache, awsHeaderEncoding)(schema)
new SchemaVisitorMetadataReader(
cache,
awsHeaderEncoding,
allowNaNAndInfiniteValues
)(schema)
metaDecode match {
case internals.MetaDecode.StructureMetaDecode(decodeFunction) =>
decodeFunction(_: Metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ import java.util.Base64
* contains values such as path-parameters, query-parameters, headers, and status code.
*
* @param awsHeaderEncoding defines whether the AWS encoding of headers should be expected.
* @param allowNaNAndInfiniteValues defines whether or not Double and Float values of 'NaN'
* positive/negative infinity should be accepted.
*/
private[http] class SchemaVisitorMetadataReader(
val cache: CompilationCache[MetaDecode],
awsHeaderEncoding: Boolean
awsHeaderEncoding: Boolean,
allowNaNAndInfiniteValues: Boolean
) extends SchemaVisitor.Cached[MetaDecode]
with ScalaCompat { self =>

Expand All @@ -50,6 +53,38 @@ private[http] class SchemaVisitorMetadataReader(
tag: Primitive[P]
): MetaDecode[P] = {
val desc = SchemaDescription.primitive(shapeId, hints, tag)

tag match {
case Primitive.PDouble =>
val decode: MetaDecode[Double] =
primitiveHandler(shapeId, hints, tag, desc)
decode.map(d =>
if (!allowNaNAndInfiniteValues && (d.isNaN || d.isInfinite))
throw MetadataError.ImpossibleDecoding(
s"NaN or pos/neg infinity are not allowed for inputs of type $desc"
)
else d
)
case Primitive.PFloat =>
val decode: MetaDecode[Float] =
primitiveHandler(shapeId, hints, tag, desc)
decode.map(f =>
if (!allowNaNAndInfiniteValues && (f.isNaN || f.isInfinite))
throw MetadataError.ImpossibleDecoding(
s"NaN or pos/neg infinity are not allowed for inputs of type $desc"
)
else f
)
case _ => primitiveHandler(shapeId, hints, tag, desc)
}
}

private def primitiveHandler[P](
shapeId: ShapeId,
hints: Hints,
tag: Primitive[P],
desc: String
): MetaDecode[P] = {
val hasMedia = hints.has(smithy.api.MediaType)
Primitive.stringParser(tag, hints) match {
case Some(parse) if hasMedia =>
Expand Down
11 changes: 11 additions & 0 deletions modules/json/test/src/smithy4s/json/JsonSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ class JsonSpec() extends FunSuite {
assertEquals(roundTripped, Right(foo))
}

test("Json read - NaN") {
implicit val schemaDouble: Schema[Double] =
double.validated(smithy.api.Range(None, Some(BigDecimal(3))))
val expectedJson = """"NaN""""
val roundTripped = Json.read[Double](Blob(expectedJson))

assert(
roundTripped.left.toOption.get.message.startsWith("illegal number")
)
}

test("Json document read/write") {
val foo =
Document.obj("a" -> Document.fromInt(1), "b" -> Document.fromInt(2))
Expand Down
3 changes: 3 additions & 0 deletions sampleSpecs/metadata.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ structure Queries {
on: OpenNums
@httpQuery("openNumsStr")
ons: OpenNumsStr
@httpQuery("dbl")
@range(min: 0, max: 100)
dbl: Double
@httpQueryParams
slm: StringMap
}
Expand Down

0 comments on commit 2522c02

Please sign in to comment.