Skip to content

Commit fc4597f

Browse files
authored
PostgreSQL: filter pushdown over numeric columns (#340)
1 parent 58dbb72 commit fc4597f

File tree

10 files changed

+376
-6
lines changed

10 files changed

+376
-6
lines changed

app/server/datasource/rdbms/postgresql/sql_formatter.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ func (f sqlFormatter) supportsConstantValueExpression(t *Ydb.Type) bool {
6060
return f.supportsType(v.TypeId)
6161
case *Ydb.Type_OptionalType:
6262
return f.supportsConstantValueExpression(v.OptionalType.Item)
63+
case *Ydb.Type_DecimalType:
64+
return true
6365
default:
6466
return false
6567
}

app/server/datasource/rdbms/utils/predicate_builder.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
api_common "github.com/ydb-platform/fq-connector-go/api/common"
1313
api_service_protos "github.com/ydb-platform/fq-connector-go/api/service/protos"
14+
"github.com/ydb-platform/fq-connector-go/app/server/utils/decimal"
1415
"github.com/ydb-platform/fq-connector-go/common"
1516
)
1617

@@ -36,11 +37,11 @@ func (pb *predicateBuilder) formatValue(
3637
return pb.formatOptionalValue(value)
3738
}
3839

39-
return pb.formatPrimitiveValue(value, embedBool)
40+
return pb.formatTypedValue(value, embedBool)
4041
}
4142

4243
//nolint:gocyclo
43-
func (pb *predicateBuilder) formatPrimitiveValue(
44+
func (pb *predicateBuilder) formatTypedValue(
4445
value *Ydb.TypedValue,
4546
embedBool bool, // remove after YQ-4191, KIKIMR-22852 is fixed
4647
) (string, error) {
@@ -86,8 +87,18 @@ func (pb *predicateBuilder) formatPrimitiveValue(
8687
pb.args.AddTyped(value.Type, v.DoubleValue)
8788
return pb.formatter.GetPlaceholder(pb.args.Count() - 1), nil
8889
case *Ydb.Value_BytesValue:
89-
pb.args.AddTyped(value.Type, v.BytesValue)
90-
return pb.formatter.GetPlaceholder(pb.args.Count() - 1), nil
90+
switch t := value.Type.Type.(type) {
91+
case *Ydb.Type_TypeId:
92+
pb.args.AddTyped(value.Type, v.BytesValue)
93+
return pb.formatter.GetPlaceholder(pb.args.Count() - 1), nil
94+
case *Ydb.Type_DecimalType:
95+
decimalValue := decimal.Deserialize(v.BytesValue, t.DecimalType.Scale)
96+
pb.args.AddTyped(value.Type, decimalValue)
97+
98+
return pb.formatter.GetPlaceholder(pb.args.Count() - 1), nil
99+
default:
100+
return "", fmt.Errorf("unsupported type '%T' for bytes value: %w", v, common.ErrUnimplementedTypedValue)
101+
}
91102
case *Ydb.Value_TextValue:
92103
pb.args.AddTyped(value.Type, v.TextValue)
93104
return pb.formatter.GetPlaceholder(pb.args.Count() - 1), nil
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Package decimal provides utilities for working with decimal numbers.
2+
package decimal
3+
4+
import (
5+
"math/big"
6+
"slices"
7+
8+
"github.com/shopspring/decimal"
9+
)
10+
11+
// Deserialize converts a byte array representation to a decimal value
12+
func Deserialize(
13+
src []byte, // source byte array
14+
scale uint32, // scale factor
15+
) *decimal.Decimal {
16+
// Make a copy of the source to avoid modifying the original
17+
buf := make([]byte, len(src))
18+
copy(buf, src)
19+
20+
// LittleEndian -> BigEndian
21+
slices.Reverse(buf)
22+
23+
// Create a new big.Int from the bytes
24+
bigInt := new(big.Int).SetBytes(buf)
25+
26+
// Check if the number is negative (most significant bit is set)
27+
isNegative := len(buf) > 0 && (buf[0]&0x80) != 0
28+
29+
if isNegative {
30+
// For negative numbers: subtract from 2^{8*blobSize} to get the original negative value
31+
twoToThe128 := new(big.Int).Lsh(big.NewInt(1), uint(blobSize*8))
32+
bigInt = new(big.Int).Sub(twoToThe128, bigInt)
33+
bigInt.Neg(bigInt)
34+
}
35+
36+
// Create decimal from big.Int
37+
result := decimal.NewFromBigInt(bigInt, 0)
38+
39+
// Only shift when scale > 0
40+
if scale > 0 {
41+
result = result.Shift(-int32(scale))
42+
}
43+
44+
return &result
45+
}
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
package decimal
2+
3+
import (
4+
"testing"
5+
6+
"github.com/shopspring/decimal"
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestDeserialize(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
input []byte
14+
scale uint32
15+
expected string
16+
}{
17+
{
18+
name: "positive small number (1)",
19+
input: []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
20+
scale: 0,
21+
expected: "1",
22+
},
23+
{
24+
name: "positive number with two bytes (257)",
25+
input: []byte{1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
26+
scale: 0,
27+
expected: "257",
28+
},
29+
{
30+
name: "negative number (-2)",
31+
input: []byte{254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
32+
scale: 0,
33+
expected: "-2",
34+
},
35+
{
36+
name: "zero",
37+
input: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
38+
scale: 0,
39+
expected: "0",
40+
},
41+
{
42+
name: "large positive number",
43+
input: []byte{255, 255, 255, 255, 255, 255, 255, 127, 0, 0, 0, 0, 0, 0, 0, 0},
44+
scale: 0,
45+
expected: "9223372036854775807", // max int64
46+
},
47+
{
48+
name: "with scale - divide by 10",
49+
input: []byte{206, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // 1230
50+
scale: 1,
51+
expected: "123",
52+
},
53+
{
54+
name: "example from task - scale 0",
55+
input: []byte{174, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
56+
scale: 0,
57+
expected: "2222",
58+
},
59+
{
60+
name: "example from task - scale 2",
61+
input: []byte{174, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
62+
scale: 2,
63+
expected: "22.22",
64+
},
65+
}
66+
67+
for _, tt := range tests {
68+
t.Run(tt.name, func(t *testing.T) {
69+
// Call Deserialize
70+
result := Deserialize(tt.input, tt.scale)
71+
72+
// Check result
73+
expected, err := decimal.NewFromString(tt.expected)
74+
assert.NoError(t, err)
75+
assert.True(t, expected.Equal(*result), "Expected %s, got %s", expected.String(), result.String())
76+
})
77+
}
78+
}
79+
80+
// TestDeserializeEdgeCases tests edge cases for the Deserialize function
81+
func TestDeserializeEdgeCases(t *testing.T) {
82+
// Test with a very large decimal number that requires big.Int
83+
t.Run("very large number", func(t *testing.T) {
84+
input := []byte{0, 0, 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 0}
85+
result := Deserialize(input, 0)
86+
87+
expected, err := decimal.NewFromString("9223372036854775808") // One more than max int64
88+
assert.NoError(t, err)
89+
assert.True(t, expected.Equal(*result), "Expected %s, got %s", expected.String(), result.String())
90+
})
91+
92+
// Test with a very small negative number
93+
t.Run("very small negative number", func(t *testing.T) {
94+
input := []byte{255, 255, 255, 255, 255, 255, 255, 127, 255, 255, 255, 255, 255, 255, 255, 255}
95+
result := Deserialize(input, 0)
96+
97+
expected, err := decimal.NewFromString("-9223372036854775809") // One less than min int64
98+
assert.NoError(t, err)
99+
assert.True(t, expected.Equal(*result), "Expected %s, got %s", expected.String(), result.String())
100+
})
101+
}
102+
103+
func TestRoundTrip(t *testing.T) {
104+
serializer := NewSerializer()
105+
106+
tests := []struct {
107+
name string
108+
value string
109+
scale uint32
110+
}{
111+
{
112+
name: "positive integer",
113+
value: "12345",
114+
scale: 0,
115+
},
116+
{
117+
name: "negative integer",
118+
value: "-98765",
119+
scale: 0,
120+
},
121+
{
122+
name: "decimal value",
123+
value: "123.45",
124+
scale: 2,
125+
},
126+
{
127+
name: "negative decimal value",
128+
value: "-987.65",
129+
scale: 2,
130+
},
131+
{
132+
name: "large number",
133+
value: "9223372036854775807",
134+
scale: 0,
135+
},
136+
{
137+
name: "very large number",
138+
value: "9223372036854775808",
139+
scale: 0,
140+
},
141+
}
142+
143+
for _, tt := range tests {
144+
t.Run(tt.name, func(t *testing.T) {
145+
// Create original decimal
146+
original, err := decimal.NewFromString(tt.value)
147+
assert.NoError(t, err)
148+
149+
// Serialize
150+
buffer := make([]byte, blobSize)
151+
serializer.Serialize(&original, tt.scale, buffer)
152+
153+
// Deserialize
154+
result := Deserialize(buffer, tt.scale)
155+
156+
// Compare
157+
assert.True(t, original.Equal(*result), "Round trip failed: original %s, got %s", original.String(), result.String())
158+
})
159+
}
160+
}
161+
162+
func BenchmarkDeserialize(b *testing.B) {
163+
tests := []struct {
164+
name string
165+
input []byte
166+
scale uint32
167+
}{
168+
{
169+
name: "positive small number",
170+
input: []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
171+
scale: 0,
172+
},
173+
{
174+
name: "positive number with two bytes",
175+
input: []byte{1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
176+
scale: 0,
177+
},
178+
{
179+
name: "negative number",
180+
input: []byte{254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
181+
scale: 0,
182+
},
183+
{
184+
name: "zero",
185+
input: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
186+
scale: 0,
187+
},
188+
{
189+
name: "large positive number",
190+
input: []byte{255, 255, 255, 255, 255, 255, 255, 127, 0, 0, 0, 0, 0, 0, 0, 0},
191+
scale: 0,
192+
},
193+
{
194+
name: "with scale",
195+
input: []byte{206, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
196+
scale: 1,
197+
},
198+
{
199+
name: "very large number",
200+
input: []byte{0, 0, 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 0},
201+
scale: 0,
202+
},
203+
{
204+
name: "very small negative number",
205+
input: []byte{255, 255, 255, 255, 255, 255, 255, 127, 255, 255, 255, 255, 255, 255, 255, 255},
206+
scale: 0,
207+
},
208+
}
209+
210+
for _, tt := range tests {
211+
b.Run(tt.name, func(b *testing.B) {
212+
b.ReportAllocs()
213+
214+
for i := 0; i < b.N; i++ {
215+
Deserialize(tt.input, tt.scale)
216+
}
217+
})
218+
}
219+
}
File renamed without changes.

scripts/debug/kqprun/script.postgresql.local.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
-- SELECT * FROM external_datasource.lineitem WHERE l_linenumber > 0;
44

5-
SELECT col_27_numeric_int, col_28_numeric_rational FROM external_datasource.primitives WHERE col_27_numeric_int = Decimal(1);
5+
SELECT col_27_numeric_int, col_28_numeric_rational FROM external_datasource.primitives WHERE col_27_numeric_int = Decimal("1", 10, 0);
6+
7+
-- SELECT col_27_numeric_int, col_28_numeric_rational FROM external_datasource.primitives WHERE col_28_numeric_rational = Decimal("-22.22", 4, 2);
68

7-
-- SELECT col_27_numeric_int, col_28_numeric_rational FROM external_datasource.primitives WHERE id = 1;

tests/infra/datasource/postgresql/init/init_db.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,16 @@ psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-E
9494
(2, 20, 'b'), \
9595
(3, 30, 'c'), \
9696
(4, NULL, NULL);
97+
EOSQL
98+
99+
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
100+
DROP TABLE IF EXISTS pushdown_decimal;
101+
CREATE TABLE pushdown_decimal (
102+
id int,
103+
col_27_numeric_int numeric(10, 0),
104+
col_28_numeric_rational numeric(4, 2)
105+
);
106+
INSERT INTO pushdown_decimal (id, col_27_numeric_int, col_28_numeric_rational) VALUES \
107+
(1, 1, 11.11), \
108+
(2, -2, -22.22);
97109
EOSQL

tests/infra/datasource/postgresql/suite.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,36 @@ func NewSuite(
345345

346346
return result
347347
}
348+
349+
func (s *Suite) TestPushdownDecimalIntEQ() {
350+
// Test for: SELECT * FROM table WHERE col_27_numeric_int = Decimal("1", 10, 0);
351+
s.ValidateTable(
352+
s.dataSource,
353+
tables["pushdown_decimal_int_EQ"],
354+
suite.WithPredicate(&api_service_protos.TPredicate{
355+
Payload: tests_utils.MakePredicateComparisonColumn(
356+
"col_27_numeric_int",
357+
api_service_protos.TPredicate_TComparison_EQ,
358+
common.MakeTypedValue(common.MakeDecimalType(10, 0), []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
359+
),
360+
}),
361+
)
362+
}
363+
364+
func (s *Suite) TestPushdownDecimalRationalEQ() {
365+
// Test for: SELECT * FROM table WHERE col_28_numeric_rational = Decimal("-22.22", 4, 2);
366+
s.ValidateTable(
367+
s.dataSource,
368+
tables["pushdown_decimal_rational_EQ"],
369+
suite.WithPredicate(&api_service_protos.TPredicate{
370+
Payload: tests_utils.MakePredicateComparisonColumn(
371+
"col_28_numeric_rational",
372+
api_service_protos.TPredicate_TComparison_EQ,
373+
common.MakeTypedValue(
374+
common.MakeDecimalType(4, 2),
375+
[]byte{82, 247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255},
376+
),
377+
),
378+
}),
379+
)
380+
}

0 commit comments

Comments
 (0)