Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 61 additions & 4 deletions llm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3310,23 +3310,80 @@ def load_rows(fp):
rows, label="Embedding", show_percent=True, length=expected_length
) as rows:

def tuples() -> Iterable[Tuple[str, Union[bytes, str]]]:
for row in rows:
def tuples_generator(rows_list) -> Iterable[Tuple[str, Union[bytes, str]]]:
for row in rows_list:
values = list(row.values())
id: str = prefix + str(values[0])
content: Optional[Union[bytes, str]] = None
if binary:
content = cast(bytes, values[1])
else:
content = " ".join(v or "" for v in values[1:])
# Skip metadata if it exists - only concatenate string/numeric values
content_values = []
for v in values[1:]:
if isinstance(v, dict):
continue # Skip metadata dicts
content_values.append(str(v) if v is not None else "")
content = " ".join(content_values)
if prepend and isinstance(content, str):
content = prepend + content
yield id, content or ""

def tuples_with_metadata_generator(rows_list) -> Iterable[Tuple[str, Union[bytes, str], Optional[Dict[str, Any]]]]:
for row in rows_list:
values = list(row.values())
keys = list(row.keys())
id: str = prefix + str(values[0])
content: Optional[Union[bytes, str]] = None
metadata: Optional[Dict[str, Any]] = None

# Extract metadata if present
if "metadata" in keys:
metadata_value = row["metadata"]
if isinstance(metadata_value, dict):
metadata = metadata_value

if binary:
content = cast(bytes, values[1])
else:
# Only concatenate non-metadata, non-id values
content_values = []
for i, v in enumerate(values[1:], 1): # Start from index 1 (skip id)
key = keys[i]
if key == "metadata":
continue # Skip metadata field
content_values.append(str(v) if v is not None else "")
content = " ".join(content_values)

if prepend and isinstance(content, str):
content = prepend + content
yield id, content or "", metadata

# Check if any row has metadata to determine which method to use
has_metadata = False
first_row = None

# Peek at the first row to determine if metadata exists
rows_list = list(rows) # Convert to list so we can examine the first row
if rows_list:
first_row = rows_list[0]
has_metadata = isinstance(first_row, dict) and "metadata" in first_row
else:
# No rows to process
return

embed_kwargs = {"store": store}
if batch_size:
embed_kwargs["batch_size"] = batch_size
collection_obj.embed_multi(tuples(), **embed_kwargs)

if has_metadata:
collection_obj.embed_multi_with_metadata(
(tuples_with_metadata_generator(rows_list)), **embed_kwargs
)
else:
collection_obj.embed_multi(
(tuples_generator(rows_list)), **embed_kwargs
)


@cli.command()
Expand Down
32 changes: 32 additions & 0 deletions tests/test_embed_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,38 @@ def test_embed_multi_file_input(tmpdir, use_stdin, prefix, prepend, filename, co
assert ids == expected_ids


def test_embed_multi_with_metadata(tmpdir):
"""Test that embed-multi works with JSON files containing metadata."""
db_path = tmpdir / "embeddings.db"
content = '[{"id": "1", "content": "An item", "metadata": {"key1": "value1", "key2": "value2"}}, {"id": "2", "content": "Another item", "metadata": {"key1": "value1", "key2": "value2"}}]'

path = tmpdir / "metadata.json"
path.write_text(content, "utf-8")

args = ["embed-multi", "metadata-test", "-d", str(db_path), "-m", "embed-demo", str(path), "--store"]

runner = CliRunner()
result = runner.invoke(cli, args, catch_exceptions=False)
assert result.exit_code == 0

# Check that everything was embedded correctly with metadata
db = sqlite_utils.Database(str(db_path))
assert db["embeddings"].count == 2

rows = list(db["embeddings"].rows)
assert len(rows) == 2

# Check first row
row1 = [row for row in rows if row["id"] == "1"][0]
assert row1["content"] == "An item"
assert json.loads(row1["metadata"]) == {"key1": "value1", "key2": "value2"}

# Check second row
row2 = [row for row in rows if row["id"] == "2"][0]
assert row2["content"] == "Another item"
assert json.loads(row2["metadata"]) == {"key1": "value1", "key2": "value2"}


def test_embed_multi_files_binary_store(tmpdir):
db_path = tmpdir / "embeddings.db"
args = ["embed-multi", "binfiles", "-d", str(db_path), "-m", "embed-demo"]
Expand Down