From 1751abbc2f3ab195213f8026495973faffb2d9cd Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 27 Apr 2026 02:38:05 -0700 Subject: [PATCH 01/16] update cockroachdb/apd to v3 --- go.mod | 2 +- go.sum | 5 ++--- postgres/parser/encoding/decimal.go | 2 +- postgres/parser/encoding/encoding.go | 2 +- postgres/parser/json/encode.go | 2 +- postgres/parser/json/json.go | 2 +- postgres/parser/sem/tree/datum.go | 4 ++-- postgres/parser/sem/tree/decimal.go | 2 +- server/ast/expr.go | 2 +- 9 files changed, 11 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 40c7ff0ceb..83e0a75f23 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.26.2 require ( github.com/PuerkitoBio/goquery v1.8.1 - github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a + github.com/cockroachdb/apd/v3 v3.2.3 github.com/cockroachdb/errors v1.7.5 github.com/dolthub/dolt/go v0.40.5-0.20260424225502-433406d3ff75 github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 diff --git a/go.sum b/go.sum index 022dddfbb6..392cc5756b 100644 --- a/go.sum +++ b/go.sum @@ -205,8 +205,8 @@ github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/T github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a h1:9VFe4R5FRCUyidB1rdm3XdCRVuD/75P7Y4PtzEGhEE4= -github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a/go.mod h1:DDxRlzC2lo3/vSlmSoS7JkqbbrARPuFOGr0B9pvN3Gw= +github.com/cockroachdb/apd/v3 v3.2.3 h1:4Zx+I3R35bFXMnltzmjP79i2cravE4jTRL6ps9Aux80= +github.com/cockroachdb/apd/v3 v3.2.3/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc= github.com/cockroachdb/datadriven v1.0.0/go.mod h1:5Ib8Meh+jk1RlHIXej6Pzevx/NLlNvQB9pmSBZErGA4= github.com/cockroachdb/errors v1.6.1/go.mod h1:tm6FTP5G81vwJ5lC0SizQo374JNCOPrHyXGitRJoDqM= github.com/cockroachdb/errors v1.7.5 h1:ptyO1BLW+sBxwBTSKJfS6kGzYCVKhI7MyBhoXAnPIKM= @@ -670,7 +670,6 @@ github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTw github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/postgres/parser/encoding/decimal.go b/postgres/parser/encoding/decimal.go index 43c1e62526..245909e8cb 100644 --- a/postgres/parser/encoding/decimal.go +++ b/postgres/parser/encoding/decimal.go @@ -34,7 +34,7 @@ import ( "math/big" "unsafe" - "github.com/cockroachdb/apd/v2" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" ) diff --git a/postgres/parser/encoding/encoding.go b/postgres/parser/encoding/encoding.go index 65d4487c5d..c827013f69 100644 --- a/postgres/parser/encoding/encoding.go +++ b/postgres/parser/encoding/encoding.go @@ -30,7 +30,7 @@ import ( "fmt" "unsafe" - "github.com/cockroachdb/apd/v2" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/doltgresql/postgres/parser/uuid" diff --git a/postgres/parser/json/encode.go b/postgres/parser/json/encode.go index f3e0fe6c5f..780bc3fecc 100644 --- a/postgres/parser/json/encode.go +++ b/postgres/parser/json/encode.go @@ -25,7 +25,7 @@ package json import ( - "github.com/cockroachdb/apd/v2" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/doltgresql/postgres/parser/encoding" diff --git a/postgres/parser/json/json.go b/postgres/parser/json/json.go index 0cea86b341..9554384afe 100644 --- a/postgres/parser/json/json.go +++ b/postgres/parser/json/json.go @@ -36,7 +36,7 @@ import ( "unicode/utf8" "unsafe" - "github.com/cockroachdb/apd/v2" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/doltgresql/postgres/parser/encoding" diff --git a/postgres/parser/sem/tree/datum.go b/postgres/parser/sem/tree/datum.go index 7e5ef5d916..0db0d4ed61 100644 --- a/postgres/parser/sem/tree/datum.go +++ b/postgres/parser/sem/tree/datum.go @@ -34,7 +34,7 @@ import ( "unicode" "unsafe" - "github.com/cockroachdb/apd/v2" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" @@ -563,7 +563,7 @@ func (d *DDecimal) IsComposite() bool { // Check if d is divisible by 10. var r big.Int - r.Rem(&d.Decimal.Coeff, bigTen) + r.Rem((&d.Decimal.Coeff).MathBigInt(), bigTen) return r.Sign() == 0 } diff --git a/postgres/parser/sem/tree/decimal.go b/postgres/parser/sem/tree/decimal.go index 7eef146cfa..01526aaffc 100644 --- a/postgres/parser/sem/tree/decimal.go +++ b/postgres/parser/sem/tree/decimal.go @@ -28,7 +28,7 @@ import ( "fmt" "math" - "github.com/cockroachdb/apd/v2" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/doltgresql/postgres/parser/pgcode" "github.com/dolthub/doltgresql/postgres/parser/pgerror" diff --git a/server/ast/expr.go b/server/ast/expr.go index 35344fd2bf..4c2a742850 100644 --- a/server/ast/expr.go +++ b/server/ast/expr.go @@ -525,7 +525,7 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) { bigInt = bigInt.Neg(bigInt) } return vitess.InjectedExpr{ - Expression: pgexprs.NewRawLiteralNumeric(decimal.NewFromBigInt(bigInt, node.Exponent)), + Expression: pgexprs.NewRawLiteralNumeric(decimal.NewFromBigInt(bigInt.MathBigInt(), node.Exponent)), }, nil case *tree.DEnum: return nil, errors.Errorf("the statement is not yet supported") From 298edecd0f39adf60eedca104c67ed5406df2f10 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 27 Apr 2026 10:12:41 -0700 Subject: [PATCH 02/16] update decimal.Decimal to apd.Decimal --- go.mod | 2 +- go.sum | 4 +- server/analyzer/type_sanitizer.go | 3 +- server/ast/expr.go | 10 +- server/cast/float32.go | 9 +- server/cast/float64.go | 10 +- server/cast/int16.go | 4 +- server/cast/int32.go | 4 +- server/cast/int64.go | 4 +- server/cast/jsonb.go | 40 +- server/cast/numeric.go | 40 +- server/expression/gms_cast.go | 11 +- server/expression/literal.go | 16 +- server/functions/abs.go | 6 +- server/functions/binary/divide.go | 13 +- server/functions/binary/equal.go | 4 +- server/functions/binary/greater.go | 4 +- server/functions/binary/greater_equal.go | 4 +- server/functions/binary/less.go | 4 +- server/functions/binary/less_equal.go | 4 +- server/functions/binary/minus.go | 10 +- server/functions/binary/mod.go | 10 +- server/functions/binary/multiply.go | 13 +- server/functions/binary/not_equal.go | 4 +- server/functions/binary/plus.go | 22 +- server/functions/ceil.go | 11 +- server/functions/date_part.go | 64 ++- server/functions/div.go | 43 +- server/functions/exp.go | 12 +- server/functions/extract.go | 108 +++-- server/functions/factorial.go | 12 +- server/functions/floor.go | 12 +- server/functions/generate_series.go | 44 +- server/functions/ln.go | 18 +- server/functions/log.go | 58 +-- server/functions/min_scale.go | 8 +- server/functions/mod.go | 33 +- server/functions/numeric.go | 104 +---- server/functions/power.go | 69 +++- server/functions/round.go | 35 +- server/functions/sign.go | 7 +- server/functions/sqrt.go | 23 +- server/functions/trim_scale.go | 6 +- server/functions/trunc.go | 29 +- server/functions/unary/minus.go | 6 +- server/functions/width_bucket.go | 51 ++- server/types/json_document.go | 31 +- server/types/numeric.go | 251 ++++++++++- server/types/type.go | 10 +- server/types/typeinfo.go | 4 +- .../function_coverage/generators.go | 3 + .../output/framework_test.go | 28 +- testing/go/coercion_test.go | 80 ++++ testing/go/framework.go | 27 +- testing/go/functions_test.go | 105 +++++ testing/go/operators_test.go | 389 ++++++++++++++++++ testing/go/types_test.go | 24 ++ testing/go/wire_test.go | 37 +- 58 files changed, 1531 insertions(+), 466 deletions(-) diff --git a/go.mod b/go.mod index 83e0a75f23..a82874d589 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/dolthub/dolt/go v0.40.5-0.20260424225502-433406d3ff75 github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.20.1-0.20260424221156-62a3b12d1f59 + github.com/dolthub/go-mysql-server v0.20.1-0.20260427164548-6bc0cfa4e92a github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20260424215137-ec6bd432b0be diff --git a/go.sum b/go.sum index 392cc5756b..d39e9e94da 100644 --- a/go.sum +++ b/go.sum @@ -255,8 +255,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866 h1:U6gSf5I0e6h6GP1/5Sa7D2lWW1CWfcVPtY5wkyHq6jY= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE= -github.com/dolthub/go-mysql-server v0.20.1-0.20260424221156-62a3b12d1f59 h1:rLs9hbZmhQ3G2myEL9VzKpY9E08s/tOLxvL6FvUxdB4= -github.com/dolthub/go-mysql-server v0.20.1-0.20260424221156-62a3b12d1f59/go.mod h1:O43PPQxMeNi7O5idizj6Itf2TZcSYfI/0WU24xhXg4I= +github.com/dolthub/go-mysql-server v0.20.1-0.20260427164548-6bc0cfa4e92a h1:l5b092QxSRIrWey7P7KhhEVfyuQUY8AqfIokKJuAHQQ= +github.com/dolthub/go-mysql-server v0.20.1-0.20260427164548-6bc0cfa4e92a/go.mod h1:XFCNmCSCXcQ6KNZr/FHoERkYpgMEMyIs9CTGSsiXRz4= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20260414231531-5f031e3e9037 h1:oIW9HwuWrhxv+4HZxA+QQSKHLqWFyXZ2FmNjUYwkdiM= diff --git a/server/analyzer/type_sanitizer.go b/server/analyzer/type_sanitizer.go index 276ba6d82f..df6e7480e7 100644 --- a/server/analyzer/type_sanitizer.go +++ b/server/analyzer/type_sanitizer.go @@ -19,6 +19,7 @@ import ( "strings" "time" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" @@ -177,7 +178,7 @@ func typeSanitizerLiterals(ctx *sql.Context, gmsLiteral *expression.Literal) (sq if !ok { return nil, transform.NewTree, errors.Errorf("SANITIZER: expected decimal type: %T", gmsLiteral.Value()) } - return pgexprs.NewRawLiteralNumeric(dec), transform.NewTree, nil + return pgexprs.NewRawLiteralNumeric(*apd.New(dec.Coefficient().Int64(), dec.Exponent())), transform.NewTree, nil case query.Type_DATE, query.Type_DATETIME, query.Type_TIMESTAMP: newVal, _, err := types.Datetime.Convert(ctx, gmsLiteral.Value()) if err != nil { diff --git a/server/ast/expr.go b/server/ast/expr.go index 4c2a742850..244a04c29b 100644 --- a/server/ast/expr.go +++ b/server/ast/expr.go @@ -23,7 +23,6 @@ import ( "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" - "github.com/shopspring/decimal" "github.com/sirupsen/logrus" "github.com/dolthub/doltgresql/core/id" @@ -518,15 +517,8 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) { Expression: pgexprs.NewRawLiteralDate(t), }, nil case *tree.DDecimal: - // TODO: should we use apd.Decimal for Numeric type values? - // |Coeff| is always positive, so need to |Negative| to negate the big.Int - bigInt := &node.Coeff - if node.Negative { - bigInt = bigInt.Neg(bigInt) - } return vitess.InjectedExpr{ - Expression: pgexprs.NewRawLiteralNumeric(decimal.NewFromBigInt(bigInt.MathBigInt(), node.Exponent)), - }, nil + Expression: pgexprs.NewRawLiteralNumeric(node.Decimal)}, nil case *tree.DEnum: return nil, errors.Errorf("the statement is not yet supported") case *tree.DFloat: diff --git a/server/cast/float32.go b/server/cast/float32.go index d86bc627cc..0692f111d8 100644 --- a/server/cast/float32.go +++ b/server/cast/float32.go @@ -17,9 +17,9 @@ package cast import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -70,7 +70,12 @@ func float32Assignment() { FromType: pgtypes.Float32, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return pgtypes.GetNumericValueWithTypmod(decimal.NewFromFloat(float64(val.(float32))), targetType.GetAttTypMod()) + d := new(apd.Decimal) + err := d.Scan(float64(val.(float32))) + if err != nil { + return nil, err + } + return pgtypes.GetNumericValueWithTypmod(*d, targetType.GetAttTypMod()) }, }) } diff --git a/server/cast/float64.go b/server/cast/float64.go index aedae498bd..cba265c9de 100644 --- a/server/cast/float64.go +++ b/server/cast/float64.go @@ -17,9 +17,9 @@ package cast import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -76,7 +76,13 @@ func float64Assignment() { FromType: pgtypes.Float64, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return pgtypes.GetNumericValueWithTypmod(decimal.NewFromFloat(val.(float64)), targetType.GetAttTypMod()) + d := new(apd.Decimal) + d.String() + err := d.Scan(val.(float64)) + if err != nil { + return nil, err + } + return pgtypes.GetNumericValueWithTypmod(*d, targetType.GetAttTypMod()) }, }) } diff --git a/server/cast/int16.go b/server/cast/int16.go index 5ac3fffa2d..03e18e58a5 100644 --- a/server/cast/int16.go +++ b/server/cast/int16.go @@ -15,8 +15,8 @@ package cast import ( + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" @@ -62,7 +62,7 @@ func int16Implicit() { FromType: pgtypes.Int16, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return decimal.NewFromInt(int64(val.(int16))), nil + return pgtypes.GetNumericValueWithTypmod(*apd.New(int64(val.(int16)), 0), targetType.GetAttTypMod()) }, }) framework.MustAddImplicitTypeCast(framework.TypeCast{ diff --git a/server/cast/int32.go b/server/cast/int32.go index cbf46fb5c0..55df7aa07e 100644 --- a/server/cast/int32.go +++ b/server/cast/int32.go @@ -15,10 +15,10 @@ package cast import ( + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" @@ -84,7 +84,7 @@ func int32Implicit() { FromType: pgtypes.Int32, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return decimal.NewFromInt(int64(val.(int32))), nil + return pgtypes.GetNumericValueWithTypmod(*apd.New(int64(val.(int32)), 0), targetType.GetAttTypMod()) }, }) framework.MustAddImplicitTypeCast(framework.TypeCast{ diff --git a/server/cast/int64.go b/server/cast/int64.go index bbfdbf5ca9..b56700d3ee 100644 --- a/server/cast/int64.go +++ b/server/cast/int64.go @@ -15,10 +15,10 @@ package cast import ( + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" @@ -75,7 +75,7 @@ func int64Implicit() { FromType: pgtypes.Int64, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return decimal.NewFromInt(val.(int64)), nil + return pgtypes.GetNumericValueWithTypmod(*apd.New(val.(int64), 0), targetType.GetAttTypMod()) }, }) framework.MustAddImplicitTypeCast(framework.TypeCast{ diff --git a/server/cast/jsonb.go b/server/cast/jsonb.go index db236c68f7..2f9b3614c4 100644 --- a/server/cast/jsonb.go +++ b/server/cast/jsonb.go @@ -15,10 +15,10 @@ package cast import ( + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -66,7 +66,8 @@ func jsonbExplicit() { case pgtypes.JsonValueString: return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: - f, _ := decimal.Decimal(value).Float64() + d := apd.Decimal(value) + f, _ := d.Float64() return float32(f), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) @@ -89,7 +90,8 @@ func jsonbExplicit() { case pgtypes.JsonValueString: return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: - f, _ := decimal.Decimal(value).Float64() + d := apd.Decimal(value) + f, _ := d.Float64() return f, nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) @@ -112,11 +114,15 @@ func jsonbExplicit() { case pgtypes.JsonValueString: return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: - d := decimal.Decimal(value) - if d.LessThan(pgtypes.NumericValueMinInt16) || d.GreaterThan(pgtypes.NumericValueMaxInt16) { + d := apd.Decimal(value) + if d.Cmp(pgtypes.NumericValueMinInt16) < 0 || d.Cmp(pgtypes.NumericValueMaxInt16) > 0 { return nil, errors.Errorf("smallint out of range") } - return int16(d.IntPart()), nil + i, err := d.Int64() + if err != nil { + return nil, err + } + return int16(i), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: @@ -138,11 +144,15 @@ func jsonbExplicit() { case pgtypes.JsonValueString: return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: - d := decimal.Decimal(value) - if d.LessThan(pgtypes.NumericValueMinInt32) || d.GreaterThan(pgtypes.NumericValueMaxInt32) { + d := apd.Decimal(value) + if d.Cmp(pgtypes.NumericValueMinInt32) < 0 || d.Cmp(pgtypes.NumericValueMaxInt32) > 0 { return nil, errors.Errorf("integer out of range") } - return int32(d.IntPart()), nil + i, err := d.Int64() + if err != nil { + return nil, err + } + return int32(i), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: @@ -164,11 +174,15 @@ func jsonbExplicit() { case pgtypes.JsonValueString: return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: - d := decimal.Decimal(value) - if d.LessThan(pgtypes.NumericValueMinInt64) || d.GreaterThan(pgtypes.NumericValueMaxInt64) { + d := apd.Decimal(value) + if d.Cmp(pgtypes.NumericValueMinInt64) < 0 || d.Cmp(pgtypes.NumericValueMaxInt64) > 0 { return nil, errors.Errorf("bigint out of range") } - return int64(d.IntPart()), nil + i, err := d.Int64() + if err != nil { + return nil, err + } + return int64(i), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: @@ -190,7 +204,7 @@ func jsonbExplicit() { case pgtypes.JsonValueString: return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: - return decimal.Decimal(value), nil + return apd.Decimal(value), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: diff --git a/server/cast/numeric.go b/server/cast/numeric.go index 8eab4ef909..761d38b506 100644 --- a/server/cast/numeric.go +++ b/server/cast/numeric.go @@ -15,10 +15,10 @@ package cast import ( + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -36,33 +36,45 @@ func numericAssignment() { FromType: pgtypes.Numeric, ToType: pgtypes.Int16, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - d := val.(decimal.Decimal) - if d.LessThan(pgtypes.NumericValueMinInt16) || d.GreaterThan(pgtypes.NumericValueMaxInt16) { + d := val.(apd.Decimal) + if d.Cmp(pgtypes.NumericValueMinInt16) < 0 || d.Cmp(pgtypes.NumericValueMaxInt16) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "smallint out of range") } - return int16(d.Round(0).IntPart()), nil + i, err := d.Int64() + if err != nil { + return nil, err + } + return int16(i), nil }, }) framework.MustAddAssignmentTypeCast(framework.TypeCast{ FromType: pgtypes.Numeric, ToType: pgtypes.Int32, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - d := val.(decimal.Decimal) - if d.LessThan(pgtypes.NumericValueMinInt32) || d.GreaterThan(pgtypes.NumericValueMaxInt32) { + d := val.(apd.Decimal) + if d.Cmp(pgtypes.NumericValueMinInt32) < 0 || d.Cmp(pgtypes.NumericValueMaxInt32) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "integer out of range") } - return int32(d.Round(0).IntPart()), nil + i, err := d.Int64() + if err != nil { + return nil, err + } + return int32(i), nil }, }) framework.MustAddAssignmentTypeCast(framework.TypeCast{ FromType: pgtypes.Numeric, ToType: pgtypes.Int64, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - d := val.(decimal.Decimal) - if d.LessThan(pgtypes.NumericValueMinInt64) || d.GreaterThan(pgtypes.NumericValueMaxInt64) { + d := val.(apd.Decimal) + if d.Cmp(pgtypes.NumericValueMinInt64) < 0 || d.Cmp(pgtypes.NumericValueMaxInt64) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "bigint out of range") } - return int64(d.Round(0).IntPart()), nil + i, err := d.Int64() + if err != nil { + return nil, err + } + return int64(i), nil }, }) } @@ -73,7 +85,8 @@ func numericImplicit() { FromType: pgtypes.Numeric, ToType: pgtypes.Float32, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - f, _ := val.(decimal.Decimal).Float64() + d := val.(apd.Decimal) + f, _ := d.Float64() return float32(f), nil }, }) @@ -81,7 +94,8 @@ func numericImplicit() { FromType: pgtypes.Numeric, ToType: pgtypes.Float64, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - f, _ := val.(decimal.Decimal).Float64() + d := val.(apd.Decimal) + f, _ := d.Float64() return f, nil }, }) @@ -89,7 +103,7 @@ func numericImplicit() { FromType: pgtypes.Numeric, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return pgtypes.GetNumericValueWithTypmod(val.(decimal.Decimal), targetType.GetAttTypMod()) + return pgtypes.GetNumericValueWithTypmod(val.(apd.Decimal), targetType.GetAttTypMod()) }, }) } diff --git a/server/expression/gms_cast.go b/server/expression/gms_cast.go index f4ff1b384e..0ea167e5c5 100644 --- a/server/expression/gms_cast.go +++ b/server/expression/gms_cast.go @@ -18,6 +18,7 @@ import ( "encoding/json" "time" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/go-mysql-server/sql" @@ -122,10 +123,11 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil { return nil, err } - if _, ok := newVal.(decimal.Decimal); !ok { + dec, ok := newVal.(decimal.Decimal) + if !ok { return nil, errors.Errorf("GMSCast expected type `decimal.Decimal`, got `%T`", val) } - return newVal, nil + return *apd.New(dec.CoefficientInt64(), dec.Exponent()), nil case query.Type_FLOAT32: newVal, _, err := types.Float32.Convert(ctx, val) if err != nil { @@ -149,10 +151,11 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil { return nil, err } - if _, ok := newVal.(decimal.Decimal); !ok { + dec, ok := newVal.(decimal.Decimal) + if !ok { return nil, errors.Errorf("GMSCast expected type `decimal.Decimal`, got `%T`", val) } - return newVal, nil + return *apd.New(dec.CoefficientInt64(), dec.Exponent()), nil case query.Type_DATE, query.Type_DATETIME, query.Type_TIMESTAMP: if val, ok := val.(time.Time); ok { return val, nil diff --git a/server/expression/literal.go b/server/expression/literal.go index b9444fb407..7e991ffc65 100644 --- a/server/expression/literal.go +++ b/server/expression/literal.go @@ -18,9 +18,9 @@ import ( "strconv" "time" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -32,7 +32,11 @@ import ( // NewNumericLiteral returns a new *expression.Literal containing a NUMERIC value. func NewNumericLiteral(numericValue string) (*expression.Literal, error) { - d, err := decimal.NewFromString(numericValue) + //TODO: should use the input function of the type + d, err := pgtypes.GetNumericValueFromStringWithTypmod(numericValue, -1) + if err != nil { + return nil, err + } return expression.NewLiteral(d, pgtypes.Numeric), err } @@ -48,8 +52,7 @@ func NewIntegerLiteral(integerValue string) (*expression.Literal, error) { } } else { // If we errored the first time, then we'll assume it's a NUMERIC value - d, err := decimal.NewFromString(integerValue) - return expression.NewLiteral(d, pgtypes.Numeric), err + return NewNumericLiteral(integerValue) } } @@ -105,7 +108,7 @@ func NewRawLiteralFloat64(val float64) *expression.Literal { } // NewRawLiteralNumeric returns a new *expression.Literal containing a decimal.Decimal value. -func NewRawLiteralNumeric(val decimal.Decimal) *expression.Literal { +func NewRawLiteralNumeric(val apd.Decimal) *expression.Literal { return expression.NewLiteral(val, pgtypes.Numeric) } @@ -173,7 +176,8 @@ func ToVitessLiteral(l *expression.Literal) *vitess.SQLVal { case pgtypes.Int64.ID: return vitess.NewIntVal([]byte(strconv.FormatInt(l.Value().(int64), 10))) case pgtypes.Numeric.ID: - return vitess.NewFloatVal([]byte(l.Value().(decimal.Decimal).String())) + d := l.Value().(apd.Decimal) + return vitess.NewFloatVal([]byte(d.String())) case pgtypes.Text.ID: return vitess.NewStrVal([]byte(l.Value().(string))) case pgtypes.Unknown.ID: diff --git a/server/functions/abs.go b/server/functions/abs.go index b60ff4a153..fa4cbf573b 100644 --- a/server/functions/abs.go +++ b/server/functions/abs.go @@ -15,8 +15,8 @@ package functions import ( + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -83,6 +83,8 @@ var abs_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - return val1.(decimal.Decimal).Abs(), nil + dec := val1.(apd.Decimal) + abs := dec.Abs(&dec) + return *abs, nil }, } diff --git a/server/functions/binary/divide.go b/server/functions/binary/divide.go index 895d77870f..88307fa23e 100644 --- a/server/functions/binary/divide.go +++ b/server/functions/binary/divide.go @@ -16,11 +16,10 @@ package binary import ( "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/postgres/parser/duration" + "github.com/dolthub/doltgresql/server/functions" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -285,19 +284,11 @@ var interval_div = framework.Function2{ Callable: interval_div_callable, } -// numeric_div_callable is the callable logic for the numeric_div function. -func numeric_div_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - if val2.(decimal.Decimal).Equal(decimal.Zero) { - return nil, errors.Errorf("division by zero") - } - return val1.(decimal.Decimal).Div(val2.(decimal.Decimal)), nil -} - // numeric_div represents the PostgreSQL function of the same name, taking the same parameters. var numeric_div = framework.Function2{ Name: "numeric_div", Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: numeric_div_callable, + Callable: functions.NumericDivCallable, } diff --git a/server/functions/binary/equal.go b/server/functions/binary/equal.go index 49c6818b6c..244db200dc 100644 --- a/server/functions/binary/equal.go +++ b/server/functions/binary/equal.go @@ -17,8 +17,8 @@ package binary import ( "time" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -459,7 +459,7 @@ var nameeqtext = framework.Function2{ // numeric_eq_callable is the callable logic for the numeric_eq function. func numeric_eq_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(decimal.Decimal), val2.(decimal.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) return res == 0, err } diff --git a/server/functions/binary/greater.go b/server/functions/binary/greater.go index b2f58588f8..e953c2a808 100644 --- a/server/functions/binary/greater.go +++ b/server/functions/binary/greater.go @@ -18,8 +18,8 @@ import ( "cmp" "time" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -385,7 +385,7 @@ var numeric_gt = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(decimal.Decimal), val2.(decimal.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) return res == 1, err }, } diff --git a/server/functions/binary/greater_equal.go b/server/functions/binary/greater_equal.go index 8985e8e57a..c35ea6c707 100644 --- a/server/functions/binary/greater_equal.go +++ b/server/functions/binary/greater_equal.go @@ -18,8 +18,8 @@ import ( "cmp" "time" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -385,7 +385,7 @@ var numeric_ge = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(decimal.Decimal), val2.(decimal.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) return res >= 0, err }, } diff --git a/server/functions/binary/less.go b/server/functions/binary/less.go index 70ae88372f..d8fed02a42 100644 --- a/server/functions/binary/less.go +++ b/server/functions/binary/less.go @@ -18,8 +18,8 @@ import ( "cmp" "time" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -385,7 +385,7 @@ var numeric_lt = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(decimal.Decimal), val2.(decimal.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) return res == -1, err }, } diff --git a/server/functions/binary/less_equal.go b/server/functions/binary/less_equal.go index 3b15c1ca4c..8f4eb2a15f 100644 --- a/server/functions/binary/less_equal.go +++ b/server/functions/binary/less_equal.go @@ -18,8 +18,8 @@ import ( "cmp" "time" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -385,7 +385,7 @@ var numeric_le = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(decimal.Decimal), val2.(decimal.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) return res <= 0, err }, } diff --git a/server/functions/binary/minus.go b/server/functions/binary/minus.go index 4d0baba6cd..a2937749ee 100644 --- a/server/functions/binary/minus.go +++ b/server/functions/binary/minus.go @@ -18,9 +18,9 @@ import ( "math" "time" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/timeofday" @@ -240,7 +240,13 @@ var numeric_sub = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - return val1.(decimal.Decimal).Sub(val2.(decimal.Decimal)), nil + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + _, err := pgtypes.BaseContext.Sub(&num1, &num1, &num2) + if err != nil { + return nil, err + } + return num1, nil }, } diff --git a/server/functions/binary/mod.go b/server/functions/binary/mod.go index 1564716b66..60e08f1649 100644 --- a/server/functions/binary/mod.go +++ b/server/functions/binary/mod.go @@ -16,10 +16,9 @@ package binary import ( "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" + "github.com/dolthub/doltgresql/server/functions" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -83,10 +82,5 @@ var numeric_mod = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - if val2.(decimal.Decimal).Equal(decimal.Zero) { - return nil, errors.Errorf("division by zero") - } - return val1.(decimal.Decimal).Mod(val2.(decimal.Decimal)), nil - }, + Callable: functions.NumericModCallable, } diff --git a/server/functions/binary/multiply.go b/server/functions/binary/multiply.go index 0b1db95af4..5377e7f2b0 100644 --- a/server/functions/binary/multiply.go +++ b/server/functions/binary/multiply.go @@ -17,9 +17,9 @@ package binary import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/server/functions/framework" @@ -225,7 +225,16 @@ var numeric_mul = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - return val1.(decimal.Decimal).Mul(val2.(decimal.Decimal)), nil + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + if (num1.Form == apd.Infinite || num2.Form == apd.Infinite) && (num1.IsZero() || num2.IsZero()) { + return pgtypes.NumericNaN, nil + } + _, err := pgtypes.BaseContext.Mul(&num1, &num1, &num2) + if err != nil { + return nil, err + } + return num1, nil }, } diff --git a/server/functions/binary/not_equal.go b/server/functions/binary/not_equal.go index 63a308c278..88f2854e73 100644 --- a/server/functions/binary/not_equal.go +++ b/server/functions/binary/not_equal.go @@ -17,8 +17,8 @@ package binary import ( "time" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -386,7 +386,7 @@ var numeric_ne = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(decimal.Decimal), val2.(decimal.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) return res != 0, err }, } diff --git a/server/functions/binary/plus.go b/server/functions/binary/plus.go index b58bdb876f..e3873de328 100644 --- a/server/functions/binary/plus.go +++ b/server/functions/binary/plus.go @@ -18,9 +18,9 @@ import ( "math" "time" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/timeofday" @@ -376,7 +376,25 @@ var interval_pl_timestamptz = framework.Function2{ // numeric_add_callable is the callable logic for the numeric_add function. func numeric_add_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - return val1.(decimal.Decimal).Add(val2.(decimal.Decimal)), nil + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + if num1.Form == apd.NaN || num2.Form == apd.NaN || + (num1.Form == apd.Infinite && num2.Form == apd.Infinite && num2.Negative) || + (num2.Form == apd.Infinite && num1.Form == apd.Infinite && num1.Negative) { + return pgtypes.NumericNaN, nil + } + if num1.Form == apd.Infinite || num2.Form == apd.Infinite { + if num1.Negative || num2.Negative { + return pgtypes.NumericNegInf, nil + } + return pgtypes.NumericInf, nil + } + + _, err := pgtypes.BaseContext.Add(&num1, &num1, &num2) + if err != nil { + return nil, err + } + return num1, nil } // numeric_add represents the PostgreSQL function of the same name, taking the same parameters. diff --git a/server/functions/ceil.go b/server/functions/ceil.go index ff2cfa6aa4..fc1e10ba17 100644 --- a/server/functions/ceil.go +++ b/server/functions/ceil.go @@ -17,8 +17,8 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -54,7 +54,12 @@ var ceil_numeric = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - return val1.(decimal.Decimal).Ceil(), nil + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + dec := val.(apd.Decimal) + _, err := pgtypes.BaseContext.Ceil(&dec, &dec) + if err != nil { + return nil, err + } + return dec, nil }, } diff --git a/server/functions/date_part.go b/server/functions/date_part.go index 2de104f0b0..36facc27e5 100644 --- a/server/functions/date_part.go +++ b/server/functions/date_part.go @@ -18,8 +18,9 @@ import ( "strings" "time" + "github.com/cockroachdb/apd/v3" + cerrors "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/timeofday" @@ -191,14 +192,20 @@ var date_part_text_interval = framework.Function2{ // This mirrors the exact logic from extract_text_interval switch strings.ToLower(field) { case "century", "centuries": - result := decimal.NewFromFloat(float64(dur.Months) / 12 / 100).Floor() - f, _ := result.Float64() + dec, err := numericFloor(float64(dur.Months) / 12 / 100) + if err != nil { + return nil, err + } + f, _ := dec.Float64() return f, nil case "day", "days": return float64(dur.Days), nil case "decade", "decades": - result := decimal.NewFromFloat(float64(dur.Months) / 12 / 10).Floor() - f, _ := result.Float64() + dec, err := numericFloor(float64(dur.Months) / 12 / 10) + if err != nil { + return nil, err + } + f, _ := dec.Float64() return f, nil case "epoch": epoch := float64(duration.SecsPerDay*duration.DaysPerMonth*dur.Months) + float64(duration.SecsPerDay*dur.Days) + @@ -206,16 +213,22 @@ var date_part_text_interval = framework.Function2{ return epoch, nil case "hour", "hours": hours := float64(dur.Nanos()) / float64(NanosPerSec*duration.SecsPerHour) - result := decimal.NewFromFloat(hours).Floor() - f, _ := result.Float64() + dec, err := numericFloor(hours) + if err != nil { + return nil, err + } + f, _ := dec.Float64() return f, nil case "microsecond", "microseconds": secondsInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerMinute) microseconds := float64(secondsInNanos) / float64(NanosPerMicro) return microseconds, nil case "millennium", "millenniums": - result := decimal.NewFromFloat(float64(dur.Months) / 12 / 1000).Floor() - f, _ := result.Float64() + dec, err := numericFloor(float64(dur.Months) / 12 / 1000) + if err != nil { + return nil, err + } + f, _ := dec.Float64() return f, nil case "millisecond", "milliseconds": secondsInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerMinute) @@ -224,8 +237,11 @@ var date_part_text_interval = framework.Function2{ case "minute", "minutes": minutesInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerHour) minutes := float64(minutesInNanos) / float64(NanosPerSec*duration.SecsPerMinute) - result := decimal.NewFromFloat(minutes).Floor() - f, _ := result.Float64() + dec, err := numericFloor(minutes) + if err != nil { + return nil, err + } + f, _ := dec.Float64() return f, nil case "month", "months": return float64(dur.Months % 12), nil @@ -236,8 +252,11 @@ var date_part_text_interval = framework.Function2{ seconds := float64(secondsInNanos) / float64(NanosPerSec) return seconds, nil case "year", "years": - result := decimal.NewFromFloat(float64(dur.Months) / 12).Floor() - f, _ := result.Float64() + dec, err := numericFloor(float64(dur.Months) / 12) + if err != nil { + return nil, err + } + f, _ := dec.Float64() return f, nil case "dow", "doy", "isodow", "isoyear", "julian", "timezone", "timezone_hour", "timezone_minute", "week": return nil, ErrUnitNotSupported.New(field, "interval") @@ -246,3 +265,22 @@ var date_part_text_interval = framework.Function2{ } }, } + +func numericFloor(val any) (apd.Decimal, error) { + switch val.(type) { + case int64, float64: + // expects these types to Scan from + default: + return apd.Decimal{}, cerrors.Errorf("invalid type for numeric convert: %T", val) + } + dec := new(apd.Decimal) + err := dec.Scan(val) + if err != nil { + return apd.Decimal{}, err + } + _, err = pgtypes.BaseContext.Floor(dec, dec) + if err != nil { + return apd.Decimal{}, err + } + return *dec, nil +} diff --git a/server/functions/div.go b/server/functions/div.go index 10adf7e7fb..d3aae61072 100644 --- a/server/functions/div.go +++ b/server/functions/div.go @@ -15,10 +15,9 @@ package functions import ( - "github.com/cockroachdb/errors" - + "github.com/cockroachdb/apd/v3" + errors "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -35,13 +34,33 @@ var div_numeric = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1Interface any, val2Interface any) (any, error) { - val1 := val1Interface.(decimal.Decimal) - val2 := val2Interface.(decimal.Decimal) - if val2.Cmp(decimal.Zero) == 0 { - return nil, errors.Errorf("division by zero") - } - val := val1.Div(val2) - return val.Truncate(0), nil - }, + Callable: NumericDivCallable, +} + +// NumericDivCallable is the callable logic for the numeric_div and div functions. +func NumericDivCallable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + if num1.Form == apd.NaN || num2.Form == apd.NaN || + (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { + return pgtypes.NumericNaN, nil + } + if num2.IsZero() { + return nil, errors.Errorf("division by zero") + } + if num1.Form == apd.Infinite { + return num1, nil + } + if num2.Form == apd.Infinite { + return *apd.New(0, 0), nil + } + _, err := pgtypes.BaseContext.Quo(&num1, &num1, &num2) + if err != nil { + return nil, err + } + _, err = pgtypes.BaseContext.Quantize(&num1, &num1, -16) + if err != nil { + return nil, err + } + return num1, nil } diff --git a/server/functions/exp.go b/server/functions/exp.go index 1987844334..dadbca2b91 100644 --- a/server/functions/exp.go +++ b/server/functions/exp.go @@ -17,8 +17,8 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -47,10 +47,12 @@ var exp_numeric = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - if val1 == nil { - return nil, nil + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + dec := val.(apd.Decimal) + _, err := pgtypes.BaseContext.WithPrecision(32).Exp(&dec, &dec) + if err != nil { + return nil, err } - return val1.(decimal.Decimal).ExpTaylor(32) + return dec, nil }, } diff --git a/server/functions/extract.go b/server/functions/extract.go index 65269b03d2..2a6b833af3 100644 --- a/server/functions/extract.go +++ b/server/functions/extract.go @@ -19,9 +19,9 @@ import ( "strings" "time" + "github.com/cockroachdb/apd/v3" cerrors "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -58,7 +58,7 @@ var extract_text_date = framework.Function2{ "minute", "minutes", "second", "seconds", "timezone", "timezone_hour", "timezone_minute": return nil, ErrUnitNotSupported.New(field, "date") case "epoch": - return decimal.NewFromFloat(float64(dateVal.UnixMicro()) / 1000000), nil + return numeric(float64(dateVal.UnixMicro())/1000000, false, 0) default: return getFieldFromTimeVal(field, dateVal) } @@ -103,11 +103,11 @@ var extract_text_timetz = framework.Function2{ "quarter", "week", "year", "years": return nil, ErrUnitNotSupported.New(field, "time with time zone") case "timezone": - return decimal.NewFromInt(-int64(-currentOffset)), nil + return numeric(-int64(-currentOffset), false, 0) case "timezone_hour": - return decimal.NewFromInt(-int64(-currentOffset / 3600)), nil + return numeric(-int64(-currentOffset/3600), false, 0) case "timezone_minute": - return decimal.NewFromInt(-int64((-currentOffset % 3600) / 60)), nil + return numeric(-int64((-currentOffset%3600)/60), false, 0) default: return getFieldFromTimeVal(field, timetzVal) } @@ -150,13 +150,13 @@ var extract_text_timestamptz = framework.Function2{ switch strings.ToLower(field) { case "timezone": // TODO: postgres seem to use server timezone regardless of input value - return decimal.NewFromInt(-28800), nil + return numeric(int64(-28800), false, 0) case "timezone_hour": // TODO: postgres seem to use server timezone regardless of input value - return decimal.NewFromInt(-8), nil + return numeric(int64(-8), false, 0) case "timezone_minute": // TODO: postgres seem to use server timezone regardless of input value - return decimal.NewFromInt(0), nil + return numeric(int64(0), false, 0) default: return getFieldFromTimeVal(field, tstzVal) } @@ -181,42 +181,42 @@ var extract_text_interval = framework.Function2{ dur := val2.(duration.Duration) switch strings.ToLower(field) { case "century", "centuries": - return decimal.NewFromFloat(math.Floor(float64(dur.Months) / 12 / 100)), nil + return numeric(math.Floor(float64(dur.Months)/12/100), false, 0) case "day", "days": - return decimal.NewFromInt(dur.Days), nil + return numeric(dur.Days, false, 0) case "decade", "decades": - return decimal.NewFromFloat(math.Floor(float64(dur.Months) / 12 / 10)), nil + return numeric(math.Floor(float64(dur.Months)/12/10), false, 0) case "epoch": epoch := float64(duration.SecsPerDay*duration.DaysPerMonth*dur.Months) + float64(duration.SecsPerDay*dur.Days) + (float64(dur.Nanos()) / (NanosPerSec)) - return decimal.NewFromString(decimal.NewFromFloat(epoch).StringFixed(6)) + return numeric(epoch, true, 6) case "hour", "hours": hours := math.Floor(float64(dur.Nanos()) / (NanosPerSec * duration.SecsPerHour)) - return decimal.NewFromFloat(hours), nil + return numeric(hours, false, 0) case "microsecond", "microseconds": secondsInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerMinute) microseconds := float64(secondsInNanos) / NanosPerMicro - return decimal.NewFromFloat(microseconds), nil + return numeric(microseconds, false, 0) case "millennium", "millenniums": - return decimal.NewFromFloat(math.Floor(float64(dur.Months) / 12 / 1000)), nil + return numeric(math.Floor(float64(dur.Months)/12/1000), false, 0) case "millisecond", "milliseconds": secondsInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerMinute) milliseconds := float64(secondsInNanos) / NanosPerMilli - return decimal.NewFromString(decimal.NewFromFloat(milliseconds).StringFixed(3)) + return numeric(milliseconds, true, 3) case "minute", "minutes": minutesInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerHour) minutes := math.Floor(float64(minutesInNanos) / (NanosPerSec * duration.SecsPerMinute)) - return decimal.NewFromFloat(minutes), nil + return numeric(minutes, false, 0) case "month", "months": - return decimal.NewFromInt(dur.Months % 12), nil + return numeric(dur.Months%12, false, 0) case "quarter": - return decimal.NewFromInt((dur.Months%12-1)/3 + 1), nil + return numeric((dur.Months%12-1)/3+1, false, 0) case "second", "seconds": secondsInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerMinute) seconds := float64(secondsInNanos) / NanosPerSec - return decimal.NewFromString(decimal.NewFromFloat(seconds).StringFixed(6)) + return numeric(seconds, true, 6) case "year", "years": - return decimal.NewFromFloat(math.Floor(float64(dur.Months) / 12)), nil + return numeric(math.Floor(float64(dur.Months)/12), false, 0) case "dow", "doy", "isodow", "isoyear", "julian", "timezone", "timezone_hour", "timezone_minute", "week": return nil, ErrUnitNotSupported.New(field, "interval") default: @@ -226,64 +226,86 @@ var extract_text_interval = framework.Function2{ } // getFieldFromTimeVal returns the value for given field extracted from non-interval values. -func getFieldFromTimeVal(field string, tVal time.Time) (decimal.Decimal, error) { +func getFieldFromTimeVal(field string, tVal time.Time) (apd.Decimal, error) { switch strings.ToLower(field) { case "century", "centuries": if year := tVal.Year(); year <= 0 { - return decimal.NewFromFloat(math.Floor(float64(year-1) / 100)), nil + return numeric(math.Floor(float64(year-1)/100), false, 0) } else { - return decimal.NewFromFloat(math.Ceil(float64(year) / 100)), nil + return numeric(math.Ceil(float64(year)/100), false, 0) } case "day", "days": - return decimal.NewFromInt(int64(tVal.Day())), nil + return numeric(int64(tVal.Day()), false, 0) case "decade", "decades": - return decimal.NewFromFloat(math.Floor(float64(tVal.Year()) / 10)), nil + return numeric(math.Floor(float64(tVal.Year())/10), false, 0) case "dow": - return decimal.NewFromInt(int64(tVal.Weekday())), nil + return numeric(int64(tVal.Weekday()), false, 0) case "doy": - return decimal.NewFromInt(int64(tVal.YearDay())), nil + return numeric(int64(tVal.YearDay()), false, 0) case "epoch": - return decimal.NewFromString(decimal.NewFromFloat(float64(tVal.UnixMicro()) / 1000000).StringFixed(6)) + return numeric(float64(tVal.UnixMicro())/1000000, true, 6) case "hour", "hours": - return decimal.NewFromInt(int64(tVal.Hour())), nil + return numeric(int64(tVal.Hour()), false, 0) case "isodow": wd := int64(tVal.Weekday()) if wd == 0 { wd = 7 } - return decimal.NewFromInt(wd), nil + return numeric(wd, false, 0) case "isoyear": year, _ := tVal.ISOWeek() - return decimal.NewFromInt(int64(year)), nil + return numeric(int64(year), false, 0) case "julian": - return decimal.NewFromInt(int64(date2J(tVal.Year(), int(tVal.Month()), tVal.Day()))), nil + return numeric(int64(date2J(tVal.Year(), int(tVal.Month()), tVal.Day())), false, 0) case "microsecond", "microseconds", "usec", "usecs": w := float64(tVal.Second() * 1000000) f := float64(tVal.Nanosecond()) / float64(1000) - return decimal.NewFromFloat(w + f), nil + return numeric(w+f, false, 0) case "millennium", "millenniums": - return decimal.NewFromFloat(math.Ceil(float64(tVal.Year()) / 1000)), nil + return numeric(math.Ceil(float64(tVal.Year())/1000), false, 0) case "millisecond", "milliseconds", "msec", "msecs": w := float64(tVal.Second() * 1000) f := float64(tVal.Nanosecond()) / float64(1000000) - return decimal.NewFromString(decimal.NewFromFloat(w + f).StringFixed(3)) + return numeric(w+f, true, 3) case "minute", "minutes": - return decimal.NewFromInt(int64(tVal.Minute())), nil + return numeric(int64(tVal.Minute()), false, 0) case "month", "months": - return decimal.NewFromInt(int64(tVal.Month())), nil + return numeric(int64(tVal.Month()), false, 0) case "quarter": q := (int(tVal.Month())-1)/3 + 1 - return decimal.NewFromInt(int64(q)), nil + return numeric(int64(q), false, 0) case "second", "seconds": w := float64(tVal.Second()) f := float64(tVal.Nanosecond()) / float64(1000000000) - return decimal.NewFromString(decimal.NewFromFloat(w + f).StringFixed(6)) + return numeric(w+f, true, 6) + case "week": _, week := tVal.ISOWeek() - return decimal.NewFromInt(int64(week)), nil + return numeric(int64(week), false, 0) case "year", "years": - return decimal.NewFromInt(int64(tVal.Year())), nil + return numeric(int64(tVal.Year()), false, 0) + default: + return apd.Decimal{}, cerrors.Errorf("unknown field given: %s", field) + } +} + +func numeric(val any, setScale bool, scale int32) (apd.Decimal, error) { + switch val.(type) { + case int64, float64: + // expects these types to Scan from default: - return decimal.Decimal{}, cerrors.Errorf("unknown field given: %s", field) + return apd.Decimal{}, cerrors.Errorf("invalid type for numeric convert: %T", val) + } + dec := new(apd.Decimal) + err := dec.Scan(val) + if err != nil { + return apd.Decimal{}, err + } + if setScale { + _, err = pgtypes.BaseContext.Quantize(dec, dec, -scale) + if err != nil { + return apd.Decimal{}, err + } } + return *dec, nil } diff --git a/server/functions/factorial.go b/server/functions/factorial.go index 4aa12d16e5..0ef3568ca9 100644 --- a/server/functions/factorial.go +++ b/server/functions/factorial.go @@ -15,10 +15,10 @@ package functions import ( + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -35,15 +35,15 @@ var factorial_int64 = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Int64}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1Interface any) (any, error) { - val1 := val1Interface.(int64) - if val1 < 0 { + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + n := val.(int64) + if n < 0 { return nil, errors.Errorf("factorial of a negative number is undefined") } total := int64(1) - for i := int64(2); i <= val1; i++ { + for i := int64(2); i <= n; i++ { total *= i } - return decimal.NewFromInt(total), nil + return *apd.New(total, 0), nil }, } diff --git a/server/functions/floor.go b/server/functions/floor.go index 3191ada5d8..7e58060fed 100644 --- a/server/functions/floor.go +++ b/server/functions/floor.go @@ -17,8 +17,8 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -47,10 +47,12 @@ var floor_numeric = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - if val1 == nil { - return nil, nil + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + dec := val.(apd.Decimal) + _, err := pgtypes.BaseContext.Floor(&dec, &dec) + if err != nil { + return nil, err } - return val1.(decimal.Decimal).Floor(), nil + return dec, nil }, } diff --git a/server/functions/generate_series.go b/server/functions/generate_series.go index 519071a6fb..d7e572557e 100644 --- a/server/functions/generate_series.go +++ b/server/functions/generate_series.go @@ -18,9 +18,9 @@ import ( "io" "time" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/server/functions/framework" @@ -143,10 +143,10 @@ var generate_series_numeric_numeric = framework.Function2{ Strict: true, SRF: true, Callable: func(ctx *sql.Context, t [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) { - start := val1.(decimal.Decimal) - finish := val2.(decimal.Decimal) + start := val1.(apd.Decimal) + stop := val2.(apd.Decimal) step := numericOne // by default - return numericGenerateSeries(start, finish, step) + return numericGenerateSeries(start, stop, *step) }, } @@ -158,25 +158,43 @@ var generate_series_numeric_numeric_numeric = framework.Function3{ Strict: true, SRF: true, Callable: func(ctx *sql.Context, t [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) { - start := val1.(decimal.Decimal) - finish := val2.(decimal.Decimal) - step := val3.(decimal.Decimal) - return numericGenerateSeries(start, finish, step) + start := val1.(apd.Decimal) + stop := val2.(apd.Decimal) + step := val3.(apd.Decimal) + return numericGenerateSeries(start, stop, step) }, } // numericGenerateSeries returns RowIter for generate_series function results for given numeric values. // This function checks for error of step being zero. -func numericGenerateSeries(start, finish, step decimal.Decimal) (*pgtypes.SetReturningFunctionRowIter, error) { - if step.Equal(decimal.Zero) { +func numericGenerateSeries(start, stop, step apd.Decimal) (*pgtypes.SetReturningFunctionRowIter, error) { + if step.IsZero() { return nil, errStepSizeZero } + if start.Form == apd.NaN { + return nil, errors.Errorf(`start value cannot be NaN`) + } else if start.Form == apd.Infinite { + return nil, errors.Errorf(`start value cannot be infinity`) + } + if stop.Form == apd.NaN { + return nil, errors.Errorf(`stop value cannot be NaN`) + } else if stop.Form == apd.Infinite { + return nil, errors.Errorf(`stop value cannot be infinity`) + } + if step.Form == apd.NaN { + return nil, errors.Errorf(`step value cannot be NaN`) + } else if step.Form == apd.Infinite { + return nil, errors.Errorf(`step value cannot be infinity`) + } return pgtypes.NewSetReturningFunctionRowIter(func(ctx *sql.Context) (sql.Row, error) { defer func() { - start = start.Add(step) + _, err := pgtypes.BaseContext.Add(&start, &start, &step) + if err != nil { + // TODO + panic(err) + } }() - if (step.GreaterThan(decimal.Zero) && start.GreaterThan(finish)) || - (step.LessThan(decimal.Zero) && start.LessThan(finish)) { + if (step.Sign() > 0 && start.Cmp(&stop) > 0) || (step.Sign() < 0 && start.Cmp(&stop) < 0) { return nil, io.EOF } return sql.Row{start}, nil diff --git a/server/functions/ln.go b/server/functions/ln.go index b5aced1afe..69c05371a1 100644 --- a/server/functions/ln.go +++ b/server/functions/ln.go @@ -17,9 +17,9 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -54,16 +54,16 @@ var ln_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - if val1 == nil { - return nil, nil - } - // TODO: add an actual ln for numerics rather than relying on float64 - f, _ := val1.(decimal.Decimal).Float64() - if f == 0 { + dec := val1.(apd.Decimal) + if dec.Sign() == 0 { return nil, errors.Errorf("cannot take logarithm of zero") - } else if f < 0 { + } else if dec.Sign() < 0 { return nil, errors.Errorf("cannot take logarithm of a negative number") } - return decimal.NewFromFloat(math.Log(f)), nil + _, err := pgtypes.BaseContext.Ln(&dec, &dec) + if err != nil { + return nil, err + } + return dec, nil }, } diff --git a/server/functions/log.go b/server/functions/log.go index 1bd0b77c4c..d3b60c6090 100644 --- a/server/functions/log.go +++ b/server/functions/log.go @@ -17,9 +17,9 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -55,19 +55,18 @@ var log_numeric = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1Interface any) (any, error) { - if val1Interface == nil { - return nil, nil - } - val1 := val1Interface.(decimal.Decimal) - if val1.Equal(decimal.Zero) { + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { + dec := val1.(apd.Decimal) + if dec.IsZero() { return nil, errors.Errorf("cannot take logarithm of zero") - } else if val1.LessThan(decimal.Zero) { + } else if dec.Sign() < 0 { return nil, errors.Errorf("cannot take logarithm of a negative number") } - // TODO: implement log for numeric instead of relying on float64 - f, _ := val1.Float64() - return decimal.NewFromFloat(math.Log10(f)), nil + _, err := pgtypes.BaseContext.Log10(&dec, &dec) + if err != nil { + return nil, err + } + return dec, nil }, } @@ -77,25 +76,32 @@ var log_numeric_numeric = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1Interface any, val2Interface any) (any, error) { - if val1Interface == nil || val2Interface == nil { - return nil, nil - } - val1 := val1Interface.(decimal.Decimal) - val2 := val2Interface.(decimal.Decimal) - if val1.Equal(decimal.Zero) || val2.Equal(decimal.Zero) { + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + base := val1.(apd.Decimal) + num := val2.(apd.Decimal) + if base.IsZero() || num.IsZero() { return nil, errors.Errorf("cannot take logarithm of zero") - } else if val1.LessThan(decimal.Zero) || val2.LessThan(decimal.Zero) { + } else if base.Sign() < 0 || num.Sign() < 0 { return nil, errors.Errorf("cannot take logarithm of a negative number") } - // TODO: implement log for numeric instead of relying on float64 - base, _ := val1.Float64() - num, _ := val2.Float64() - logNum := math.Log(num) - logBase := math.Log(base) - if logBase == 0 { + logBase := new(apd.Decimal) + _, err := pgtypes.BaseContext.Log10(&base, &base) + if err != nil { + return nil, err + } + logNum := new(apd.Decimal) + _, err = pgtypes.BaseContext.Log10(&num, &num) + if err != nil { + return nil, err + } + if logNum.IsZero() { return nil, errors.Errorf("division by zero") } - return decimal.NewFromFloat(logNum / logBase), nil + res := new(apd.Decimal) + _, err = pgtypes.BaseContext.Quo(res, logNum, logBase) + if err != nil { + return nil, err + } + return *res, nil }, } diff --git a/server/functions/min_scale.go b/server/functions/min_scale.go index c0cd908bbb..65105a8cba 100644 --- a/server/functions/min_scale.go +++ b/server/functions/min_scale.go @@ -17,8 +17,8 @@ package functions import ( "strings" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -36,7 +36,11 @@ var min_scale_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - str := val1.(decimal.Decimal).String() + dec := val1.(apd.Decimal) + if dec.Form == apd.NaN || dec.Form == apd.Infinite { + return nil, nil + } + str := dec.String() if idx := strings.Index(str, "."); idx != -1 { str = str[idx+1:] i := len(str) - 1 diff --git a/server/functions/mod.go b/server/functions/mod.go index 49b622d37c..85d1bcf407 100644 --- a/server/functions/mod.go +++ b/server/functions/mod.go @@ -15,10 +15,10 @@ package functions import ( + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -80,10 +80,29 @@ var mod_numeric_numeric = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - if val2.(decimal.Decimal).Cmp(decimal.Zero) == 0 { - return nil, errors.Errorf("division by zero") - } - return val1.(decimal.Decimal).Mod(val2.(decimal.Decimal)), nil - }, + Callable: NumericModCallable, +} + +// NumericModCallable is the callable logic for the numeric_mod and mod functions. +func NumericModCallable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + if num1.Form == apd.NaN || num2.Form == apd.NaN || + (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { + return pgtypes.NumericNaN, nil + } + if num2.IsZero() { + return nil, errors.Errorf("division by zero") + } + if num1.Form == apd.Infinite { + return num1, nil + } + if num2.Form == apd.Infinite { + return *apd.New(0, 0), nil + } + _, err := pgtypes.BaseContext.Rem(&num1, &num1, &num2) + if err != nil { + return nil, err + } + return num1, nil } diff --git a/server/functions/numeric.go b/server/functions/numeric.go index 1b0b9d7578..47df1875cd 100644 --- a/server/functions/numeric.go +++ b/server/functions/numeric.go @@ -17,15 +17,12 @@ package functions import ( "fmt" "strconv" - "strings" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/jackc/pgtype" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" - "github.com/dolthub/doltgresql/utils" ) // initNumeric registers the functions to the catalog. @@ -47,12 +44,12 @@ var numeric_in = framework.Function3{ Strict: true, Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) { input := val1.(string) - val, err := decimal.NewFromString(strings.TrimSpace(input)) + typmod := val3.(int32) + dec, _, err := apd.NewFromString(input) if err != nil { return nil, pgtypes.ErrInvalidSyntaxForType.New("numeric", input) } - typmod := val3.(int32) - return pgtypes.GetNumericValueWithTypmod(val, typmod) + return pgtypes.GetNumericValueWithTypmod(*dec, typmod) }, } @@ -64,14 +61,14 @@ var numeric_out = framework.Function1{ Strict: true, Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) { typ := t[0] - dec := val.(decimal.Decimal) + dec := val.(apd.Decimal) tm := typ.GetAttTypMod() - if tm == -1 { - return dec.StringFixed(dec.Exponent() * -1), nil - } else { - _, s := pgtypes.GetPrecisionAndScaleFromTypmod(tm) - return dec.StringFixed(s), nil + dec, err := pgtypes.GetNumericValueWithTypmod(dec, tm) + if err != nil { + return nil, err } + return dec.Text('f'), nil + //return dec.StringFixed(dec.Exponent() * -1), nil }, } @@ -87,12 +84,9 @@ var numeric_recv = framework.Function3{ return nil, nil } typmod := val3.(int32) - var out pgtype.Numeric - err := out.DecodeBinary(nil, data) - if err != nil { - return nil, err - } - return pgtypes.GetNumericValueWithTypmod(decimal.NewFromBigInt(out.Int, out.Exp), typmod) + // TODO: chekc this doesn't update the original type + newType := *pgtypes.Numeric.WithAttTypMod(typmod) + return newType.DeserializationFunc(ctx, &newType, data) }, } @@ -103,70 +97,8 @@ var numeric_send = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) { - dec := val.(decimal.Decimal) - writer := utils.NewWireWriter() - // Short-circuit if this is the zero value - if dec.IsZero() { - writer.WriteBytes([]byte{0, 0, 0, 0, 0, 0, 0, 0}) - return writer.BufferData(), nil - } - // There's a way to do this more efficiently, but we can do that work once this becomes a performance issue. - // This is based on the terminology used in Postgres' `numeric.c` file - decStr := dec.String() - isNegative := false - if strings.HasPrefix(decStr, "-") { - isNegative = true - decStr = decStr[1:] - } - // Split the integer and fractional parts - var intPart string - var fractPart string - if idx := strings.Index(decStr, "."); idx != -1 { - intPart = decStr[:idx] - fractPart = decStr[idx+1:] - } else { - intPart = decStr - } - // Find the "dscale", which is the number of digits in the fractional part - typmod := t[0].GetAttTypMod() - var dscale int16 - if typmod != -1 { - _, dscale32 := pgtypes.GetPrecisionAndScaleFromTypmod(typmod) - dscale = int16(dscale32) - } else { - dscale = int16(len(fractPart)) - } - // Pad the integer and fractional parts so that we can take groups of 4 numbers - if intPart == "0" { - intPart = "" - } else if len(intPart)%4 != 0 { - intPart = strings.Repeat("0", 4-(len(intPart)%4)) + intPart - } - if len(fractPart)%4 != 0 { - fractPart = fractPart + strings.Repeat("0", 4-(len(fractPart)%4)) - } - // Write the "ndigits" first, or the number of base-10000 digits - writer.WriteInt16(int16((len(intPart) / 4) + (len(fractPart) / 4))) - // Write the "weight", which is the number of base-10000 digits in the integer part subtracted by 1 - writer.WriteInt16(int16((len(intPart) / 4) - 1)) - // Write the "sign" - if isNegative { - writer.WriteInt16(16384) - } else { - writer.WriteInt16(0) - } - // Write the "dscale" - writer.WriteInt16(dscale) - // Write all of the digits - fullPart := intPart + fractPart - for i := 0; i < len(fullPart); i += 4 { - part, err := strconv.Atoi(fullPart[i : i+4]) - if err != nil { - return nil, err - } - writer.WriteInt16(int16(part)) - } - return writer.BufferData(), nil + dec := val.(apd.Decimal) + return pgtypes.Numeric.SerializationFunc(ctx, t[0], dec) }, } @@ -221,8 +153,8 @@ var numeric_cmp = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) { - ab := val1.(decimal.Decimal) - bb := val2.(decimal.Decimal) - return int32(ab.Cmp(bb)), nil + ab := val1.(apd.Decimal) + bb := val2.(apd.Decimal) + return int32(pgtypes.NumericCompare(ab, bb)), nil }, } diff --git a/server/functions/power.go b/server/functions/power.go index 335dc722b0..d82aeb42e1 100644 --- a/server/functions/power.go +++ b/server/functions/power.go @@ -17,9 +17,9 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -35,7 +35,7 @@ var ( // errPowerZeroToNegative is an error for raising zero to a negative power in the "power" functions. errPowerZeroToNegative = errors.New("zero raised to a negative power is undefined") // numericOne is equivalent to decimal.NewFromInt(1), but represented as a value for the sake of efficiency. - numericOne = decimal.NewFromInt(1) + numericOne = apd.New(1, 0) ) // power_float64_float64 represents the PostgreSQL function of the same name, taking the same parameters. @@ -61,22 +61,63 @@ var power_numeric_numeric = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - if val1 == nil || val2 == nil { - return nil, nil + dec1 := val1.(apd.Decimal) + dec2 := val2.(apd.Decimal) + if dec1.Form == apd.NaN || dec2.Form == apd.NaN { + return pgtypes.NumericNaN, nil } - d1 := val1.(decimal.Decimal) - d2 := val2.(decimal.Decimal) - if d1.Equal(numericOne) { - return numericOne, nil + if dec1.Form == apd.Infinite && dec1.Negative { + even := dec2.Form == apd.Infinite && !dec2.Negative + if dec2.Form == apd.Finite { + i, err := dec2.Int64() + if err != nil { + return nil, errors.Errorf(`a negative number raised to a non-integer power yields a complex result`) + } + even = i%2 == 0 + } + + if dec2.Sign() > 0 { + // +inf will return neginf == fix!! + if even { + return pgtypes.NumericInf, nil + } + return pgtypes.NumericNegInf, nil + } + if (dec2.Form == apd.Infinite && dec2.Negative) || dec2.Sign() < 0 { + return *apd.New(0, 0), nil + } + return *apd.New(1, 0), nil } - if d1.Equal(decimal.Zero) && d2.Cmp(decimal.Zero) == -1 { - return nil, errPowerZeroToNegative + if dec1.IsZero() { + if dec2.Sign() < 0 { + // includes neg inf + return nil, errPowerZeroToNegative + } + if dec2.Form == apd.Infinite { + return *apd.New(0, 0), nil + } + if dec2.Sign() > 0 { + d := *apd.New(0, 0) + _, _ = pgtypes.BaseContext.Quantize(&d, &d, -16) + return d, nil + } } // decimal.Pow() does not handle the zero exponent properly, so we special case it - if d2.Equal(decimal.Zero) { - return numericOne, nil + + if dec2.IsZero() || dec1.Cmp(numericOne) == 0 { + d := *apd.New(1, 0) + _, _ = pgtypes.BaseContext.Quantize(&d, &d, -16) + return d, nil + } + // give enough precision that we can round it to 16 exp + _, err := pgtypes.BaseContext.WithPrecision(17).Pow(&dec1, &dec1, &dec2) + if err != nil { + return nil, err + } + _, err = pgtypes.BaseContext.Quantize(&dec1, &dec1, -16) + if err != nil { + return nil, err } - // TODO: this doesn't handle non-integer exponents - return d1.Pow(d2), nil + return dec1, nil }, } diff --git a/server/functions/round.go b/server/functions/round.go index 7529d51d1e..a4d474850d 100644 --- a/server/functions/round.go +++ b/server/functions/round.go @@ -17,8 +17,8 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -37,11 +37,8 @@ var round_float64 = framework.Function1{ Return: pgtypes.Float64, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Float64}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - if val1 == nil { - return nil, nil - } - return math.RoundToEven(val1.(float64)), nil + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + return math.RoundToEven(val.(float64)), nil }, } @@ -51,11 +48,17 @@ var round_numeric = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - if val1 == nil { - return nil, nil + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + dec := val.(apd.Decimal) + _, err := pgtypes.BaseContext.Round(&dec, &dec) + if err != nil { + return nil, err + } + _, err = pgtypes.BaseContext.Quantize(&dec, &dec, 0) + if err != nil { + return nil, err } - return val1.(decimal.Decimal).Round(0), nil + return dec, nil }, } @@ -66,6 +69,16 @@ var round_numeric_int64 = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Int64}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - return val1.(decimal.Decimal).Round(int32(val2.(int64))), nil + dec := val1.(apd.Decimal) + places := val2.(int64) + _, err := pgtypes.BaseContext.Round(&dec, &dec) + if err != nil { + return nil, err + } + _, err = pgtypes.BaseContext.Quantize(&dec, &dec, int32(-places)) + if err != nil { + return nil, err + } + return dec, nil }, } diff --git a/server/functions/sign.go b/server/functions/sign.go index 17d8af2ffa..daf99dfcc8 100644 --- a/server/functions/sign.go +++ b/server/functions/sign.go @@ -15,8 +15,8 @@ package functions import ( + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" @@ -52,7 +52,8 @@ var sign_numeric = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - return decimal.NewFromInt(int64(val1.(decimal.Decimal).Cmp(decimal.Zero))), nil + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + dec := val.(apd.Decimal) + return *apd.New(int64(dec.Sign()), 0), nil }, } diff --git a/server/functions/sqrt.go b/server/functions/sqrt.go index d8e4efdf5d..adcff81157 100644 --- a/server/functions/sqrt.go +++ b/server/functions/sqrt.go @@ -17,9 +17,9 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -37,14 +37,11 @@ var sqrt_float64 = framework.Function1{ Return: pgtypes.Float64, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Float64}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - if val1 == nil { - return nil, nil - } - if val1.(float64) < 0 { + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + if val.(float64) < 0 { return nil, errors.Errorf("cannot take square root of a negative number") } - return math.Sqrt(val1.(float64)), nil + return math.Sqrt(val.(float64)), nil }, } @@ -54,11 +51,15 @@ var sqrt_numeric = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - if val1.(decimal.Decimal).Cmp(decimal.Zero) == -1 { + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + dec := val.(apd.Decimal) + if dec.Sign() < 0 { return nil, errors.Errorf("cannot take square root of a negative number") } - // TODO: decimal's Pow function does not work correctly using an exponent of 0.5, need to fix - return decimal.NewFromFloat(math.Sqrt(val1.(decimal.Decimal).InexactFloat64())), nil + _, err := pgtypes.BaseContext.Sqrt(&dec, &dec) + if err != nil { + return nil, err + } + return dec, nil }, } diff --git a/server/functions/trim_scale.go b/server/functions/trim_scale.go index 4de98277a2..7a66e0f5e8 100644 --- a/server/functions/trim_scale.go +++ b/server/functions/trim_scale.go @@ -15,8 +15,8 @@ package functions import ( + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" @@ -34,9 +34,9 @@ var trim_scale_numeric = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { // We don't store the scale in the value, so I'm not sure if this is functionally correct. // Seems like we'd need to modify the type of the return value (by trimming the scale), rather than the value itself. - return val1.(decimal.Decimal), nil + return val.(apd.Decimal), nil }, } diff --git a/server/functions/trunc.go b/server/functions/trunc.go index 73eb83ca4d..a432e82cc1 100644 --- a/server/functions/trunc.go +++ b/server/functions/trunc.go @@ -17,8 +17,8 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -37,11 +37,8 @@ var trunc_float64 = framework.Function1{ Return: pgtypes.Float64, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Float64}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - if val1 == nil { - return nil, nil - } - return math.Trunc(val1.(float64)), nil + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + return math.Trunc(val.(float64)), nil }, } @@ -51,11 +48,13 @@ var trunc_numeric = framework.Function1{ Return: pgtypes.Numeric, Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - if val1 == nil { - return nil, nil + Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { + dec := val.(apd.Decimal) + _, err := pgtypes.BaseContext.Quantize(&dec, &dec, 0) + if err != nil { + return nil, err } - return decimal.NewFromInt(val1.(decimal.Decimal).IntPart()), nil + return dec, nil }, } @@ -65,8 +64,14 @@ var trunc_numeric_int64 = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Int32}, Strict: true, - Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, num any, places any) (any, error) { + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { //TODO: test for negative values in places - return num.(decimal.Decimal).Truncate(places.(int32)), nil + dec := val1.(apd.Decimal) + places := val2.(int32) + _, err := pgtypes.BaseContext.Quantize(&dec, &dec, -places) + if err != nil { + return nil, err + } + return dec, nil }, } diff --git a/server/functions/unary/minus.go b/server/functions/unary/minus.go index 4a85214a1d..126b0bda4e 100644 --- a/server/functions/unary/minus.go +++ b/server/functions/unary/minus.go @@ -15,8 +15,8 @@ package unary import ( + "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -112,6 +112,8 @@ var numeric_uminus = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - return val1.(decimal.Decimal).Neg(), nil + dec := val1.(apd.Decimal) + neg := dec.Neg(&dec) + return *neg, nil }, } diff --git a/server/functions/width_bucket.go b/server/functions/width_bucket.go index 5b7e0d52c4..86145db24b 100644 --- a/server/functions/width_bucket.go +++ b/server/functions/width_bucket.go @@ -17,9 +17,9 @@ package functions import ( "math" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -71,29 +71,52 @@ var width_bucket_numeric_numeric_numeric_int64 = framework.Function4{ Parameters: [4]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric, pgtypes.Numeric, pgtypes.Int32}, Strict: true, Callable: func(ctx *sql.Context, _ [5]*pgtypes.DoltgresType, operandInterface any, lowInterface any, highInterface any, countInterface any) (any, error) { - operand := operandInterface.(decimal.Decimal) - low := lowInterface.(decimal.Decimal) - high := highInterface.(decimal.Decimal) - if low.Cmp(high) == 0 { + operand := operandInterface.(apd.Decimal) + low := lowInterface.(apd.Decimal) + high := highInterface.(apd.Decimal) + if low.Cmp(&high) == 0 { return nil, errors.Errorf("lower bound cannot equal upper bound") } count := countInterface.(int32) if count <= 0 { return nil, errors.Errorf("count must be greater than zero") } - if operand.Equal(high) { + if operand.Cmp(&high) == 0 { return count + 1, nil - } else if operand.Equal(low) { + } else if operand.Cmp(&low) == 0 { return int32(1), nil } - bucket := high.Sub(low).Div(decimal.NewFromInt(int64(count))) - result := operand.Sub(low).Div(bucket).Ceil() - if result.LessThan(decimal.Zero) { - result = decimal.Zero - } else if result.GreaterThan(decimal.NewFromInt(int64(count + 1))) { - result = decimal.NewFromInt(int64(count + 1)) + bucket := new(apd.Decimal) + _, err := pgtypes.BaseContext.Sub(bucket, &high, &low) + if err != nil { + return nil, err + } + _, err = pgtypes.BaseContext.Quo(bucket, bucket, apd.New(int64(count), 0)) + if err != nil { + return nil, err + } + result := new(apd.Decimal) + _, err = pgtypes.BaseContext.Sub(result, &operand, &low) + if err != nil { + return nil, err + } + _, err = pgtypes.BaseContext.Sub(result, result, bucket) + if err != nil { + return nil, err + } + _, err = pgtypes.BaseContext.Ceil(result, result) + if err != nil { + return nil, err + } + if result.Sign() < 0 { + result = apd.New(0, 0) + } else if c1 := apd.New(int64(count+1), 0); result.Cmp(c1) > 0 { + result = c1 + } + i64, err := result.Int64() + if err != nil { + return nil, err } - i64 := result.IntPart() return int32(i64), nil }, } diff --git a/server/types/json_document.go b/server/types/json_document.go index 0023505a9c..3233520816 100644 --- a/server/types/json_document.go +++ b/server/types/json_document.go @@ -19,9 +19,9 @@ import ( "sort" "strings" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/goccy/go-json" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/utils" ) @@ -72,7 +72,7 @@ type JsonValueArray []JsonValue type JsonValueString string // JsonValueNumber represents a number. -type JsonValueNumber decimal.Decimal +type JsonValueNumber apd.Decimal // JsonValueBoolean represents a boolean value. type JsonValueBoolean bool @@ -189,7 +189,9 @@ func JsonValueCompare(v1 JsonValue, v2 JsonValue) int { return 1 } case JsonValueNumber: - return decimal.Decimal(v1).Cmp(decimal.Decimal(v2.(JsonValueNumber))) + n1 := apd.Decimal(v1) + n2 := apd.Decimal(v2.(JsonValueNumber)) + return n1.Cmp(&n2) case JsonValueBoolean: v2 := v2.(JsonValueBoolean) if v1 == v2 { @@ -249,7 +251,7 @@ func JsonValueSerialize(writer *utils.Writer, value JsonValue) { case JsonValueNumber: writer.Byte(byte(JsonValueType_Number)) // MarshalBinary cannot error, so we can safely ignore it - bytes, _ := decimal.Decimal(value).MarshalBinary() + bytes, _ := Numeric.SerializationFunc(nil, Numeric, apd.Decimal(value)) writer.ByteSlice(bytes) case JsonValueBoolean: writer.Byte(byte(JsonValueType_Boolean)) @@ -289,9 +291,14 @@ func JsonValueDeserialize(reader *utils.Reader) (_ JsonValue, err error) { case JsonValueType_String: return JsonValueString(reader.String()), nil case JsonValueType_Number: - d := decimal.Decimal{} - err = d.UnmarshalBinary(reader.ByteSlice()) - return JsonValueNumber(d), err + d, err := Numeric.DeserializationFunc(nil, Numeric, reader.ByteSlice()) + if err != nil { + return nil, err + } + if d == nil { + d = apd.Decimal{} + } + return JsonValueNumber(d.(apd.Decimal)), err case JsonValueType_Boolean: return JsonValueBoolean(reader.Bool()), nil case JsonValueType_Null: @@ -330,7 +337,8 @@ func JsonValueFormatter(sb *strings.Builder, value JsonValue) { sb.WriteString(strings.ReplaceAll(string(value), `"`, `\"`)) sb.WriteRune('"') case JsonValueNumber: - sb.WriteString(decimal.Decimal(value).String()) + d := apd.Decimal(value) + sb.WriteString(d.Text('f')) case JsonValueBoolean: if value { sb.WriteString(`true`) @@ -407,7 +415,12 @@ func ConvertToJsonDocument(val interface{}) (JsonValue, error) { return JsonValueString(val), nil case float64: // TODO: handle this as a proper numeric as float64 is not precise enough - return JsonValueNumber(decimal.NewFromFloat(val)), nil + d := new(apd.Decimal) + err := d.Scan(val) + if err != nil { + return nil, err + } + return JsonValueNumber(*d), nil case bool: return JsonValueBoolean(val), nil case nil: diff --git a/server/types/numeric.go b/server/types/numeric.go index c74382f28a..ecada41a01 100644 --- a/server/types/numeric.go +++ b/server/types/numeric.go @@ -15,28 +15,47 @@ package types import ( + "encoding/binary" + "strconv" "strings" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/utils" ) const ( - MaxUint32 = 4294967295 // MaxUint32 is the largest possible value of Uint32 - MinInt32 = -2147483648 // MinInt32 is the smallest possible value of Int32 + MaxUint32 = 4294967295 // MaxUint32 is the largest possible value of Uint32 + MinInt32 = -2147483648 // MinInt32 is the smallest possible value of Int32 + MaxPrecision = uint32(100000) +) + +const ( + pgNumericNaN = 0x00000000c0000000 + pgNumericNaNSign = 0xc000 + + pgNumericPosInf = 0x00000000d0000000 + pgNumericPosInfSign = 0xd000 + + pgNumericNegInf = 0x00000000f0000000 + pgNumericNegInfSign = 0xf000 ) var ( - NumericValueMaxInt16 = decimal.NewFromInt(32767) // NumericValueMaxInt16 is the max Int16 value for NUMERIC types - NumericValueMaxInt32 = decimal.NewFromInt(2147483647) // NumericValueMaxInt32 is the max Int32 value for NUMERIC types - NumericValueMaxInt64 = decimal.NewFromInt(9223372036854775807) // NumericValueMaxInt64 is the max Int64 value for NUMERIC types - NumericValueMinInt16 = decimal.NewFromInt(-32768) // NumericValueMinInt16 is the min Int16 value for NUMERIC types - NumericValueMinInt32 = decimal.NewFromInt(MinInt32) // NumericValueMinInt32 is the min Int32 value for NUMERIC types - NumericValueMinInt64 = decimal.NewFromInt(-9223372036854775808) // NumericValueMinInt64 is the min Int64 value for NUMERIC types - NumericValueMaxUint32 = decimal.NewFromInt(MaxUint32) // NumericValueMaxUint32 is the max Uint32 value for NUMERIC types + NumericValueMaxInt16 = apd.New(32767, 0) // NumericValueMaxInt16 is the max Int16 value for NUMERIC types + NumericValueMaxInt32 = apd.New(2147483647, 0) // NumericValueMaxInt32 is the max Int32 value for NUMERIC types + NumericValueMaxInt64 = apd.New(9223372036854775807, 0) // NumericValueMaxInt64 is the max Int64 value for NUMERIC types + NumericValueMinInt16 = apd.New(-32768, 0) // NumericValueMinInt16 is the min Int16 value for NUMERIC types + NumericValueMinInt32 = apd.New(MinInt32, 0) // NumericValueMinInt32 is the min Int32 value for NUMERIC types + NumericValueMinInt64 = apd.New(-9223372036854775808, 0) // NumericValueMinInt64 is the min Int64 value for NUMERIC types + NumericValueMaxUint32 = apd.New(MaxUint32, 0) // NumericValueMaxUint32 is the max Uint32 value for NUMERIC types + NumericNaN = apd.Decimal{Form: apd.NaN} + NumericInf = apd.Decimal{Form: apd.Infinite} + NumericNegInf = apd.Decimal{Form: apd.Infinite, Negative: true} + BaseContext = apd.BaseContext.WithPrecision(MaxPrecision) ) // Numeric is a precise and unbounded decimal value. @@ -108,24 +127,122 @@ func GetPrecisionAndScaleFromTypmod(typmod int32) (int32, int32) { // GetNumericValueWithTypmod returns either given numeric value or truncated or error // depending on the precision and scale decoded from given type modifier value. -func GetNumericValueWithTypmod(val decimal.Decimal, typmod int32) (decimal.Decimal, error) { +func GetNumericValueWithTypmod(val apd.Decimal, typmod int32) (apd.Decimal, error) { if typmod == -1 { return val, nil } + res := new(apd.Decimal) precision, scale := GetPrecisionAndScaleFromTypmod(typmod) - str := val.StringFixed(scale) - parts := strings.Split(str, ".") - if int32(len(parts[0])) > precision-scale && val.IntPart() != 0 { - // TODO: split error message to ERROR and DETAIL - return decimal.Decimal{}, errors.Errorf("numeric field overflow - A field with precision %v, scale %v must round to an absolute value less than 10^%v", precision, scale, precision-scale) + _, err := BaseContext.WithPrecision(uint32(precision)).Quantize(res, &val, -scale) + if err != nil { + return apd.Decimal{}, errors.Errorf("numeric field overflow - A field with precision %v, scale %v must round to an absolute value less than 10^%v", precision, scale, precision-scale) } - return decimal.NewFromString(str) + return *res, nil +} + +// GetNumericValueFromStringWithTypmod returns either given numeric value or truncated or error +// depending on the precision and scale decoded from given type modifier value. +func GetNumericValueFromStringWithTypmod(val string, typmod int32) (apd.Decimal, error) { + dec, cond, err := BaseContext.WithPrecision(MaxPrecision).NewFromString(val) + if err != nil { + return apd.Decimal{}, err + } + if cond.Inexact() || cond.Rounded() { + return apd.Decimal{}, errors.Errorf(`numeric precision was lost or truncated for %s`, val) + } + return GetNumericValueWithTypmod(*dec, typmod) } // serializeTypeNumeric handles serialization from the standard representation to our serialized representation that is // written in Dolt. func serializeTypeNumeric(ctx *sql.Context, t *DoltgresType, val any) ([]byte, error) { - return val.(decimal.Decimal).MarshalBinary() + num := val.(apd.Decimal) + typmod := t.GetAttTypMod() + writer := utils.NewWireWriter() + if num.Form == apd.Finite { + // Short-circuit if this is the zero value + if num.IsZero() { + writer.WriteBytes([]byte{0, 0, 0, 0, 0, 0, 0, 0}) + return writer.BufferData(), nil + } + // There's a way to do this more efficiently, but we can do that work once this becomes a performance issue. + // This is based on the terminology used in Postgres' `numeric.c` file + decStr := num.Text('f') + isNegative := false + if strings.HasPrefix(decStr, "-") { + isNegative = true + decStr = decStr[1:] + } + // Split the integer and fractional parts + var intPart string + var fractPart string + if idx := strings.Index(decStr, "."); idx != -1 { + intPart = decStr[:idx] + fractPart = decStr[idx+1:] + } else { + intPart = decStr + } + // Find the "dscale", which is the number of digits in the fractional part + var dscale int16 + if typmod != -1 { + _, dscale32 := GetPrecisionAndScaleFromTypmod(typmod) + dscale = int16(dscale32) + } else { + dscale = int16(len(fractPart)) + } + // Pad the integer and fractional parts so that we can take groups of 4 numbers + if intPart == "0" { + intPart = "" + } else if len(intPart)%4 != 0 { + intPart = strings.Repeat("0", 4-(len(intPart)%4)) + intPart + } + if len(fractPart)%4 != 0 { + // remove trailing zeroes on right side before filling it. + fractPart = strings.TrimRightFunc(fractPart, func(r rune) bool { + return r == '0' + }) + fractPart = fractPart + strings.Repeat("0", 4-(len(fractPart)%4)) + } + // Write the "ndigits" first, or the number of base-10000 digits + writer.WriteInt16(int16((len(intPart) / 4) + (len(fractPart) / 4))) + // Write the "weight", which is the number of base-10000 digits in the integer part subtracted by 1 + writer.WriteInt16(int16((len(intPart) / 4) - 1)) + // Write the "sign" + if isNegative { + writer.WriteInt16(16384) + } else { + writer.WriteInt16(0) + } + // Write the "dscale" + writer.WriteInt16(dscale) + // Write all of the digits + fullPart := intPart + fractPart + for i := 0; i < len(fullPart); i += 4 { + part, err := strconv.Atoi(fullPart[i : i+4]) + if err != nil { + return nil, err + } + writer.WriteInt16(int16(part)) + } + } else { + var buf []byte + wp := len(buf) + buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) + if num.Form == apd.NaN { + binary.BigEndian.PutUint64(buf[wp:], pgNumericNaN) + } else if num.Form == apd.Infinite { + if num.Negative { + binary.BigEndian.PutUint64(buf[wp:], pgNumericNegInf) + } else { + binary.BigEndian.PutUint64(buf[wp:], pgNumericPosInf) + } + } + if typmod == -1 { + binary.BigEndian.PutUint16(buf[6:], uint16(32)) + } + writer.WriteBytes(buf) + } + return writer.BufferData(), nil } // deserializeTypeNumeric handles deserialization from the Dolt serialized format to our standard representation used by @@ -134,7 +251,99 @@ func deserializeTypeNumeric(ctx *sql.Context, t *DoltgresType, data []byte) (any if len(data) == 0 { return nil, nil } - retVal := decimal.NewFromInt(0) - err := retVal.UnmarshalBinary(data) - return retVal, err + reader := utils.NewWireReader(data) + var d apd.Decimal + + // 1. Read Header + ndigits := reader.ReadInt16() + weight := reader.ReadInt16() + sign := reader.ReadInt16() + dscale := reader.ReadInt16() + + // 2. Handle Special Values (NaN, Inf) + // These usually manifest as specific bit patterns in the header + switch uint16(sign) { + case 0xC000: // pgNumericNaN + d.Form = apd.NaN + return d, nil + case 0xD000: // pgNumericPosInf + d.Form = apd.Infinite + return d, nil + case 0xF000: // pgNumericNegInf + d.Form = apd.Infinite + d.Negative = true + return d, nil + } + + // 3. Handle Finite Values + if ndigits == 0 { + d.SetInt64(0) + return d, nil + } + + // Read base-10000 digits + digits := make([]int16, ndigits) + for i := 0; i < int(ndigits); i++ { + digits[i] = reader.ReadInt16() + } + + // 4. Convert base-10000 to string for apd.Decimal + // Each digit is exactly 4 characters wide (except potentially the first) + var sb strings.Builder + if sign == 16384 { + sb.WriteByte('-') + } + + for i, digit := range digits { + // Calculate how many 10000-base digits are before the decimal + // 'weight' is the index of the first digit, where 0 is 10^0 in base 10000 + if i == int(weight)+1 { + sb.WriteByte('.') + } + + sDigit := strconv.Itoa(int(digit)) + // Pad with leading zeros if not the very first digit + if l := len(sDigit); l < 4 { + padding := 4 - l + for p := 0; p < padding; p++ { + sb.WriteByte('0') + } + } + sb.WriteString(sDigit) + } + + // If weight is larger than digits, we need trailing zeros + if int(weight) >= len(digits) { + for i := 0; i < int(weight)-len(digits)+1; i++ { + sb.WriteString("0000") + } + } + + // If weight is negative, we need leading zeros after decimal point + if weight < 0 { + // This logic can get complex; using apd.SetString is the safest path + // but ensure the decimal point is placed correctly based on dscale. + } + + dec, _, err := BaseContext.NewFromString(sb.String()) + if err != nil { + return nil, err + } + _, _ = BaseContext.Quantize(dec, dec, int32(-dscale)) + return *dec, err +} + +// NumericCompare compares two apd.Decimal values handling NaN separately. +func NumericCompare(ab, bb apd.Decimal) int { + if (ab.Form == apd.NaN && bb.Form == apd.NaN) || + (ab.Form == apd.Infinite && bb.Form == apd.Infinite && ab.Negative == bb.Negative) { + return 0 + } + if ab.Form == apd.NaN { + return 1 + } + if bb.Form == apd.NaN { + return -1 + } + return ab.Cmp(&bb) } diff --git a/server/types/type.go b/server/types/type.go index 424807f042..36044e1f1d 100644 --- a/server/types/type.go +++ b/server/types/type.go @@ -23,6 +23,7 @@ import ( "reflect" "time" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" "github.com/dolthub/dolt/go/store/val" @@ -30,7 +31,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/duration" @@ -318,9 +318,9 @@ func (t *DoltgresType) Compare(ctx context.Context, v1 interface{}, v2 interface case JsonDocument: bb := v2.(JsonDocument) return JsonValueCompare(ab.Value, bb.Value), nil - case decimal.Decimal: - bb := v2.(decimal.Decimal) - return ab.Cmp(bb), nil + case apd.Decimal: + bb := v2.(apd.Decimal) + return NumericCompare(ab, bb), nil case timeofday.TimeOfDay: bb := v2.(timeofday.TimeOfDay) return ab.Compare(bb), nil @@ -1042,7 +1042,7 @@ func (t *DoltgresType) Zero() interface{} { case "int8": return int64(0) case "numeric": - return decimal.Zero + return *apd.New(0, 0) case "oid", "regclass", "regproc", "regtype": return id.Null default: diff --git a/server/types/typeinfo.go b/server/types/typeinfo.go index 93b00b151e..52a441af76 100644 --- a/server/types/typeinfo.go +++ b/server/types/typeinfo.go @@ -65,8 +65,8 @@ func (t typeInfo) Encoding() val.Encoding { return val.Float32Enc case "float8": return val.Float64Enc - case "numeric", "decimal": - return val.DecimalEnc + //case "numeric", "decimal": + // return val.DecimalEnc case "bytea": return val.BytesAdaptiveEnc // TODO: use dolt JSON document encoding here diff --git a/testing/generation/function_coverage/generators.go b/testing/generation/function_coverage/generators.go index bdc51cfe80..228aa9e094 100644 --- a/testing/generation/function_coverage/generators.go +++ b/testing/generation/function_coverage/generators.go @@ -123,6 +123,9 @@ var int64ValueGenerators = utils.Or( // numericValueGenerators contains an assortment of numbers that may be used for testing NUMERIC. var numericValueGenerators = utils.Or( + utils.Text("'NaN'::numeric"), + utils.Text("'Infinity'::numeric"), + utils.Text("'-Infinity'::numeric"), utils.Text("0::numeric"), utils.Text("-1::numeric"), utils.Text("1::numeric"), diff --git a/testing/generation/function_coverage/output/framework_test.go b/testing/generation/function_coverage/output/framework_test.go index 2992319708..5ba8d38101 100644 --- a/testing/generation/function_coverage/output/framework_test.go +++ b/testing/generation/function_coverage/output/framework_test.go @@ -25,15 +25,16 @@ import ( "testing" "time" + "github.com/cockroachdb/apd/v3" "github.com/dolthub/dolt/go/libraries/utils/svcs" "github.com/dolthub/go-mysql-server/sql" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" dserver "github.com/dolthub/doltgresql/server" + pgtypes "github.com/dolthub/doltgresql/server/types" "github.com/dolthub/doltgresql/servercfg" "github.com/dolthub/doltgresql/servercfg/cfgdetails" "github.com/dolthub/doltgresql/utils" @@ -47,9 +48,6 @@ var ( // between Go and C/C++, so this threshold allows us to check that non-integer values are close enough. Over time, // we may reduce or remove this value as we become more accurate. EquivalenceThresholdFloat64 = 0.001 - // EquivalenceThresholdNumeric represents the allowable delta for values to be considered equivalent. - // This is computed using the float64 variant so that they're equivalent. - EquivalenceThresholdNumeric = decimal.RequireFromString(strconv.FormatFloat(EquivalenceThresholdFloat64, 'f', -1, 64)) ) // ScriptTest defines a consistent structure for testing queries. @@ -249,12 +247,16 @@ func Numeric(str string) pgtype.Numeric { } // NumericToDecimal converts a pgtype.Numeric value to a decimal.Decimal value. -func NumericToDecimal(val pgtype.Numeric) decimal.Decimal { +func NumericToDecimal(val pgtype.Numeric) apd.Decimal { strVal, err := val.Value() if err != nil { panic(err) } - return decimal.RequireFromString(strVal.(string)) + d, _, err := apd.NewFromString(strVal.(string)) + if err != nil { + panic(err) + } + return *d } // CompareResults compares two sets of results, taking the equivalence thresholds into account when making the @@ -293,8 +295,18 @@ func CompareRows(t *testing.T, a sql.Row, b sql.Row) bool { return false } case pgtype.Numeric: - delta := NumericToDecimal(aVal.(pgtype.Numeric)).Sub(NumericToDecimal(bVal.(pgtype.Numeric))).Abs() - if delta.Cmp(EquivalenceThresholdNumeric) == 1 { + aDec := NumericToDecimal(aVal.(pgtype.Numeric)) + bDec := NumericToDecimal(bVal.(pgtype.Numeric)) + _, err := pgtypes.BaseContext.Sub(&aDec, &aDec, &bDec) + if err != nil { + return false + } + aDec = *aDec.Abs(&aDec) + _, err = pgtypes.BaseContext.Sub(&aDec, &aDec, &bDec) + // EquivalenceThresholdNumeric represents the allowable delta for values to be considered equivalent. + // This is computed using the float64 variant so that they're equivalent. + EquivalenceThresholdNumeric, _, _ := apd.NewFromString(strconv.FormatFloat(EquivalenceThresholdFloat64, 'f', -1, 64)) + if aDec.Cmp(EquivalenceThresholdNumeric) == 1 { return false } default: diff --git a/testing/go/coercion_test.go b/testing/go/coercion_test.go index 3838de291a..aef84e386d 100644 --- a/testing/go/coercion_test.go +++ b/testing/go/coercion_test.go @@ -110,6 +110,26 @@ func TestCoercion(t *testing.T) { Query: `SELECT abs('12345671297673227365.5123624235623456')`, Expected: []sql.Row{{float64(1.2345671297673228e+19)}}, }, + { + Query: `SELECT abs('NaN'::numeric)`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT abs('Inf'::numeric)`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT abs('-infinity'::numeric)`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT abs('0'::numeric)`, + Expected: []sql.Row{{Numeric("0")}}, + }, + { + Query: `SELECT abs('-0.50'::numeric)`, + Expected: []sql.Row{{Numeric("0.50")}}, + }, { Query: `SELECT factorial('1')`, Expected: []sql.Row{{Numeric("1")}}, @@ -118,6 +138,66 @@ func TestCoercion(t *testing.T) { Query: `SELECT factorial('1.5')`, ExpectedErr: "invalid input", }, + { + Query: `SELECT ceil('NaN'::numeric)`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT ceil('Inf'::numeric)`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT ceil('-infinity'::numeric)`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT floor('NaN'::numeric)`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT floor('Inf'::numeric)`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT floor('-infinity'::numeric)`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT ln('NaN'::numeric)`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT ln('Inf'::numeric)`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT ln('-infinity'::numeric)`, + ExpectedErr: `cannot take logarithm of a negative number`, + }, + { + Query: `SELECT log('NaN'::numeric)`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT log('Inf'::numeric)`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT log('-infinity'::numeric)`, + ExpectedErr: `cannot take logarithm of a negative number`, + }, + { + Query: `SELECT min_scale('NaN'::numeric)`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT min_scale('Inf'::numeric)`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT min_scale('-infinity'::numeric)`, + Expected: []sql.Row{{nil}}, + }, }, }, }) diff --git a/testing/go/framework.go b/testing/go/framework.go index d8071624a4..cb8fe07d66 100644 --- a/testing/go/framework.go +++ b/testing/go/framework.go @@ -27,6 +27,7 @@ import ( "testing" "time" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" @@ -36,7 +37,6 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" - "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -709,17 +709,6 @@ func NormalizeValToString(dt *types.DoltgresType, v any) any { } else { return "f" } - case pgtype.Numeric: - if val.NaN { - return math.NaN() - } else if val.InfinityModifier != pgtype.Finite { - return math.Inf(int(val.InfinityModifier)) - } else if !val.Valid { - return nil - } else { - decStr := decimal.NewFromBigInt(val.Int, val.Exp).StringFixed(val.Exp * -1) - return Numeric(decStr) - } case pgtype.Time: return timeofday.TimeOfDay(val.Microseconds).String() case []any: @@ -788,7 +777,7 @@ func NormalizeVal(dt *types.DoltgresType, v any) any { } else if !val.Valid { return nil } else { - return decimal.NewFromBigInt(val.Int, val.Exp) + return *apd.New(val.Int.Int64(), val.Exp) } case pgtype.Time: // This value type is used for TIME type. @@ -854,12 +843,12 @@ func NormalizeIntsAndFloats(v any) any { // Numeric creates a numeric value from a string. func Numeric(str string) pgtype.Numeric { - // 250.0 != 250 and 42.90 != 42.9, so we trim all trailing fractional zeroes (and decimal if no fractional zeroes) - // to ensure that the input strings are homogenized, which will give us comparable representations for the same value - if idx := strings.Index(str, "."); idx != -1 { - str = strings.TrimRight(str, "0") - } - str = strings.TrimRight(str, ".") + //// 250.0 != 250 and 42.90 != 42.9, so we trim all trailing fractional zeroes (and decimal if no fractional zeroes) + //// to ensure that the input strings are homogenized, which will give us comparable representations for the same value + //if idx := strings.Index(str, "."); idx != -1 { + // str = strings.TrimRight(str, "0") + //} + //str = strings.TrimRight(str, ".") numeric := pgtype.Numeric{} if err := numeric.Scan(str); err != nil { panic(err) diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index f324c76f9b..85b1a12e78 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -420,6 +420,16 @@ func TestFunctionsMath(t *testing.T) { {4.0}, }, }, + { + Query: `SELECT round('NaN'::numeric);`, + ExpectedColNames: []string{"round"}, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 0::numeric::float8;`, + ExpectedColNames: []string{"float8"}, + Expected: []sql.Row{{0.0}}, + }, }, }, { @@ -585,6 +595,73 @@ func TestFunctionsMath(t *testing.T) { Query: `SELECT power(0::numeric, -1::numeric);`, ExpectedErr: `zero raised to a negative power is undefined`, }, + { + Query: `select power('3'::numeric, '-1'::numeric);`, + Expected: []sql.Row{ + {Numeric("0.3333333333333333")}, + }, + }, + { + Query: `select power('3'::numeric, '-3'::numeric);`, + Expected: []sql.Row{ + {Numeric("0.0370370370370370")}, + }, + }, + { + Query: `select power('Nan'::numeric, '-3'::numeric);`, + Expected: []sql.Row{ + {Numeric("NaN")}, + }, + }, + { + Query: `select power('inf'::numeric, '-3'::numeric);`, + Expected: []sql.Row{ + {Numeric("0")}, + }, + }, + { + Query: `select power('-inf'::numeric, '-3'::numeric);`, + Expected: []sql.Row{ + {Numeric("0")}, + }, + }, + { + Query: `select power('-inf'::numeric, '3'::numeric);`, + Expected: []sql.Row{ + {Numeric("-Infinity")}, + }, + }, + { + Query: `select power('-inf'::numeric, '4'::numeric);`, + Expected: []sql.Row{ + {Numeric("Infinity")}, + }, + }, + { + Skip: true, //TODO: fix + Query: `select power('0'::numeric, '3'::numeric);`, + Expected: []sql.Row{ + {Numeric("0.0000000000000000")}, + }, + }, + { + Query: `select power('-inf'::numeric, '-inf'::numeric);`, + Expected: []sql.Row{ + {Numeric("0")}, + }, + }, + { + Query: `select power('-inf'::numeric, 'inf'::numeric);`, + Expected: []sql.Row{ + {Numeric("Infinity")}, + }, + }, + { + Query: `select power('-inf'::numeric, 'nan'::numeric);`, + Expected: []sql.Row{ + {Numeric("NaN")}, + }, + }, }, }, }) @@ -3836,6 +3913,34 @@ func TestSetReturningFunctions(t *testing.T) { {"2008-03-01 06:00:00"}, }, }, + { + Query: `SELECT generate_series('1.2'::numeric,2.4)`, + Expected: []sql.Row{{Numeric("1.2")}, {Numeric("2.2")}}, + }, + { + Query: `SELECT generate_series('1.2'::numeric,1.4,0.1)`, + Expected: []sql.Row{{Numeric("1.2")}, {Numeric("1.3")}, {Numeric("1.4")}}, + }, + { + Query: `SELECT generate_series('Nan'::numeric,1.4,0.1)`, + ExpectedErr: `start value cannot be NaN`, + }, + { + Query: `SELECT generate_series('NaN'::numeric,1.4)`, + ExpectedErr: `start value cannot be NaN`, + }, + { + Query: `SELECT generate_series('1.2'::numeric,'Infinity',0.1)`, + ExpectedErr: `stop value cannot be infinity`, + }, + { + Query: `SELECT generate_series('1.2'::numeric,'-Infinity')`, + ExpectedErr: `stop value cannot be infinity`, + }, + { + Query: `SELECT generate_series('1.2'::numeric,1.4,'NAN')`, + ExpectedErr: `step value cannot be NaN`, + }, }, }, { diff --git a/testing/go/operators_test.go b/testing/go/operators_test.go index ce1a7a6f70..3b8627254d 100644 --- a/testing/go/operators_test.go +++ b/testing/go/operators_test.go @@ -172,6 +172,34 @@ func TestOperators(t *testing.T) { Query: `SELECT 1::numeric + 2::numeric;`, Expected: []sql.Row{{Numeric("3")}}, }, + { + Query: `SELECT 1::numeric + 'nan'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 1::numeric + 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT '-inf'::numeric + 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'nan'::numeric + 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT '-Inf'::numeric + '2'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT '-Infinity'::numeric + 'NaN'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'infinity'::numeric + '-inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, { Query: `select interval '2 days' + interval '1.5 days';`, Expected: []sql.Row{{"3 days 12:00:00"}}, @@ -348,6 +376,38 @@ func TestOperators(t *testing.T) { Query: `SELECT 1::numeric - 2::numeric;`, Expected: []sql.Row{{Numeric("-1")}}, }, + { + Query: `SELECT 1::numeric - 'nan'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 1::numeric - 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT '-inf'::numeric - 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT 'nan'::numeric - 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT '-Inf'::numeric - '2'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT '2'::numeric - '-inf'::numeric;`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT '-Infinity'::numeric - 'NaN'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'infinity'::numeric - '-inf'::numeric;`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, { Query: `select interval '2 days' - interval '1.5 days';`, Expected: []sql.Row{{"1 day -12:00:00"}}, @@ -504,6 +564,46 @@ func TestOperators(t *testing.T) { Query: `SELECT 1::numeric * 2::numeric;`, Expected: []sql.Row{{Numeric("2")}}, }, + { + Query: `SELECT 1::numeric * 'nan'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 0::numeric * 'nan'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 0::numeric * 'inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 1::numeric * 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT '-inf'::numeric * 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT 'nan'::numeric * 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT '-Inf'::numeric * '2'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT '2'::numeric * '-inf'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT '-Infinity'::numeric * 'NaN'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'infinity'::numeric * '-inf'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, { Query: `select interval '20 days' * 2.3`, Expected: []sql.Row{{"46 days"}}, @@ -648,6 +748,11 @@ func TestOperators(t *testing.T) { Query: `SELECT 8::numeric / 2::int2;`, Expected: []sql.Row{{Numeric("4.0000000000000000")}}, }, + { + Skip: true, // TODO: fix scaling for division + Query: `SELECT 44400080::numeric / 2::int2;`, + Expected: []sql.Row{{Numeric("22200040.000000000000")}}, + }, { Query: `SELECT 8::numeric / 2::int4;`, Expected: []sql.Row{{Numeric("4.0000000000000000")}}, @@ -660,6 +765,66 @@ func TestOperators(t *testing.T) { Query: `SELECT 8::numeric / 2::numeric;`, Expected: []sql.Row{{Numeric("4.0000000000000000")}}, }, + { + Query: `SELECT 1::numeric / 'nan'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 0::numeric / 'nan'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 0::numeric / 'inf'::numeric;`, + Expected: []sql.Row{{Numeric("0")}}, + }, + { + Query: `SELECT 1::numeric / 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("0")}}, + }, + { + Query: `SELECT '-inf'::numeric / 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'nan'::numeric / 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT '-Inf'::numeric / '2'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT '2'::numeric / '-inf'::numeric;`, + Expected: []sql.Row{{Numeric("0")}}, + }, + { + Query: `SELECT '-Infinity'::numeric / 'NaN'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'infinity'::numeric / '-inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT '-Infinity'::numeric / '0'::numeric;`, + ExpectedErr: `division by zero`, + }, + { + Query: `SELECT 'nan'::numeric / '1'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'nan'::numeric / '0'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'infinity'::numeric / '2'::numeric;`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT '-infinity'::numeric / '2'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, { Query: `select interval '20 days' / 2.3`, Expected: []sql.Row{{"8 days 16:41:44.347826"}}, @@ -736,6 +901,66 @@ func TestOperators(t *testing.T) { Query: `SELECT 11::numeric % 3::numeric;`, Expected: []sql.Row{{Numeric("2")}}, }, + { + Query: `SELECT 1::numeric % 'nan'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 0::numeric % 'nan'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 0::numeric % 'inf'::numeric;`, + Expected: []sql.Row{{Numeric("0")}}, + }, + { + Query: `SELECT 1::numeric % 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("0")}}, + }, + { + Query: `SELECT '-inf'::numeric % 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'nan'::numeric % 'Inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT '-Inf'::numeric % '2'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: `SELECT '2'::numeric % '-inf'::numeric;`, + Expected: []sql.Row{{Numeric("0")}}, + }, + { + Query: `SELECT '-Infinity'::numeric % 'NaN'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'infinity'::numeric % '-inf'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT '-Infinity'::numeric % '0'::numeric;`, + ExpectedErr: `division by zero`, + }, + { + Query: `SELECT 'nan'::numeric % '1'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'nan'::numeric % '0'::numeric;`, + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: `SELECT 'infinity'::numeric % '2'::numeric;`, + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: `SELECT '-infinity'::numeric % '2'::numeric;`, + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, }, }, { @@ -1024,6 +1249,30 @@ func TestOperators(t *testing.T) { Query: `SELECT 20.10::numeric < 10.20::numeric;`, Expected: []sql.Row{{"f"}}, }, + { + Query: `SELECT '-inf'::numeric < 'inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'nan'::numeric < 'inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'nan'::numeric < '-inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'nan'::numeric < 'NaN'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT '-inf'::numeric < 'nan'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'inf'::numeric < 'nan'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, { Query: `SELECT 101::oid < 202::oid;`, Expected: []sql.Row{{"t"}}, @@ -1348,6 +1597,34 @@ func TestOperators(t *testing.T) { Query: `SELECT 20.10::numeric > 10.20::numeric;`, Expected: []sql.Row{{"t"}}, }, + { + Query: `SELECT '-inf'::numeric > 'inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'nan'::numeric > 'inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'nan'::numeric > '-inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'nan'::numeric > 'NaN'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT '-inf'::numeric > 'nan'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'inf'::numeric > 'nan'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'inf'::numeric > 'inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, { Query: `SELECT 101::oid > 202::oid;`, Expected: []sql.Row{{"f"}}, @@ -1764,6 +2041,34 @@ func TestOperators(t *testing.T) { Query: `SELECT 20.10::numeric <= 10.20::numeric;`, Expected: []sql.Row{{"f"}}, }, + { + Query: `SELECT '-inf'::numeric <= 'inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'nan'::numeric <= 'inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'nan'::numeric <= '-inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'nan'::numeric <= 'NaN'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT '-inf'::numeric <= 'nan'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'inf'::numeric <= 'nan'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'inf'::numeric <= 'inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, { Query: `SELECT 101::oid <= 202::oid;`, Expected: []sql.Row{{"t"}}, @@ -2232,6 +2537,34 @@ func TestOperators(t *testing.T) { Query: `SELECT 20.10::numeric >= 10.20::numeric;`, Expected: []sql.Row{{"t"}}, }, + { + Query: `SELECT '-inf'::numeric >= 'inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'nan'::numeric >= 'inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'nan'::numeric >= '-inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'nan'::numeric >= 'NaN'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT '-inf'::numeric >= 'nan'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'inf'::numeric >= 'nan'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT '-inf'::numeric >= '-infinity'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, { Query: `SELECT 101::oid >= 202::oid;`, Expected: []sql.Row{{"f"}}, @@ -2608,6 +2941,34 @@ func TestOperators(t *testing.T) { Query: `SELECT 20.10::numeric = 10.20::numeric;`, Expected: []sql.Row{{"f"}}, }, + { + Query: `SELECT '-inf'::numeric = 'inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'nan'::numeric = 'inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'nan'::numeric = '-inf'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'nan'::numeric = 'NaN'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT '-inf'::numeric = 'nan'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT 'inf'::numeric = 'nan'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT '-inf'::numeric = '-infinity'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, { Query: `SELECT 101::oid = 101::oid;`, Expected: []sql.Row{{"t"}}, @@ -2920,6 +3281,34 @@ func TestOperators(t *testing.T) { Query: `SELECT 20.10::numeric <> 10.20::numeric;`, Expected: []sql.Row{{"t"}}, }, + { + Query: `SELECT '-inf'::numeric <> 'inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'nan'::numeric <> 'inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'nan'::numeric <> '-inf'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'nan'::numeric <> 'NaN'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, + { + Query: `SELECT '-inf'::numeric <> 'nan'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'inf'::numeric <> 'nan'::numeric;`, + Expected: []sql.Row{{"t"}}, + }, + { + Query: `SELECT 'inf'::numeric <> 'infinity'::numeric;`, + Expected: []sql.Row{{"f"}}, + }, { Query: `SELECT 101::oid <> 101::oid;`, Expected: []sql.Row{{"f"}}, diff --git a/testing/go/types_test.go b/testing/go/types_test.go index d86b486107..16dc7063f3 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -1866,6 +1866,30 @@ var typesTests = []ScriptTest{ Query: "select 1.03::float4::numeric(2,2);", ExpectedErr: `numeric field overflow`, }, + { + Query: "SELECT 'NaN'::numeric;", + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: "SELECT 'nan'::numeric;", + Expected: []sql.Row{{Numeric("NaN")}}, + }, + { + Query: "SELECT '-inf'::numeric;", + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: "SELECT '-infinity'::numeric;", + Expected: []sql.Row{{Numeric("-Infinity")}}, + }, + { + Query: "SELECT 'inf'::numeric;", + Expected: []sql.Row{{Numeric("Infinity")}}, + }, + { + Query: "SELECT 'infinity'::numeric;", + Expected: []sql.Row{{Numeric("Infinity")}}, + }, }, }, { diff --git a/testing/go/wire_test.go b/testing/go/wire_test.go index 065efb6207..245d7219cb 100644 --- a/testing/go/wire_test.go +++ b/testing/go/wire_test.go @@ -2911,7 +2911,7 @@ func TestWireTypesSending(t *testing.T) { Name: "NUMERIC returning text format", SetUpScript: []string{ "CREATE TABLE test (v1 NUMERIC, v2 NUMERIC(5,2), v3 NUMERIC(14,5));", - "INSERT INTO test VALUES (0, -0.1, NULL), (12357232.456786653224768755799, 235.67, 4278.009);", + "INSERT INTO test VALUES (0, -0.1, NULL), (12357232.456786653224768755799, 235.67, 4278.009), ('Infinity', 'NaN', 'NaN'), ('-Infinity', '0.05', '0.1045678');", }, Assertions: []WireScriptTestAssertion{ { @@ -2980,6 +2980,13 @@ func TestWireTypesSending(t *testing.T) { }, Receive: []pgproto3.BackendMessage{ &pgproto3.BindComplete{}, + &pgproto3.DataRow{ + Values: [][]byte{ + []byte(`-Infinity`), + []byte(`0.05`), + []byte(`0.10457`), + }, + }, &pgproto3.DataRow{ Values: [][]byte{ []byte(`0`), @@ -2994,7 +3001,15 @@ func TestWireTypesSending(t *testing.T) { []byte(`4278.00900`), }, }, - &pgproto3.CommandComplete{CommandTag: []byte("SELECT 2")}, + &pgproto3.DataRow{ + Values: [][]byte{ + []byte(`Infinity`), + []byte(`NaN`), + []byte(`NaN`), + }, + }, + + &pgproto3.CommandComplete{CommandTag: []byte("SELECT 4")}, &pgproto3.CloseComplete{}, &pgproto3.ReadyForQuery{TxStatus: 'I'}, }, @@ -3005,7 +3020,7 @@ func TestWireTypesSending(t *testing.T) { Name: "NUMERIC returning binary format", SetUpScript: []string{ "CREATE TABLE test (v1 NUMERIC, v2 NUMERIC(5,2), v3 NUMERIC(14,5));", - "INSERT INTO test VALUES (0, -0.1, NULL), (12357232.456786653224768755799, 235.67, 4278.009);", + "INSERT INTO test VALUES (0, -0.1, NULL), (12357232.456786653224768755799, 235.67, 4278.009), ('Infinity', 'NaN', 'NaN'), ('-Infinity', '0.05', '0.1045678');", }, Assertions: []WireScriptTestAssertion{ { @@ -3074,6 +3089,13 @@ func TestWireTypesSending(t *testing.T) { }, Receive: []pgproto3.BackendMessage{ &pgproto3.BindComplete{}, + &pgproto3.DataRow{ + Values: [][]byte{ + {0, 0, 0, 0, 240, 0, 0, 32}, + {0, 1, 255, 255, 0, 0, 0, 2, 1, 244}, + {0, 2, 255, 255, 0, 0, 0, 5, 4, 21, 27, 88}, + }, + }, &pgproto3.DataRow{ Values: [][]byte{ {0, 0, 0, 0, 0, 0, 0, 0}, @@ -3088,7 +3110,14 @@ func TestWireTypesSending(t *testing.T) { {0, 2, 0, 0, 0, 0, 0, 5, 16, 182, 0, 90}, }, }, - &pgproto3.CommandComplete{CommandTag: []byte("SELECT 2")}, + &pgproto3.DataRow{ + Values: [][]byte{ + {0, 0, 0, 0, 208, 0, 0, 32}, + {0, 0, 0, 0, 192, 0, 0, 0}, + {0, 0, 0, 0, 192, 0, 0, 0}, + }, + }, + &pgproto3.CommandComplete{CommandTag: []byte("SELECT 4")}, &pgproto3.CloseComplete{}, &pgproto3.ReadyForQuery{TxStatus: 'I'}, }, From 2f47943c342f2f16085eee7799bbb175986f6338 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 27 Apr 2026 15:29:03 -0700 Subject: [PATCH 03/16] fix precision for div, ln, log, sqrt, trunc --- server/cast/float64.go | 1 - server/expression/literal.go | 2 +- server/functions/binary/divide.go | 32 +++++++++++- server/functions/binary/mod.go | 26 +++++++++- server/functions/div.go | 51 +++++++++---------- server/functions/ln.go | 19 ++++++- server/functions/log.go | 47 ++++++++++++++--- server/functions/mod.go | 46 ++++++++--------- server/functions/sqrt.go | 25 ++++++++- server/functions/trunc.go | 8 ++- server/types/numeric.go | 8 +-- .../output/framework_test.go | 17 +++---- 12 files changed, 196 insertions(+), 86 deletions(-) diff --git a/server/cast/float64.go b/server/cast/float64.go index cba265c9de..05ce629589 100644 --- a/server/cast/float64.go +++ b/server/cast/float64.go @@ -77,7 +77,6 @@ func float64Assignment() { ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { d := new(apd.Decimal) - d.String() err := d.Scan(val.(float64)) if err != nil { return nil, err diff --git a/server/expression/literal.go b/server/expression/literal.go index 7e991ffc65..49012cc7ce 100644 --- a/server/expression/literal.go +++ b/server/expression/literal.go @@ -107,7 +107,7 @@ func NewRawLiteralFloat64(val float64) *expression.Literal { return expression.NewLiteral(val, pgtypes.Float64) } -// NewRawLiteralNumeric returns a new *expression.Literal containing a decimal.Decimal value. +// NewRawLiteralNumeric returns a new *expression.Literal containing an apd.Decimal value. func NewRawLiteralNumeric(val apd.Decimal) *expression.Literal { return expression.NewLiteral(val, pgtypes.Numeric) } diff --git a/server/functions/binary/divide.go b/server/functions/binary/divide.go index 88307fa23e..13108de633 100644 --- a/server/functions/binary/divide.go +++ b/server/functions/binary/divide.go @@ -15,11 +15,11 @@ package binary import ( + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/postgres/parser/duration" - "github.com/dolthub/doltgresql/server/functions" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -290,5 +290,33 @@ var numeric_div = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: functions.NumericDivCallable, + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + if num1.Form == apd.NaN || num2.Form == apd.NaN || + (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { + return pgtypes.NumericNaN, nil + } + if num2.IsZero() { + return nil, errors.Errorf("division by zero") + } + if num1.Form == apd.Infinite { + return num1, nil + } + if num2.Form == apd.Infinite { + return *apd.New(0, 0), nil + } + // TODO: calculate precision and scale accurately + c := apd.BaseContext.WithPrecision(1000000) + _, err := c.QuoInteger(&num1, &num1, &num2) + if err != nil { + return nil, err + } + _, err = c.Quantize(&num1, &num1, -16) + if err != nil { + return nil, err + } + + return num1, nil + }, } diff --git a/server/functions/binary/mod.go b/server/functions/binary/mod.go index 60e08f1649..564073d35c 100644 --- a/server/functions/binary/mod.go +++ b/server/functions/binary/mod.go @@ -15,10 +15,10 @@ package binary import ( + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/doltgresql/server/functions" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -82,5 +82,27 @@ var numeric_mod = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: functions.NumericModCallable, + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + if num1.Form == apd.NaN || num2.Form == apd.NaN || + (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { + return pgtypes.NumericNaN, nil + } + if num2.IsZero() { + return nil, errors.Errorf("division by zero") + } + if num1.Form == apd.Infinite { + return num1, nil + } + if num2.Form == apd.Infinite { + return *apd.New(0, 0), nil + } + c := apd.BaseContext.WithPrecision(1000000) + _, err := c.Rem(&num1, &num1, &num2) + if err != nil { + return nil, err + } + return num1, nil + }, } diff --git a/server/functions/div.go b/server/functions/div.go index d3aae61072..a54f2d918d 100644 --- a/server/functions/div.go +++ b/server/functions/div.go @@ -34,33 +34,28 @@ var div_numeric = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: NumericDivCallable, -} - -// NumericDivCallable is the callable logic for the numeric_div and div functions. -func NumericDivCallable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) - if num1.Form == apd.NaN || num2.Form == apd.NaN || - (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { - return pgtypes.NumericNaN, nil - } - if num2.IsZero() { - return nil, errors.Errorf("division by zero") - } - if num1.Form == apd.Infinite { + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + if num1.Form == apd.NaN || num2.Form == apd.NaN || + (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { + return pgtypes.NumericNaN, nil + } + if num2.IsZero() { + return nil, errors.Errorf("division by zero") + } + if num1.Form == apd.Infinite { + return num1, nil + } + if num2.Form == apd.Infinite { + return *apd.New(0, 0), nil + } + // TODO: calculate precision and scale accurately + c := apd.BaseContext.WithPrecision(1000000) + _, err := c.QuoInteger(&num1, &num1, &num2) + if err != nil { + return nil, err + } return num1, nil - } - if num2.Form == apd.Infinite { - return *apd.New(0, 0), nil - } - _, err := pgtypes.BaseContext.Quo(&num1, &num1, &num2) - if err != nil { - return nil, err - } - _, err = pgtypes.BaseContext.Quantize(&num1, &num1, -16) - if err != nil { - return nil, err - } - return num1, nil + }, } diff --git a/server/functions/ln.go b/server/functions/ln.go index 69c05371a1..6f678efa58 100644 --- a/server/functions/ln.go +++ b/server/functions/ln.go @@ -16,6 +16,7 @@ package functions import ( "math" + "strings" "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" @@ -60,7 +61,23 @@ var ln_numeric = framework.Function1{ } else if dec.Sign() < 0 { return nil, errors.Errorf("cannot take logarithm of a negative number") } - _, err := pgtypes.BaseContext.Ln(&dec, &dec) + + // TODO: calculate precision and scale accurately + s := dec.Text('f') + parts := strings.Split(s, ".") + + exp := int32(-16) + if dec.Exponent < exp { + exp = dec.Exponent + } + p := uint32(len(parts[0]) + int(-exp)) + + c := apd.BaseContext.WithPrecision(p) + _, err := c.Ln(&dec, &dec) + if err != nil { + return nil, err + } + _, err = c.Quantize(&dec, &dec, exp) if err != nil { return nil, err } diff --git a/server/functions/log.go b/server/functions/log.go index d3b60c6090..ff759ef745 100644 --- a/server/functions/log.go +++ b/server/functions/log.go @@ -16,6 +16,7 @@ package functions import ( "math" + "strings" "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" @@ -62,7 +63,14 @@ var log_numeric = framework.Function1{ } else if dec.Sign() < 0 { return nil, errors.Errorf("cannot take logarithm of a negative number") } - _, err := pgtypes.BaseContext.Log10(&dec, &dec) + + // TODO: calculate precision and scale accurately + p := uint32(17) + if dec.Exponent < 0 { + p += uint32(-dec.Exponent) + } + c := apd.BaseContext.WithPrecision(p) + _, err := c.Log10(&dec, &dec) if err != nil { return nil, err } @@ -84,21 +92,44 @@ var log_numeric_numeric = framework.Function2{ } else if base.Sign() < 0 || num.Sign() < 0 { return nil, errors.Errorf("cannot take logarithm of a negative number") } - logBase := new(apd.Decimal) - _, err := pgtypes.BaseContext.Log10(&base, &base) + + // TODO: calculate precision and scale accurately + sNum := num.Text('f') + sBase := base.Text('f') + partsNum := strings.Split(sNum, ".") + partsBase := strings.Split(sBase, ".") + exp := int32(-16) + if minExp := math.Min(float64(base.Exponent), float64(num.Exponent)); int32(minExp) < exp { + exp = int32(minExp) + } + p := uint32(int32(math.Max(float64(len(partsNum[0])), float64(len(partsBase[0])))) + (-exp)) + c := apd.BaseContext.WithPrecision(p) + + lnBase := new(apd.Decimal) + _, err := c.Ln(lnBase, &base) if err != nil { return nil, err } - logNum := new(apd.Decimal) - _, err = pgtypes.BaseContext.Log10(&num, &num) + if lnBase.IsZero() { + return nil, errors.Errorf("division by zero") + } + + lnNum := new(apd.Decimal) + _, err = c.Ln(lnNum, &num) if err != nil { return nil, err } - if logNum.IsZero() { - return nil, errors.Errorf("division by zero") + if lnNum.IsZero() { + return *apd.New(0, -16), nil } + res := new(apd.Decimal) - _, err = pgtypes.BaseContext.Quo(res, logNum, logBase) + _, err = c.Quo(res, lnNum, lnBase) + if err != nil { + return nil, err + } + + _, err = c.Quantize(res, res, exp) if err != nil { return nil, err } diff --git a/server/functions/mod.go b/server/functions/mod.go index 85d1bcf407..9afdb7bc67 100644 --- a/server/functions/mod.go +++ b/server/functions/mod.go @@ -80,29 +80,27 @@ var mod_numeric_numeric = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: NumericModCallable, -} - -// NumericModCallable is the callable logic for the numeric_mod and mod functions. -func NumericModCallable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) - if num1.Form == apd.NaN || num2.Form == apd.NaN || - (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { - return pgtypes.NumericNaN, nil - } - if num2.IsZero() { - return nil, errors.Errorf("division by zero") - } - if num1.Form == apd.Infinite { + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + if num1.Form == apd.NaN || num2.Form == apd.NaN || + (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { + return pgtypes.NumericNaN, nil + } + if num2.IsZero() { + return nil, errors.Errorf("division by zero") + } + if num1.Form == apd.Infinite { + return num1, nil + } + if num2.Form == apd.Infinite { + return *apd.New(0, 0), nil + } + c := apd.BaseContext.WithPrecision(1000000) + _, err := c.Rem(&num1, &num1, &num2) + if err != nil { + return nil, err + } return num1, nil - } - if num2.Form == apd.Infinite { - return *apd.New(0, 0), nil - } - _, err := pgtypes.BaseContext.Rem(&num1, &num1, &num2) - if err != nil { - return nil, err - } - return num1, nil + }, } diff --git a/server/functions/sqrt.go b/server/functions/sqrt.go index adcff81157..bdbcbb1d47 100644 --- a/server/functions/sqrt.go +++ b/server/functions/sqrt.go @@ -16,6 +16,7 @@ package functions import ( "math" + "strings" "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" @@ -56,7 +57,29 @@ var sqrt_numeric = framework.Function1{ if dec.Sign() < 0 { return nil, errors.Errorf("cannot take square root of a negative number") } - _, err := pgtypes.BaseContext.Sqrt(&dec, &dec) + + // TODO: calculate precision and scale accurately + s := dec.Text('f') + parts := strings.Split(s, ".") + + exp := int32(-16) + whole := int32(len(parts[0]) / 2) + if dec.Exponent == 0 { + exp = whole - 16 + } else if dec.Exponent < -16 { + exp = dec.Exponent + } + p := uint32(whole) + 1 + if exp < 0 { + p += uint32(-exp) + } + + c := apd.BaseContext.WithPrecision(p) + _, err := c.Sqrt(&dec, &dec) + if err != nil { + return nil, err + } + _, err = c.Quantize(&dec, &dec, exp) if err != nil { return nil, err } diff --git a/server/functions/trunc.go b/server/functions/trunc.go index a432e82cc1..4b37778a06 100644 --- a/server/functions/trunc.go +++ b/server/functions/trunc.go @@ -50,7 +50,9 @@ var trunc_numeric = framework.Function1{ Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { dec := val.(apd.Decimal) - _, err := pgtypes.BaseContext.Quantize(&dec, &dec, 0) + // TODO: calculate precision and scale accurately + c := apd.BaseContext.WithPrecision(1000000) + _, err := c.Quantize(&dec, &dec, 0) if err != nil { return nil, err } @@ -68,7 +70,9 @@ var trunc_numeric_int64 = framework.Function2{ //TODO: test for negative values in places dec := val1.(apd.Decimal) places := val2.(int32) - _, err := pgtypes.BaseContext.Quantize(&dec, &dec, -places) + // TODO: calculate precision and scale accurately + c := apd.BaseContext.WithPrecision(1000000) + _, err := c.Quantize(&dec, &dec, -places) if err != nil { return nil, err } diff --git a/server/types/numeric.go b/server/types/numeric.go index ecada41a01..047013f88e 100644 --- a/server/types/numeric.go +++ b/server/types/numeric.go @@ -55,7 +55,7 @@ var ( NumericNaN = apd.Decimal{Form: apd.NaN} NumericInf = apd.Decimal{Form: apd.Infinite} NumericNegInf = apd.Decimal{Form: apd.Infinite, Negative: true} - BaseContext = apd.BaseContext.WithPrecision(MaxPrecision) + BaseContext = apd.BaseContext // .WithPrecision(MaxPrecision) ) // Numeric is a precise and unbounded decimal value. @@ -319,12 +319,6 @@ func deserializeTypeNumeric(ctx *sql.Context, t *DoltgresType, data []byte) (any } } - // If weight is negative, we need leading zeros after decimal point - if weight < 0 { - // This logic can get complex; using apd.SetString is the safest path - // but ensure the decimal point is placed correctly based on dscale. - } - dec, _, err := BaseContext.NewFromString(sb.String()) if err != nil { return nil, err diff --git a/testing/generation/function_coverage/output/framework_test.go b/testing/generation/function_coverage/output/framework_test.go index 5ba8d38101..9188a40b33 100644 --- a/testing/generation/function_coverage/output/framework_test.go +++ b/testing/generation/function_coverage/output/framework_test.go @@ -248,15 +248,15 @@ func Numeric(str string) pgtype.Numeric { // NumericToDecimal converts a pgtype.Numeric value to a decimal.Decimal value. func NumericToDecimal(val pgtype.Numeric) apd.Decimal { - strVal, err := val.Value() - if err != nil { - panic(err) - } - d, _, err := apd.NewFromString(strVal.(string)) - if err != nil { - panic(err) + if val.NaN { + return pgtypes.NumericNaN + } else if val.InfinityModifier == pgtype.Infinity { + return pgtypes.NumericInf + } else if val.InfinityModifier == pgtype.NegativeInfinity { + return pgtypes.NumericNegInf } - return *d + + return *apd.New(val.Int.Int64(), val.Exp) } // CompareResults compares two sets of results, taking the equivalence thresholds into account when making the @@ -302,7 +302,6 @@ func CompareRows(t *testing.T, a sql.Row, b sql.Row) bool { return false } aDec = *aDec.Abs(&aDec) - _, err = pgtypes.BaseContext.Sub(&aDec, &aDec, &bDec) // EquivalenceThresholdNumeric represents the allowable delta for values to be considered equivalent. // This is computed using the float64 variant so that they're equivalent. EquivalenceThresholdNumeric, _, _ := apd.NewFromString(strconv.FormatFloat(EquivalenceThresholdFloat64, 'f', -1, 64)) From ad4ed71b5f8eb5d148316b93ce2f3cfbe2279eb9 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 5 May 2026 15:49:08 -0700 Subject: [PATCH 04/16] update for dolt and gms bumps --- go.mod | 7 +- go.sum | 16 +- server/analyzer/resolve_values_types.go | 2 +- server/analyzer/type_sanitizer.go | 5 +- server/cast/int64.go | 11 +- server/cast/jsonb.go | 6 +- server/cast/numeric.go | 24 +- server/expression/gms_cast.go | 13 +- server/functions/binary/divide.go | 58 ++--- server/functions/binary/minus.go | 2 +- server/functions/binary/mod.go | 3 +- server/functions/binary/multiply.go | 2 +- server/functions/binary/plus.go | 2 +- server/functions/ceil.go | 2 +- server/functions/date_part.go | 2 +- server/functions/div.go | 4 +- server/functions/exp.go | 2 +- server/functions/extract.go | 2 +- server/functions/floor.go | 2 +- server/functions/generate_series.go | 3 +- server/functions/ln.go | 4 +- server/functions/log.go | 4 +- server/functions/mod.go | 3 +- server/functions/numeric.go | 186 ++++++++++++++- server/functions/oid.go | 3 +- server/functions/power.go | 10 +- server/functions/round.go | 8 +- server/functions/sqrt.go | 2 +- server/functions/to_char.go | 2 +- server/functions/trunc.go | 9 +- server/functions/width_bucket.go | 10 +- server/tables/dtables/ignore.go | 2 +- server/tables/dtables/rebase.go | 4 +- server/types/numeric.go | 213 ++---------------- server/types/typeinfo.go | 4 +- .../output/framework_test.go | 4 +- testing/go/framework.go | 6 - 37 files changed, 306 insertions(+), 336 deletions(-) diff --git a/go.mod b/go.mod index 059f6596eb..dbfef95cf6 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,13 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v3 v3.2.3 github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20260430172110-36fcc634f302 + github.com/dolthub/dolt/go v0.40.5-0.20260505224614-63ad07a5e185 github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.20.1-0.20260427172105-a0b357da2f1d + github.com/dolthub/go-mysql-server v0.20.1-0.20260505215953-bee2a70fec58 github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 - github.com/dolthub/vitess v0.0.0-20260424215137-ec6bd432b0be + github.com/dolthub/vitess v0.0.0-20260505163811-77e5224be390 github.com/fatih/color v1.13.0 github.com/goccy/go-json v0.10.2 github.com/gogo/protobuf v1.3.2 @@ -31,7 +31,6 @@ require ( github.com/pierrre/geohash v1.0.0 github.com/pkg/profile v1.5.0 github.com/sergi/go-diff v1.1.0 - github.com/shopspring/decimal v1.4.0 github.com/sirupsen/logrus v1.8.3 github.com/stretchr/testify v1.11.1 github.com/twpayne/go-geom v1.3.6 diff --git a/go.sum b/go.sum index af4cd77b16..b99ac82e2f 100644 --- a/go.sum +++ b/go.sum @@ -245,8 +245,8 @@ github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:I github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo= github.com/dolthub/dolt-mcp v0.3.4 h1:AyG5cw+fNWXDHXujtQnqUPZrpWtPg6FN6yYtjv1pP44= github.com/dolthub/dolt-mcp v0.3.4/go.mod h1:bCZ7KHvDYs+M0e+ySgmGiNvLhcwsN7bbf5YCyillLrk= -github.com/dolthub/dolt/go v0.40.5-0.20260430172110-36fcc634f302 h1:rsAzhociWfQbOyPGaoPOh6gHLEG3dXO9NM6kHVp6Zcc= -github.com/dolthub/dolt/go v0.40.5-0.20260430172110-36fcc634f302/go.mod h1:+OAy6JNtk6d+T6WbDqSOvhwjapYmmmH0cKHI0VGygwY= +github.com/dolthub/dolt/go v0.40.5-0.20260505224614-63ad07a5e185 h1:0sG2SfllH8RXmL7GJVkO0A1iA4FvfrZwfhsXJKwvwpg= +github.com/dolthub/dolt/go v0.40.5-0.20260505224614-63ad07a5e185/go.mod h1:gg5m62C/jboequKCO9PmuTmE97PTK3AC5+WW/5r4FAc= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 h1:JShhbqMw26nKx3pqqu/cFxOpzBkN+4elVhzuUfgDw2k= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69/go.mod h1:SSLraQS/jGLYFgff3vuZ+JbVUct6vyEeMzjLBqWqoyM= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -255,8 +255,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866 h1:U6gSf5I0e6h6GP1/5Sa7D2lWW1CWfcVPtY5wkyHq6jY= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE= -github.com/dolthub/go-mysql-server v0.20.1-0.20260427172105-a0b357da2f1d h1:0ptkeudblr8Ee6fmah6T/1nmnqAC5XXUmHiknZiGL58= -github.com/dolthub/go-mysql-server v0.20.1-0.20260427172105-a0b357da2f1d/go.mod h1:O43PPQxMeNi7O5idizj6Itf2TZcSYfI/0WU24xhXg4I= +github.com/dolthub/go-mysql-server v0.20.1-0.20260505215953-bee2a70fec58 h1:vaw9NBd3aI2qo09GfUNktff8zpTFePIeZw+upfVb4qc= +github.com/dolthub/go-mysql-server v0.20.1-0.20260505215953-bee2a70fec58/go.mod h1:uE//4hYOi/1zIgOYgr8D+u7OvDfI4yPXFsL4dTZWk5I= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20260414231531-5f031e3e9037 h1:oIW9HwuWrhxv+4HZxA+QQSKHLqWFyXZ2FmNjUYwkdiM= @@ -267,8 +267,8 @@ github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 h1:GY17cGA4 github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1/go.mod h1:qnrZP3/1slFl2Bq5yw38HLOsArZareGwdpEceriblLc= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4EHUcEVQCMRBej8DYxjYjRz/9MdF/NNQh0o70= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= -github.com/dolthub/vitess v0.0.0-20260424215137-ec6bd432b0be h1:m01DFtn4xQhLWni8bCbRRjcjeiOZkAgaj2fOWGpPr6A= -github.com/dolthub/vitess v0.0.0-20260424215137-ec6bd432b0be/go.mod h1:dKAkzdfRkAudpc0g8JOQ0eiEjV83TYIFz/yNIEdcjXM= +github.com/dolthub/vitess v0.0.0-20260505163811-77e5224be390 h1:FjUnJYan3i3mrk/i+qENwFOQhA+XfH+wgAucbgi/4sQ= +github.com/dolthub/vitess v0.0.0-20260505163811-77e5224be390/go.mod h1:dKAkzdfRkAudpc0g8JOQ0eiEjV83TYIFz/yNIEdcjXM= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= @@ -717,8 +717,8 @@ github.com/shirou/gopsutil/v4 v4.25.12 h1:e7PvW/0RmJ8p8vPGJH4jvNkOyLmbkXgXW4m6ZP github.com/shirou/gopsutil/v4 v4.25.12/go.mod h1:EivAfP5x2EhLp2ovdpKSozecVXn1TmuG7SMzs/Wh4PU= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= -github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= -github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= diff --git a/server/analyzer/resolve_values_types.go b/server/analyzer/resolve_values_types.go index 002434aeab..f319648fb1 100644 --- a/server/analyzer/resolve_values_types.go +++ b/server/analyzer/resolve_values_types.go @@ -113,7 +113,7 @@ func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, s // MIN now returns numeric, so GroupBy produces numeric. But the // Project's GetField still says int4 because its tableId=GroupBy, // which wasn't in transformedVDTs. At runtime this causes a panic - // because the actual value is decimal.Decimal but the type says int32. + // because the actual value is apd.Decimal but the type says int32. // // This pass catches those: for each GetField, check if its type // disagrees with what the child node actually produces. diff --git a/server/analyzer/type_sanitizer.go b/server/analyzer/type_sanitizer.go index df6e7480e7..e74cab427c 100644 --- a/server/analyzer/type_sanitizer.go +++ b/server/analyzer/type_sanitizer.go @@ -28,7 +28,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/shopspring/decimal" pgexprs "github.com/dolthub/doltgresql/server/expression" "github.com/dolthub/doltgresql/server/functions/framework" @@ -174,11 +173,11 @@ func typeSanitizerLiterals(ctx *sql.Context, gmsLiteral *expression.Literal) (sq } return pgexprs.NewRawLiteralFloat64(newVal.(float64)), transform.NewTree, nil case query.Type_DECIMAL: - dec, ok := gmsLiteral.Value().(decimal.Decimal) + dec, ok := gmsLiteral.Value().(apd.Decimal) if !ok { return nil, transform.NewTree, errors.Errorf("SANITIZER: expected decimal type: %T", gmsLiteral.Value()) } - return pgexprs.NewRawLiteralNumeric(*apd.New(dec.Coefficient().Int64(), dec.Exponent())), transform.NewTree, nil + return pgexprs.NewRawLiteralNumeric(dec), transform.NewTree, nil case query.Type_DATE, query.Type_DATETIME, query.Type_TIMESTAMP: newVal, _, err := types.Datetime.Convert(ctx, gmsLiteral.Value()) if err != nil { diff --git a/server/cast/int64.go b/server/cast/int64.go index b56700d3ee..7da9a674a8 100644 --- a/server/cast/int64.go +++ b/server/cast/int64.go @@ -15,9 +15,10 @@ package cast import ( + "math" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/core/id" @@ -82,7 +83,7 @@ func int64Implicit() { FromType: pgtypes.Int64, ToType: pgtypes.Oid, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val.(int64) > pgtypes.MaxUint32 || val.(int64) < 0 { + if val.(int64) > int64(math.MaxUint32) || val.(int64) < 0 { return nil, errOutOfRange.New(targetType.String()) } if internalID := id.Cache().ToInternal(uint32(val.(int64))); internalID.IsValid() { @@ -95,7 +96,7 @@ func int64Implicit() { FromType: pgtypes.Int64, ToType: pgtypes.Regclass, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val.(int64) > pgtypes.MaxUint32 || val.(int64) < 0 { + if val.(int64) > int64(math.MaxUint32) || val.(int64) < 0 { return nil, errOutOfRange.New(targetType.String()) } if internalID := id.Cache().ToInternal(uint32(val.(int64))); internalID.IsValid() { @@ -108,7 +109,7 @@ func int64Implicit() { FromType: pgtypes.Int64, ToType: pgtypes.Regproc, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val.(int64) > pgtypes.MaxUint32 || val.(int64) < 0 { + if val.(int64) > int64(math.MaxUint32) || val.(int64) < 0 { return nil, errOutOfRange.New(targetType.String()) } if internalID := id.Cache().ToInternal(uint32(val.(int64))); internalID.IsValid() { @@ -121,7 +122,7 @@ func int64Implicit() { FromType: pgtypes.Int64, ToType: pgtypes.Regtype, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - if val.(int64) > pgtypes.MaxUint32 || val.(int64) < 0 { + if val.(int64) > int64(math.MaxUint32) || val.(int64) < 0 { return nil, errOutOfRange.New(targetType.String()) } if internalID := id.Cache().ToInternal(uint32(val.(int64))); internalID.IsValid() { diff --git a/server/cast/jsonb.go b/server/cast/jsonb.go index 2f9b3614c4..be60923e3c 100644 --- a/server/cast/jsonb.go +++ b/server/cast/jsonb.go @@ -115,7 +115,7 @@ func jsonbExplicit() { return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: d := apd.Decimal(value) - if d.Cmp(pgtypes.NumericValueMinInt16) < 0 || d.Cmp(pgtypes.NumericValueMaxInt16) > 0 { + if d.Cmp(&pgtypes.NumericValueMinInt16) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt16) > 0 { return nil, errors.Errorf("smallint out of range") } i, err := d.Int64() @@ -145,7 +145,7 @@ func jsonbExplicit() { return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: d := apd.Decimal(value) - if d.Cmp(pgtypes.NumericValueMinInt32) < 0 || d.Cmp(pgtypes.NumericValueMaxInt32) > 0 { + if d.Cmp(&pgtypes.NumericValueMinInt32) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt32) > 0 { return nil, errors.Errorf("integer out of range") } i, err := d.Int64() @@ -175,7 +175,7 @@ func jsonbExplicit() { return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: d := apd.Decimal(value) - if d.Cmp(pgtypes.NumericValueMinInt64) < 0 || d.Cmp(pgtypes.NumericValueMaxInt64) > 0 { + if d.Cmp(&pgtypes.NumericValueMinInt64) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt64) > 0 { return nil, errors.Errorf("bigint out of range") } i, err := d.Int64() diff --git a/server/cast/numeric.go b/server/cast/numeric.go index 761d38b506..7f9d859219 100644 --- a/server/cast/numeric.go +++ b/server/cast/numeric.go @@ -17,8 +17,8 @@ package cast import ( "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" - "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -37,13 +37,10 @@ func numericAssignment() { ToType: pgtypes.Int16, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { d := val.(apd.Decimal) - if d.Cmp(pgtypes.NumericValueMinInt16) < 0 || d.Cmp(pgtypes.NumericValueMaxInt16) > 0 { + if d.Cmp(&pgtypes.NumericValueMinInt16) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt16) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "smallint out of range") } - i, err := d.Int64() - if err != nil { - return nil, err - } + i := types.DecimalIntPart(d) return int16(i), nil }, }) @@ -52,13 +49,10 @@ func numericAssignment() { ToType: pgtypes.Int32, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { d := val.(apd.Decimal) - if d.Cmp(pgtypes.NumericValueMinInt32) < 0 || d.Cmp(pgtypes.NumericValueMaxInt32) > 0 { + if d.Cmp(&pgtypes.NumericValueMinInt32) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt32) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "integer out of range") } - i, err := d.Int64() - if err != nil { - return nil, err - } + i := types.DecimalIntPart(d) return int32(i), nil }, }) @@ -67,14 +61,10 @@ func numericAssignment() { ToType: pgtypes.Int64, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { d := val.(apd.Decimal) - if d.Cmp(pgtypes.NumericValueMinInt64) < 0 || d.Cmp(pgtypes.NumericValueMaxInt64) > 0 { + if d.Cmp(&pgtypes.NumericValueMinInt64) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt64) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "bigint out of range") } - i, err := d.Int64() - if err != nil { - return nil, err - } - return int64(i), nil + return types.DecimalIntPart(d), nil }, }) } diff --git a/server/expression/gms_cast.go b/server/expression/gms_cast.go index 0ea167e5c5..580eeb34b6 100644 --- a/server/expression/gms_cast.go +++ b/server/expression/gms_cast.go @@ -25,7 +25,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/shopspring/decimal" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -123,11 +122,11 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil { return nil, err } - dec, ok := newVal.(decimal.Decimal) + dec, ok := newVal.(apd.Decimal) if !ok { - return nil, errors.Errorf("GMSCast expected type `decimal.Decimal`, got `%T`", val) + return nil, errors.Errorf("GMSCast expected type `apd.Decimal`, got `%T`", val) } - return *apd.New(dec.CoefficientInt64(), dec.Exponent()), nil + return dec, nil case query.Type_FLOAT32: newVal, _, err := types.Float32.Convert(ctx, val) if err != nil { @@ -151,11 +150,11 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil { return nil, err } - dec, ok := newVal.(decimal.Decimal) + dec, ok := newVal.(apd.Decimal) if !ok { - return nil, errors.Errorf("GMSCast expected type `decimal.Decimal`, got `%T`", val) + return nil, errors.Errorf("GMSCast expected type `apd.Decimal`, got `%T`", val) } - return *apd.New(dec.CoefficientInt64(), dec.Exponent()), nil + return dec, nil case query.Type_DATE, query.Type_DATETIME, query.Type_TIMESTAMP: if val, ok := val.(time.Time); ok { return val, nil diff --git a/server/functions/binary/divide.go b/server/functions/binary/divide.go index 13108de633..e5e6827fbc 100644 --- a/server/functions/binary/divide.go +++ b/server/functions/binary/divide.go @@ -284,39 +284,39 @@ var interval_div = framework.Function2{ Callable: interval_div_callable, } +// numeric_div_callable is the callable logic for the numeric_div function. +func numeric_div_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { + num1 := val1.(apd.Decimal) + num2 := val2.(apd.Decimal) + if num1.Form == apd.NaN || num2.Form == apd.NaN || + (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { + return pgtypes.NumericNaN, nil + } + if num2.IsZero() { + return nil, errors.Errorf("division by zero") + } + if num1.Form == apd.Infinite { + return num1, nil + } + if num2.Form == apd.Infinite { + return *apd.New(0, 0), nil + } + _, err := sql.HighPrecisionCtx.Quo(&num1, &num1, &num2) + if err != nil { + return nil, err + } + _, err = sql.DecimalCtx.Quantize(&num1, &num1, -16) + if err != nil { + return nil, err + } + return num1, nil +} + // numeric_div represents the PostgreSQL function of the same name, taking the same parameters. var numeric_div = framework.Function2{ Name: "numeric_div", Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, - Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) - if num1.Form == apd.NaN || num2.Form == apd.NaN || - (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { - return pgtypes.NumericNaN, nil - } - if num2.IsZero() { - return nil, errors.Errorf("division by zero") - } - if num1.Form == apd.Infinite { - return num1, nil - } - if num2.Form == apd.Infinite { - return *apd.New(0, 0), nil - } - // TODO: calculate precision and scale accurately - c := apd.BaseContext.WithPrecision(1000000) - _, err := c.QuoInteger(&num1, &num1, &num2) - if err != nil { - return nil, err - } - _, err = c.Quantize(&num1, &num1, -16) - if err != nil { - return nil, err - } - - return num1, nil - }, + Callable: numeric_div_callable, } diff --git a/server/functions/binary/minus.go b/server/functions/binary/minus.go index a2937749ee..1f622523c2 100644 --- a/server/functions/binary/minus.go +++ b/server/functions/binary/minus.go @@ -242,7 +242,7 @@ var numeric_sub = framework.Function2{ Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { num1 := val1.(apd.Decimal) num2 := val2.(apd.Decimal) - _, err := pgtypes.BaseContext.Sub(&num1, &num1, &num2) + _, err := sql.DecimalCtx.Sub(&num1, &num1, &num2) if err != nil { return nil, err } diff --git a/server/functions/binary/mod.go b/server/functions/binary/mod.go index 564073d35c..a49c9be654 100644 --- a/server/functions/binary/mod.go +++ b/server/functions/binary/mod.go @@ -98,8 +98,7 @@ var numeric_mod = framework.Function2{ if num2.Form == apd.Infinite { return *apd.New(0, 0), nil } - c := apd.BaseContext.WithPrecision(1000000) - _, err := c.Rem(&num1, &num1, &num2) + _, err := sql.HighPrecisionCtx.Rem(&num1, &num1, &num2) if err != nil { return nil, err } diff --git a/server/functions/binary/multiply.go b/server/functions/binary/multiply.go index 5377e7f2b0..8fb5ca9d98 100644 --- a/server/functions/binary/multiply.go +++ b/server/functions/binary/multiply.go @@ -230,7 +230,7 @@ var numeric_mul = framework.Function2{ if (num1.Form == apd.Infinite || num2.Form == apd.Infinite) && (num1.IsZero() || num2.IsZero()) { return pgtypes.NumericNaN, nil } - _, err := pgtypes.BaseContext.Mul(&num1, &num1, &num2) + _, err := sql.DecimalCtx.Mul(&num1, &num1, &num2) if err != nil { return nil, err } diff --git a/server/functions/binary/plus.go b/server/functions/binary/plus.go index e3873de328..2250ffabb8 100644 --- a/server/functions/binary/plus.go +++ b/server/functions/binary/plus.go @@ -390,7 +390,7 @@ func numeric_add_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any return pgtypes.NumericInf, nil } - _, err := pgtypes.BaseContext.Add(&num1, &num1, &num2) + _, err := sql.DecimalCtx.Add(&num1, &num1, &num2) if err != nil { return nil, err } diff --git a/server/functions/ceil.go b/server/functions/ceil.go index fc1e10ba17..e4dcc9a6bc 100644 --- a/server/functions/ceil.go +++ b/server/functions/ceil.go @@ -56,7 +56,7 @@ var ceil_numeric = framework.Function1{ Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { dec := val.(apd.Decimal) - _, err := pgtypes.BaseContext.Ceil(&dec, &dec) + _, err := sql.DecimalCtx.Ceil(&dec, &dec) if err != nil { return nil, err } diff --git a/server/functions/date_part.go b/server/functions/date_part.go index 36facc27e5..f59e0be449 100644 --- a/server/functions/date_part.go +++ b/server/functions/date_part.go @@ -278,7 +278,7 @@ func numericFloor(val any) (apd.Decimal, error) { if err != nil { return apd.Decimal{}, err } - _, err = pgtypes.BaseContext.Floor(dec, dec) + _, err = sql.DecimalCtx.Floor(dec, dec) if err != nil { return apd.Decimal{}, err } diff --git a/server/functions/div.go b/server/functions/div.go index a54f2d918d..720bbf2e7e 100644 --- a/server/functions/div.go +++ b/server/functions/div.go @@ -50,9 +50,7 @@ var div_numeric = framework.Function2{ if num2.Form == apd.Infinite { return *apd.New(0, 0), nil } - // TODO: calculate precision and scale accurately - c := apd.BaseContext.WithPrecision(1000000) - _, err := c.QuoInteger(&num1, &num1, &num2) + _, err := sql.DecimalCtx.QuoInteger(&num1, &num1, &num2) if err != nil { return nil, err } diff --git a/server/functions/exp.go b/server/functions/exp.go index dadbca2b91..82e08eb271 100644 --- a/server/functions/exp.go +++ b/server/functions/exp.go @@ -49,7 +49,7 @@ var exp_numeric = framework.Function1{ Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { dec := val.(apd.Decimal) - _, err := pgtypes.BaseContext.WithPrecision(32).Exp(&dec, &dec) + _, err := sql.DecimalCtx.WithPrecision(32).Exp(&dec, &dec) if err != nil { return nil, err } diff --git a/server/functions/extract.go b/server/functions/extract.go index 2a6b833af3..c38a5c00a0 100644 --- a/server/functions/extract.go +++ b/server/functions/extract.go @@ -302,7 +302,7 @@ func numeric(val any, setScale bool, scale int32) (apd.Decimal, error) { return apd.Decimal{}, err } if setScale { - _, err = pgtypes.BaseContext.Quantize(dec, dec, -scale) + _, err = sql.DecimalCtx.Quantize(dec, dec, -scale) if err != nil { return apd.Decimal{}, err } diff --git a/server/functions/floor.go b/server/functions/floor.go index 7e58060fed..a3f51aadf8 100644 --- a/server/functions/floor.go +++ b/server/functions/floor.go @@ -49,7 +49,7 @@ var floor_numeric = framework.Function1{ Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { dec := val.(apd.Decimal) - _, err := pgtypes.BaseContext.Floor(&dec, &dec) + _, err := sql.DecimalCtx.Floor(&dec, &dec) if err != nil { return nil, err } diff --git a/server/functions/generate_series.go b/server/functions/generate_series.go index d7e572557e..d414e5be7c 100644 --- a/server/functions/generate_series.go +++ b/server/functions/generate_series.go @@ -188,9 +188,8 @@ func numericGenerateSeries(start, stop, step apd.Decimal) (*pgtypes.SetReturning } return pgtypes.NewSetReturningFunctionRowIter(func(ctx *sql.Context) (sql.Row, error) { defer func() { - _, err := pgtypes.BaseContext.Add(&start, &start, &step) + _, err := sql.DecimalCtx.Add(&start, &start, &step) if err != nil { - // TODO panic(err) } }() diff --git a/server/functions/ln.go b/server/functions/ln.go index 6f678efa58..6fd95517d7 100644 --- a/server/functions/ln.go +++ b/server/functions/ln.go @@ -60,6 +60,8 @@ var ln_numeric = framework.Function1{ return nil, errors.Errorf("cannot take logarithm of zero") } else if dec.Sign() < 0 { return nil, errors.Errorf("cannot take logarithm of a negative number") + } else if dec.Form == apd.NaN || dec.Form == apd.Infinite { + return dec, nil } // TODO: calculate precision and scale accurately @@ -72,7 +74,7 @@ var ln_numeric = framework.Function1{ } p := uint32(len(parts[0]) + int(-exp)) - c := apd.BaseContext.WithPrecision(p) + c := sql.DecimalCtx.WithPrecision(p) _, err := c.Ln(&dec, &dec) if err != nil { return nil, err diff --git a/server/functions/log.go b/server/functions/log.go index ff759ef745..8b994564fa 100644 --- a/server/functions/log.go +++ b/server/functions/log.go @@ -69,7 +69,7 @@ var log_numeric = framework.Function1{ if dec.Exponent < 0 { p += uint32(-dec.Exponent) } - c := apd.BaseContext.WithPrecision(p) + c := sql.DecimalCtx.WithPrecision(p) _, err := c.Log10(&dec, &dec) if err != nil { return nil, err @@ -103,7 +103,7 @@ var log_numeric_numeric = framework.Function2{ exp = int32(minExp) } p := uint32(int32(math.Max(float64(len(partsNum[0])), float64(len(partsBase[0])))) + (-exp)) - c := apd.BaseContext.WithPrecision(p) + c := sql.DecimalCtx.WithPrecision(p) lnBase := new(apd.Decimal) _, err := c.Ln(lnBase, &base) diff --git a/server/functions/mod.go b/server/functions/mod.go index 9afdb7bc67..f861e6f4b1 100644 --- a/server/functions/mod.go +++ b/server/functions/mod.go @@ -96,8 +96,7 @@ var mod_numeric_numeric = framework.Function2{ if num2.Form == apd.Infinite { return *apd.New(0, 0), nil } - c := apd.BaseContext.WithPrecision(1000000) - _, err := c.Rem(&num1, &num1, &num2) + _, err := sql.HighPrecisionCtx.Rem(&num1, &num1, &num2) if err != nil { return nil, err } diff --git a/server/functions/numeric.go b/server/functions/numeric.go index 47df1875cd..b5c2382bd1 100644 --- a/server/functions/numeric.go +++ b/server/functions/numeric.go @@ -15,14 +15,17 @@ package functions import ( + "encoding/binary" "fmt" "strconv" + "strings" "github.com/cockroachdb/apd/v3" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" + "github.com/dolthub/doltgresql/utils" ) // initNumeric registers the functions to the catalog. @@ -68,7 +71,6 @@ var numeric_out = framework.Function1{ return nil, err } return dec.Text('f'), nil - //return dec.StringFixed(dec.Exponent() * -1), nil }, } @@ -83,10 +85,89 @@ var numeric_recv = framework.Function3{ if data == nil { return nil, nil } - typmod := val3.(int32) - // TODO: chekc this doesn't update the original type - newType := *pgtypes.Numeric.WithAttTypMod(typmod) - return newType.DeserializationFunc(ctx, &newType, data) + //typmod := val3.(int32) + if len(data) == 0 { + return nil, nil + } + reader := utils.NewWireReader(data) + var d apd.Decimal + + // 1. Read Header + ndigits := reader.ReadInt16() + weight := reader.ReadInt16() + sign := reader.ReadInt16() + dscale := reader.ReadInt16() + + // 2. Handle Special Values (NaN, Inf) + // These usually manifest as specific bit patterns in the header + switch uint16(sign) { + case 0xC000: // pgNumericNaN + d.Form = apd.NaN + return d, nil + case 0xD000: // pgNumericPosInf + d.Form = apd.Infinite + return d, nil + case 0xF000: // pgNumericNegInf + d.Form = apd.Infinite + d.Negative = true + return d, nil + } + + // 3. Handle Finite Values + if ndigits == 0 { + d.SetInt64(0) + return d, nil + } + + // Read base-10000 digits + digits := make([]int16, ndigits) + for i := 0; i < int(ndigits); i++ { + digits[i] = reader.ReadInt16() + } + + // 4. Convert base-10000 to string for apd.Decimal + // Each digit is exactly 4 characters wide (except potentially the first) + var sb strings.Builder + if sign == 16384 { + sb.WriteByte('-') + } + + for i, digit := range digits { + // Calculate how many 10000-base digits are before the decimal + // 'weight' is the index of the first digit, where 0 is 10^0 in base 10000 + if i == int(weight)+1 { + sb.WriteByte('.') + } + + sDigit := strconv.Itoa(int(digit)) + // Pad with leading zeros if not the very first digit + if l := len(sDigit); l < 4 { + padding := 4 - l + for p := 0; p < padding; p++ { + sb.WriteByte('0') + } + } + sb.WriteString(sDigit) + } + + // If weight is larger than digits, we need trailing zeros + if int(weight) >= len(digits) { + for i := 0; i < int(weight)-len(digits)+1; i++ { + sb.WriteString("0000") + } + } + dec, _, err := sql.HighPrecisionCtx.NewFromString(sb.String()) + if err != nil { + return nil, err + } + str := dec.Text('f') + if str == " " { + } + _, err = sql.HighPrecisionCtx.Quantize(dec, dec, int32(-dscale)) + if err != nil { + return nil, err + } + return *dec, nil }, } @@ -97,8 +178,93 @@ var numeric_send = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) { - dec := val.(apd.Decimal) - return pgtypes.Numeric.SerializationFunc(ctx, t[0], dec) + num := val.(apd.Decimal) + typmod := t[0].GetAttTypMod() + writer := utils.NewWireWriter() + if num.Form == apd.Finite { + // Short-circuit if this is the zero value + if num.IsZero() { + writer.WriteBytes([]byte{0, 0, 0, 0, 0, 0, 0, 0}) + return writer.BufferData(), nil + } + // There's a way to do this more efficiently, but we can do that work once this becomes a performance issue. + // This is based on the terminology used in Postgres' `numeric.c` file + decStr := num.Text('f') + isNegative := false + if strings.HasPrefix(decStr, "-") { + isNegative = true + decStr = decStr[1:] + } + // Split the integer and fractional parts + var intPart string + var fractPart string + if idx := strings.Index(decStr, "."); idx != -1 { + intPart = decStr[:idx] + fractPart = decStr[idx+1:] + } else { + intPart = decStr + } + // Find the "dscale", which is the number of digits in the fractional part + var dscale int16 + if typmod != -1 { + _, dscale32 := pgtypes.GetPrecisionAndScaleFromTypmod(typmod) + dscale = int16(dscale32) + } else { + dscale = int16(len(fractPart)) + } + // Pad the integer and fractional parts so that we can take groups of 4 numbers + if intPart == "0" { + intPart = "" + } else if len(intPart)%4 != 0 { + intPart = strings.Repeat("0", 4-(len(intPart)%4)) + intPart + } + if len(fractPart)%4 != 0 { + // remove trailing zeroes on right side before filling it. + fractPart = strings.TrimRightFunc(fractPart, func(r rune) bool { + return r == '0' + }) + fractPart = fractPart + strings.Repeat("0", 4-(len(fractPart)%4)) + } + // Write the "ndigits" first, or the number of base-10000 digits + writer.WriteInt16(int16((len(intPart) / 4) + (len(fractPart) / 4))) + // Write the "weight", which is the number of base-10000 digits in the integer part subtracted by 1 + writer.WriteInt16(int16((len(intPart) / 4) - 1)) + // Write the "sign" + if isNegative { + writer.WriteInt16(16384) + } else { + writer.WriteInt16(0) + } + // Write the "dscale" + writer.WriteInt16(dscale) + // Write all of the digits + fullPart := intPart + fractPart + for i := 0; i < len(fullPart); i += 4 { + part, err := strconv.Atoi(fullPart[i : i+4]) + if err != nil { + return nil, err + } + writer.WriteInt16(int16(part)) + } + } else { + var buf []byte + wp := len(buf) + buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) + if num.Form == apd.NaN { + binary.BigEndian.PutUint64(buf[wp:], pgNumericNaN) + } else if num.Form == apd.Infinite { + if num.Negative { + binary.BigEndian.PutUint64(buf[wp:], pgNumericNegInf) + } else { + binary.BigEndian.PutUint64(buf[wp:], pgNumericPosInf) + } + } + if typmod == -1 { + binary.BigEndian.PutUint16(buf[6:], uint16(32)) + } + writer.WriteBytes(buf) + } + return writer.BufferData(), nil }, } @@ -158,3 +324,9 @@ var numeric_cmp = framework.Function2{ return int32(pgtypes.NumericCompare(ab, bb)), nil }, } + +const ( + pgNumericNaN = 0x00000000c0000000 + pgNumericPosInf = 0x00000000d0000000 + pgNumericNegInf = 0x00000000f0000000 +) diff --git a/server/functions/oid.go b/server/functions/oid.go index 115efd6044..68099eb9a7 100644 --- a/server/functions/oid.go +++ b/server/functions/oid.go @@ -17,6 +17,7 @@ package functions import ( "cmp" "fmt" + "math" "strconv" "strings" @@ -50,7 +51,7 @@ var oidin = framework.Function1{ return id.Null, pgtypes.ErrInvalidSyntaxForType.New("oid", input) } // Note: This minimum is different (-4294967295) for Postgres 15.4 compiled by Visual C++ - if iVal > pgtypes.MaxUint32 || iVal < pgtypes.MinInt32 { + if iVal > int64(math.MaxUint32) || iVal < int64(math.MinInt32) { return id.Null, pgtypes.ErrValueIsOutOfRangeForType.New(input, "oid") } uVal := uint32(iVal) diff --git a/server/functions/power.go b/server/functions/power.go index d82aeb42e1..45170657cf 100644 --- a/server/functions/power.go +++ b/server/functions/power.go @@ -34,7 +34,7 @@ func initPower() { var ( // errPowerZeroToNegative is an error for raising zero to a negative power in the "power" functions. errPowerZeroToNegative = errors.New("zero raised to a negative power is undefined") - // numericOne is equivalent to decimal.NewFromInt(1), but represented as a value for the sake of efficiency. + // numericOne is equivalent to apt.NewFromInt(1, 0), but represented as a value for the sake of efficiency. numericOne = apd.New(1, 0) ) @@ -98,7 +98,7 @@ var power_numeric_numeric = framework.Function2{ } if dec2.Sign() > 0 { d := *apd.New(0, 0) - _, _ = pgtypes.BaseContext.Quantize(&d, &d, -16) + _, _ = sql.DecimalCtx.Quantize(&d, &d, -16) return d, nil } } @@ -106,15 +106,15 @@ var power_numeric_numeric = framework.Function2{ if dec2.IsZero() || dec1.Cmp(numericOne) == 0 { d := *apd.New(1, 0) - _, _ = pgtypes.BaseContext.Quantize(&d, &d, -16) + _, _ = sql.DecimalCtx.Quantize(&d, &d, -16) return d, nil } // give enough precision that we can round it to 16 exp - _, err := pgtypes.BaseContext.WithPrecision(17).Pow(&dec1, &dec1, &dec2) + _, err := sql.DecimalCtx.WithPrecision(17).Pow(&dec1, &dec1, &dec2) if err != nil { return nil, err } - _, err = pgtypes.BaseContext.Quantize(&dec1, &dec1, -16) + _, err = sql.DecimalCtx.Quantize(&dec1, &dec1, -16) if err != nil { return nil, err } diff --git a/server/functions/round.go b/server/functions/round.go index a4d474850d..17a868548a 100644 --- a/server/functions/round.go +++ b/server/functions/round.go @@ -50,11 +50,11 @@ var round_numeric = framework.Function1{ Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { dec := val.(apd.Decimal) - _, err := pgtypes.BaseContext.Round(&dec, &dec) + _, err := sql.DecimalCtx.Round(&dec, &dec) if err != nil { return nil, err } - _, err = pgtypes.BaseContext.Quantize(&dec, &dec, 0) + _, err = sql.DecimalCtx.Quantize(&dec, &dec, 0) if err != nil { return nil, err } @@ -71,11 +71,11 @@ var round_numeric_int64 = framework.Function2{ Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { dec := val1.(apd.Decimal) places := val2.(int64) - _, err := pgtypes.BaseContext.Round(&dec, &dec) + _, err := sql.DecimalCtx.Round(&dec, &dec) if err != nil { return nil, err } - _, err = pgtypes.BaseContext.Quantize(&dec, &dec, int32(-places)) + _, err = sql.DecimalCtx.Quantize(&dec, &dec, int32(-places)) if err != nil { return nil, err } diff --git a/server/functions/sqrt.go b/server/functions/sqrt.go index bdbcbb1d47..6393c36d68 100644 --- a/server/functions/sqrt.go +++ b/server/functions/sqrt.go @@ -74,7 +74,7 @@ var sqrt_numeric = framework.Function1{ p += uint32(-exp) } - c := apd.BaseContext.WithPrecision(p) + c := sql.DecimalCtx.WithPrecision(p) _, err := c.Sqrt(&dec, &dec) if err != nil { return nil, err diff --git a/server/functions/to_char.go b/server/functions/to_char.go index 00463f3c12..d948e8db0a 100644 --- a/server/functions/to_char.go +++ b/server/functions/to_char.go @@ -138,7 +138,7 @@ var to_char_numeric_text = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Text}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) { - //timestamp := val1.(decimal.Decimal) + //timestamp := val1.(apd.Decimal) //format := val2.(string) return nil, errors.Errorf(`to_char(numeric,text) is not supported yet`) diff --git a/server/functions/trunc.go b/server/functions/trunc.go index 4b37778a06..0818830a49 100644 --- a/server/functions/trunc.go +++ b/server/functions/trunc.go @@ -50,9 +50,7 @@ var trunc_numeric = framework.Function1{ Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { dec := val.(apd.Decimal) - // TODO: calculate precision and scale accurately - c := apd.BaseContext.WithPrecision(1000000) - _, err := c.Quantize(&dec, &dec, 0) + _, err := sql.HighPrecisionCtx.Quantize(&dec, &dec, 0) if err != nil { return nil, err } @@ -67,12 +65,9 @@ var trunc_numeric_int64 = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Int32}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - //TODO: test for negative values in places dec := val1.(apd.Decimal) places := val2.(int32) - // TODO: calculate precision and scale accurately - c := apd.BaseContext.WithPrecision(1000000) - _, err := c.Quantize(&dec, &dec, -places) + _, err := sql.HighPrecisionCtx.Quantize(&dec, &dec, -places) if err != nil { return nil, err } diff --git a/server/functions/width_bucket.go b/server/functions/width_bucket.go index 86145db24b..92087eed51 100644 --- a/server/functions/width_bucket.go +++ b/server/functions/width_bucket.go @@ -87,24 +87,24 @@ var width_bucket_numeric_numeric_numeric_int64 = framework.Function4{ return int32(1), nil } bucket := new(apd.Decimal) - _, err := pgtypes.BaseContext.Sub(bucket, &high, &low) + _, err := sql.DecimalCtx.Sub(bucket, &high, &low) if err != nil { return nil, err } - _, err = pgtypes.BaseContext.Quo(bucket, bucket, apd.New(int64(count), 0)) + _, err = sql.DecimalCtx.Quo(bucket, bucket, apd.New(int64(count), 0)) if err != nil { return nil, err } result := new(apd.Decimal) - _, err = pgtypes.BaseContext.Sub(result, &operand, &low) + _, err = sql.DecimalCtx.Sub(result, &operand, &low) if err != nil { return nil, err } - _, err = pgtypes.BaseContext.Sub(result, result, bucket) + _, err = sql.DecimalCtx.Sub(result, result, bucket) if err != nil { return nil, err } - _, err = pgtypes.BaseContext.Ceil(result, result) + _, err = sql.DecimalCtx.Ceil(result, result) if err != nil { return nil, err } diff --git a/server/tables/dtables/ignore.go b/server/tables/dtables/ignore.go index db0ac0125f..ea3fe4290a 100644 --- a/server/tables/dtables/ignore.go +++ b/server/tables/dtables/ignore.go @@ -60,7 +60,7 @@ func convertTupleToIgnoreBoolean(ctx context.Context, valueDesc *val.TupleDesc, // getIgnoreTablePatternKey reads the pattern key from a tuple and returns it. func getIgnoreTablePatternKey(ctx context.Context, keyDesc *val.TupleDesc, keyTuple val.Tuple) (string, error) { - key, ok, err := keyDesc.GetStringAdaptiveValue(0, nil, keyTuple) + key, ok, err := keyDesc.GetStringAdaptiveValue(ctx, 0, nil, keyTuple) if err != nil { return "", err } diff --git a/server/tables/dtables/rebase.go b/server/tables/dtables/rebase.go index a8a4da93b7..99ccbc54f3 100644 --- a/server/tables/dtables/rebase.go +++ b/server/tables/dtables/rebase.go @@ -22,7 +22,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/rebase" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dprocedures" "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" + gmstypes "github.com/dolthub/go-mysql-server/sql/types" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -89,7 +89,7 @@ func convertRowToRebasePlanStep(ctx context.Context, row sql.Row) (rebase.Rebase } return rebase.RebasePlanStep{ - RebaseOrder: decimal.NewFromFloat32(order), + RebaseOrder: gmstypes.DecimalFromFloat32(order), Action: rebaseAction, CommitHash: commitHash, CommitMsg: commitMsg, diff --git a/server/types/numeric.go b/server/types/numeric.go index 047013f88e..f8d3ec8701 100644 --- a/server/types/numeric.go +++ b/server/types/numeric.go @@ -15,47 +15,26 @@ package types import ( - "encoding/binary" - "strconv" - "strings" + "math" "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/doltgresql/core/id" - "github.com/dolthub/doltgresql/utils" -) - -const ( - MaxUint32 = 4294967295 // MaxUint32 is the largest possible value of Uint32 - MinInt32 = -2147483648 // MinInt32 is the smallest possible value of Int32 - MaxPrecision = uint32(100000) -) - -const ( - pgNumericNaN = 0x00000000c0000000 - pgNumericNaNSign = 0xc000 - - pgNumericPosInf = 0x00000000d0000000 - pgNumericPosInfSign = 0xd000 - - pgNumericNegInf = 0x00000000f0000000 - pgNumericNegInfSign = 0xf000 ) var ( - NumericValueMaxInt16 = apd.New(32767, 0) // NumericValueMaxInt16 is the max Int16 value for NUMERIC types - NumericValueMaxInt32 = apd.New(2147483647, 0) // NumericValueMaxInt32 is the max Int32 value for NUMERIC types - NumericValueMaxInt64 = apd.New(9223372036854775807, 0) // NumericValueMaxInt64 is the max Int64 value for NUMERIC types - NumericValueMinInt16 = apd.New(-32768, 0) // NumericValueMinInt16 is the min Int16 value for NUMERIC types - NumericValueMinInt32 = apd.New(MinInt32, 0) // NumericValueMinInt32 is the min Int32 value for NUMERIC types - NumericValueMinInt64 = apd.New(-9223372036854775808, 0) // NumericValueMinInt64 is the min Int64 value for NUMERIC types - NumericValueMaxUint32 = apd.New(MaxUint32, 0) // NumericValueMaxUint32 is the max Uint32 value for NUMERIC types - NumericNaN = apd.Decimal{Form: apd.NaN} - NumericInf = apd.Decimal{Form: apd.Infinite} - NumericNegInf = apd.Decimal{Form: apd.Infinite, Negative: true} - BaseContext = apd.BaseContext // .WithPrecision(MaxPrecision) + NumericValueMaxInt16 = types.DecimalFromInt64(math.MaxInt16) // NumericValueMaxInt16 is the max Int16 value for NUMERIC types + NumericValueMaxInt32 = types.DecimalFromInt64(math.MaxInt32) // NumericValueMaxInt32 is the max Int32 value for NUMERIC types + NumericValueMaxInt64 = types.DecimalFromInt64(math.MaxInt64) // NumericValueMaxInt64 is the max Int64 value for NUMERIC types + NumericValueMinInt16 = types.DecimalFromInt64(math.MinInt16) // NumericValueMinInt16 is the min Int16 value for NUMERIC types + NumericValueMinInt32 = types.DecimalFromInt64(math.MinInt32) // NumericValueMinInt32 is the min Int32 value for NUMERIC types + NumericValueMinInt64 = types.DecimalFromInt64(math.MinInt64) // NumericValueMinInt64 is the min Int64 value for NUMERIC types + NumericNaN = apd.Decimal{Form: apd.NaN} + NumericInf = apd.Decimal{Form: apd.Infinite} + NumericNegInf = apd.Decimal{Form: apd.Infinite, Negative: true} ) // Numeric is a precise and unbounded decimal value. @@ -133,7 +112,7 @@ func GetNumericValueWithTypmod(val apd.Decimal, typmod int32) (apd.Decimal, erro } res := new(apd.Decimal) precision, scale := GetPrecisionAndScaleFromTypmod(typmod) - _, err := BaseContext.WithPrecision(uint32(precision)).Quantize(res, &val, -scale) + _, err := sql.DecimalCtx.WithPrecision(uint32(precision)).Quantize(res, &val, -scale) if err != nil { return apd.Decimal{}, errors.Errorf("numeric field overflow - A field with precision %v, scale %v must round to an absolute value less than 10^%v", precision, scale, precision-scale) } @@ -143,7 +122,7 @@ func GetNumericValueWithTypmod(val apd.Decimal, typmod int32) (apd.Decimal, erro // GetNumericValueFromStringWithTypmod returns either given numeric value or truncated or error // depending on the precision and scale decoded from given type modifier value. func GetNumericValueFromStringWithTypmod(val string, typmod int32) (apd.Decimal, error) { - dec, cond, err := BaseContext.WithPrecision(MaxPrecision).NewFromString(val) + dec, cond, err := sql.HighPrecisionCtx.NewFromString(val) if err != nil { return apd.Decimal{}, err } @@ -156,93 +135,8 @@ func GetNumericValueFromStringWithTypmod(val string, typmod int32) (apd.Decimal, // serializeTypeNumeric handles serialization from the standard representation to our serialized representation that is // written in Dolt. func serializeTypeNumeric(ctx *sql.Context, t *DoltgresType, val any) ([]byte, error) { - num := val.(apd.Decimal) - typmod := t.GetAttTypMod() - writer := utils.NewWireWriter() - if num.Form == apd.Finite { - // Short-circuit if this is the zero value - if num.IsZero() { - writer.WriteBytes([]byte{0, 0, 0, 0, 0, 0, 0, 0}) - return writer.BufferData(), nil - } - // There's a way to do this more efficiently, but we can do that work once this becomes a performance issue. - // This is based on the terminology used in Postgres' `numeric.c` file - decStr := num.Text('f') - isNegative := false - if strings.HasPrefix(decStr, "-") { - isNegative = true - decStr = decStr[1:] - } - // Split the integer and fractional parts - var intPart string - var fractPart string - if idx := strings.Index(decStr, "."); idx != -1 { - intPart = decStr[:idx] - fractPart = decStr[idx+1:] - } else { - intPart = decStr - } - // Find the "dscale", which is the number of digits in the fractional part - var dscale int16 - if typmod != -1 { - _, dscale32 := GetPrecisionAndScaleFromTypmod(typmod) - dscale = int16(dscale32) - } else { - dscale = int16(len(fractPart)) - } - // Pad the integer and fractional parts so that we can take groups of 4 numbers - if intPart == "0" { - intPart = "" - } else if len(intPart)%4 != 0 { - intPart = strings.Repeat("0", 4-(len(intPart)%4)) + intPart - } - if len(fractPart)%4 != 0 { - // remove trailing zeroes on right side before filling it. - fractPart = strings.TrimRightFunc(fractPart, func(r rune) bool { - return r == '0' - }) - fractPart = fractPart + strings.Repeat("0", 4-(len(fractPart)%4)) - } - // Write the "ndigits" first, or the number of base-10000 digits - writer.WriteInt16(int16((len(intPart) / 4) + (len(fractPart) / 4))) - // Write the "weight", which is the number of base-10000 digits in the integer part subtracted by 1 - writer.WriteInt16(int16((len(intPart) / 4) - 1)) - // Write the "sign" - if isNegative { - writer.WriteInt16(16384) - } else { - writer.WriteInt16(0) - } - // Write the "dscale" - writer.WriteInt16(dscale) - // Write all of the digits - fullPart := intPart + fractPart - for i := 0; i < len(fullPart); i += 4 { - part, err := strconv.Atoi(fullPart[i : i+4]) - if err != nil { - return nil, err - } - writer.WriteInt16(int16(part)) - } - } else { - var buf []byte - wp := len(buf) - buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) - if num.Form == apd.NaN { - binary.BigEndian.PutUint64(buf[wp:], pgNumericNaN) - } else if num.Form == apd.Infinite { - if num.Negative { - binary.BigEndian.PutUint64(buf[wp:], pgNumericNegInf) - } else { - binary.BigEndian.PutUint64(buf[wp:], pgNumericPosInf) - } - } - if typmod == -1 { - binary.BigEndian.PutUint16(buf[6:], uint16(32)) - } - writer.WriteBytes(buf) - } - return writer.BufferData(), nil + d := val.(apd.Decimal) + return d.MarshalText() } // deserializeTypeNumeric handles deserialization from the Dolt serialized format to our standard representation used by @@ -251,80 +145,9 @@ func deserializeTypeNumeric(ctx *sql.Context, t *DoltgresType, data []byte) (any if len(data) == 0 { return nil, nil } - reader := utils.NewWireReader(data) - var d apd.Decimal - - // 1. Read Header - ndigits := reader.ReadInt16() - weight := reader.ReadInt16() - sign := reader.ReadInt16() - dscale := reader.ReadInt16() - - // 2. Handle Special Values (NaN, Inf) - // These usually manifest as specific bit patterns in the header - switch uint16(sign) { - case 0xC000: // pgNumericNaN - d.Form = apd.NaN - return d, nil - case 0xD000: // pgNumericPosInf - d.Form = apd.Infinite - return d, nil - case 0xF000: // pgNumericNegInf - d.Form = apd.Infinite - d.Negative = true - return d, nil - } - - // 3. Handle Finite Values - if ndigits == 0 { - d.SetInt64(0) - return d, nil - } - - // Read base-10000 digits - digits := make([]int16, ndigits) - for i := 0; i < int(ndigits); i++ { - digits[i] = reader.ReadInt16() - } - - // 4. Convert base-10000 to string for apd.Decimal - // Each digit is exactly 4 characters wide (except potentially the first) - var sb strings.Builder - if sign == 16384 { - sb.WriteByte('-') - } - - for i, digit := range digits { - // Calculate how many 10000-base digits are before the decimal - // 'weight' is the index of the first digit, where 0 is 10^0 in base 10000 - if i == int(weight)+1 { - sb.WriteByte('.') - } - - sDigit := strconv.Itoa(int(digit)) - // Pad with leading zeros if not the very first digit - if l := len(sDigit); l < 4 { - padding := 4 - l - for p := 0; p < padding; p++ { - sb.WriteByte('0') - } - } - sb.WriteString(sDigit) - } - - // If weight is larger than digits, we need trailing zeros - if int(weight) >= len(digits) { - for i := 0; i < int(weight)-len(digits)+1; i++ { - sb.WriteString("0000") - } - } - - dec, _, err := BaseContext.NewFromString(sb.String()) - if err != nil { - return nil, err - } - _, _ = BaseContext.Quantize(dec, dec, int32(-dscale)) - return *dec, err + retVal := *apd.New(0, 0) + err := retVal.UnmarshalText(data) + return retVal, err } // NumericCompare compares two apd.Decimal values handling NaN separately. diff --git a/server/types/typeinfo.go b/server/types/typeinfo.go index 52a441af76..93b00b151e 100644 --- a/server/types/typeinfo.go +++ b/server/types/typeinfo.go @@ -65,8 +65,8 @@ func (t typeInfo) Encoding() val.Encoding { return val.Float32Enc case "float8": return val.Float64Enc - //case "numeric", "decimal": - // return val.DecimalEnc + case "numeric", "decimal": + return val.DecimalEnc case "bytea": return val.BytesAdaptiveEnc // TODO: use dolt JSON document encoding here diff --git a/testing/generation/function_coverage/output/framework_test.go b/testing/generation/function_coverage/output/framework_test.go index 9188a40b33..f71303f3a2 100644 --- a/testing/generation/function_coverage/output/framework_test.go +++ b/testing/generation/function_coverage/output/framework_test.go @@ -246,7 +246,7 @@ func Numeric(str string) pgtype.Numeric { return numeric } -// NumericToDecimal converts a pgtype.Numeric value to a decimal.Decimal value. +// NumericToDecimal converts a pgtype.Numeric value to a apd.Decimal value. func NumericToDecimal(val pgtype.Numeric) apd.Decimal { if val.NaN { return pgtypes.NumericNaN @@ -297,7 +297,7 @@ func CompareRows(t *testing.T, a sql.Row, b sql.Row) bool { case pgtype.Numeric: aDec := NumericToDecimal(aVal.(pgtype.Numeric)) bDec := NumericToDecimal(bVal.(pgtype.Numeric)) - _, err := pgtypes.BaseContext.Sub(&aDec, &aDec, &bDec) + _, err := sql.DecimalCtx.Sub(&aDec, &aDec, &bDec) if err != nil { return false } diff --git a/testing/go/framework.go b/testing/go/framework.go index cb8fe07d66..314b51dac3 100644 --- a/testing/go/framework.go +++ b/testing/go/framework.go @@ -843,12 +843,6 @@ func NormalizeIntsAndFloats(v any) any { // Numeric creates a numeric value from a string. func Numeric(str string) pgtype.Numeric { - //// 250.0 != 250 and 42.90 != 42.9, so we trim all trailing fractional zeroes (and decimal if no fractional zeroes) - //// to ensure that the input strings are homogenized, which will give us comparable representations for the same value - //if idx := strings.Index(str, "."); idx != -1 { - // str = strings.TrimRight(str, "0") - //} - //str = strings.TrimRight(str, ".") numeric := pgtype.Numeric{} if err := numeric.Scan(str); err != nil { panic(err) From 2c3c52f21037d5a92e98ffda58bcef7f9a53b44f Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 5 May 2026 16:54:03 -0700 Subject: [PATCH 05/16] update dolt --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index e5fdad42b1..c49f33a04c 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,13 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v3 v3.2.3 github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20260505174628-1b8bc4a1f0fd + github.com/dolthub/dolt/go v0.40.5-0.20260505235150-c334899b1138 github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.20.1-0.20260504234906-b4fc3a6e3cc5 + github.com/dolthub/go-mysql-server v0.20.1-0.20260505215953-bee2a70fec58 github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 - github.com/dolthub/vitess v0.0.0-20260504232330-bcb1b9015c48 + github.com/dolthub/vitess v0.0.0-20260505163811-77e5224be390 github.com/fatih/color v1.13.0 github.com/goccy/go-json v0.10.2 github.com/gogo/protobuf v1.3.2 diff --git a/go.sum b/go.sum index cf3cdf9493..f7f5842b9e 100644 --- a/go.sum +++ b/go.sum @@ -245,8 +245,8 @@ github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:I github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo= github.com/dolthub/dolt-mcp v0.3.4 h1:AyG5cw+fNWXDHXujtQnqUPZrpWtPg6FN6yYtjv1pP44= github.com/dolthub/dolt-mcp v0.3.4/go.mod h1:bCZ7KHvDYs+M0e+ySgmGiNvLhcwsN7bbf5YCyillLrk= -github.com/dolthub/dolt/go v0.40.5-0.20260505174628-1b8bc4a1f0fd h1:nWpxjJMGeNNrKnMzqcihLdsOcdzHPZXw1VfwKrh1eHU= -github.com/dolthub/dolt/go v0.40.5-0.20260505174628-1b8bc4a1f0fd/go.mod h1:bjYYYVBPSlOKvpsX5AdGftveZQCIC5cyqvFUK5I+0CQ= +github.com/dolthub/dolt/go v0.40.5-0.20260505235150-c334899b1138 h1:ON9buSp8ADonlYXELGoct3+3gZ8czXbx3ILogqUS8ok= +github.com/dolthub/dolt/go v0.40.5-0.20260505235150-c334899b1138/go.mod h1:r9y445V5FvEh1GPGko1NSspH9pzdN1yLml9OvELAUbQ= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 h1:JShhbqMw26nKx3pqqu/cFxOpzBkN+4elVhzuUfgDw2k= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69/go.mod h1:SSLraQS/jGLYFgff3vuZ+JbVUct6vyEeMzjLBqWqoyM= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -255,8 +255,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866 h1:U6gSf5I0e6h6GP1/5Sa7D2lWW1CWfcVPtY5wkyHq6jY= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE= -github.com/dolthub/go-mysql-server v0.20.1-0.20260504234906-b4fc3a6e3cc5 h1:F7prHhrdLPuzacr9+hpklMxZ4Q2M2dqsKzxX8FqiRXs= -github.com/dolthub/go-mysql-server v0.20.1-0.20260504234906-b4fc3a6e3cc5/go.mod h1:hhGHXWslZ2AFzPkgJqoHtH6fZ/lNMk25SAe2RHKpQLU= +github.com/dolthub/go-mysql-server v0.20.1-0.20260505215953-bee2a70fec58 h1:vaw9NBd3aI2qo09GfUNktff8zpTFePIeZw+upfVb4qc= +github.com/dolthub/go-mysql-server v0.20.1-0.20260505215953-bee2a70fec58/go.mod h1:uE//4hYOi/1zIgOYgr8D+u7OvDfI4yPXFsL4dTZWk5I= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20260414231531-5f031e3e9037 h1:oIW9HwuWrhxv+4HZxA+QQSKHLqWFyXZ2FmNjUYwkdiM= @@ -267,8 +267,8 @@ github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 h1:GY17cGA4 github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1/go.mod h1:qnrZP3/1slFl2Bq5yw38HLOsArZareGwdpEceriblLc= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4EHUcEVQCMRBej8DYxjYjRz/9MdF/NNQh0o70= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= -github.com/dolthub/vitess v0.0.0-20260504232330-bcb1b9015c48 h1:znkwaIJnYUf9kL6hNctwWOaKDH0h1jWevwdmo0Otojo= -github.com/dolthub/vitess v0.0.0-20260504232330-bcb1b9015c48/go.mod h1:dKAkzdfRkAudpc0g8JOQ0eiEjV83TYIFz/yNIEdcjXM= +github.com/dolthub/vitess v0.0.0-20260505163811-77e5224be390 h1:FjUnJYan3i3mrk/i+qENwFOQhA+XfH+wgAucbgi/4sQ= +github.com/dolthub/vitess v0.0.0-20260505163811-77e5224be390/go.mod h1:dKAkzdfRkAudpc0g8JOQ0eiEjV83TYIFz/yNIEdcjXM= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= From 3b092c58e3d2d7b9a6930a1eb898660422bf7856 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Wed, 6 May 2026 14:07:43 -0700 Subject: [PATCH 06/16] fix regressed tests --- server/cast/jsonb.go | 19 +++------------ server/functions/binary/minus.go | 6 ++++- server/functions/binary/multiply.go | 2 +- server/functions/binary/plus.go | 6 ++++- server/functions/floor.go | 9 ++++++- server/functions/ln.go | 24 ++++++++++-------- server/functions/numeric.go | 5 +--- server/functions/round.go | 4 +-- server/functions/sqrt.go | 21 ++++++---------- server/types/json_document.go | 2 +- testing/go/types_test.go | 38 +++++++++++++++++++++++++++++ 11 files changed, 86 insertions(+), 50 deletions(-) diff --git a/server/cast/jsonb.go b/server/cast/jsonb.go index be60923e3c..6b3b767a4d 100644 --- a/server/cast/jsonb.go +++ b/server/cast/jsonb.go @@ -17,6 +17,7 @@ package cast import ( "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/go-mysql-server/sql" @@ -118,11 +119,7 @@ func jsonbExplicit() { if d.Cmp(&pgtypes.NumericValueMinInt16) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt16) > 0 { return nil, errors.Errorf("smallint out of range") } - i, err := d.Int64() - if err != nil { - return nil, err - } - return int16(i), nil + return int16(types.DecimalIntPart(d)), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: @@ -148,11 +145,7 @@ func jsonbExplicit() { if d.Cmp(&pgtypes.NumericValueMinInt32) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt32) > 0 { return nil, errors.Errorf("integer out of range") } - i, err := d.Int64() - if err != nil { - return nil, err - } - return int32(i), nil + return int32(types.DecimalIntPart(d)), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: @@ -178,11 +171,7 @@ func jsonbExplicit() { if d.Cmp(&pgtypes.NumericValueMinInt64) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt64) > 0 { return nil, errors.Errorf("bigint out of range") } - i, err := d.Int64() - if err != nil { - return nil, err - } - return int64(i), nil + return int64(types.DecimalIntPart(d)), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: diff --git a/server/functions/binary/minus.go b/server/functions/binary/minus.go index 1f622523c2..30c8730620 100644 --- a/server/functions/binary/minus.go +++ b/server/functions/binary/minus.go @@ -242,7 +242,11 @@ var numeric_sub = framework.Function2{ Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { num1 := val1.(apd.Decimal) num2 := val2.(apd.Decimal) - _, err := sql.DecimalCtx.Sub(&num1, &num1, &num2) + p := num1.NumDigits() + if p2 := num2.NumDigits(); p < p2 { + p = p2 + } + _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Sub(&num1, &num1, &num2) if err != nil { return nil, err } diff --git a/server/functions/binary/multiply.go b/server/functions/binary/multiply.go index 8fb5ca9d98..1b7eb9f896 100644 --- a/server/functions/binary/multiply.go +++ b/server/functions/binary/multiply.go @@ -230,7 +230,7 @@ var numeric_mul = framework.Function2{ if (num1.Form == apd.Infinite || num2.Form == apd.Infinite) && (num1.IsZero() || num2.IsZero()) { return pgtypes.NumericNaN, nil } - _, err := sql.DecimalCtx.Mul(&num1, &num1, &num2) + _, err := sql.DecimalCtx.WithPrecision(uint32(num1.NumDigits()+num2.NumDigits())).Mul(&num1, &num1, &num2) if err != nil { return nil, err } diff --git a/server/functions/binary/plus.go b/server/functions/binary/plus.go index 2250ffabb8..42cccaef00 100644 --- a/server/functions/binary/plus.go +++ b/server/functions/binary/plus.go @@ -390,7 +390,11 @@ func numeric_add_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any return pgtypes.NumericInf, nil } - _, err := sql.DecimalCtx.Add(&num1, &num1, &num2) + p := num1.NumDigits() + if p2 := num2.NumDigits(); p < p2 { + p = p2 + } + _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Add(&num1, &num1, &num2) if err != nil { return nil, err } diff --git a/server/functions/floor.go b/server/functions/floor.go index a3f51aadf8..8b1f89004c 100644 --- a/server/functions/floor.go +++ b/server/functions/floor.go @@ -49,10 +49,17 @@ var floor_numeric = framework.Function1{ Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { dec := val.(apd.Decimal) - _, err := sql.DecimalCtx.Floor(&dec, &dec) + newDecimalCtx := *sql.DecimalCtx + newDecimalCtx.Rounding = apd.RoundFloor + + _, err := newDecimalCtx.Floor(&dec, &dec) if err != nil { return nil, err } + // floor(-0.1) returns -0, which is -1 in postgres because postgres does not support -0 + if dec.IsZero() && dec.Negative { + return *apd.New(-1, 0), nil + } return dec, nil }, } diff --git a/server/functions/ln.go b/server/functions/ln.go index 6fd95517d7..4bef292081 100644 --- a/server/functions/ln.go +++ b/server/functions/ln.go @@ -64,22 +64,26 @@ var ln_numeric = framework.Function1{ return dec, nil } - // TODO: calculate precision and scale accurately - s := dec.Text('f') - parts := strings.Split(s, ".") - - exp := int32(-16) - if dec.Exponent < exp { - exp = dec.Exponent + exp := dec.Exponent + c := sql.DecimalCtx + if nd := uint32(dec.NumDigits()); nd > c.Precision { + c = c.WithPrecision(nd) } - p := uint32(len(parts[0]) + int(-exp)) - c := sql.DecimalCtx.WithPrecision(p) _, err := c.Ln(&dec, &dec) if err != nil { return nil, err } - _, err = c.Quantize(&dec, &dec, exp) + + // TODO: calculate precision and scale accurately + if exp > -16 { + // use ln result + parts := strings.Split(dec.Text('f'), ".") + whole := int32(len(parts[0]) / 2) + exp = whole - 16 + } + + _, err = sql.HighPrecisionCtx.Quantize(&dec, &dec, exp) if err != nil { return nil, err } diff --git a/server/functions/numeric.go b/server/functions/numeric.go index b5c2382bd1..05a6459bdf 100644 --- a/server/functions/numeric.go +++ b/server/functions/numeric.go @@ -48,7 +48,7 @@ var numeric_in = framework.Function3{ Callable: func(ctx *sql.Context, _ [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) { input := val1.(string) typmod := val3.(int32) - dec, _, err := apd.NewFromString(input) + dec, _, err := apd.NewFromString(strings.TrimSpace(input)) if err != nil { return nil, pgtypes.ErrInvalidSyntaxForType.New("numeric", input) } @@ -160,9 +160,6 @@ var numeric_recv = framework.Function3{ if err != nil { return nil, err } - str := dec.Text('f') - if str == " " { - } _, err = sql.HighPrecisionCtx.Quantize(dec, dec, int32(-dscale)) if err != nil { return nil, err diff --git a/server/functions/round.go b/server/functions/round.go index 17a868548a..820bdbc4ca 100644 --- a/server/functions/round.go +++ b/server/functions/round.go @@ -71,11 +71,11 @@ var round_numeric_int64 = framework.Function2{ Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { dec := val1.(apd.Decimal) places := val2.(int64) - _, err := sql.DecimalCtx.Round(&dec, &dec) + _, err := sql.HighPrecisionCtx.Round(&dec, &dec) if err != nil { return nil, err } - _, err = sql.DecimalCtx.Quantize(&dec, &dec, int32(-places)) + _, err = sql.HighPrecisionCtx.Quantize(&dec, &dec, int32(-places)) if err != nil { return nil, err } diff --git a/server/functions/sqrt.go b/server/functions/sqrt.go index 6393c36d68..365ee5929d 100644 --- a/server/functions/sqrt.go +++ b/server/functions/sqrt.go @@ -16,7 +16,6 @@ package functions import ( "math" - "strings" "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" @@ -58,23 +57,17 @@ var sqrt_numeric = framework.Function1{ return nil, errors.Errorf("cannot take square root of a negative number") } + exp := dec.Exponent + c := sql.DecimalCtx + nd := uint32(dec.NumDigits()) + if nd > c.Precision { + c = c.WithPrecision(nd) + } // TODO: calculate precision and scale accurately - s := dec.Text('f') - parts := strings.Split(s, ".") - - exp := int32(-16) - whole := int32(len(parts[0]) / 2) if dec.Exponent == 0 { - exp = whole - 16 - } else if dec.Exponent < -16 { - exp = dec.Exponent - } - p := uint32(whole) + 1 - if exp < 0 { - p += uint32(-exp) + exp = int32(nd/2) - 16 } - c := sql.DecimalCtx.WithPrecision(p) _, err := c.Sqrt(&dec, &dec) if err != nil { return nil, err diff --git a/server/types/json_document.go b/server/types/json_document.go index 00eb638c58..ac049e3729 100644 --- a/server/types/json_document.go +++ b/server/types/json_document.go @@ -421,7 +421,7 @@ func ConvertToJsonDocument(val interface{}) (JsonValue, error) { case json.Number: str := string(val) // Strip trailing fractional zeros: "25.0"→{250,-1} and "25"→{25,0} differ in MarshalBinary, breaking GROUP BY hash equality. - if strings.IndexByte(str, '.') != -1 { + if strings.IndexByte(str, '.') != -1 && strings.IndexByte(str, 'e') == -1 { // remove trailing 0s after '.' str = strings.TrimRightFunc(str, func(r rune) bool { return r == '0' diff --git a/testing/go/types_test.go b/testing/go/types_test.go index 00c3a1b58a..ecc1bd6572 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -1296,6 +1296,16 @@ var typesTests = []ScriptTest{ {"t"}, }, }, + { + Query: `SELECT '1.3e100'::jsonb;`, + Expected: []sql.Row{ + {"13000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + }, + }, + { + Query: `select '12345.05'::jsonb::int2;`, + Expected: []sql.Row{{12345}}, + }, }, }, { @@ -1876,6 +1886,10 @@ var typesTests = []ScriptTest{ "CREATE TABLE t_numeric (id INTEGER primary key, v1 NUMERIC(5,2));", "INSERT INTO t_numeric VALUES (1, 123.45), (2, 67.89), (3, 100.3);", "CREATE TABLE fract_only (id int, val numeric(4,4));", + "CREATE TABLE num_data (id int4, val numeric(210,10));", + "INSERT INTO num_data VALUES (2, '-34338492.215397047');", + "CREATE TABLE ceil_floor_round (a numeric);", + "INSERT INTO ceil_floor_round VALUES ('-0.000001');", }, Assertions: []ScriptTestAssertion{ { @@ -1934,6 +1948,30 @@ var typesTests = []ScriptTest{ Query: "SELECT 'infinity'::numeric;", Expected: []sql.Row{{Numeric("Infinity")}}, }, + { + Query: "SELECT ' 123'::numeric;", + Expected: []sql.Row{{Numeric("123")}}, + }, + { + Query: "SELECT t1.id, t2.id, round(t1.val * t2.val, 30) FROM num_data t1, num_data t2;", + Expected: []sql.Row{{2, 2, Numeric("1179132047626883.596862135856320209000000000000")}}, + }, + { + Query: "select sqrt(1.000000000000004::numeric);", + Expected: []sql.Row{{Numeric("1.000000000000002")}}, + }, + { + Query: "select ln(5.80397490724e5);", + Expected: []sql.Row{{Numeric("13.271468476626518")}}, + }, + { + Query: "select 4770999999999999999999999999999999999999999999999999999999999999999999999999999999999999 * 9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999;", + Expected: []sql.Row{{Numeric("47709999999999999999999999999999999999999999999999999999999999999999999999999999999999985229000000000000000000000000000000000000000000000000000000000000000000000000000000000001")}}, + }, + { + Query: "SELECT floor(-0.000001);", + Expected: []sql.Row{{Numeric("-1")}}, + }, }, }, { From 81456d6d944b1601fdbb2840b137d4de8421e3c5 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Wed, 6 May 2026 14:29:55 -0700 Subject: [PATCH 07/16] feedback update --- server/functions/trunc.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/server/functions/trunc.go b/server/functions/trunc.go index 0818830a49..0e37c12ca8 100644 --- a/server/functions/trunc.go +++ b/server/functions/trunc.go @@ -64,10 +64,10 @@ var trunc_numeric_int64 = framework.Function2{ Return: pgtypes.Numeric, Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Int32}, Strict: true, - Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - dec := val1.(apd.Decimal) - places := val2.(int32) - _, err := sql.HighPrecisionCtx.Quantize(&dec, &dec, -places) + Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, num any, places any) (any, error) { + dec := num.(apd.Decimal) + scale := places.(int32) + _, err := sql.HighPrecisionCtx.Quantize(&dec, &dec, -scale) if err != nil { return nil, err } From eb5b75a35a01d9770f60d2546187c557fb58e38a Mon Sep 17 00:00:00 2001 From: jennifersp Date: Mon, 11 May 2026 23:36:30 -0700 Subject: [PATCH 08/16] update to use pointer --- go.mod | 4 +-- go.sum | 12 ++++--- server/analyzer/resolve_values_types.go | 2 +- server/analyzer/type_sanitizer.go | 2 +- server/ast/expr.go | 2 +- server/cast/float32.go | 2 +- server/cast/float64.go | 2 +- server/cast/int16.go | 2 +- server/cast/int32.go | 2 +- server/cast/int64.go | 2 +- server/cast/jsonb.go | 12 +++---- server/cast/numeric.go | 18 +++++----- server/expression/gms_cast.go | 8 ++--- server/expression/literal.go | 7 ++-- server/functions/abs.go | 6 ++-- server/functions/binary/divide.go | 16 ++++----- server/functions/binary/equal.go | 2 +- server/functions/binary/greater.go | 2 +- server/functions/binary/greater_equal.go | 2 +- server/functions/binary/less.go | 2 +- server/functions/binary/less_equal.go | 2 +- server/functions/binary/minus.go | 10 +++--- server/functions/binary/mod.go | 13 +++---- server/functions/binary/multiply.go | 9 ++--- server/functions/binary/not_equal.go | 2 +- server/functions/binary/plus.go | 10 +++--- server/functions/ceil.go | 10 ++++-- server/functions/date_part.go | 10 +++--- server/functions/div.go | 11 +++--- server/functions/exp.go | 7 ++-- server/functions/extract.go | 17 +++++----- server/functions/factorial.go | 2 +- server/functions/floor.go | 18 ++++++---- server/functions/generate_series.go | 22 ++++++------ server/functions/ln.go | 23 ++++++------- server/functions/log.go | 26 ++++++-------- server/functions/min_scale.go | 2 +- server/functions/mod.go | 12 ++++--- server/functions/numeric.go | 28 +++++++-------- server/functions/power.go | 30 +++++++--------- server/functions/round.go | 22 +++++------- server/functions/sign.go | 4 +-- server/functions/sqrt.go | 26 ++++++-------- server/functions/to_char.go | 2 +- server/functions/trim_scale.go | 2 +- server/functions/trunc.go | 16 +++------ server/functions/unary/minus.go | 6 ++-- server/functions/width_bucket.go | 16 ++++----- server/types/json_document.go | 5 +-- server/types/numeric.go | 34 +++++++++---------- server/types/type.go | 6 ++-- .../output/framework_test.go | 14 +++++--- testing/go/framework.go | 2 +- 53 files changed, 259 insertions(+), 267 deletions(-) diff --git a/go.mod b/go.mod index e422a02006..6e128a450c 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,10 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v3 v3.2.3 github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20260506222236-acbcbc3583f4 + github.com/dolthub/dolt/go v0.40.5-0.20260512062409-c59eb8452854 github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.20.1-0.20260505171600-5a9dda3f04ff + github.com/dolthub/go-mysql-server v0.20.1-0.20260512060642-0776ab53f95d github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20260505163811-77e5224be390 diff --git a/go.sum b/go.sum index b23111bf13..7237752a81 100644 --- a/go.sum +++ b/go.sum @@ -245,8 +245,10 @@ github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:I github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo= github.com/dolthub/dolt-mcp v0.3.4 h1:AyG5cw+fNWXDHXujtQnqUPZrpWtPg6FN6yYtjv1pP44= github.com/dolthub/dolt-mcp v0.3.4/go.mod h1:bCZ7KHvDYs+M0e+ySgmGiNvLhcwsN7bbf5YCyillLrk= -github.com/dolthub/dolt/go v0.40.5-0.20260506222236-acbcbc3583f4 h1:V9XgEwdwKaktPt1nQTYUBocFtCLZMdqms3BOktpq7fA= -github.com/dolthub/dolt/go v0.40.5-0.20260506222236-acbcbc3583f4/go.mod h1:APpkaZsKNNwupmOvCSNyyYstpUHn+sSo9Ix/7mGx6/0= +github.com/dolthub/dolt/go v0.40.5-0.20260511212315-baeef95a98d5 h1:cGEG3nV79qSeZ9f3rCbn+kBOCANTHzu/rxIdNhyr144= +github.com/dolthub/dolt/go v0.40.5-0.20260511212315-baeef95a98d5/go.mod h1:Qj6WUBBBNPRakBaGzPaBQ/tYNuKaMFA+pV23gAyZXiA= +github.com/dolthub/dolt/go v0.40.5-0.20260512062409-c59eb8452854 h1:Ad+I/PiVmeM/L0nVCrYBASOgHxDeBA9mr3N83WRYMaM= +github.com/dolthub/dolt/go v0.40.5-0.20260512062409-c59eb8452854/go.mod h1:jZU260NYTVVIX3cwwlMt5etPPQNvi3EpdBGyWpdRADU= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 h1:JShhbqMw26nKx3pqqu/cFxOpzBkN+4elVhzuUfgDw2k= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69/go.mod h1:SSLraQS/jGLYFgff3vuZ+JbVUct6vyEeMzjLBqWqoyM= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -255,8 +257,10 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866 h1:U6gSf5I0e6h6GP1/5Sa7D2lWW1CWfcVPtY5wkyHq6jY= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE= -github.com/dolthub/go-mysql-server v0.20.1-0.20260505171600-5a9dda3f04ff h1:q3GZb7jKVgbn0hU4f458E4vTX4NMzEVSkddC3M3dI80= -github.com/dolthub/go-mysql-server v0.20.1-0.20260505171600-5a9dda3f04ff/go.mod h1:55n1yslSIZ5uewFbtd82DsYt3f9vUKwnRN5GZJie+nE= +github.com/dolthub/go-mysql-server v0.20.1-0.20260512052117-1770783d298d h1:j2IpyW3xIS/VP0W734Zh9z0JdphzoVtfnvTE8uOmslQ= +github.com/dolthub/go-mysql-server v0.20.1-0.20260512052117-1770783d298d/go.mod h1:uE//4hYOi/1zIgOYgr8D+u7OvDfI4yPXFsL4dTZWk5I= +github.com/dolthub/go-mysql-server v0.20.1-0.20260512060642-0776ab53f95d h1:8W3f9ey7QVeK503w2NubP1c17xMZ64eoSHNIaOXzUsI= +github.com/dolthub/go-mysql-server v0.20.1-0.20260512060642-0776ab53f95d/go.mod h1:uE//4hYOi/1zIgOYgr8D+u7OvDfI4yPXFsL4dTZWk5I= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20260414231531-5f031e3e9037 h1:oIW9HwuWrhxv+4HZxA+QQSKHLqWFyXZ2FmNjUYwkdiM= diff --git a/server/analyzer/resolve_values_types.go b/server/analyzer/resolve_values_types.go index f319648fb1..a4f17317e7 100644 --- a/server/analyzer/resolve_values_types.go +++ b/server/analyzer/resolve_values_types.go @@ -113,7 +113,7 @@ func ResolveValuesTypes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, s // MIN now returns numeric, so GroupBy produces numeric. But the // Project's GetField still says int4 because its tableId=GroupBy, // which wasn't in transformedVDTs. At runtime this causes a panic - // because the actual value is apd.Decimal but the type says int32. + // because the actual value is *apd.Decimal but the type says int32. // // This pass catches those: for each GetField, check if its type // disagrees with what the child node actually produces. diff --git a/server/analyzer/type_sanitizer.go b/server/analyzer/type_sanitizer.go index e74cab427c..245aeed2c5 100644 --- a/server/analyzer/type_sanitizer.go +++ b/server/analyzer/type_sanitizer.go @@ -173,7 +173,7 @@ func typeSanitizerLiterals(ctx *sql.Context, gmsLiteral *expression.Literal) (sq } return pgexprs.NewRawLiteralFloat64(newVal.(float64)), transform.NewTree, nil case query.Type_DECIMAL: - dec, ok := gmsLiteral.Value().(apd.Decimal) + dec, ok := gmsLiteral.Value().(*apd.Decimal) if !ok { return nil, transform.NewTree, errors.Errorf("SANITIZER: expected decimal type: %T", gmsLiteral.Value()) } diff --git a/server/ast/expr.go b/server/ast/expr.go index 244a04c29b..30a21427be 100644 --- a/server/ast/expr.go +++ b/server/ast/expr.go @@ -518,7 +518,7 @@ func nodeExpr(ctx *Context, node tree.Expr) (vitess.Expr, error) { }, nil case *tree.DDecimal: return vitess.InjectedExpr{ - Expression: pgexprs.NewRawLiteralNumeric(node.Decimal)}, nil + Expression: pgexprs.NewRawLiteralNumeric(&node.Decimal)}, nil case *tree.DEnum: return nil, errors.Errorf("the statement is not yet supported") case *tree.DFloat: diff --git a/server/cast/float32.go b/server/cast/float32.go index 0692f111d8..e1aa998ced 100644 --- a/server/cast/float32.go +++ b/server/cast/float32.go @@ -75,7 +75,7 @@ func float32Assignment() { if err != nil { return nil, err } - return pgtypes.GetNumericValueWithTypmod(*d, targetType.GetAttTypMod()) + return pgtypes.GetNumericValueWithTypmod(d, targetType.GetAttTypMod()) }, }) } diff --git a/server/cast/float64.go b/server/cast/float64.go index 05ce629589..a87788809f 100644 --- a/server/cast/float64.go +++ b/server/cast/float64.go @@ -81,7 +81,7 @@ func float64Assignment() { if err != nil { return nil, err } - return pgtypes.GetNumericValueWithTypmod(*d, targetType.GetAttTypMod()) + return pgtypes.GetNumericValueWithTypmod(d, targetType.GetAttTypMod()) }, }) } diff --git a/server/cast/int16.go b/server/cast/int16.go index 03e18e58a5..425b7d8877 100644 --- a/server/cast/int16.go +++ b/server/cast/int16.go @@ -62,7 +62,7 @@ func int16Implicit() { FromType: pgtypes.Int16, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return pgtypes.GetNumericValueWithTypmod(*apd.New(int64(val.(int16)), 0), targetType.GetAttTypMod()) + return pgtypes.GetNumericValueWithTypmod(apd.New(int64(val.(int16)), 0), targetType.GetAttTypMod()) }, }) framework.MustAddImplicitTypeCast(framework.TypeCast{ diff --git a/server/cast/int32.go b/server/cast/int32.go index 55df7aa07e..6ce1cb1283 100644 --- a/server/cast/int32.go +++ b/server/cast/int32.go @@ -84,7 +84,7 @@ func int32Implicit() { FromType: pgtypes.Int32, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return pgtypes.GetNumericValueWithTypmod(*apd.New(int64(val.(int32)), 0), targetType.GetAttTypMod()) + return pgtypes.GetNumericValueWithTypmod(apd.New(int64(val.(int32)), 0), targetType.GetAttTypMod()) }, }) framework.MustAddImplicitTypeCast(framework.TypeCast{ diff --git a/server/cast/int64.go b/server/cast/int64.go index 7da9a674a8..e238dd4211 100644 --- a/server/cast/int64.go +++ b/server/cast/int64.go @@ -76,7 +76,7 @@ func int64Implicit() { FromType: pgtypes.Int64, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return pgtypes.GetNumericValueWithTypmod(*apd.New(val.(int64), 0), targetType.GetAttTypMod()) + return pgtypes.GetNumericValueWithTypmod(apd.New(val.(int64), 0), targetType.GetAttTypMod()) }, }) framework.MustAddImplicitTypeCast(framework.TypeCast{ diff --git a/server/cast/jsonb.go b/server/cast/jsonb.go index 6b3b767a4d..0fe90e2c24 100644 --- a/server/cast/jsonb.go +++ b/server/cast/jsonb.go @@ -116,10 +116,10 @@ func jsonbExplicit() { return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: d := apd.Decimal(value) - if d.Cmp(&pgtypes.NumericValueMinInt16) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt16) > 0 { + if d.Cmp(pgtypes.NumericValueMinInt16) < 0 || d.Cmp(pgtypes.NumericValueMaxInt16) > 0 { return nil, errors.Errorf("smallint out of range") } - return int16(types.DecimalIntPart(d)), nil + return int16(types.DecimalIntPart(&d)), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: @@ -142,10 +142,10 @@ func jsonbExplicit() { return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: d := apd.Decimal(value) - if d.Cmp(&pgtypes.NumericValueMinInt32) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt32) > 0 { + if d.Cmp(pgtypes.NumericValueMinInt32) < 0 || d.Cmp(pgtypes.NumericValueMaxInt32) > 0 { return nil, errors.Errorf("integer out of range") } - return int32(types.DecimalIntPart(d)), nil + return int32(types.DecimalIntPart(&d)), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: @@ -168,10 +168,10 @@ func jsonbExplicit() { return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: d := apd.Decimal(value) - if d.Cmp(&pgtypes.NumericValueMinInt64) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt64) > 0 { + if d.Cmp(pgtypes.NumericValueMinInt64) < 0 || d.Cmp(pgtypes.NumericValueMaxInt64) > 0 { return nil, errors.Errorf("bigint out of range") } - return int64(types.DecimalIntPart(d)), nil + return int64(types.DecimalIntPart(&d)), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: diff --git a/server/cast/numeric.go b/server/cast/numeric.go index 7f9d859219..100a0fb219 100644 --- a/server/cast/numeric.go +++ b/server/cast/numeric.go @@ -36,8 +36,8 @@ func numericAssignment() { FromType: pgtypes.Numeric, ToType: pgtypes.Int16, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - d := val.(apd.Decimal) - if d.Cmp(&pgtypes.NumericValueMinInt16) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt16) > 0 { + d := val.(*apd.Decimal) + if d.Cmp(pgtypes.NumericValueMinInt16) < 0 || d.Cmp(pgtypes.NumericValueMaxInt16) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "smallint out of range") } i := types.DecimalIntPart(d) @@ -48,8 +48,8 @@ func numericAssignment() { FromType: pgtypes.Numeric, ToType: pgtypes.Int32, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - d := val.(apd.Decimal) - if d.Cmp(&pgtypes.NumericValueMinInt32) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt32) > 0 { + d := val.(*apd.Decimal) + if d.Cmp(pgtypes.NumericValueMinInt32) < 0 || d.Cmp(pgtypes.NumericValueMaxInt32) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "integer out of range") } i := types.DecimalIntPart(d) @@ -60,8 +60,8 @@ func numericAssignment() { FromType: pgtypes.Numeric, ToType: pgtypes.Int64, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - d := val.(apd.Decimal) - if d.Cmp(&pgtypes.NumericValueMinInt64) < 0 || d.Cmp(&pgtypes.NumericValueMaxInt64) > 0 { + d := val.(*apd.Decimal) + if d.Cmp(pgtypes.NumericValueMinInt64) < 0 || d.Cmp(pgtypes.NumericValueMaxInt64) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "bigint out of range") } return types.DecimalIntPart(d), nil @@ -75,7 +75,7 @@ func numericImplicit() { FromType: pgtypes.Numeric, ToType: pgtypes.Float32, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - d := val.(apd.Decimal) + d := val.(*apd.Decimal) f, _ := d.Float64() return float32(f), nil }, @@ -84,7 +84,7 @@ func numericImplicit() { FromType: pgtypes.Numeric, ToType: pgtypes.Float64, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - d := val.(apd.Decimal) + d := val.(*apd.Decimal) f, _ := d.Float64() return f, nil }, @@ -93,7 +93,7 @@ func numericImplicit() { FromType: pgtypes.Numeric, ToType: pgtypes.Numeric, Function: func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { - return pgtypes.GetNumericValueWithTypmod(val.(apd.Decimal), targetType.GetAttTypMod()) + return pgtypes.GetNumericValueWithTypmod(val.(*apd.Decimal), targetType.GetAttTypMod()) }, }) } diff --git a/server/expression/gms_cast.go b/server/expression/gms_cast.go index 580eeb34b6..fc39017917 100644 --- a/server/expression/gms_cast.go +++ b/server/expression/gms_cast.go @@ -122,9 +122,9 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil { return nil, err } - dec, ok := newVal.(apd.Decimal) + dec, ok := newVal.(*apd.Decimal) if !ok { - return nil, errors.Errorf("GMSCast expected type `apd.Decimal`, got `%T`", val) + return nil, errors.Errorf("GMSCast expected type `*apd.Decimal`, got `%T`", val) } return dec, nil case query.Type_FLOAT32: @@ -150,9 +150,9 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil { return nil, err } - dec, ok := newVal.(apd.Decimal) + dec, ok := newVal.(*apd.Decimal) if !ok { - return nil, errors.Errorf("GMSCast expected type `apd.Decimal`, got `%T`", val) + return nil, errors.Errorf("GMSCast expected type `*apd.Decimal`, got `%T`", val) } return dec, nil case query.Type_DATE, query.Type_DATETIME, query.Type_TIMESTAMP: diff --git a/server/expression/literal.go b/server/expression/literal.go index 49012cc7ce..2b01a784a2 100644 --- a/server/expression/literal.go +++ b/server/expression/literal.go @@ -32,7 +32,6 @@ import ( // NewNumericLiteral returns a new *expression.Literal containing a NUMERIC value. func NewNumericLiteral(numericValue string) (*expression.Literal, error) { - //TODO: should use the input function of the type d, err := pgtypes.GetNumericValueFromStringWithTypmod(numericValue, -1) if err != nil { return nil, err @@ -107,8 +106,8 @@ func NewRawLiteralFloat64(val float64) *expression.Literal { return expression.NewLiteral(val, pgtypes.Float64) } -// NewRawLiteralNumeric returns a new *expression.Literal containing an apd.Decimal value. -func NewRawLiteralNumeric(val apd.Decimal) *expression.Literal { +// NewRawLiteralNumeric returns a new *expression.Literal containing an *apd.Decimal value. +func NewRawLiteralNumeric(val *apd.Decimal) *expression.Literal { return expression.NewLiteral(val, pgtypes.Numeric) } @@ -176,7 +175,7 @@ func ToVitessLiteral(l *expression.Literal) *vitess.SQLVal { case pgtypes.Int64.ID: return vitess.NewIntVal([]byte(strconv.FormatInt(l.Value().(int64), 10))) case pgtypes.Numeric.ID: - d := l.Value().(apd.Decimal) + d := l.Value().(*apd.Decimal) return vitess.NewFloatVal([]byte(d.String())) case pgtypes.Text.ID: return vitess.NewStrVal([]byte(l.Value().(string))) diff --git a/server/functions/abs.go b/server/functions/abs.go index fa4cbf573b..62c2af0e7e 100644 --- a/server/functions/abs.go +++ b/server/functions/abs.go @@ -83,8 +83,8 @@ var abs_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - dec := val1.(apd.Decimal) - abs := dec.Abs(&dec) - return *abs, nil + dec := val1.(*apd.Decimal) + res := new(apd.Decimal) + return res.Abs(dec), nil }, } diff --git a/server/functions/binary/divide.go b/server/functions/binary/divide.go index e5e6827fbc..40d0a893d5 100644 --- a/server/functions/binary/divide.go +++ b/server/functions/binary/divide.go @@ -286,8 +286,8 @@ var interval_div = framework.Function2{ // numeric_div_callable is the callable logic for the numeric_div function. func numeric_div_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) + num1 := val1.(*apd.Decimal) + num2 := val2.(*apd.Decimal) if num1.Form == apd.NaN || num2.Form == apd.NaN || (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { return pgtypes.NumericNaN, nil @@ -299,17 +299,15 @@ func numeric_div_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any return num1, nil } if num2.Form == apd.Infinite { - return *apd.New(0, 0), nil + return apd.New(0, 0), nil } - _, err := sql.HighPrecisionCtx.Quo(&num1, &num1, &num2) - if err != nil { - return nil, err - } - _, err = sql.DecimalCtx.Quantize(&num1, &num1, -16) + + res := new(apd.Decimal) + _, err := sql.DecimalHighPrecisionCtx.Quo(res, num1, num2) if err != nil { return nil, err } - return num1, nil + return sql.DecimalRound(res, 16) } // numeric_div represents the PostgreSQL function of the same name, taking the same parameters. diff --git a/server/functions/binary/equal.go b/server/functions/binary/equal.go index 244db200dc..34428762d1 100644 --- a/server/functions/binary/equal.go +++ b/server/functions/binary/equal.go @@ -459,7 +459,7 @@ var nameeqtext = framework.Function2{ // numeric_eq_callable is the callable logic for the numeric_eq function. func numeric_eq_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(*apd.Decimal), val2.(*apd.Decimal)) return res == 0, err } diff --git a/server/functions/binary/greater.go b/server/functions/binary/greater.go index e953c2a808..1331615661 100644 --- a/server/functions/binary/greater.go +++ b/server/functions/binary/greater.go @@ -385,7 +385,7 @@ var numeric_gt = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(*apd.Decimal), val2.(*apd.Decimal)) return res == 1, err }, } diff --git a/server/functions/binary/greater_equal.go b/server/functions/binary/greater_equal.go index c35ea6c707..0e82ed29f5 100644 --- a/server/functions/binary/greater_equal.go +++ b/server/functions/binary/greater_equal.go @@ -385,7 +385,7 @@ var numeric_ge = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(*apd.Decimal), val2.(*apd.Decimal)) return res >= 0, err }, } diff --git a/server/functions/binary/less.go b/server/functions/binary/less.go index d8fed02a42..9f8ec91ade 100644 --- a/server/functions/binary/less.go +++ b/server/functions/binary/less.go @@ -385,7 +385,7 @@ var numeric_lt = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(*apd.Decimal), val2.(*apd.Decimal)) return res == -1, err }, } diff --git a/server/functions/binary/less_equal.go b/server/functions/binary/less_equal.go index 8f4eb2a15f..d34e6bde87 100644 --- a/server/functions/binary/less_equal.go +++ b/server/functions/binary/less_equal.go @@ -385,7 +385,7 @@ var numeric_le = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(*apd.Decimal), val2.(*apd.Decimal)) return res <= 0, err }, } diff --git a/server/functions/binary/minus.go b/server/functions/binary/minus.go index 30c8730620..8ac3fd968a 100644 --- a/server/functions/binary/minus.go +++ b/server/functions/binary/minus.go @@ -240,17 +240,19 @@ var numeric_sub = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) + num1 := val1.(*apd.Decimal) + num2 := val2.(*apd.Decimal) p := num1.NumDigits() if p2 := num2.NumDigits(); p < p2 { p = p2 } - _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Sub(&num1, &num1, &num2) + res := new(apd.Decimal) + // TODO does this need precision?? + _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Sub(res, num1, num2) if err != nil { return nil, err } - return num1, nil + return res, nil }, } diff --git a/server/functions/binary/mod.go b/server/functions/binary/mod.go index a49c9be654..93f50416da 100644 --- a/server/functions/binary/mod.go +++ b/server/functions/binary/mod.go @@ -18,6 +18,7 @@ import ( "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -83,8 +84,8 @@ var numeric_mod = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) + num1 := val1.(*apd.Decimal) + num2 := val2.(*apd.Decimal) if num1.Form == apd.NaN || num2.Form == apd.NaN || (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { return pgtypes.NumericNaN, nil @@ -96,12 +97,8 @@ var numeric_mod = framework.Function2{ return num1, nil } if num2.Form == apd.Infinite { - return *apd.New(0, 0), nil + return apd.New(0, 0), nil } - _, err := sql.HighPrecisionCtx.Rem(&num1, &num1, &num2) - if err != nil { - return nil, err - } - return num1, nil + return types.DecimalMod(num1, num2) }, } diff --git a/server/functions/binary/multiply.go b/server/functions/binary/multiply.go index 1b7eb9f896..fbb39205b7 100644 --- a/server/functions/binary/multiply.go +++ b/server/functions/binary/multiply.go @@ -225,16 +225,17 @@ var numeric_mul = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) + num1 := val1.(*apd.Decimal) + num2 := val2.(*apd.Decimal) if (num1.Form == apd.Infinite || num2.Form == apd.Infinite) && (num1.IsZero() || num2.IsZero()) { return pgtypes.NumericNaN, nil } - _, err := sql.DecimalCtx.WithPrecision(uint32(num1.NumDigits()+num2.NumDigits())).Mul(&num1, &num1, &num2) + res := new(apd.Decimal) + _, err := sql.DecimalCtx.WithPrecision(uint32(num1.NumDigits()+num2.NumDigits())).Mul(res, num1, num2) if err != nil { return nil, err } - return num1, nil + return res, nil }, } diff --git a/server/functions/binary/not_equal.go b/server/functions/binary/not_equal.go index 88f2854e73..3264f426b4 100644 --- a/server/functions/binary/not_equal.go +++ b/server/functions/binary/not_equal.go @@ -386,7 +386,7 @@ var numeric_ne = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - res, err := pgtypes.Numeric.Compare(ctx, val1.(apd.Decimal), val2.(apd.Decimal)) + res, err := pgtypes.Numeric.Compare(ctx, val1.(*apd.Decimal), val2.(*apd.Decimal)) return res != 0, err }, } diff --git a/server/functions/binary/plus.go b/server/functions/binary/plus.go index 42cccaef00..133374d356 100644 --- a/server/functions/binary/plus.go +++ b/server/functions/binary/plus.go @@ -376,8 +376,8 @@ var interval_pl_timestamptz = framework.Function2{ // numeric_add_callable is the callable logic for the numeric_add function. func numeric_add_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) + num1 := val1.(*apd.Decimal) + num2 := val2.(*apd.Decimal) if num1.Form == apd.NaN || num2.Form == apd.NaN || (num1.Form == apd.Infinite && num2.Form == apd.Infinite && num2.Negative) || (num2.Form == apd.Infinite && num1.Form == apd.Infinite && num1.Negative) { @@ -394,11 +394,13 @@ func numeric_add_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any if p2 := num2.NumDigits(); p < p2 { p = p2 } - _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Add(&num1, &num1, &num2) + + res := new(apd.Decimal) + _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Add(res, num1, num2) if err != nil { return nil, err } - return num1, nil + return res, nil } // numeric_add represents the PostgreSQL function of the same name, taking the same parameters. diff --git a/server/functions/ceil.go b/server/functions/ceil.go index e4dcc9a6bc..539e4e164d 100644 --- a/server/functions/ceil.go +++ b/server/functions/ceil.go @@ -55,11 +55,15 @@ var ceil_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - dec := val.(apd.Decimal) - _, err := sql.DecimalCtx.Ceil(&dec, &dec) + dec := val.(*apd.Decimal) + if dec.Form != apd.Finite { + return dec, nil + } + res := new(apd.Decimal) + _, err := sql.DecimalCtx.Ceil(res, dec) if err != nil { return nil, err } - return dec, nil + return res, nil }, } diff --git a/server/functions/date_part.go b/server/functions/date_part.go index f59e0be449..b6630df76d 100644 --- a/server/functions/date_part.go +++ b/server/functions/date_part.go @@ -266,21 +266,21 @@ var date_part_text_interval = framework.Function2{ }, } -func numericFloor(val any) (apd.Decimal, error) { +func numericFloor(val any) (*apd.Decimal, error) { switch val.(type) { case int64, float64: // expects these types to Scan from default: - return apd.Decimal{}, cerrors.Errorf("invalid type for numeric convert: %T", val) + return nil, cerrors.Errorf("invalid type for numeric convert: %T", val) } dec := new(apd.Decimal) err := dec.Scan(val) if err != nil { - return apd.Decimal{}, err + return nil, err } _, err = sql.DecimalCtx.Floor(dec, dec) if err != nil { - return apd.Decimal{}, err + return nil, err } - return *dec, nil + return dec, nil } diff --git a/server/functions/div.go b/server/functions/div.go index 720bbf2e7e..c48fff0649 100644 --- a/server/functions/div.go +++ b/server/functions/div.go @@ -35,8 +35,8 @@ var div_numeric = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) + num1 := val1.(*apd.Decimal) + num2 := val2.(*apd.Decimal) if num1.Form == apd.NaN || num2.Form == apd.NaN || (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { return pgtypes.NumericNaN, nil @@ -48,12 +48,13 @@ var div_numeric = framework.Function2{ return num1, nil } if num2.Form == apd.Infinite { - return *apd.New(0, 0), nil + return apd.New(0, 0), nil } - _, err := sql.DecimalCtx.QuoInteger(&num1, &num1, &num2) + res := new(apd.Decimal) + _, err := sql.DecimalHighPrecisionCtx.QuoInteger(res, num1, num2) if err != nil { return nil, err } - return num1, nil + return res, nil }, } diff --git a/server/functions/exp.go b/server/functions/exp.go index 82e08eb271..74069d94e5 100644 --- a/server/functions/exp.go +++ b/server/functions/exp.go @@ -48,11 +48,12 @@ var exp_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - dec := val.(apd.Decimal) - _, err := sql.DecimalCtx.WithPrecision(32).Exp(&dec, &dec) + dec := val.(*apd.Decimal) + res := new(apd.Decimal) + _, err := sql.DecimalCtx.WithPrecision(32).Exp(res, dec) if err != nil { return nil, err } - return dec, nil + return res, nil }, } diff --git a/server/functions/extract.go b/server/functions/extract.go index c38a5c00a0..07ae85d991 100644 --- a/server/functions/extract.go +++ b/server/functions/extract.go @@ -226,7 +226,7 @@ var extract_text_interval = framework.Function2{ } // getFieldFromTimeVal returns the value for given field extracted from non-interval values. -func getFieldFromTimeVal(field string, tVal time.Time) (apd.Decimal, error) { +func getFieldFromTimeVal(field string, tVal time.Time) (*apd.Decimal, error) { switch strings.ToLower(field) { case "century", "centuries": if year := tVal.Year(); year <= 0 { @@ -285,27 +285,28 @@ func getFieldFromTimeVal(field string, tVal time.Time) (apd.Decimal, error) { case "year", "years": return numeric(int64(tVal.Year()), false, 0) default: - return apd.Decimal{}, cerrors.Errorf("unknown field given: %s", field) + return nil, cerrors.Errorf("unknown field given: %s", field) } } -func numeric(val any, setScale bool, scale int32) (apd.Decimal, error) { +func numeric(val any, setScale bool, scale int32) (*apd.Decimal, error) { switch val.(type) { case int64, float64: // expects these types to Scan from default: - return apd.Decimal{}, cerrors.Errorf("invalid type for numeric convert: %T", val) + return nil, cerrors.Errorf("invalid type for numeric convert: %T", val) } dec := new(apd.Decimal) err := dec.Scan(val) if err != nil { - return apd.Decimal{}, err + return nil, err } + if setScale { - _, err = sql.DecimalCtx.Quantize(dec, dec, -scale) + dec, err = sql.DecimalRound(dec, scale) if err != nil { - return apd.Decimal{}, err + return nil, err } } - return *dec, nil + return dec, nil } diff --git a/server/functions/factorial.go b/server/functions/factorial.go index 0ef3568ca9..45650af82b 100644 --- a/server/functions/factorial.go +++ b/server/functions/factorial.go @@ -44,6 +44,6 @@ var factorial_int64 = framework.Function1{ for i := int64(2); i <= n; i++ { total *= i } - return *apd.New(total, 0), nil + return apd.New(total, 0), nil }, } diff --git a/server/functions/floor.go b/server/functions/floor.go index 8b1f89004c..c5f7894f64 100644 --- a/server/functions/floor.go +++ b/server/functions/floor.go @@ -48,18 +48,22 @@ var floor_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - dec := val.(apd.Decimal) - newDecimalCtx := *sql.DecimalCtx - newDecimalCtx.Rounding = apd.RoundFloor + dec := val.(*apd.Decimal) + if dec.Form != apd.Finite { + return dec, nil + } - _, err := newDecimalCtx.Floor(&dec, &dec) + res := new(apd.Decimal) + newDecimalCtx := sql.DecimalCtx + newDecimalCtx.Rounding = apd.RoundFloor + _, err := newDecimalCtx.Floor(res, dec) if err != nil { return nil, err } // floor(-0.1) returns -0, which is -1 in postgres because postgres does not support -0 - if dec.IsZero() && dec.Negative { - return *apd.New(-1, 0), nil + if res.IsZero() && res.Negative { + return apd.New(-1, 0), nil } - return dec, nil + return res, nil }, } diff --git a/server/functions/generate_series.go b/server/functions/generate_series.go index d414e5be7c..15ce6481cd 100644 --- a/server/functions/generate_series.go +++ b/server/functions/generate_series.go @@ -143,10 +143,10 @@ var generate_series_numeric_numeric = framework.Function2{ Strict: true, SRF: true, Callable: func(ctx *sql.Context, t [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) { - start := val1.(apd.Decimal) - stop := val2.(apd.Decimal) + start := val1.(*apd.Decimal) + stop := val2.(*apd.Decimal) step := numericOne // by default - return numericGenerateSeries(start, stop, *step) + return numericGenerateSeries(start, stop, step) }, } @@ -158,16 +158,16 @@ var generate_series_numeric_numeric_numeric = framework.Function3{ Strict: true, SRF: true, Callable: func(ctx *sql.Context, t [4]*pgtypes.DoltgresType, val1, val2, val3 any) (any, error) { - start := val1.(apd.Decimal) - stop := val2.(apd.Decimal) - step := val3.(apd.Decimal) + start := val1.(*apd.Decimal) + stop := val2.(*apd.Decimal) + step := val3.(*apd.Decimal) return numericGenerateSeries(start, stop, step) }, } // numericGenerateSeries returns RowIter for generate_series function results for given numeric values. // This function checks for error of step being zero. -func numericGenerateSeries(start, stop, step apd.Decimal) (*pgtypes.SetReturningFunctionRowIter, error) { +func numericGenerateSeries(start, stop, step *apd.Decimal) (*pgtypes.SetReturningFunctionRowIter, error) { if step.IsZero() { return nil, errStepSizeZero } @@ -188,15 +188,17 @@ func numericGenerateSeries(start, stop, step apd.Decimal) (*pgtypes.SetReturning } return pgtypes.NewSetReturningFunctionRowIter(func(ctx *sql.Context) (sql.Row, error) { defer func() { - _, err := sql.DecimalCtx.Add(&start, &start, &step) + + _, err := sql.DecimalCtx.Add(start, start, step) if err != nil { panic(err) } }() - if (step.Sign() > 0 && start.Cmp(&stop) > 0) || (step.Sign() < 0 && start.Cmp(&stop) < 0) { + if (step.Sign() > 0 && start.Cmp(stop) > 0) || (step.Sign() < 0 && start.Cmp(stop) < 0) { return nil, io.EOF } - return sql.Row{start}, nil + res := new(*start) + return sql.Row{res}, nil }), nil } diff --git a/server/functions/ln.go b/server/functions/ln.go index 4bef292081..b568b94a43 100644 --- a/server/functions/ln.go +++ b/server/functions/ln.go @@ -55,7 +55,7 @@ var ln_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - dec := val1.(apd.Decimal) + dec := val1.(*apd.Decimal) if dec.Sign() == 0 { return nil, errors.Errorf("cannot take logarithm of zero") } else if dec.Sign() < 0 { @@ -64,29 +64,28 @@ var ln_numeric = framework.Function1{ return dec, nil } + // TODO: calculate precision and scale accurately exp := dec.Exponent - c := sql.DecimalCtx - if nd := uint32(dec.NumDigits()); nd > c.Precision { - c = c.WithPrecision(nd) + p := dec.NumDigits() + if exp < 0 { + p += int64(-exp) + } else if exp == 0 { + p += 16 } - _, err := c.Ln(&dec, &dec) + res := new(apd.Decimal) + _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Ln(res, dec) if err != nil { return nil, err } - // TODO: calculate precision and scale accurately if exp > -16 { // use ln result - parts := strings.Split(dec.Text('f'), ".") + parts := strings.Split(res.Text('f'), ".") whole := int32(len(parts[0]) / 2) exp = whole - 16 } - _, err = sql.HighPrecisionCtx.Quantize(&dec, &dec, exp) - if err != nil { - return nil, err - } - return dec, nil + return sql.DecimalRound(res, -exp) }, } diff --git a/server/functions/log.go b/server/functions/log.go index 8b994564fa..5b5115d734 100644 --- a/server/functions/log.go +++ b/server/functions/log.go @@ -57,7 +57,7 @@ var log_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - dec := val1.(apd.Decimal) + dec := val1.(*apd.Decimal) if dec.IsZero() { return nil, errors.Errorf("cannot take logarithm of zero") } else if dec.Sign() < 0 { @@ -69,12 +69,13 @@ var log_numeric = framework.Function1{ if dec.Exponent < 0 { p += uint32(-dec.Exponent) } - c := sql.DecimalCtx.WithPrecision(p) - _, err := c.Log10(&dec, &dec) + + res := new(apd.Decimal) + _, err := sql.DecimalCtx.WithPrecision(p).Log10(res, dec) if err != nil { return nil, err } - return dec, nil + return res, nil }, } @@ -85,8 +86,8 @@ var log_numeric_numeric = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - base := val1.(apd.Decimal) - num := val2.(apd.Decimal) + base := val1.(*apd.Decimal) + num := val2.(*apd.Decimal) if base.IsZero() || num.IsZero() { return nil, errors.Errorf("cannot take logarithm of zero") } else if base.Sign() < 0 || num.Sign() < 0 { @@ -106,7 +107,7 @@ var log_numeric_numeric = framework.Function2{ c := sql.DecimalCtx.WithPrecision(p) lnBase := new(apd.Decimal) - _, err := c.Ln(lnBase, &base) + _, err := c.Ln(lnBase, base) if err != nil { return nil, err } @@ -115,12 +116,12 @@ var log_numeric_numeric = framework.Function2{ } lnNum := new(apd.Decimal) - _, err = c.Ln(lnNum, &num) + _, err = c.Ln(lnNum, num) if err != nil { return nil, err } if lnNum.IsZero() { - return *apd.New(0, -16), nil + return apd.New(0, -16), nil } res := new(apd.Decimal) @@ -128,11 +129,6 @@ var log_numeric_numeric = framework.Function2{ if err != nil { return nil, err } - - _, err = c.Quantize(res, res, exp) - if err != nil { - return nil, err - } - return *res, nil + return sql.DecimalRound(res, -exp) }, } diff --git a/server/functions/min_scale.go b/server/functions/min_scale.go index 65105a8cba..a937d9079d 100644 --- a/server/functions/min_scale.go +++ b/server/functions/min_scale.go @@ -36,7 +36,7 @@ var min_scale_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - dec := val1.(apd.Decimal) + dec := val1.(*apd.Decimal) if dec.Form == apd.NaN || dec.Form == apd.Infinite { return nil, nil } diff --git a/server/functions/mod.go b/server/functions/mod.go index f861e6f4b1..643eb12624 100644 --- a/server/functions/mod.go +++ b/server/functions/mod.go @@ -19,6 +19,7 @@ import ( "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -81,8 +82,8 @@ var mod_numeric_numeric = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - num1 := val1.(apd.Decimal) - num2 := val2.(apd.Decimal) + num1 := val1.(*apd.Decimal) + num2 := val2.(*apd.Decimal) if num1.Form == apd.NaN || num2.Form == apd.NaN || (num1.Form == apd.Infinite && num2.Form == apd.Infinite) { return pgtypes.NumericNaN, nil @@ -94,12 +95,13 @@ var mod_numeric_numeric = framework.Function2{ return num1, nil } if num2.Form == apd.Infinite { - return *apd.New(0, 0), nil + return apd.New(0, 0), nil } - _, err := sql.HighPrecisionCtx.Rem(&num1, &num1, &num2) + + res, err := types.DecimalMod(num1, num2) if err != nil { return nil, err } - return num1, nil + return res, nil }, } diff --git a/server/functions/numeric.go b/server/functions/numeric.go index 05a6459bdf..dae3e41a5f 100644 --- a/server/functions/numeric.go +++ b/server/functions/numeric.go @@ -52,7 +52,7 @@ var numeric_in = framework.Function3{ if err != nil { return nil, pgtypes.ErrInvalidSyntaxForType.New("numeric", input) } - return pgtypes.GetNumericValueWithTypmod(*dec, typmod) + return pgtypes.GetNumericValueWithTypmod(dec, typmod) }, } @@ -64,13 +64,13 @@ var numeric_out = framework.Function1{ Strict: true, Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) { typ := t[0] - dec := val.(apd.Decimal) + dec := val.(*apd.Decimal) tm := typ.GetAttTypMod() - dec, err := pgtypes.GetNumericValueWithTypmod(dec, tm) + res, err := pgtypes.GetNumericValueWithTypmod(dec, tm) if err != nil { return nil, err } - return dec.Text('f'), nil + return res.Text('f'), nil }, } @@ -90,7 +90,7 @@ var numeric_recv = framework.Function3{ return nil, nil } reader := utils.NewWireReader(data) - var d apd.Decimal + var d *apd.Decimal // 1. Read Header ndigits := reader.ReadInt16() @@ -125,7 +125,7 @@ var numeric_recv = framework.Function3{ digits[i] = reader.ReadInt16() } - // 4. Convert base-10000 to string for apd.Decimal + // 4. Convert base-10000 to string for *apd.Decimal // Each digit is exactly 4 characters wide (except potentially the first) var sb strings.Builder if sign == 16384 { @@ -156,15 +156,13 @@ var numeric_recv = framework.Function3{ sb.WriteString("0000") } } - dec, _, err := sql.HighPrecisionCtx.NewFromString(sb.String()) - if err != nil { - return nil, err - } - _, err = sql.HighPrecisionCtx.Quantize(dec, dec, int32(-dscale)) + + dec, _, err := sql.DecimalHighPrecisionCtx.NewFromString(sb.String()) if err != nil { return nil, err } - return *dec, nil + + return sql.DecimalRound(dec, int32(dscale)) }, } @@ -175,7 +173,7 @@ var numeric_send = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) { - num := val.(apd.Decimal) + num := val.(*apd.Decimal) typmod := t[0].GetAttTypMod() writer := utils.NewWireWriter() if num.Form == apd.Finite { @@ -316,8 +314,8 @@ var numeric_cmp = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) { - ab := val1.(apd.Decimal) - bb := val2.(apd.Decimal) + ab := val1.(*apd.Decimal) + bb := val2.(*apd.Decimal) return int32(pgtypes.NumericCompare(ab, bb)), nil }, } diff --git a/server/functions/power.go b/server/functions/power.go index 45170657cf..456a4acc84 100644 --- a/server/functions/power.go +++ b/server/functions/power.go @@ -61,8 +61,8 @@ var power_numeric_numeric = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - dec1 := val1.(apd.Decimal) - dec2 := val2.(apd.Decimal) + dec1 := val1.(*apd.Decimal) + dec2 := val2.(*apd.Decimal) if dec1.Form == apd.NaN || dec2.Form == apd.NaN { return pgtypes.NumericNaN, nil } @@ -84,40 +84,34 @@ var power_numeric_numeric = framework.Function2{ return pgtypes.NumericNegInf, nil } if (dec2.Form == apd.Infinite && dec2.Negative) || dec2.Sign() < 0 { - return *apd.New(0, 0), nil + return apd.New(0, 0), nil } - return *apd.New(1, 0), nil + return apd.New(1, 0), nil } + if dec1.IsZero() { if dec2.Sign() < 0 { // includes neg inf return nil, errPowerZeroToNegative } if dec2.Form == apd.Infinite { - return *apd.New(0, 0), nil + return apd.New(0, 0), nil } if dec2.Sign() > 0 { - d := *apd.New(0, 0) - _, _ = sql.DecimalCtx.Quantize(&d, &d, -16) - return d, nil + return sql.DecimalRound(apd.New(0, 0), 16) } } // decimal.Pow() does not handle the zero exponent properly, so we special case it - if dec2.IsZero() || dec1.Cmp(numericOne) == 0 { - d := *apd.New(1, 0) - _, _ = sql.DecimalCtx.Quantize(&d, &d, -16) - return d, nil + return sql.DecimalRound(apd.New(1, 0), 16) } + // give enough precision that we can round it to 16 exp - _, err := sql.DecimalCtx.WithPrecision(17).Pow(&dec1, &dec1, &dec2) - if err != nil { - return nil, err - } - _, err = sql.DecimalCtx.Quantize(&dec1, &dec1, -16) + res := new(apd.Decimal) + _, err := sql.DecimalCtx.WithPrecision(17).Pow(res, dec1, dec2) if err != nil { return nil, err } - return dec1, nil + return sql.DecimalRound(res, 16) }, } diff --git a/server/functions/round.go b/server/functions/round.go index 820bdbc4ca..3835db3e2f 100644 --- a/server/functions/round.go +++ b/server/functions/round.go @@ -49,16 +49,13 @@ var round_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - dec := val.(apd.Decimal) - _, err := sql.DecimalCtx.Round(&dec, &dec) + dec := val.(*apd.Decimal) + res := new(apd.Decimal) + _, err := sql.DecimalHighPrecisionCtx.Round(res, dec) if err != nil { return nil, err } - _, err = sql.DecimalCtx.Quantize(&dec, &dec, 0) - if err != nil { - return nil, err - } - return dec, nil + return sql.DecimalRound(res, 0) }, } @@ -69,16 +66,13 @@ var round_numeric_int64 = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Int64}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { - dec := val1.(apd.Decimal) + dec := val1.(*apd.Decimal) places := val2.(int64) - _, err := sql.HighPrecisionCtx.Round(&dec, &dec) - if err != nil { - return nil, err - } - _, err = sql.HighPrecisionCtx.Quantize(&dec, &dec, int32(-places)) + res := new(apd.Decimal) + _, err := sql.DecimalHighPrecisionCtx.Round(res, dec) if err != nil { return nil, err } - return dec, nil + return sql.DecimalRound(res, int32(places)) }, } diff --git a/server/functions/sign.go b/server/functions/sign.go index daf99dfcc8..5cfd6b6f97 100644 --- a/server/functions/sign.go +++ b/server/functions/sign.go @@ -53,7 +53,7 @@ var sign_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - dec := val.(apd.Decimal) - return *apd.New(int64(dec.Sign()), 0), nil + dec := val.(*apd.Decimal) + return apd.New(int64(dec.Sign()), 0), nil }, } diff --git a/server/functions/sqrt.go b/server/functions/sqrt.go index 365ee5929d..446214953c 100644 --- a/server/functions/sqrt.go +++ b/server/functions/sqrt.go @@ -52,30 +52,26 @@ var sqrt_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - dec := val.(apd.Decimal) + dec := val.(*apd.Decimal) if dec.Sign() < 0 { return nil, errors.Errorf("cannot take square root of a negative number") } - exp := dec.Exponent - c := sql.DecimalCtx - nd := uint32(dec.NumDigits()) - if nd > c.Precision { - c = c.WithPrecision(nd) - } // TODO: calculate precision and scale accurately - if dec.Exponent == 0 { - exp = int32(nd/2) - 16 + exp := dec.Exponent + p := dec.NumDigits() + if exp < 0 { + p += int64(-exp) + } else if exp == 0 { + exp = int32(p/2) - 16 + p += 16 } - _, err := c.Sqrt(&dec, &dec) - if err != nil { - return nil, err - } - _, err = c.Quantize(&dec, &dec, exp) + res := new(apd.Decimal) + _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Sqrt(res, dec) if err != nil { return nil, err } - return dec, nil + return sql.DecimalRound(res, -exp) }, } diff --git a/server/functions/to_char.go b/server/functions/to_char.go index d948e8db0a..6d4b98601c 100644 --- a/server/functions/to_char.go +++ b/server/functions/to_char.go @@ -138,7 +138,7 @@ var to_char_numeric_text = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Text}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) { - //timestamp := val1.(apd.Decimal) + //timestamp := val1.(*apd.Decimal) //format := val2.(string) return nil, errors.Errorf(`to_char(numeric,text) is not supported yet`) diff --git a/server/functions/trim_scale.go b/server/functions/trim_scale.go index 7a66e0f5e8..cb19850c9b 100644 --- a/server/functions/trim_scale.go +++ b/server/functions/trim_scale.go @@ -37,6 +37,6 @@ var trim_scale_numeric = framework.Function1{ Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { // We don't store the scale in the value, so I'm not sure if this is functionally correct. // Seems like we'd need to modify the type of the return value (by trimming the scale), rather than the value itself. - return val.(apd.Decimal), nil + return val.(*apd.Decimal), nil }, } diff --git a/server/functions/trunc.go b/server/functions/trunc.go index 0e37c12ca8..15c6a7272c 100644 --- a/server/functions/trunc.go +++ b/server/functions/trunc.go @@ -49,12 +49,8 @@ var trunc_numeric = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - dec := val.(apd.Decimal) - _, err := sql.HighPrecisionCtx.Quantize(&dec, &dec, 0) - if err != nil { - return nil, err - } - return dec, nil + dec := val.(*apd.Decimal) + return sql.DecimalRound(dec, 0) }, } @@ -65,12 +61,8 @@ var trunc_numeric_int64 = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Int32}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, num any, places any) (any, error) { - dec := num.(apd.Decimal) + dec := num.(*apd.Decimal) scale := places.(int32) - _, err := sql.HighPrecisionCtx.Quantize(&dec, &dec, -scale) - if err != nil { - return nil, err - } - return dec, nil + return sql.DecimalRound(dec, -scale) }, } diff --git a/server/functions/unary/minus.go b/server/functions/unary/minus.go index 126b0bda4e..04b11f281c 100644 --- a/server/functions/unary/minus.go +++ b/server/functions/unary/minus.go @@ -112,8 +112,8 @@ var numeric_uminus = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.Numeric}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val1 any) (any, error) { - dec := val1.(apd.Decimal) - neg := dec.Neg(&dec) - return *neg, nil + dec := val1.(*apd.Decimal) + res := new(apd.Decimal) + return res.Neg(dec), nil }, } diff --git a/server/functions/width_bucket.go b/server/functions/width_bucket.go index 92087eed51..814348411c 100644 --- a/server/functions/width_bucket.go +++ b/server/functions/width_bucket.go @@ -71,23 +71,23 @@ var width_bucket_numeric_numeric_numeric_int64 = framework.Function4{ Parameters: [4]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Numeric, pgtypes.Numeric, pgtypes.Int32}, Strict: true, Callable: func(ctx *sql.Context, _ [5]*pgtypes.DoltgresType, operandInterface any, lowInterface any, highInterface any, countInterface any) (any, error) { - operand := operandInterface.(apd.Decimal) - low := lowInterface.(apd.Decimal) - high := highInterface.(apd.Decimal) - if low.Cmp(&high) == 0 { + operand := operandInterface.(*apd.Decimal) + low := lowInterface.(*apd.Decimal) + high := highInterface.(*apd.Decimal) + if low.Cmp(high) == 0 { return nil, errors.Errorf("lower bound cannot equal upper bound") } count := countInterface.(int32) if count <= 0 { return nil, errors.Errorf("count must be greater than zero") } - if operand.Cmp(&high) == 0 { + if operand.Cmp(high) == 0 { return count + 1, nil - } else if operand.Cmp(&low) == 0 { + } else if operand.Cmp(low) == 0 { return int32(1), nil } bucket := new(apd.Decimal) - _, err := sql.DecimalCtx.Sub(bucket, &high, &low) + _, err := sql.DecimalCtx.Sub(bucket, high, low) if err != nil { return nil, err } @@ -96,7 +96,7 @@ var width_bucket_numeric_numeric_numeric_int64 = framework.Function4{ return nil, err } result := new(apd.Decimal) - _, err = sql.DecimalCtx.Sub(result, &operand, &low) + _, err = sql.DecimalCtx.Sub(result, operand, low) if err != nil { return nil, err } diff --git a/server/types/json_document.go b/server/types/json_document.go index ac049e3729..744c3fcbce 100644 --- a/server/types/json_document.go +++ b/server/types/json_document.go @@ -252,7 +252,8 @@ func JsonValueSerialize(writer *utils.Writer, value JsonValue) { case JsonValueNumber: writer.Byte(byte(JsonValueType_Number)) // MarshalBinary cannot error, so we can safely ignore it - bytes, _ := Numeric.SerializationFunc(nil, Numeric, apd.Decimal(value)) + v := apd.Decimal(value) + bytes, _ := Numeric.SerializationFunc(nil, Numeric, &v) writer.ByteSlice(bytes) case JsonValueBoolean: writer.Byte(byte(JsonValueType_Boolean)) @@ -299,7 +300,7 @@ func JsonValueDeserialize(reader *utils.Reader) (_ JsonValue, err error) { if d == nil { d = apd.Decimal{} } - return JsonValueNumber(d.(apd.Decimal)), err + return JsonValueNumber(*(d.(*apd.Decimal))), err case JsonValueType_Boolean: return JsonValueBoolean(reader.Bool()), nil case JsonValueType_Null: diff --git a/server/types/numeric.go b/server/types/numeric.go index f8d3ec8701..eb283daddf 100644 --- a/server/types/numeric.go +++ b/server/types/numeric.go @@ -32,9 +32,9 @@ var ( NumericValueMinInt16 = types.DecimalFromInt64(math.MinInt16) // NumericValueMinInt16 is the min Int16 value for NUMERIC types NumericValueMinInt32 = types.DecimalFromInt64(math.MinInt32) // NumericValueMinInt32 is the min Int32 value for NUMERIC types NumericValueMinInt64 = types.DecimalFromInt64(math.MinInt64) // NumericValueMinInt64 is the min Int64 value for NUMERIC types - NumericNaN = apd.Decimal{Form: apd.NaN} - NumericInf = apd.Decimal{Form: apd.Infinite} - NumericNegInf = apd.Decimal{Form: apd.Infinite, Negative: true} + NumericNaN = &apd.Decimal{Form: apd.NaN} + NumericInf = &apd.Decimal{Form: apd.Infinite} + NumericNegInf = &apd.Decimal{Form: apd.Infinite, Negative: true} ) // Numeric is a precise and unbounded decimal value. @@ -106,36 +106,36 @@ func GetPrecisionAndScaleFromTypmod(typmod int32) (int32, int32) { // GetNumericValueWithTypmod returns either given numeric value or truncated or error // depending on the precision and scale decoded from given type modifier value. -func GetNumericValueWithTypmod(val apd.Decimal, typmod int32) (apd.Decimal, error) { +func GetNumericValueWithTypmod(val *apd.Decimal, typmod int32) (*apd.Decimal, error) { if typmod == -1 { return val, nil } res := new(apd.Decimal) precision, scale := GetPrecisionAndScaleFromTypmod(typmod) - _, err := sql.DecimalCtx.WithPrecision(uint32(precision)).Quantize(res, &val, -scale) + _, err := sql.DecimalCtx.WithPrecision(uint32(precision)).Quantize(res, val, -scale) if err != nil { - return apd.Decimal{}, errors.Errorf("numeric field overflow - A field with precision %v, scale %v must round to an absolute value less than 10^%v", precision, scale, precision-scale) + return nil, errors.Errorf("numeric field overflow - A field with precision %v, scale %v must round to an absolute value less than 10^%v", precision, scale, precision-scale) } - return *res, nil + return res, nil } // GetNumericValueFromStringWithTypmod returns either given numeric value or truncated or error // depending on the precision and scale decoded from given type modifier value. -func GetNumericValueFromStringWithTypmod(val string, typmod int32) (apd.Decimal, error) { - dec, cond, err := sql.HighPrecisionCtx.NewFromString(val) +func GetNumericValueFromStringWithTypmod(val string, typmod int32) (*apd.Decimal, error) { + dec, cond, err := sql.DecimalHighPrecisionCtx.NewFromString(val) if err != nil { - return apd.Decimal{}, err + return nil, err } if cond.Inexact() || cond.Rounded() { - return apd.Decimal{}, errors.Errorf(`numeric precision was lost or truncated for %s`, val) + return nil, errors.Errorf(`numeric precision was lost or truncated for %s`, val) } - return GetNumericValueWithTypmod(*dec, typmod) + return GetNumericValueWithTypmod(dec, typmod) } // serializeTypeNumeric handles serialization from the standard representation to our serialized representation that is // written in Dolt. func serializeTypeNumeric(ctx *sql.Context, t *DoltgresType, val any) ([]byte, error) { - d := val.(apd.Decimal) + d := val.(*apd.Decimal) return d.MarshalText() } @@ -145,13 +145,13 @@ func deserializeTypeNumeric(ctx *sql.Context, t *DoltgresType, data []byte) (any if len(data) == 0 { return nil, nil } - retVal := *apd.New(0, 0) + retVal := apd.New(0, 0) err := retVal.UnmarshalText(data) return retVal, err } -// NumericCompare compares two apd.Decimal values handling NaN separately. -func NumericCompare(ab, bb apd.Decimal) int { +// NumericCompare compares two *apd.Decimal values handling NaN separately. +func NumericCompare(ab, bb *apd.Decimal) int { if (ab.Form == apd.NaN && bb.Form == apd.NaN) || (ab.Form == apd.Infinite && bb.Form == apd.Infinite && ab.Negative == bb.Negative) { return 0 @@ -162,5 +162,5 @@ func NumericCompare(ab, bb apd.Decimal) int { if bb.Form == apd.NaN { return -1 } - return ab.Cmp(&bb) + return ab.Cmp(bb) } diff --git a/server/types/type.go b/server/types/type.go index fc1fce9821..aa188a814b 100644 --- a/server/types/type.go +++ b/server/types/type.go @@ -328,8 +328,8 @@ func (t *DoltgresType) Compare(ctx context.Context, v1 interface{}, v2 interface case JsonDocument: bb := v2.(JsonDocument) return JsonValueCompare(ab.Value, bb.Value), nil - case apd.Decimal: - bb := v2.(apd.Decimal) + case *apd.Decimal: + bb := v2.(*apd.Decimal) return NumericCompare(ab, bb), nil case timeofday.TimeOfDay: bb := v2.(timeofday.TimeOfDay) @@ -1060,7 +1060,7 @@ func (t *DoltgresType) Zero() interface{} { case "int8": return int64(0) case "numeric": - return *apd.New(0, 0) + return apd.New(0, 0) case "oid", "regclass", "regproc", "regtype": return id.Null default: diff --git a/testing/generation/function_coverage/output/framework_test.go b/testing/generation/function_coverage/output/framework_test.go index f71303f3a2..5ad94542a1 100644 --- a/testing/generation/function_coverage/output/framework_test.go +++ b/testing/generation/function_coverage/output/framework_test.go @@ -247,7 +247,7 @@ func Numeric(str string) pgtype.Numeric { } // NumericToDecimal converts a pgtype.Numeric value to a apd.Decimal value. -func NumericToDecimal(val pgtype.Numeric) apd.Decimal { +func NumericToDecimal(val pgtype.Numeric) *apd.Decimal { if val.NaN { return pgtypes.NumericNaN } else if val.InfinityModifier == pgtype.Infinity { @@ -256,7 +256,7 @@ func NumericToDecimal(val pgtype.Numeric) apd.Decimal { return pgtypes.NumericNegInf } - return *apd.New(val.Int.Int64(), val.Exp) + return apd.New(val.Int.Int64(), val.Exp) } // CompareResults compares two sets of results, taking the equivalence thresholds into account when making the @@ -297,15 +297,19 @@ func CompareRows(t *testing.T, a sql.Row, b sql.Row) bool { case pgtype.Numeric: aDec := NumericToDecimal(aVal.(pgtype.Numeric)) bDec := NumericToDecimal(bVal.(pgtype.Numeric)) - _, err := sql.DecimalCtx.Sub(&aDec, &aDec, &bDec) + p := aDec.NumDigits() + if p2 := bDec.NumDigits(); p < p2 { + p = p2 + } + res := new(apd.Decimal) + _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Sub(res, aDec, bDec) if err != nil { return false } - aDec = *aDec.Abs(&aDec) // EquivalenceThresholdNumeric represents the allowable delta for values to be considered equivalent. // This is computed using the float64 variant so that they're equivalent. EquivalenceThresholdNumeric, _, _ := apd.NewFromString(strconv.FormatFloat(EquivalenceThresholdFloat64, 'f', -1, 64)) - if aDec.Cmp(EquivalenceThresholdNumeric) == 1 { + if res.Abs(res).Cmp(EquivalenceThresholdNumeric) == 1 { return false } default: diff --git a/testing/go/framework.go b/testing/go/framework.go index 314b51dac3..00ed40f995 100644 --- a/testing/go/framework.go +++ b/testing/go/framework.go @@ -777,7 +777,7 @@ func NormalizeVal(dt *types.DoltgresType, v any) any { } else if !val.Valid { return nil } else { - return *apd.New(val.Int.Int64(), val.Exp) + return apd.New(val.Int.Int64(), val.Exp) } case pgtype.Time: // This value type is used for TIME type. From cd4e5e9a2aa7680ab3edcac5621680bd4fcb181a Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 12 May 2026 10:01:28 -0700 Subject: [PATCH 09/16] fix regressed test --- server/cast/jsonb.go | 3 ++- testing/go/types_test.go | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/server/cast/jsonb.go b/server/cast/jsonb.go index 0fe90e2c24..568d605877 100644 --- a/server/cast/jsonb.go +++ b/server/cast/jsonb.go @@ -193,7 +193,8 @@ func jsonbExplicit() { case pgtypes.JsonValueString: return nil, errors.Errorf("cannot cast jsonb string to type %s", targetType.String()) case pgtypes.JsonValueNumber: - return apd.Decimal(value), nil + v := apd.Decimal(value) + return &v, nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: diff --git a/testing/go/types_test.go b/testing/go/types_test.go index ecc1bd6572..a59c110917 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -1972,6 +1972,10 @@ var typesTests = []ScriptTest{ Query: "SELECT floor(-0.000001);", Expected: []sql.Row{{Numeric("-1")}}, }, + { + Query: `select '12345'::jsonb::numeric;`, + Expected: []sql.Row{{Numeric("12345")}}, + }, }, }, { From d953ec8b98514567d9a22a50284bc17eb4f7b451 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 12 May 2026 16:20:21 -0700 Subject: [PATCH 10/16] fix tests and feedback --- go.mod | 4 ++-- go.sum | 8 ++++---- server/cast/jsonb.go | 6 +++--- server/cast/numeric.go | 6 +++--- server/functions/binary/divide.go | 15 +++++++++++++-- server/functions/binary/minus.go | 9 +++------ server/functions/binary/multiply.go | 3 ++- server/functions/binary/plus.go | 9 +++------ server/functions/ln.go | 7 ++++--- server/functions/log.go | 2 +- server/functions/numeric.go | 2 +- server/functions/power.go | 17 +++++++++++++---- server/functions/sqrt.go | 2 +- server/functions/to_char.go | 3 --- testing/go/functions_test.go | 1 - testing/go/operators_test.go | 13 ++++++++++++- 16 files changed, 65 insertions(+), 42 deletions(-) diff --git a/go.mod b/go.mod index 84bd35fa08..824a321b72 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,10 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v3 v3.2.3 github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20260511182609-ab34f8c300fb + github.com/dolthub/dolt/go v0.40.5-0.20260512211612-add2b8cc21e8 github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.20.1-0.20260507202550-43d6daf5958b + github.com/dolthub/go-mysql-server v0.20.1-0.20260512202859-bb7a7d4fe7ee github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20260505163811-77e5224be390 diff --git a/go.sum b/go.sum index addca0bb13..d2e9ab56c2 100644 --- a/go.sum +++ b/go.sum @@ -245,8 +245,8 @@ github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:I github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo= github.com/dolthub/dolt-mcp v0.3.4 h1:AyG5cw+fNWXDHXujtQnqUPZrpWtPg6FN6yYtjv1pP44= github.com/dolthub/dolt-mcp v0.3.4/go.mod h1:bCZ7KHvDYs+M0e+ySgmGiNvLhcwsN7bbf5YCyillLrk= -github.com/dolthub/dolt/go v0.40.5-0.20260511182609-ab34f8c300fb h1:KUkux82SVxpTIRtuzbBIu9T3FDTs4zCmbrMMScwpCSo= -github.com/dolthub/dolt/go v0.40.5-0.20260511182609-ab34f8c300fb/go.mod h1:zQgkmBI5uIRpzo+liO3Q6by+rcLWGcvyJhAAI/rmkHE= +github.com/dolthub/dolt/go v0.40.5-0.20260512211612-add2b8cc21e8 h1:xArHGgqwLIezRsv/ccEvnarZ3Ln2dYhLqgZfYD4XZo0= +github.com/dolthub/dolt/go v0.40.5-0.20260512211612-add2b8cc21e8/go.mod h1:Ptwc329qRFlzs6oK+dYJ7ELZViw0YDqidlF9fyuCw18= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 h1:JShhbqMw26nKx3pqqu/cFxOpzBkN+4elVhzuUfgDw2k= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69/go.mod h1:SSLraQS/jGLYFgff3vuZ+JbVUct6vyEeMzjLBqWqoyM= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -255,8 +255,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866 h1:U6gSf5I0e6h6GP1/5Sa7D2lWW1CWfcVPtY5wkyHq6jY= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE= -github.com/dolthub/go-mysql-server v0.20.1-0.20260507202550-43d6daf5958b h1:Ew5nacrGUHRVyae1+/0vUjt3ZbX/6vaDEZfeDNMCSj8= -github.com/dolthub/go-mysql-server v0.20.1-0.20260507202550-43d6daf5958b/go.mod h1:55n1yslSIZ5uewFbtd82DsYt3f9vUKwnRN5GZJie+nE= +github.com/dolthub/go-mysql-server v0.20.1-0.20260512202859-bb7a7d4fe7ee h1:SKOX+Z9td39qkUrJg7y52JFy5pvOApZVWAnnSx4Qekk= +github.com/dolthub/go-mysql-server v0.20.1-0.20260512202859-bb7a7d4fe7ee/go.mod h1:uE//4hYOi/1zIgOYgr8D+u7OvDfI4yPXFsL4dTZWk5I= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20260414231531-5f031e3e9037 h1:oIW9HwuWrhxv+4HZxA+QQSKHLqWFyXZ2FmNjUYwkdiM= diff --git a/server/cast/jsonb.go b/server/cast/jsonb.go index 568d605877..2302913d52 100644 --- a/server/cast/jsonb.go +++ b/server/cast/jsonb.go @@ -119,7 +119,7 @@ func jsonbExplicit() { if d.Cmp(pgtypes.NumericValueMinInt16) < 0 || d.Cmp(pgtypes.NumericValueMaxInt16) > 0 { return nil, errors.Errorf("smallint out of range") } - return int16(types.DecimalIntPart(&d)), nil + return int16(types.DecimalRoundedIntPart(&d)), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: @@ -145,7 +145,7 @@ func jsonbExplicit() { if d.Cmp(pgtypes.NumericValueMinInt32) < 0 || d.Cmp(pgtypes.NumericValueMaxInt32) > 0 { return nil, errors.Errorf("integer out of range") } - return int32(types.DecimalIntPart(&d)), nil + return int32(types.DecimalRoundedIntPart(&d)), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: @@ -171,7 +171,7 @@ func jsonbExplicit() { if d.Cmp(pgtypes.NumericValueMinInt64) < 0 || d.Cmp(pgtypes.NumericValueMaxInt64) > 0 { return nil, errors.Errorf("bigint out of range") } - return int64(types.DecimalIntPart(&d)), nil + return int64(types.DecimalRoundedIntPart(&d)), nil case pgtypes.JsonValueBoolean: return nil, errors.Errorf("cannot cast jsonb boolean to type %s", targetType.String()) case pgtypes.JsonValueNull: diff --git a/server/cast/numeric.go b/server/cast/numeric.go index 100a0fb219..814c73fa9f 100644 --- a/server/cast/numeric.go +++ b/server/cast/numeric.go @@ -40,7 +40,7 @@ func numericAssignment() { if d.Cmp(pgtypes.NumericValueMinInt16) < 0 || d.Cmp(pgtypes.NumericValueMaxInt16) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "smallint out of range") } - i := types.DecimalIntPart(d) + i := types.DecimalRoundedIntPart(d) return int16(i), nil }, }) @@ -52,7 +52,7 @@ func numericAssignment() { if d.Cmp(pgtypes.NumericValueMinInt32) < 0 || d.Cmp(pgtypes.NumericValueMaxInt32) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "integer out of range") } - i := types.DecimalIntPart(d) + i := types.DecimalRoundedIntPart(d) return int32(i), nil }, }) @@ -64,7 +64,7 @@ func numericAssignment() { if d.Cmp(pgtypes.NumericValueMinInt64) < 0 || d.Cmp(pgtypes.NumericValueMaxInt64) > 0 { return nil, errors.Wrap(pgtypes.ErrCastOutOfRange, "bigint out of range") } - return types.DecimalIntPart(d), nil + return types.DecimalRoundedIntPart(d), nil }, }) } diff --git a/server/functions/binary/divide.go b/server/functions/binary/divide.go index 40d0a893d5..b8df8f45df 100644 --- a/server/functions/binary/divide.go +++ b/server/functions/binary/divide.go @@ -15,6 +15,9 @@ package binary import ( + "math" + "strings" + "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" @@ -303,11 +306,19 @@ func numeric_div_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any } res := new(apd.Decimal) - _, err := sql.DecimalHighPrecisionCtx.Quo(res, num1, num2) + // enough precision to scale to at most 16 decimal places + p := num1.NumDigits() + int64(math.Abs(float64(num1.Exponent))) + 16 + _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Quo(res, num1, num2) if err != nil { return nil, err } - return sql.DecimalRound(res, 16) + + // the exponent value depends on the number of digits before decimal places. + parts := strings.Split(res.Text('f'), ".") + whole := (len(parts[0]) + 4 - 1) / 4 + exp := int32(16 - (whole-1)*4) + + return sql.DecimalRound(res, exp) } // numeric_div represents the PostgreSQL function of the same name, taking the same parameters. diff --git a/server/functions/binary/minus.go b/server/functions/binary/minus.go index 8ac3fd968a..1ea782ab1c 100644 --- a/server/functions/binary/minus.go +++ b/server/functions/binary/minus.go @@ -242,13 +242,10 @@ var numeric_sub = framework.Function2{ Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any, val2 any) (any, error) { num1 := val1.(*apd.Decimal) num2 := val2.(*apd.Decimal) - p := num1.NumDigits() - if p2 := num2.NumDigits(); p < p2 { - p = p2 - } res := new(apd.Decimal) - // TODO does this need precision?? - _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Sub(res, num1, num2) + p := uint32(math.Max(float64(num1.NumDigits()), float64(num2.NumDigits())+ + math.Max(math.Abs(float64(num1.Exponent)), math.Abs(float64(num2.Exponent))))) + _, err := sql.DecimalCtx.WithPrecision(p).Sub(res, num1, num2) if err != nil { return nil, err } diff --git a/server/functions/binary/multiply.go b/server/functions/binary/multiply.go index fbb39205b7..cff962bcfc 100644 --- a/server/functions/binary/multiply.go +++ b/server/functions/binary/multiply.go @@ -231,7 +231,8 @@ var numeric_mul = framework.Function2{ return pgtypes.NumericNaN, nil } res := new(apd.Decimal) - _, err := sql.DecimalCtx.WithPrecision(uint32(num1.NumDigits()+num2.NumDigits())).Mul(res, num1, num2) + p := num1.NumDigits() + num2.NumDigits() + int64(math.Max(math.Abs(float64(num1.Exponent)), math.Abs(float64(num2.Exponent)))) + _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Mul(res, num1, num2) if err != nil { return nil, err } diff --git a/server/functions/binary/plus.go b/server/functions/binary/plus.go index 133374d356..f9a7933c2b 100644 --- a/server/functions/binary/plus.go +++ b/server/functions/binary/plus.go @@ -390,13 +390,10 @@ func numeric_add_callable(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1 any return pgtypes.NumericInf, nil } - p := num1.NumDigits() - if p2 := num2.NumDigits(); p < p2 { - p = p2 - } - res := new(apd.Decimal) - _, err := sql.DecimalCtx.WithPrecision(uint32(p)).Add(res, num1, num2) + p := uint32(math.Max(float64(num1.NumDigits()), float64(num2.NumDigits())+ + math.Max(math.Abs(float64(num1.Exponent)), math.Abs(float64(num2.Exponent))))) + _, err := sql.DecimalCtx.WithPrecision(p).Add(res, num1, num2) if err != nil { return nil, err } diff --git a/server/functions/ln.go b/server/functions/ln.go index b568b94a43..fd827e0246 100644 --- a/server/functions/ln.go +++ b/server/functions/ln.go @@ -64,12 +64,12 @@ var ln_numeric = framework.Function1{ return dec, nil } - // TODO: calculate precision and scale accurately + // calculate precision and scale exp := dec.Exponent p := dec.NumDigits() - if exp < 0 { + if exp < -16 { p += int64(-exp) - } else if exp == 0 { + } else { p += 16 } @@ -79,6 +79,7 @@ var ln_numeric = framework.Function1{ return nil, err } + // calculate exponent if exp > -16 { // use ln result parts := strings.Split(res.Text('f'), ".") diff --git a/server/functions/log.go b/server/functions/log.go index 5b5115d734..a3ffd3cd26 100644 --- a/server/functions/log.go +++ b/server/functions/log.go @@ -64,7 +64,7 @@ var log_numeric = framework.Function1{ return nil, errors.Errorf("cannot take logarithm of a negative number") } - // TODO: calculate precision and scale accurately + // calculate precision and scale p := uint32(17) if dec.Exponent < 0 { p += uint32(-dec.Exponent) diff --git a/server/functions/numeric.go b/server/functions/numeric.go index dae3e41a5f..d2b2aa646f 100644 --- a/server/functions/numeric.go +++ b/server/functions/numeric.go @@ -178,7 +178,7 @@ var numeric_send = framework.Function1{ writer := utils.NewWireWriter() if num.Form == apd.Finite { // Short-circuit if this is the zero value - if num.IsZero() { + if num.IsZero() && num.Exponent == 0 { writer.WriteBytes([]byte{0, 0, 0, 0, 0, 0, 0, 0}) return writer.BufferData(), nil } diff --git a/server/functions/power.go b/server/functions/power.go index 456a4acc84..7b5a20c3a2 100644 --- a/server/functions/power.go +++ b/server/functions/power.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" @@ -67,6 +68,7 @@ var power_numeric_numeric = framework.Function2{ return pgtypes.NumericNaN, nil } if dec1.Form == apd.Infinite && dec1.Negative { + // dec1 is -Infinity even := dec2.Form == apd.Infinite && !dec2.Negative if dec2.Form == apd.Finite { i, err := dec2.Int64() @@ -75,7 +77,6 @@ var power_numeric_numeric = framework.Function2{ } even = i%2 == 0 } - if dec2.Sign() > 0 { // +inf will return neginf == fix!! if even { @@ -87,9 +88,17 @@ var power_numeric_numeric = framework.Function2{ return apd.New(0, 0), nil } return apd.New(1, 0), nil - } - - if dec1.IsZero() { + } else if dec1.Form == apd.Infinite { + // dec1 is +Infinity + d2Sign := dec2.Sign() + if dec2.Sign() < 0 { + return apd.New(0, 0), nil + } else if d2Sign == 0 { + return apd.New(1, 0), nil + } + return types.DecimalPosInf, nil + } else if dec1.IsZero() { + // dec1 is 0 if dec2.Sign() < 0 { // includes neg inf return nil, errPowerZeroToNegative diff --git a/server/functions/sqrt.go b/server/functions/sqrt.go index 446214953c..09780f2dd2 100644 --- a/server/functions/sqrt.go +++ b/server/functions/sqrt.go @@ -57,7 +57,7 @@ var sqrt_numeric = framework.Function1{ return nil, errors.Errorf("cannot take square root of a negative number") } - // TODO: calculate precision and scale accurately + // calculate precision and scale exp := dec.Exponent p := dec.NumDigits() if exp < 0 { diff --git a/server/functions/to_char.go b/server/functions/to_char.go index 6d4b98601c..ea4f0fd8b9 100644 --- a/server/functions/to_char.go +++ b/server/functions/to_char.go @@ -138,9 +138,6 @@ var to_char_numeric_text = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Numeric, pgtypes.Text}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) { - //timestamp := val1.(*apd.Decimal) - //format := val2.(string) - return nil, errors.Errorf(`to_char(numeric,text) is not supported yet`) }, } diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 85b1a12e78..6a3aea56ca 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -638,7 +638,6 @@ func TestFunctionsMath(t *testing.T) { }, }, { - Skip: true, //TODO: fix Query: `select power('0'::numeric, '3'::numeric);`, Expected: []sql.Row{ {Numeric("0.0000000000000000")}, diff --git a/testing/go/operators_test.go b/testing/go/operators_test.go index 3b8627254d..97a0dcd211 100644 --- a/testing/go/operators_test.go +++ b/testing/go/operators_test.go @@ -749,10 +749,21 @@ func TestOperators(t *testing.T) { Expected: []sql.Row{{Numeric("4.0000000000000000")}}, }, { - Skip: true, // TODO: fix scaling for division + Query: `SELECT 444000804::numeric / 2::int2;`, + Expected: []sql.Row{{Numeric("222000402.00000000")}}, + }, + { Query: `SELECT 44400080::numeric / 2::int2;`, Expected: []sql.Row{{Numeric("22200040.000000000000")}}, }, + { + Query: `SELECT 44400::numeric / 2::int2;`, + Expected: []sql.Row{{Numeric("22200.000000000000")}}, + }, + { + Query: `SELECT 4440::numeric / 2::int2;`, + Expected: []sql.Row{{Numeric("2220.0000000000000000")}}, + }, { Query: `SELECT 8::numeric / 2::int4;`, Expected: []sql.Row{{Numeric("4.0000000000000000")}}, From 3dc622965296ff7ab4fcc0d4c3e0b706a8043a52 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 12 May 2026 16:36:21 -0700 Subject: [PATCH 11/16] update backward compatibility version to v0.56.0 --- .../compatibility/test_files/backward_compatible_versions.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/integration-tests/compatibility/test_files/backward_compatible_versions.txt b/integration-tests/compatibility/test_files/backward_compatible_versions.txt index d6cea8599b..2aaafe4d2a 100644 --- a/integration-tests/compatibility/test_files/backward_compatible_versions.txt +++ b/integration-tests/compatibility/test_files/backward_compatible_versions.txt @@ -3,5 +3,4 @@ # Format: one version tag per line, e.g. v0.8.0 # Keep this list reasonably short; each entry downloads a binary and runs the full # test suite, so CI time grows linearly. -v0.55.0 -v0.51.0 +v0.56.0 From b53e0f5cbffecefd7cb5d7045cb8e82eb2cefe2c Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 12 May 2026 21:14:33 -0700 Subject: [PATCH 12/16] set backward compatibility to v0.56.2 --- .../compatibility/test_files/backward_compatible_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration-tests/compatibility/test_files/backward_compatible_versions.txt b/integration-tests/compatibility/test_files/backward_compatible_versions.txt index 2aaafe4d2a..5c37ff4aa7 100644 --- a/integration-tests/compatibility/test_files/backward_compatible_versions.txt +++ b/integration-tests/compatibility/test_files/backward_compatible_versions.txt @@ -3,4 +3,4 @@ # Format: one version tag per line, e.g. v0.8.0 # Keep this list reasonably short; each entry downloads a binary and runs the full # test suite, so CI time grows linearly. -v0.56.0 +v0.56.2 From 2888a9800b0306ea5f01ba8dc3d881fc363bdde4 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Wed, 13 May 2026 15:19:23 -0700 Subject: [PATCH 13/16] bug fix for backwards compat --- server/types/numeric.go | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/server/types/numeric.go b/server/types/numeric.go index eb283daddf..fb2ff28f12 100644 --- a/server/types/numeric.go +++ b/server/types/numeric.go @@ -21,6 +21,7 @@ import ( "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" + "github.com/shopspring/decimal" "github.com/dolthub/doltgresql/core/id" ) @@ -134,26 +135,42 @@ func GetNumericValueFromStringWithTypmod(val string, typmod int32) (*apd.Decimal // serializeTypeNumeric handles serialization from the standard representation to our serialized representation that is // written in Dolt. +// Note: this function is only used for values serialized by older clients, which is why it uses a decimal.Decimal. +// Newer clients will use Dolt's serialization, which uses apd.Decimal directly. +// Deprecated. func serializeTypeNumeric(ctx *sql.Context, t *DoltgresType, val any) ([]byte, error) { - d := val.(*apd.Decimal) - return d.MarshalText() + switch d := val.(type) { + case decimal.Decimal: + return d.MarshalBinary() + case *apd.Decimal: + dec := decimal.NewFromBigInt(d.Coeff.MathBigInt(), d.Exponent) + return dec.MarshalBinary() + default: + return nil, errors.Errorf("cannot serialize value of type %T as numeric", val) + } } // deserializeTypeNumeric handles deserialization from the Dolt serialized format to our standard representation used by // expressions and nodes. +// Note: this function is only used for values serialized by older clients, which is why it uses a decimal.Decimal. +// Newer clients will use Dolt's serialization, which uses apd.Decimal directly. +// Deprecated. func deserializeTypeNumeric(ctx *sql.Context, t *DoltgresType, data []byte) (any, error) { if len(data) == 0 { return nil, nil } - retVal := apd.New(0, 0) - err := retVal.UnmarshalText(data) - return retVal, err + retVal := decimal.NewFromInt(0) + err := retVal.UnmarshalBinary(data) + if err != nil { + return nil, err + } + return apd.New(retVal.CoefficientInt64(), retVal.Exponent()), nil } // NumericCompare compares two *apd.Decimal values handling NaN separately. func NumericCompare(ab, bb *apd.Decimal) int { if (ab.Form == apd.NaN && bb.Form == apd.NaN) || - (ab.Form == apd.Infinite && bb.Form == apd.Infinite && ab.Negative == bb.Negative) { + (ab.Form == apd.Infinite && bb.Form == apd.Infinite && ab.Negative == bb.Negative) { return 0 } if ab.Form == apd.NaN { From 80d7cadf7985150519ce4d6204e6cf620ec62029 Mon Sep 17 00:00:00 2001 From: zachmu Date: Wed, 13 May 2026 22:25:50 +0000 Subject: [PATCH 14/16] [ga-format-pr] Run scripts/format_repo.sh --- server/types/numeric.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/types/numeric.go b/server/types/numeric.go index fb2ff28f12..82d8183bf4 100644 --- a/server/types/numeric.go +++ b/server/types/numeric.go @@ -170,7 +170,7 @@ func deserializeTypeNumeric(ctx *sql.Context, t *DoltgresType, data []byte) (any // NumericCompare compares two *apd.Decimal values handling NaN separately. func NumericCompare(ab, bb *apd.Decimal) int { if (ab.Form == apd.NaN && bb.Form == apd.NaN) || - (ab.Form == apd.Infinite && bb.Form == apd.Infinite && ab.Negative == bb.Negative) { + (ab.Form == apd.Infinite && bb.Form == apd.Infinite && ab.Negative == bb.Negative) { return 0 } if ab.Form == apd.NaN { From ff4317ba9fa6ac3f4eae98d9a0e4cfd098d3cf61 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Wed, 13 May 2026 16:08:26 -0700 Subject: [PATCH 15/16] fix --- .../test_files/backward_compatible_versions.txt | 2 +- server/types/numeric.go | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/integration-tests/compatibility/test_files/backward_compatible_versions.txt b/integration-tests/compatibility/test_files/backward_compatible_versions.txt index 5c37ff4aa7..2aaafe4d2a 100644 --- a/integration-tests/compatibility/test_files/backward_compatible_versions.txt +++ b/integration-tests/compatibility/test_files/backward_compatible_versions.txt @@ -3,4 +3,4 @@ # Format: one version tag per line, e.g. v0.8.0 # Keep this list reasonably short; each entry downloads a binary and runs the full # test suite, so CI time grows linearly. -v0.56.2 +v0.56.0 diff --git a/server/types/numeric.go b/server/types/numeric.go index 82d8183bf4..e3235c0db0 100644 --- a/server/types/numeric.go +++ b/server/types/numeric.go @@ -143,7 +143,11 @@ func serializeTypeNumeric(ctx *sql.Context, t *DoltgresType, val any) ([]byte, e case decimal.Decimal: return d.MarshalBinary() case *apd.Decimal: - dec := decimal.NewFromBigInt(d.Coeff.MathBigInt(), d.Exponent) + bigInt := d.Coeff.MathBigInt() + if d.Negative { + bigInt.Neg(bigInt) + } + dec := decimal.NewFromBigInt(bigInt, d.Exponent) return dec.MarshalBinary() default: return nil, errors.Errorf("cannot serialize value of type %T as numeric", val) From 6647fb3bf14208669cf82bc2e582ac667cf6e3a9 Mon Sep 17 00:00:00 2001 From: jennifersp <44716627+jennifersp@users.noreply.github.com> Date: Thu, 14 May 2026 02:21:45 +0000 Subject: [PATCH 16/16] [ga-bump-dep] Bump dependency in Doltgres by jennifersp --- go.mod | 5 +++-- go.sum | 10 ++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index a762cd89c7..79047fcc80 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,10 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20260511182609-ab34f8c300fb + github.com/dolthub/dolt/go v0.40.5-0.20260514021628-29c8b0a35aa9 github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.20.1-0.20260507202550-43d6daf5958b + github.com/dolthub/go-mysql-server v0.20.1-0.20260513232454-bbbd50eb8e47 github.com/dolthub/pg_query_go/v6 v6.0.0-20251215122834-fb20be4254d1 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20260505163811-77e5224be390 @@ -97,6 +97,7 @@ require ( github.com/cenkalti/backoff/v4 v4.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect + github.com/cockroachdb/apd/v3 v3.2.3 // indirect github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f // indirect github.com/cockroachdb/redact v1.0.6 // indirect github.com/cockroachdb/sentry-go v0.6.1-cockroachdb.2 // indirect diff --git a/go.sum b/go.sum index 06fd32dcdd..36b0df13e7 100644 --- a/go.sum +++ b/go.sum @@ -207,6 +207,8 @@ github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a h1:9VFe4R5FRCUyidB1rdm3XdCRVuD/75P7Y4PtzEGhEE4= github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a/go.mod h1:DDxRlzC2lo3/vSlmSoS7JkqbbrARPuFOGr0B9pvN3Gw= +github.com/cockroachdb/apd/v3 v3.2.3 h1:4Zx+I3R35bFXMnltzmjP79i2cravE4jTRL6ps9Aux80= +github.com/cockroachdb/apd/v3 v3.2.3/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc= github.com/cockroachdb/datadriven v1.0.0/go.mod h1:5Ib8Meh+jk1RlHIXej6Pzevx/NLlNvQB9pmSBZErGA4= github.com/cockroachdb/errors v1.6.1/go.mod h1:tm6FTP5G81vwJ5lC0SizQo374JNCOPrHyXGitRJoDqM= github.com/cockroachdb/errors v1.7.5 h1:ptyO1BLW+sBxwBTSKJfS6kGzYCVKhI7MyBhoXAnPIKM= @@ -245,8 +247,8 @@ github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:I github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo= github.com/dolthub/dolt-mcp v0.3.4 h1:AyG5cw+fNWXDHXujtQnqUPZrpWtPg6FN6yYtjv1pP44= github.com/dolthub/dolt-mcp v0.3.4/go.mod h1:bCZ7KHvDYs+M0e+ySgmGiNvLhcwsN7bbf5YCyillLrk= -github.com/dolthub/dolt/go v0.40.5-0.20260511182609-ab34f8c300fb h1:KUkux82SVxpTIRtuzbBIu9T3FDTs4zCmbrMMScwpCSo= -github.com/dolthub/dolt/go v0.40.5-0.20260511182609-ab34f8c300fb/go.mod h1:zQgkmBI5uIRpzo+liO3Q6by+rcLWGcvyJhAAI/rmkHE= +github.com/dolthub/dolt/go v0.40.5-0.20260514021628-29c8b0a35aa9 h1:PPYToX+QqSr89DiAKkxm5KxT7OIdwLJMPD9IujLTxjM= +github.com/dolthub/dolt/go v0.40.5-0.20260514021628-29c8b0a35aa9/go.mod h1:mhl/QwBfzZrITl6XsjaAamvfYnKWWqXvLs992+oFz3o= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69 h1:JShhbqMw26nKx3pqqu/cFxOpzBkN+4elVhzuUfgDw2k= github.com/dolthub/eventsapi_schema v0.0.0-20260310172945-37a9265ade69/go.mod h1:SSLraQS/jGLYFgff3vuZ+JbVUct6vyEeMzjLBqWqoyM= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -255,8 +257,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866 h1:U6gSf5I0e6h6GP1/5Sa7D2lWW1CWfcVPtY5wkyHq6jY= github.com/dolthub/go-icu-regex v0.0.0-20260412212219-49724d547866/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE= -github.com/dolthub/go-mysql-server v0.20.1-0.20260507202550-43d6daf5958b h1:Ew5nacrGUHRVyae1+/0vUjt3ZbX/6vaDEZfeDNMCSj8= -github.com/dolthub/go-mysql-server v0.20.1-0.20260507202550-43d6daf5958b/go.mod h1:55n1yslSIZ5uewFbtd82DsYt3f9vUKwnRN5GZJie+nE= +github.com/dolthub/go-mysql-server v0.20.1-0.20260513232454-bbbd50eb8e47 h1:59cpEXHrN4oubATtax1ybjHWDorO6h1nE5GRudFCyVc= +github.com/dolthub/go-mysql-server v0.20.1-0.20260513232454-bbbd50eb8e47/go.mod h1:uE//4hYOi/1zIgOYgr8D+u7OvDfI4yPXFsL4dTZWk5I= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20260414231531-5f031e3e9037 h1:oIW9HwuWrhxv+4HZxA+QQSKHLqWFyXZ2FmNjUYwkdiM=