Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion pgtype/pgtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,18 @@ func TryWrapBuiltinTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter,
case []byte:
return &wrapByteSliceEncodePlan{}, byteSliceWrapper(value), true
case fmt.Stringer:
return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true
// Check if the value is a renamed basic type. If it is, prefer the basic type encoding.
rv := reflect.ValueOf(value)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.Bool, reflect.String:
// For renamed basic types, don't use Stringer interface automatically
// Let the specific type match above handle it
default:
// For structs and other complex types that implement Stringer, use the Stringer interface
return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true
}
}

return nil, nil, false
Expand Down
51 changes: 51 additions & 0 deletions values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,57 @@ func TestEncodeTypeRename(t *testing.T) {
})
}

// Define custom types that are aliases of basic types but also implement fmt.Stringer
type StringerInt32 int32
type StringerFloat64 float64

// Implement the String() method for these types
func (s StringerInt32) String() string {
return fmt.Sprintf("StringerInt32(%d)", int32(s))
}

func (s StringerFloat64) String() string {
return fmt.Sprintf("StringerFloat64(%f)", float64(s))
}

// TestStringerTypes tests custom type aliases that implement the fmt.Stringer interface
func TestStringerTypes(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
// Test values
inInt := StringerInt32(42)
var outInt StringerInt32

inFloat := StringerFloat64(553.36)
var outFloat StringerFloat64

// Register types with the connection
conn.TypeMap().RegisterDefaultPgType(inInt, "int4")
conn.TypeMap().RegisterDefaultPgType(inFloat, "float8")

// Test that the underlying values are properly encoded/decoded,
// not the String() representation
err := conn.QueryRow(context.Background(), "select $1::int4, $2::float8", inInt, inFloat).
Scan(&outInt, &outFloat)
if err != nil {
t.Fatalf("Failed with Stringer types: %v", err)
}

// Check that the values are correctly preserved (not converted to their String() representation)
if inInt != outInt {
t.Errorf("StringerInt32: expected %v, got %v", inInt, outInt)
}

if inFloat != outFloat {
t.Errorf("StringerFloat64: expected %v, got %v", inFloat, outFloat)
}
})
}

// func TestRowDecodeBinary(t *testing.T) {
// t.Parallel()

Expand Down