Skip to content

Commit c00a455

Browse files
committed
add tests for special float values in vector search
1 parent 9f5a1a6 commit c00a455

File tree

1 file changed

+66
-5
lines changed

1 file changed

+66
-5
lines changed

search_test.go

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
package redis_test
22

33
import (
4+
"bytes"
45
"context"
6+
"encoding/binary"
57
"fmt"
6-
"strconv"
8+
"math"
79
"strings"
810
"time"
911

1012
. "github.com/bsm/ginkgo/v2"
1113
. "github.com/bsm/gomega"
1214
"github.com/redis/go-redis/v9"
15+
"github.com/redis/go-redis/v9/helper"
1316
)
1417

1518
func WaitForIndexing(c *redis.Client, index string) {
@@ -27,6 +30,14 @@ func WaitForIndexing(c *redis.Client, index string) {
2730
}
2831
}
2932

33+
func encodeFloat32Vector(vec []float32) []byte {
34+
buf := new(bytes.Buffer)
35+
for _, v := range vec {
36+
binary.Write(buf, binary.LittleEndian, v)
37+
}
38+
return buf.Bytes()
39+
}
40+
3041
var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
3142
ctx := context.TODO()
3243
var client *redis.Client
@@ -693,9 +704,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
693704
Expect(err).NotTo(HaveOccurred())
694705
Expect(res).ToNot(BeNil())
695706
Expect(len(res.Rows)).To(BeEquivalentTo(2))
696-
score1, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]), 64)
707+
score1, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[0].Fields["__score"]))
697708
Expect(err).NotTo(HaveOccurred())
698-
score2, err := strconv.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]), 64)
709+
score2, err := helper.ParseFloat(fmt.Sprintf("%s", res.Rows[1].Fields["__score"]))
699710
Expect(err).NotTo(HaveOccurred())
700711
Expect(score1).To(BeNumerically(">", score2))
701712

@@ -712,9 +723,9 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
712723
Expect(err).NotTo(HaveOccurred())
713724
Expect(resDM).ToNot(BeNil())
714725
Expect(len(resDM.Rows)).To(BeEquivalentTo(2))
715-
score1DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]), 64)
726+
score1DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[0].Fields["__score"]))
716727
Expect(err).NotTo(HaveOccurred())
717-
score2DM, err := strconv.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]), 64)
728+
score2DM, err := helper.ParseFloat(fmt.Sprintf("%s", resDM.Rows[1].Fields["__score"]))
718729
Expect(err).NotTo(HaveOccurred())
719730
Expect(score1DM).To(BeNumerically(">", score2DM))
720731

@@ -1684,6 +1695,56 @@ var _ = Describe("RediSearch commands Resp 2", Label("search"), func() {
16841695
Expect(resUint8.Docs[0].ID).To(BeEquivalentTo("doc1"))
16851696
})
16861697

1698+
It("should return special float scores in FT.SEARCH vecsim", Label("search", "ftsearch", "vecsim"), func() {
1699+
SkipBeforeRedisVersion(7.4, "doesn't work with older redis stack images")
1700+
1701+
vecField := &redis.FTFlatOptions{
1702+
Type: "FLOAT32",
1703+
Dim: 2,
1704+
DistanceMetric: "IP",
1705+
}
1706+
_, err := client.FTCreate(ctx, "idx_vec",
1707+
&redis.FTCreateOptions{OnHash: true, Prefix: []interface{}{"doc:"}},
1708+
&redis.FieldSchema{FieldName: "vector", FieldType: redis.SearchFieldTypeVector, VectorArgs: &redis.FTVectorArgs{FlatOptions: vecField}}).Result()
1709+
Expect(err).NotTo(HaveOccurred())
1710+
WaitForIndexing(client, "idx_vec")
1711+
1712+
bigPos := []float32{1e38, 1e38}
1713+
bigNeg := []float32{-1e38, -1e38}
1714+
nanVec := []float32{float32(math.NaN()), 0}
1715+
negNanVec := []float32{float32(math.Copysign(math.NaN(), -1)), 0}
1716+
1717+
client.HSet(ctx, "doc:1", "vector", encodeFloat32Vector(bigPos))
1718+
client.HSet(ctx, "doc:2", "vector", encodeFloat32Vector(bigNeg))
1719+
client.HSet(ctx, "doc:3", "vector", encodeFloat32Vector(nanVec))
1720+
client.HSet(ctx, "doc:4", "vector", encodeFloat32Vector(negNanVec))
1721+
1722+
searchOptions := &redis.FTSearchOptions{WithScores: true, Params: map[string]interface{}{"vec": encodeFloat32Vector(bigPos)}}
1723+
res, err := client.FTSearchWithArgs(ctx, "idx_vec", "*=>[KNN 4 @vector $vec]", searchOptions).Result()
1724+
Expect(err).NotTo(HaveOccurred())
1725+
Expect(res.Total).To(BeEquivalentTo(4))
1726+
1727+
var scores []float64
1728+
for _, row := range res.Docs {
1729+
raw := fmt.Sprintf("%v", row.Fields["__vector_score"])
1730+
f, err := helper.ParseFloat(raw)
1731+
Expect(err).NotTo(HaveOccurred())
1732+
scores = append(scores, f)
1733+
}
1734+
1735+
Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(1))))
1736+
Expect(scores).To(ContainElement(BeNumerically("==", math.Inf(-1))))
1737+
1738+
// For NaN values, use a custom check since NaN != NaN in floating point math
1739+
nanCount := 0
1740+
for _, score := range scores {
1741+
if math.IsNaN(score) {
1742+
nanCount++
1743+
}
1744+
}
1745+
Expect(nanCount).To(Equal(2))
1746+
})
1747+
16871748
It("should fail when using a non-zero offset with a zero limit", Label("search", "ftsearch"), func() {
16881749
SkipBeforeRedisVersion(7.9, "requires Redis 8.x")
16891750
val, err := client.FTCreate(ctx, "testIdx", &redis.FTCreateOptions{}, &redis.FieldSchema{

0 commit comments

Comments
 (0)