Skip to content

Commit

Permalink
Add support for passing time.Duration
Browse files Browse the repository at this point in the history
  • Loading branch information
joris-bright authored and nineinchnick committed Sep 19, 2024
1 parent 806af86 commit 2b16344
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,11 @@ types:
passed to Trino as a time with a time zone
* the result of `trino.Timestamp(year, month, day, hour, minute, second,
nanosecond)` - passed to Trino as a timestamp without a time zone
* `time.Duration` - passed to Trino as an interval day to second. Because Trino does not support nanosecond precision for intervals, if the nanosecond part of the value is not zero, an error will be returned.

It's not yet possible to pass:
* `float32` or `float64`
* `byte`
* `time.Duration`
* `json.RawMessage`
* maps

Expand Down
125 changes: 125 additions & 0 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"fmt"
"io"
"log"
"math"
"math/big"
"net/http"
"os"
Expand Down Expand Up @@ -987,3 +988,127 @@ func contextSleep(ctx context.Context, d time.Duration) error {
return ctx.Err()
}
}

func TestIntegrationDayToHourIntervalMilliPrecision(t *testing.T) {
db := integrationOpen(t)
defer db.Close()
tests := []struct {
name string
arg time.Duration
wantErr bool
}{
{
name: "valid 1234567891s",
arg: time.Duration(1234567891) * time.Second,
wantErr: false,
},
{
name: "valid 123456789.1s",
arg: time.Duration(123456789100) * time.Millisecond,
wantErr: false,
},
{
name: "valid 12345678.91s",
arg: time.Duration(12345678910) * time.Millisecond,
wantErr: false,
},
{
name: "valid 1234567.891s",
arg: time.Duration(1234567891) * time.Millisecond,
wantErr: false,
},
{
name: "valid -1234567891s",
arg: time.Duration(-1234567891) * time.Second,
wantErr: false,
},
{
name: "valid -123456789.1s",
arg: time.Duration(-123456789100) * time.Millisecond,
wantErr: false,
},
{
name: "valid -12345678.91s",
arg: time.Duration(-12345678910) * time.Millisecond,
wantErr: false,
},
{
name: "valid -1234567.891s",
arg: time.Duration(-1234567891) * time.Millisecond,
wantErr: false,
},
{
name: "invalid 1234567891.2s",
arg: time.Duration(1234567891200) * time.Millisecond,
wantErr: true,
},
{
name: "invalid 123456789.12s",
arg: time.Duration(123456789120) * time.Millisecond,
wantErr: true,
},
{
name: "invalid 12345678.912s",
arg: time.Duration(12345678912) * time.Millisecond,
wantErr: true,
},
{
name: "invalid -1234567891.2s",
arg: time.Duration(-1234567891200) * time.Millisecond,
wantErr: true,
},
{
name: "invalid -123456789.12s",
arg: time.Duration(-123456789120) * time.Millisecond,
wantErr: true,
},
{
name: "invalid -12345678.912s",
arg: time.Duration(-12345678912) * time.Millisecond,
wantErr: true,
},
{
name: "invalid max seconds (9223372036)",
arg: time.Duration(math.MaxInt64) / time.Second * time.Second,
wantErr: true,
},
{
name: "invalid min seconds (-9223372036)",
arg: time.Duration(math.MinInt64) / time.Second * time.Second,
wantErr: true,
},
{
name: "valid max seconds (2147483647)",
arg: math.MaxInt32 * time.Second,
},
{
name: "valid min seconds (-2147483647)",
arg: -math.MaxInt32 * time.Second,
},
{
name: "valid max minutes (153722867)",
arg: time.Duration(math.MaxInt64) / time.Minute * time.Minute,
},
{
name: "valid min minutes (-153722867)",
arg: time.Duration(math.MinInt64) / time.Minute * time.Minute,
},
{
name: "valid max hours (2562047)",
arg: time.Duration(math.MaxInt64) / time.Hour * time.Hour,
},
{
name: "valid min hours (-2562047)",
arg: time.Duration(math.MinInt64) / time.Hour * time.Hour,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := db.Exec("SELECT ?", test.arg)
if (err != nil) != test.wantErr {
t.Errorf("Exec() error = %v, wantErr %v", err, test.wantErr)
return
}
})
}
}
51 changes: 50 additions & 1 deletion trino/serial.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package trino
import (
"encoding/json"
"fmt"
"math"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -163,7 +164,7 @@ func Serial(v interface{}) (string, error) {
return "TIMESTAMP " + time.Time(x).Format("'2006-01-02 15:04:05.999999999 Z07:00'"), nil

case time.Duration:
return "", UnsupportedArgError{"time.Duration"}
return serialDuration(x)

// TODO - json.RawMesssage should probably be matched to 'JSON' in Trino
case json.RawMessage:
Expand Down Expand Up @@ -208,3 +209,51 @@ func serialSlice(v []interface{}) (string, error) {

return "ARRAY[" + strings.Join(ss, ", ") + "]", nil
}

const (
// For seconds with milliseconds there is a maximum length of 10 digits
// or 11 characters with the dot and 12 characters with the minus sign and dot
maxIntervalStrLenWithDot = 11 // 123456789.1 and 12345678.91 are valid
)

func serialDuration(dur time.Duration) (string, error) {
switch {
case dur%time.Hour == 0:
return serialHoursInterval(dur), nil
case dur%time.Minute == 0:
return serialMinutesInterval(dur), nil
case dur%time.Second == 0:
return serialSecondsInterval(dur)
case dur%time.Millisecond == 0:
return serialMillisecondsInterval(dur)
default:
return "", fmt.Errorf("trino: duration %v is not a multiple of hours, minutes, seconds or milliseconds", dur)
}
}

func serialHoursInterval(dur time.Duration) string {
return "INTERVAL '" + strconv.Itoa(int(dur/time.Hour)) + "' HOUR"
}

func serialMinutesInterval(dur time.Duration) string {
return "INTERVAL '" + strconv.Itoa(int(dur/time.Minute)) + "' MINUTE"
}

func serialSecondsInterval(dur time.Duration) (string, error) {
seconds := int64(dur / time.Second)
if seconds <= math.MinInt32 || seconds > math.MaxInt32 {
return "", fmt.Errorf("trino: duration %v is out of range for interval of seconds type", dur)
}
return "INTERVAL '" + strconv.FormatInt(seconds, 10) + "' SECOND", nil
}

func serialMillisecondsInterval(dur time.Duration) (string, error) {
seconds := int64(dur / time.Second)
millisInSecond := dur.Abs().Milliseconds() % 1000
intervalNr := strings.TrimRight(fmt.Sprintf("%d.%03d", seconds, millisInSecond), "0")
if seconds > 0 && len(intervalNr) > maxIntervalStrLenWithDot ||
seconds < 0 && len(intervalNr) > maxIntervalStrLenWithDot+1 { // +1 for the minus sign
return "", fmt.Errorf("trino: duration %v is out of range for interval of seconds with millis type", dur)
}
return "INTERVAL '" + intervalNr + "' SECOND", nil
}
81 changes: 81 additions & 0 deletions trino/serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package trino

import (
"math"
"testing"
"time"

Expand Down Expand Up @@ -160,6 +161,86 @@ func TestSerial(t *testing.T) {
value: time.Date(2017, 7, 10, 11, 34, 25, 123456, time.UTC),
expectedSerial: "TIMESTAMP '2017-07-10 11:34:25.000123456 Z'",
},
{
name: "duration",
value: 10*time.Second + 5*time.Millisecond,
expectedSerial: "INTERVAL '10.005' SECOND",
},
{
name: "duration with negative value",
value: -(10*time.Second + 5*time.Millisecond),
expectedSerial: "INTERVAL '-10.005' SECOND",
},
{
name: "minute duration",
value: 10 * time.Minute,
expectedSerial: "INTERVAL '10' MINUTE",
},
{
name: "hour duration",
value: 23 * time.Hour,
expectedSerial: "INTERVAL '23' HOUR",
},
{
name: "max hour duration",
value: (math.MaxInt64 / time.Hour) * time.Hour,
expectedSerial: "INTERVAL '2562047' HOUR",
},
{
name: "min hour duration",
value: (math.MinInt64 / time.Hour) * time.Hour,
expectedSerial: "INTERVAL '-2562047' HOUR",
},
{
name: "max minute duration",
value: (math.MaxInt64 / time.Minute) * time.Minute,
expectedSerial: "INTERVAL '153722867' MINUTE",
},
{
name: "min minute duration",
value: (math.MinInt64 / time.Minute) * time.Minute,
expectedSerial: "INTERVAL '-153722867' MINUTE",
},
{
name: "too big second duration",
value: (math.MaxInt64 / time.Second) * time.Second,
expectedError: true,
},
{
name: "too small second duration",
value: (math.MinInt64 / time.Second) * time.Second,
expectedError: true,
},
{
name: "too big millisecond duration",
value: time.Millisecond*912 + time.Second*12345678,
expectedError: true,
},
{
name: "too small millisecond duration",
value: -(time.Millisecond*910 + time.Second*123456789),
expectedError: true,
},
{
name: "max allowed second duration",
value: math.MaxInt32 * time.Second,
expectedSerial: "INTERVAL '2147483647' SECOND",
},
{
name: "min allowed second duration",
value: -math.MaxInt32 * time.Second,
expectedSerial: "INTERVAL '-2147483647' SECOND",
},
{
name: "max allowed second with milliseconds duration",
value: 999999999*time.Second + 900*time.Millisecond,
expectedSerial: "INTERVAL '999999999.9' SECOND",
},
{
name: "min allowed second with milliseconds duration",
value: -999999999*time.Second - 900*time.Millisecond,
expectedSerial: "INTERVAL '-999999999.9' SECOND",
},
{
name: "nil",
value: nil,
Expand Down
2 changes: 1 addition & 1 deletion trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ func (st *driverStmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case nil:
return nil
case Numeric, trinoDate, trinoTime, trinoTimeTz, trinoTimestamp:
case Numeric, trinoDate, trinoTime, trinoTimeTz, trinoTimestamp, time.Duration:
return nil
default:
{
Expand Down

0 comments on commit 2b16344

Please sign in to comment.