Skip to content

Commit

Permalink
implement json marshal and unmarshal for range
Browse files Browse the repository at this point in the history
  • Loading branch information
acim committed Dec 4, 2022
1 parent 24c5325 commit 1e883e7
Show file tree
Hide file tree
Showing 4 changed files with 489 additions and 5 deletions.
41 changes: 40 additions & 1 deletion pgtype/numeric.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"database/sql/driver"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"math/big"
Expand Down Expand Up @@ -233,13 +234,51 @@ func (n Numeric) MarshalJSON() ([]byte, error) {
return []byte("null"), nil
}

if n.NaN {
switch {
case n.InfinityModifier == Infinity:
return []byte(`"infinity"`), nil
case n.InfinityModifier == NegativeInfinity:
return []byte(`"-infinity"`), nil
case n.NaN:
return []byte(`"NaN"`), nil
}

return n.numberTextBytes(), nil
}

func (n *Numeric) UnmarshalJSON(b []byte) error {
var s *string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}

if s == nil {
*n = Numeric{}
return nil
}

switch *s {
case "infinity":
*n = Numeric{NaN: true, InfinityModifier: Infinity, Valid: true}
case "-infinity":
*n = Numeric{NaN: true, InfinityModifier: -Infinity, Valid: true}
default:
num, exp, err := parseNumericString(*s)
if err != nil {
return fmt.Errorf("failed to decode %s to numeric: %w", *s, err)
}

*n = Numeric{
Int: num,
Exp: exp,
Valid: true,
}
}

return nil
}

// numberString returns a string of the number. undefined if NaN, infinite, or NULL
func (n Numeric) numberTextBytes() []byte {
intStr := n.Int.String()
Expand Down
71 changes: 71 additions & 0 deletions pgtype/range.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pgtype
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
)

Expand Down Expand Up @@ -320,3 +321,73 @@ func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error {
r.Valid = true
return nil
}

func (r Range[T]) MarshalJSON() ([]byte, error) {
if !r.Valid {
return []byte("null"), nil
}

enc := encodePlanRangeCodecRangeValuerToText{
m: &encodePlanRangeCodecJson{},
}

buf, err := enc.Encode(r, []byte(`"`))
if err != nil {
return nil, fmt.Errorf("failed to encode %v as range: %w", r, err)
}

buf = append(buf, `"`...)

return buf, nil
}

func (r *Range[T]) UnmarshalJSON(b []byte) error {
if b[0] == byte('"') && b[len(b)-1] == byte('"') {
b = b[1 : len(b)-1]
}

s := string(b)

if s == "null" {
*r = Range[T]{}
return nil
}

utr, err := parseUntypedTextRange(s)
if err != nil {
return fmt.Errorf("failed to decode %s to range: %w", s, err)
}

*r = Range[T]{
LowerType: utr.LowerType,
UpperType: utr.UpperType,
Valid: true,
}

if r.LowerType == Empty && r.UpperType == Empty {
return nil
}

if r.LowerType != Unbounded {
if err = r.unmarshalJSON(utr.Lower, &r.Lower); err != nil {
return fmt.Errorf("failed to decode %s to range lower: %w", utr.Lower, err)
}
}

if r.UpperType != Unbounded {
if err = r.unmarshalJSON(utr.Upper, &r.Upper); err != nil {
return fmt.Errorf("failed to decode %s to range upper: %w", utr.Upper, err)
}
}

return nil
}

func (_ *Range[T]) unmarshalJSON(data string, v *T) error {
buf := make([]byte, 0, len(data)+2)
buf = append(buf, `"`...)
buf = append(buf, data...)
buf = append(buf, `"`...)

return json.Unmarshal(buf, v)
}
38 changes: 34 additions & 4 deletions pgtype/range_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgtype

import (
"database/sql/driver"
"encoding/json"
"fmt"

"github.com/jackc/pgx/v5/internal/pgio"
Expand Down Expand Up @@ -157,8 +158,10 @@ func (plan *encodePlanRangeCodecRangeValuerToBinary) Encode(value any, buf []byt
}

type encodePlanRangeCodecRangeValuerToText struct {
rc *RangeCodec
m *Map
rc Codec
m interface {
PlanEncode(oid uint32, format int16, value any) EncodePlan
}
}

func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte) (newBuf []byte, err error) {
Expand All @@ -182,12 +185,18 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte)
return nil, fmt.Errorf("unknown lower bound type %v", lowerType)
}

var oid uint32

if rc, ok := plan.rc.(*RangeCodec); ok {
oid = rc.ElementType.OID
}

if lowerType != Unbounded {
if lower == nil {
return nil, fmt.Errorf("Lower cannot be NULL unless LowerType is Unbounded")
}

lowerPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, lower)
lowerPlan := plan.m.PlanEncode(oid, TextFormatCode, lower)
if lowerPlan == nil {
return nil, fmt.Errorf("cannot encode %v as element of range", lower)
}
Expand All @@ -208,7 +217,7 @@ func (plan *encodePlanRangeCodecRangeValuerToText) Encode(value any, buf []byte)
return nil, fmt.Errorf("Upper cannot be NULL unless UpperType is Unbounded")
}

upperPlan := plan.m.PlanEncode(plan.rc.ElementType.OID, TextFormatCode, upper)
upperPlan := plan.m.PlanEncode(oid, TextFormatCode, upper)
if upperPlan == nil {
return nil, fmt.Errorf("cannot encode %v as element of range", upper)
}
Expand Down Expand Up @@ -377,3 +386,24 @@ func (c *RangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (
err := c.PlanScan(m, oid, format, &r).Scan(src, &r)
return r, err
}

type encodePlanRangeCodecJson struct{}

func (s *encodePlanRangeCodecJson) PlanEncode(_ uint32, _ int16, _ any) EncodePlan {
return s
}

func (s *encodePlanRangeCodecJson) Encode(value any, buf []byte) (newBuf []byte, err error) {
b, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("failed to encode %v as element of range: %v", value, err)
}

if b[0] == byte('"') && b[len(b)-1] == byte('"') {
buf = append(buf, b[1:len(b)-1]...)
} else {
buf = append(buf, b...)
}

return buf, nil
}
Loading

0 comments on commit 1e883e7

Please sign in to comment.