diff --git a/pgtype/json.go b/pgtype/json.go index e71dcb9bf..c2aa0d3bf 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -37,7 +37,7 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco // // https://github.com/jackc/pgx/issues/1430 // - // Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to beused + // Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to be used // when both are implemented https://github.com/jackc/pgx/issues/1805 case driver.Valuer: return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} @@ -177,13 +177,6 @@ func (scanPlanJSONToByteSlice) Scan(src []byte, dst any) error { return nil } -type scanPlanJSONToBytesScanner struct{} - -func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error { - scanner := (dst).(BytesScanner) - return scanner.ScanBytes(src) -} - type scanPlanJSONToJSONUnmarshal struct { unmarshal func(data []byte, v any) error } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 32d68f403..30f6bdef5 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -26,6 +26,8 @@ const ( XIDOID = 28 CIDOID = 29 JSONOID = 114 + XMLOID = 142 + XMLArrayOID = 143 JSONArrayOID = 199 PointOID = 600 LsegOID = 601 diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go index 9525f37c9..c81257311 100644 --- a/pgtype/pgtype_default.go +++ b/pgtype/pgtype_default.go @@ -2,6 +2,7 @@ package pgtype import ( "encoding/json" + "encoding/xml" "net" "net/netip" "reflect" @@ -89,6 +90,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{Marshal: xml.Marshal, Unmarshal: xml.Unmarshal}}) // Range types defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) @@ -153,6 +155,7 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}}) defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}}) defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_xml", OID: XMLArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XMLOID]}}) // Integer types that directly map to a PostgreSQL type registerDefaultPgTypeVariants[int16](defaultMap, "int2") diff --git a/pgtype/xml.go b/pgtype/xml.go new file mode 100644 index 000000000..fb4c49ad9 --- /dev/null +++ b/pgtype/xml.go @@ -0,0 +1,198 @@ +package pgtype + +import ( + "database/sql" + "database/sql/driver" + "encoding/xml" + "fmt" + "reflect" +) + +type XMLCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} + +func (*XMLCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (*XMLCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (c *XMLCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch value.(type) { + case string: + return encodePlanXMLCodecEitherFormatString{} + case []byte: + return encodePlanXMLCodecEitherFormatByteSlice{} + + // Cannot rely on driver.Valuer being handled later because anything can be marshalled. + // + // https://github.com/jackc/pgx/issues/1430 + // + // Check for driver.Valuer must come before xml.Marshaler so that it is guaranteed to be used + // when both are implemented https://github.com/jackc/pgx/issues/1805 + case driver.Valuer: + return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} + + // Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be + // marshalled. + // + // https://github.com/jackc/pgx/issues/1681 + case xml.Marshaler: + return &encodePlanXMLCodecEitherFormatMarshal{ + marshal: c.Marshal, + } + } + + // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the + // appropriate wrappers here. + for _, f := range []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + } { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + return &encodePlanXMLCodecEitherFormatMarshal{ + marshal: c.Marshal, + } +} + +type encodePlanXMLCodecEitherFormatString struct{} + +func (encodePlanXMLCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlString := value.(string) + buf = append(buf, xmlString...) + return buf, nil +} + +type encodePlanXMLCodecEitherFormatByteSlice struct{} + +func (encodePlanXMLCodecEitherFormatByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlBytes := value.([]byte) + if xmlBytes == nil { + return nil, nil + } + + buf = append(buf, xmlBytes...) + return buf, nil +} + +type encodePlanXMLCodecEitherFormatMarshal struct { + marshal func(v any) ([]byte, error) +} + +func (e *encodePlanXMLCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlBytes, err := e.marshal(value) + if err != nil { + return nil, err + } + + buf = append(buf, xmlBytes...) + return buf, nil +} + +func (c *XMLCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch target.(type) { + case *string: + return scanPlanAnyToString{} + + case **string: + // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better + // solution would be. + // + // https://github.com/jackc/pgx/issues/1470 -- **string + // https://github.com/jackc/pgx/issues/1691 -- ** anything else + + if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { + if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { + if _, failed := nextPlan.(*scanPlanFail); !failed { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + case *[]byte: + return scanPlanXMLToByteSlice{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} + + // Cannot rely on sql.Scanner being handled later because scanPlanXMLToXMLUnmarshal will take precedence. + // + // https://github.com/jackc/pgx/issues/1418 + case sql.Scanner: + return &scanPlanSQLScanner{formatCode: format} + } + + return &scanPlanXMLToXMLUnmarshal{ + unmarshal: c.Unmarshal, + } +} + +type scanPlanXMLToByteSlice struct{} + +func (scanPlanXMLToByteSlice) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanXMLToXMLUnmarshal struct { + unmarshal func(data []byte, v any) error +} + +func (s *scanPlanXMLToXMLUnmarshal) Scan(src []byte, dst any) error { + if src == nil { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() == reflect.Ptr { + el := dstValue.Elem() + switch el.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface, reflect.Struct: + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + elem := reflect.ValueOf(dst).Elem() + elem.Set(reflect.Zero(elem.Type())) + + return s.unmarshal(src, dst) +} + +func (c *XMLCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil +} + +func (c *XMLCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var dst any + err := c.Unmarshal(src, &dst) + return dst, err +} diff --git a/pgtype/xml_test.go b/pgtype/xml_test.go new file mode 100644 index 000000000..0f755e96f --- /dev/null +++ b/pgtype/xml_test.go @@ -0,0 +1,99 @@ +package pgtype_test + +import ( + "context" + "database/sql" + "encoding/xml" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type xmlStruct struct { + XMLName xml.Name `xml:"person"` + Name string `xml:"name"` + Age int `xml:"age,attr"` +} + +func TestXMLCodec(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "xml", []pgxtest.ValueRoundTripTest{ + {nil, new(*xmlStruct), isExpectedEq((*xmlStruct)(nil))}, + {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + + // Test sql.Scanner. + {"", new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})}, + + // Test driver.Valuer. + {sql.NullString{String: "", Valid: true}, new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "xml", []pgxtest.ValueRoundTripTest{ + {[]byte(``), new([]byte), isExpectedEqBytes([]byte(``))}, + {[]byte(``), new([]byte), isExpectedEqBytes([]byte(``))}, + {[]byte(``), new(string), isExpectedEq(``)}, + {[]byte(``), new([]byte), isExpectedEqBytes([]byte(``))}, + {[]byte(``), new(string), isExpectedEq(``)}, + {[]byte(""), new([]byte), isExpectedEqBytes([]byte(""))}, + {xmlStruct{Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, + {xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, + {[]byte(`Adam`), new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, + }) +} + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 +func TestXMLCodecUnmarshalSQLNull(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + // Byte arrays are nilified + slice := []byte{10, 4} + err := conn.QueryRow(ctx, "select null::xml").Scan(&slice) + assert.NoError(t, err) + assert.Nil(t, slice) + + // Non-pointer structs are zeroed + m := xmlStruct{Name: "Adam"} + err = conn.QueryRow(ctx, "select null::xml").Scan(&m) + assert.NoError(t, err) + assert.Empty(t, m) + + // Pointers to structs are nilified + pm := &xmlStruct{Name: "Adam"} + err = conn.QueryRow(ctx, "select null::xml").Scan(&pm) + assert.NoError(t, err) + assert.Nil(t, pm) + + // Pointer to pointer are nilified + n := "" + p := &n + err = conn.QueryRow(ctx, "select null::xml").Scan(&p) + assert.NoError(t, err) + assert.Nil(t, p) + + // A string cannot scan a NULL. + str := "foobar" + err = conn.QueryRow(ctx, "select null::xml").Scan(&str) + assert.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string") + }) +} + +func TestXMLCodecPointerToPointerToString(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var s *string + err := conn.QueryRow(ctx, "select ''::xml").Scan(&s) + require.NoError(t, err) + require.NotNil(t, s) + require.Equal(t, "", *s) + + err = conn.QueryRow(ctx, "select null::xml").Scan(&s) + require.NoError(t, err) + require.Nil(t, s) + }) +} diff --git a/stdlib/sql.go b/stdlib/sql.go index cf76900a5..c1d00ab40 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -795,6 +795,16 @@ func (r *Rows) Next(dest []driver.Value) error { } return d.Value() } + case pgtype.XMLOID: + var d []byte + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d, nil + } default: var d string scanPlan := m.PlanScan(dataTypeOID, format, &d)