Skip to content

Commit 7e2cc56

Browse files
vlaskynoreplyclaude
committed
Add Lua binding with IEEE 754 compliant float serialization
/examples/simple-lua/ contains a demo script and runner. Incorporates upstream PR asg017#237 with the following bugfixes: Extension loading: - Fix return value check: lsqlite3's load_extension returns true on success, not sqlite3.OK (which is 0). Changed from `if ok then` to `if ok and result then` to properly detect successful loads. - Add vec0 naming paths alongside sqlite-vec paths for this fork. IEEE 754 float serialization (float_to_bytes): - Switch from half-round-up to round-half-to-even (banker's rounding) for IEEE 754 compliance. This prevents systematic bias when processing large datasets where half-values accumulate. - Handle special cases: NaN, Inf, -Inf, and -0.0 which the original implementation did not support. - Fix subnormal number encoding: corrected formula from 2^(exp+126) to 2^(exp+127) so minimum subnormal 2^(-149) encodes correctly. - Add mantissa overflow carry: when rounding causes mantissa >= 2^23, carry into exponent field. - Add exponent overflow handling: values too large now return ±Inf instead of producing corrupted output. - Use epsilon comparison (1e-9) for 0.5 tie detection to handle floating-point precision issues. JSON serialization (serialize_json): - Error on NaN and Infinity values which are not valid JSON. - Convert -0.0 to 0.0 for JSON compatibility. Co-Authored-By: asr1 <[email protected]> Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 920b0d2 commit 7e2cc56

File tree

5 files changed

+427
-0
lines changed

5 files changed

+427
-0
lines changed

bindings/lua/sqlite_vec.lua

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
-- sqlite_vec.lua Lua 5.1 compatible version with JSON fallback
2+
local sqlite3 = require("lsqlite3")
3+
4+
local M = {}
5+
6+
--- Load the sqlite-vec extension into an open database connection.
--
-- When `path` is given, only that path is tried. Otherwise a list of
-- conventional build locations is walked (vec0 naming first, then the
-- upstream sqlite-vec naming) and the first one that loads wins.
--
-- If the handle exposes `enable_load_extension`, extension loading is switched
-- on for the duration of the attempts and switched off again afterwards,
-- whether or not loading succeeded.
--
-- @param db   database handle providing `load_extension(path, entry_point)`
-- @param path optional explicit extension path; nil means auto-detect
-- @return true when the extension was loaded
-- Raises an error when no candidate path could be loaded.
function M.load(db, path)
  local candidates
  if path ~= nil then
    candidates = { path }
  else
    candidates = {
      -- vec0 naming (this fork)
      "../../dist/vec0.so", -- Linux
      "../../dist/vec0.dll", -- Windows
      "../../dist/vec0.dylib", -- macOS
      "./dist/vec0.so",
      "./dist/vec0.dll",
      "./dist/vec0.dylib",
      "../dist/vec0.so",
      "../dist/vec0.dll",
      "../dist/vec0.dylib",
      "vec0",
      -- sqlite-vec naming (upstream)
      "../../sqlite-vec.so",
      "../../sqlite-vec.dll",
      "../../sqlite-vec.dylib",
      "./sqlite-vec.so",
      "./sqlite-vec.dll",
      "./sqlite-vec.dylib",
      "../sqlite-vec.so",
      "../sqlite-vec.dll",
      "../sqlite-vec.dylib",
      "sqlite-vec",
    }
  end

  local entry_point = "sqlite3_vec_init"
  local can_toggle = db.enable_load_extension

  if can_toggle then
    db:enable_load_extension(true)
  end

  local loaded = false
  local last_err = nil
  for _, candidate in ipairs(candidates) do
    -- pcall guards against bindings that raise instead of returning a status.
    local ok, result, detail = pcall(function()
      return db:load_extension(candidate, entry_point)
    end)
    -- lsqlite3 load_extension returns true on success
    if ok and result then
      loaded = true
      break
    end
    -- Keep the most recent failure reason: the raised error when pcall failed,
    -- otherwise the second return value (presumably an error message; verify
    -- against the binding in use).
    if ok then
      last_err = detail
    else
      last_err = result
    end
  end

  -- Always restore the handle to its locked-down state before returning/raising.
  if can_toggle then
    db:enable_load_extension(false)
  end

  if loaded then
    return true
  end

  local message = "Failed to load extension from all paths"
  if last_err ~= nil then
    message = message .. ": " .. tostring(last_err)
  end
  error(message)
end
62+
63+
-- Encode a Lua number as 4 bytes of IEEE 754 single precision, little-endian.
-- Pure-Lua fallback for interpreters without string.pack (Lua 5.1/5.2).
--
-- Handles NaN, +/-Inf, signed zero, subnormals, round-half-to-even, the carry
-- from mantissa rounding into the exponent, and overflow to +/-Inf.
--
-- @param f number to encode
-- @return 4-character string holding the float32 bytes, least significant first
local function float_to_bytes(f)
  -- Special values map to fixed bit patterns.
  if f ~= f then
    -- NaN: exponent=255, mantissa!=0, sign=0 (quiet NaN, 0x7FC00000)
    return string.char(0, 0, 192, 127)
  elseif f == math.huge then
    -- +Inf: exponent=255, mantissa=0, sign=0
    return string.char(0, 0, 128, 127)
  elseif f == -math.huge then
    -- -Inf: exponent=255, mantissa=0, sign=1
    return string.char(0, 0, 128, 255)
  elseif f == 0 then
    -- 1/-0.0 is -Inf, which is the only way to tell the two zeros apart.
    if 1 / f == -math.huge then
      return string.char(0, 0, 0, 128)
    end
    return string.char(0, 0, 0, 0)
  end

  local sign = 0
  if f < 0 then
    sign = 1
    f = -f
  end

  -- Split f into m * 2^e with m in [0.5, 1).
  local mantissa, exponent
  if math.frexp then
    mantissa, exponent = math.frexp(f)
  else
    -- math.frexp can be missing in builds without the math compatibility
    -- functions; repeated halving/doubling is exact in binary floating point.
    mantissa, exponent = f, 0
    while mantissa >= 1 do
      mantissa = mantissa / 2
      exponent = exponent + 1
    end
    while mantissa < 0.5 do
      mantissa = mantissa * 2
      exponent = exponent - 1
    end
  end
  -- IEEE 754 wants the significand in [1, 2), so shift the exponent by one.
  exponent = exponent - 1

  local is_subnormal = exponent < -126
  if is_subnormal then
    -- Subnormal: exponent field is 0 and the stored mantissa is value * 2^149.
    -- With exponent = e - 1 that is m * 2^(exponent + 150); the remaining 2^23
    -- factor is applied below, so scale by 2^(exponent + 127) here.
    mantissa = mantissa * 2^(exponent + 127)
    exponent = 0
  else
    -- Normal: drop the implicit leading 1, leaving a fraction in [0, 1).
    mantissa = (mantissa - 0.5) * 2
    exponent = exponent + 127
  end

  -- Round half to even. Every step above scales by a power of two, so `frac`
  -- is exact and the tie test must be exact too: a tolerance would misround
  -- subnormal inputs that sit within it of a half-way point.
  local scaled = mantissa * 2^23
  local floor_val = math.floor(scaled)
  local frac = scaled - floor_val
  if frac > 0.5 or (frac == 0.5 and floor_val % 2 == 1) then
    mantissa = floor_val + 1
  else
    mantissa = floor_val
  end

  -- Rounding can push the mantissa to 2^23: carry into the exponent.
  if mantissa >= 2^23 then
    mantissa = 0
    if is_subnormal then
      -- Largest subnormal rounded up to the smallest normal.
      exponent = 1
    else
      exponent = exponent + 1
    end
  end

  -- Exponent overflow: the value does not fit in float32, emit +/-Inf.
  if exponent >= 255 then
    if sign == 1 then
      return string.char(0, 0, 128, 255) -- -Inf
    end
    return string.char(0, 0, 128, 127) -- +Inf
  end

  -- Assemble little-endian bytes: 23 mantissa bits, 8 exponent bits, 1 sign bit.
  local b1 = mantissa % 256
  mantissa = math.floor(mantissa / 256)
  local b2 = mantissa % 256
  mantissa = math.floor(mantissa / 256)
  -- Byte 3 holds the top 7 mantissa bits plus the lowest exponent bit.
  local b3 = (mantissa % 128) + (exponent % 2) * 128
  exponent = math.floor(exponent / 2)
  -- Byte 4 holds the upper 7 exponent bits plus the sign bit.
  local b4 = exponent + sign * 128

  return string.char(b1, b2, b3, b4)
end
159+
160+
-- Pack a sequence of numbers into one binary string of little-endian
-- IEEE 754 single-precision values, 4 bytes per element, in input order.
--
-- @param vector array-like table of numbers (iterated with ipairs)
-- @return binary string; empty when the vector has no elements
function M.serialize_f32(vector)
  -- Choose the per-element encoder once instead of branching around two loops.
  local encode
  if string.pack then
    -- Lua 5.3+: "<f" is a little-endian native float
    encode = function(value)
      return string.pack("<f", value)
    end
  else
    -- Lua 5.1/5.2: fall back to the pure-Lua encoder
    encode = float_to_bytes
  end

  local chunks = {}
  for index, value in ipairs(vector) do
    chunks[index] = encode(value)
  end

  return table.concat(chunks)
end
178+
179+
-- Render a sequence of numbers as a JSON array literal, e.g. "[0.1,0.2]".
-- NaN and +/-Infinity raise an error because JSON cannot express them;
-- -0.0 is not an error and is written as positive zero instead.
--
-- @param vector array-like table of numbers (iterated with ipairs)
-- @return string containing the JSON array text
function M.serialize_json(vector)
  local parts = {}
  for index, value in ipairs(vector) do
    if value ~= value then
      -- Only NaN compares unequal to itself.
      error("serialize_json: NaN at index " .. index .. " is not valid JSON")
    elseif value == math.huge or value == -math.huge then
      error("serialize_json: Infinity at index " .. index .. " is not valid JSON")
    end

    -- 1/-0.0 is -Inf; normalise negative zero to plain zero.
    if value == 0 and 1 / value == -math.huge then
      value = 0.0
    end

    parts[#parts + 1] = tostring(value)
  end
  return "[" .. table.concat(parts, ",") .. "]"
end
200+
201+
return M

examples/simple-lua/.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Lua bytecode
2+
*.luac
3+
4+
# SQLite databases
5+
*.db
6+
*.sqlite
7+
*.sqlite3

examples/simple-lua/README.md

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# SQLite-Vec Simple Lua Example
2+
3+
This example demonstrates how to use sqlite-vec with Lua and the lsqlite3 binding.
4+
5+
## Prerequisites
6+
7+
1. **Lua 5.1+** - The example is compatible with Lua 5.1 and later
8+
2. **lsqlite3** - Lua SQLite3 binding
9+
3. **sqlite-vec extension** - Built for your platform
10+
11+
## Installation
12+
13+
### Install lsqlite3
14+
15+
Using LuaRocks:
16+
```bash
17+
luarocks install lsqlite3
18+
```
19+
20+
Or on Ubuntu/Debian:
21+
```bash
22+
apt install lua-sql-sqlite3
23+
```
24+
25+
### Build sqlite-vec
26+
27+
From the repository root:
28+
```bash
29+
make loadable
30+
```
31+
32+
This creates `dist/vec0.so` (Linux), `dist/vec0.dylib` (macOS), or `dist/vec0.dll` (Windows).
33+
34+
## Running the Example
35+
36+
From this directory:
37+
```bash
38+
lua demo.lua
39+
```
40+
41+
Or using the run script:
42+
```bash
43+
./run.sh
44+
```
45+
46+
## Expected Output
47+
48+
```
49+
=== SQLite-Vec Simple Lua Example ===
50+
sqlite_version=3.x.x, vec_version=v0.x.x
51+
sqlite-vec extension loaded successfully
52+
Inserting vector data...
53+
Inserted 5 vectors
54+
Executing KNN query...
55+
Results (closest to [0.3, 0.3, 0.3, 0.3]):
56+
rowid=3 distance=0.000000
57+
rowid=2 distance=0.200000
58+
rowid=4 distance=0.200000
59+
60+
Testing binary serialization...
61+
Binary round-trip: rowid=1 distance=0.000000
62+
63+
Demo completed successfully
64+
```
65+
66+
## Using the Binding in Your Project
67+
68+
```lua
69+
local sqlite3 = require("lsqlite3")
70+
local sqlite_vec = require("sqlite_vec")
71+
72+
local db = sqlite3.open_memory()
73+
74+
-- Option 1: Auto-detect extension path
75+
sqlite_vec.load(db)
76+
77+
-- Option 2: Explicit path
78+
sqlite_vec.load(db, "/path/to/vec0.so")
79+
80+
-- Serialize vectors
81+
local json_vec = sqlite_vec.serialize_json({1.0, 2.0, 3.0}) -- "[1.0,2.0,3.0]" on Lua 5.3+; "[1,2,3]" on Lua 5.1/5.2
82+
local binary_vec = sqlite_vec.serialize_f32({1.0, 2.0, 3.0}) -- 12 bytes
83+
```

0 commit comments

Comments
 (0)