-
Notifications
You must be signed in to change notification settings - Fork 852
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2083 from sodahealth/xml-codec
V1 XMLCodec supports encoding + scanning XML column type
- Loading branch information
Showing
6 changed files
with
313 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
package pgtype | ||
|
||
import ( | ||
"database/sql" | ||
"database/sql/driver" | ||
"encoding/xml" | ||
"fmt" | ||
"reflect" | ||
) | ||
|
||
type XMLCodec struct { | ||
Marshal func(v any) ([]byte, error) | ||
Unmarshal func(data []byte, v any) error | ||
} | ||
|
||
func (*XMLCodec) FormatSupported(format int16) bool { | ||
return format == TextFormatCode || format == BinaryFormatCode | ||
} | ||
|
||
func (*XMLCodec) PreferredFormat() int16 { | ||
return TextFormatCode | ||
} | ||
|
||
func (c *XMLCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { | ||
switch value.(type) { | ||
case string: | ||
return encodePlanXMLCodecEitherFormatString{} | ||
case []byte: | ||
return encodePlanXMLCodecEitherFormatByteSlice{} | ||
|
||
// Cannot rely on driver.Valuer being handled later because anything can be marshalled. | ||
// | ||
// https://github.com/jackc/pgx/issues/1430 | ||
// | ||
// Check for driver.Valuer must come before xml.Marshaler so that it is guaranteed to be used | ||
// when both are implemented https://github.com/jackc/pgx/issues/1805 | ||
case driver.Valuer: | ||
return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} | ||
|
||
// Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be | ||
// marshalled. | ||
// | ||
// https://github.com/jackc/pgx/issues/1681 | ||
case xml.Marshaler: | ||
return &encodePlanXMLCodecEitherFormatMarshal{ | ||
marshal: c.Marshal, | ||
} | ||
} | ||
|
||
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the | ||
// appropriate wrappers here. | ||
for _, f := range []TryWrapEncodePlanFunc{ | ||
TryWrapDerefPointerEncodePlan, | ||
TryWrapFindUnderlyingTypeEncodePlan, | ||
} { | ||
if wrapperPlan, nextValue, ok := f(value); ok { | ||
if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil { | ||
wrapperPlan.SetNext(nextPlan) | ||
return wrapperPlan | ||
} | ||
} | ||
} | ||
|
||
return &encodePlanXMLCodecEitherFormatMarshal{ | ||
marshal: c.Marshal, | ||
} | ||
} | ||
|
||
type encodePlanXMLCodecEitherFormatString struct{} | ||
|
||
func (encodePlanXMLCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { | ||
xmlString := value.(string) | ||
buf = append(buf, xmlString...) | ||
return buf, nil | ||
} | ||
|
||
type encodePlanXMLCodecEitherFormatByteSlice struct{} | ||
|
||
func (encodePlanXMLCodecEitherFormatByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { | ||
xmlBytes := value.([]byte) | ||
if xmlBytes == nil { | ||
return nil, nil | ||
} | ||
|
||
buf = append(buf, xmlBytes...) | ||
return buf, nil | ||
} | ||
|
||
type encodePlanXMLCodecEitherFormatMarshal struct { | ||
marshal func(v any) ([]byte, error) | ||
} | ||
|
||
func (e *encodePlanXMLCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { | ||
xmlBytes, err := e.marshal(value) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
buf = append(buf, xmlBytes...) | ||
return buf, nil | ||
} | ||
|
||
func (c *XMLCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { | ||
switch target.(type) { | ||
case *string: | ||
return scanPlanAnyToString{} | ||
|
||
case **string: | ||
// This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better | ||
// solution would be. | ||
// | ||
// https://github.com/jackc/pgx/issues/1470 -- **string | ||
// https://github.com/jackc/pgx/issues/1691 -- ** anything else | ||
|
||
if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { | ||
if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { | ||
if _, failed := nextPlan.(*scanPlanFail); !failed { | ||
wrapperPlan.SetNext(nextPlan) | ||
return wrapperPlan | ||
} | ||
} | ||
} | ||
|
||
case *[]byte: | ||
return scanPlanXMLToByteSlice{} | ||
case BytesScanner: | ||
return scanPlanBinaryBytesToBytesScanner{} | ||
|
||
// Cannot rely on sql.Scanner being handled later because scanPlanXMLToXMLUnmarshal will take precedence. | ||
// | ||
// https://github.com/jackc/pgx/issues/1418 | ||
case sql.Scanner: | ||
return &scanPlanSQLScanner{formatCode: format} | ||
} | ||
|
||
return &scanPlanXMLToXMLUnmarshal{ | ||
unmarshal: c.Unmarshal, | ||
} | ||
} | ||
|
||
type scanPlanXMLToByteSlice struct{} | ||
|
||
func (scanPlanXMLToByteSlice) Scan(src []byte, dst any) error { | ||
dstBuf := dst.(*[]byte) | ||
if src == nil { | ||
*dstBuf = nil | ||
return nil | ||
} | ||
|
||
*dstBuf = make([]byte, len(src)) | ||
copy(*dstBuf, src) | ||
return nil | ||
} | ||
|
||
type scanPlanXMLToXMLUnmarshal struct { | ||
unmarshal func(data []byte, v any) error | ||
} | ||
|
||
func (s *scanPlanXMLToXMLUnmarshal) Scan(src []byte, dst any) error { | ||
if src == nil { | ||
dstValue := reflect.ValueOf(dst) | ||
if dstValue.Kind() == reflect.Ptr { | ||
el := dstValue.Elem() | ||
switch el.Kind() { | ||
case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface, reflect.Struct: | ||
el.Set(reflect.Zero(el.Type())) | ||
return nil | ||
} | ||
} | ||
|
||
return fmt.Errorf("cannot scan NULL into %T", dst) | ||
} | ||
|
||
elem := reflect.ValueOf(dst).Elem() | ||
elem.Set(reflect.Zero(elem.Type())) | ||
|
||
return s.unmarshal(src, dst) | ||
} | ||
|
||
func (c *XMLCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { | ||
if src == nil { | ||
return nil, nil | ||
} | ||
|
||
dstBuf := make([]byte, len(src)) | ||
copy(dstBuf, src) | ||
return dstBuf, nil | ||
} | ||
|
||
func (c *XMLCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { | ||
if src == nil { | ||
return nil, nil | ||
} | ||
|
||
var dst any | ||
err := c.Unmarshal(src, &dst) | ||
return dst, err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
package pgtype_test | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"encoding/xml" | ||
"testing" | ||
|
||
pgx "github.com/jackc/pgx/v5" | ||
"github.com/jackc/pgx/v5/pgxtest" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
type xmlStruct struct { | ||
XMLName xml.Name `xml:"person"` | ||
Name string `xml:"name"` | ||
Age int `xml:"age,attr"` | ||
} | ||
|
||
func TestXMLCodec(t *testing.T) { | ||
skipCockroachDB(t, "CockroachDB does not support XML.") | ||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "xml", []pgxtest.ValueRoundTripTest{ | ||
{nil, new(*xmlStruct), isExpectedEq((*xmlStruct)(nil))}, | ||
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, | ||
{map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, | ||
{[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, | ||
{nil, new([]byte), isExpectedEqBytes([]byte(nil))}, | ||
|
||
// Test sql.Scanner. | ||
{"", new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})}, | ||
|
||
// Test driver.Valuer. | ||
{sql.NullString{String: "", Valid: true}, new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})}, | ||
}) | ||
|
||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "xml", []pgxtest.ValueRoundTripTest{ | ||
{[]byte(`<?xml version="1.0"?><Root></Root>`), new([]byte), isExpectedEqBytes([]byte(`<Root></Root>`))}, | ||
{[]byte(`<?xml version="1.0"?>`), new([]byte), isExpectedEqBytes([]byte(``))}, | ||
{[]byte(`<?xml version="1.0"?>`), new(string), isExpectedEq(``)}, | ||
{[]byte(`<Root></Root>`), new([]byte), isExpectedEqBytes([]byte(`<Root></Root>`))}, | ||
{[]byte(`<Root></Root>`), new(string), isExpectedEq(`<Root></Root>`)}, | ||
{[]byte(""), new([]byte), isExpectedEqBytes([]byte(""))}, | ||
{xmlStruct{Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, | ||
{xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, | ||
{[]byte(`<person age="10"><name>Adam</name></person>`), new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, | ||
}) | ||
} | ||
|
||
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 | ||
func TestXMLCodecUnmarshalSQLNull(t *testing.T) { | ||
skipCockroachDB(t, "CockroachDB does not support XML.") | ||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { | ||
// Byte arrays are nilified | ||
slice := []byte{10, 4} | ||
err := conn.QueryRow(ctx, "select null::xml").Scan(&slice) | ||
assert.NoError(t, err) | ||
assert.Nil(t, slice) | ||
|
||
// Non-pointer structs are zeroed | ||
m := xmlStruct{Name: "Adam"} | ||
err = conn.QueryRow(ctx, "select null::xml").Scan(&m) | ||
assert.NoError(t, err) | ||
assert.Empty(t, m) | ||
|
||
// Pointers to structs are nilified | ||
pm := &xmlStruct{Name: "Adam"} | ||
err = conn.QueryRow(ctx, "select null::xml").Scan(&pm) | ||
assert.NoError(t, err) | ||
assert.Nil(t, pm) | ||
|
||
// Pointer to pointer are nilified | ||
n := "" | ||
p := &n | ||
err = conn.QueryRow(ctx, "select null::xml").Scan(&p) | ||
assert.NoError(t, err) | ||
assert.Nil(t, p) | ||
|
||
// A string cannot scan a NULL. | ||
str := "foobar" | ||
err = conn.QueryRow(ctx, "select null::xml").Scan(&str) | ||
assert.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string") | ||
}) | ||
} | ||
|
||
func TestXMLCodecPointerToPointerToString(t *testing.T) { | ||
skipCockroachDB(t, "CockroachDB does not support XML.") | ||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { | ||
var s *string | ||
err := conn.QueryRow(ctx, "select ''::xml").Scan(&s) | ||
require.NoError(t, err) | ||
require.NotNil(t, s) | ||
require.Equal(t, "", *s) | ||
|
||
err = conn.QueryRow(ctx, "select null::xml").Scan(&s) | ||
require.NoError(t, err) | ||
require.Nil(t, s) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters