diff --git a/llm/cli.py b/llm/cli.py
index 4c5151d8..b64dc8cc 100644
--- a/llm/cli.py
+++ b/llm/cli.py
@@ -3310,23 +3310,40 @@ def load_rows(fp):
         rows, label="Embedding", show_percent=True, length=expected_length
     ) as rows:
 
-        def tuples() -> Iterable[Tuple[str, Union[bytes, str]]]:
+        def entries() -> (
+            Iterable[Tuple[str, Union[bytes, str], Optional[Dict[str, Any]]]]
+        ):
+            # Pull rows straight from the progress bar, one at a time, so the
+            # bar tracks real embedding progress and large inputs are never
+            # materialized in memory.
             for row in rows:
                 values = list(row.values())
                 id: str = prefix + str(values[0])
                 content: Optional[Union[bytes, str]] = None
+                # Only a dict stored under "metadata" counts as metadata; a
+                # plain column that happens to use that name is still
+                # embedded as content.
+                metadata: Optional[Dict[str, Any]] = None
+                if isinstance(row.get("metadata"), dict):
+                    metadata = row["metadata"]
                 if binary:
                     content = cast(bytes, values[1])
                 else:
-                    content = " ".join(v or "" for v in values[1:])
+                    content = " ".join(
+                        "" if value is None else str(value)
+                        for key, value in list(row.items())[1:]
+                        if metadata is None or key != "metadata"
+                    )
                 if prepend and isinstance(content, str):
                     content = prepend + content
-                yield id, content or ""
+                yield id, content or "", metadata
 
         embed_kwargs = {"store": store}
         if batch_size:
             embed_kwargs["batch_size"] = batch_size
-        collection_obj.embed_multi(tuples(), **embed_kwargs)
+        # Metadata is optional per entry (None when absent), so one code path
+        # covers inputs with, without, or only partly with metadata.
+        collection_obj.embed_multi_with_metadata(entries(), **embed_kwargs)
 
 
 @cli.command()
diff --git a/tests/test_embed_cli.py b/tests/test_embed_cli.py
index afee7712..f1ebe35e 100644
--- a/tests/test_embed_cli.py
+++ b/tests/test_embed_cli.py
@@ -345,6 +345,48 @@ def test_embed_multi_file_input(tmpdir, use_stdin, prefix, prepend, filename, co
     assert ids == expected_ids
 
 
+def test_embed_multi_with_metadata(tmpdir):
+    """Test that embed-multi stores per-row metadata from a JSON file."""
+    db_path = tmpdir / "embeddings.db"
+    metadata = {"key1": "value1", "key2": "value2"}
+    items = [
+        # No metadata on the first row: metadata handling must not be decided
+        # by looking at the first row only.
+        {"id": "plain", "content": "No metadata here"},
+        {"id": "1", "content": "An item", "metadata": metadata},
+        {"id": "2", "content": "Another item", "metadata": metadata},
+    ]
+    path = tmpdir / "metadata.json"
+    path.write_text(json.dumps(items), "utf-8")
+
+    args = [
+        "embed-multi",
+        "metadata-test",
+        "-d",
+        str(db_path),
+        "-m",
+        "embed-demo",
+        str(path),
+        "--store",
+    ]
+    result = CliRunner().invoke(cli, args, catch_exceptions=False)
+    assert result.exit_code == 0
+
+    db = sqlite_utils.Database(str(db_path))
+    rows = {row["id"]: row for row in db["embeddings"].rows}
+    assert set(rows) == {"plain", "1", "2"}
+
+    # Stored content never includes the metadata dict
+    assert rows["plain"]["content"] == "No metadata here"
+    assert rows["1"]["content"] == "An item"
+    assert rows["2"]["content"] == "Another item"
+
+    # Metadata is stored as JSON for the rows that supplied it
+    assert rows["plain"]["metadata"] is None
+    assert json.loads(rows["1"]["metadata"]) == metadata
+    assert json.loads(rows["2"]["metadata"]) == metadata
+
+
 def test_embed_multi_files_binary_store(tmpdir):
     db_path = tmpdir / "embeddings.db"
     args = ["embed-multi", "binfiles", "-d", str(db_path), "-m", "embed-demo"]