Skip to content

Commit 01a860c

Browse files
Merge pull request #9 from sqliteai/8-include-exclude-file-type
feat(add): #8 option to only/exclude files by ext
2 parents: 8f8bf27 + fe65f7b; commit: 01a860c

File tree

6 files changed: +405 additions, -44 deletions (lines changed)

src/sqlite_rag/cli.py

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,12 @@ def main(
9696
def show_settings(ctx: typer.Context):
9797
"""Show current settings"""
9898
rag_context = ctx.obj["rag_context"]
99-
rag = rag_context.get_rag(require_existing=True)
99+
try:
100+
rag = rag_context.get_rag(require_existing=True)
101+
except FileNotFoundError:
102+
typer.echo("Database not found. No settings available.")
103+
raise typer.Exit(1)
104+
100105
current_settings = rag.get_settings()
101106

102107
typer.echo("Current settings:")
@@ -242,17 +247,44 @@ def add(
242247
help="Optional metadata in JSON format to associate with the document",
243248
metavar="JSON",
244249
),
250+
only_extensions: Optional[str] = typer.Option(
251+
None,
252+
"--only",
253+
help="Only process these file extensions from supported list (comma-separated, e.g. 'py,js')",
254+
),
255+
exclude_extensions: Optional[str] = typer.Option(
256+
None,
257+
"--exclude",
258+
help="File extensions to exclude (comma-separated, e.g. 'py,js')",
259+
),
245260
):
246261
"""Add a file path to the database"""
247262
rag_context = ctx.obj["rag_context"]
248263
start_time = time.time()
249264

265+
only_list = (
266+
[e.strip().lstrip(".").lower() for e in only_extensions.split(",") if e.strip()]
267+
if only_extensions
268+
else None
269+
)
270+
exclude_list = (
271+
[
272+
e.strip().lstrip(".").lower()
273+
for e in exclude_extensions.split(",")
274+
if e.strip()
275+
]
276+
if exclude_extensions
277+
else None
278+
)
279+
250280
rag = rag_context.get_rag()
251281
rag.add(
252282
path,
253283
recursive=recursive,
254284
use_relative_paths=use_relative_paths,
255285
metadata=json.loads(metadata or "{}"),
286+
only_extensions=only_list,
287+
exclude_extensions=exclude_list,
256288
)
257289

258290
elapsed_time = time.time() - start_time

src/sqlite_rag/reader.py

Lines changed: 69 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -6,44 +6,67 @@
66

77
class FileReader:
88
extensions = [
9-
".c",
10-
".cpp",
11-
".css",
12-
".csv",
13-
".docx",
14-
".go",
15-
".h",
16-
".hpp",
17-
".html",
18-
".java",
19-
".js",
20-
".json",
21-
".kt",
22-
".md",
23-
".mdx",
24-
".mjs",
25-
".pdf",
26-
".php",
27-
".pptx",
28-
".py",
29-
".rb",
30-
".rs",
31-
".svelte",
32-
".swift",
33-
".ts",
34-
".tsx",
35-
".txt",
36-
".vue",
37-
".xml",
38-
".xlsx",
39-
".yaml",
40-
".yml",
9+
"c",
10+
"cpp",
11+
"css",
12+
"csv",
13+
"docx",
14+
"go",
15+
"h",
16+
"hpp",
17+
"html",
18+
"java",
19+
"js",
20+
"json",
21+
"kt",
22+
"md",
23+
"mdx",
24+
"mjs",
25+
"pdf",
26+
"php",
27+
"pptx",
28+
"py",
29+
"rb",
30+
"rs",
31+
"svelte",
32+
"swift",
33+
"ts",
34+
"tsx",
35+
"txt",
36+
"vue",
37+
"xml",
38+
"xlsx",
39+
"yaml",
40+
"yml",
4141
]
4242

4343
@staticmethod
44-
def is_supported(path: Path) -> bool:
45-
"""Check if the file extension is supported"""
46-
return path.suffix.lower() in FileReader.extensions
44+
def is_supported(
45+
path: Path,
46+
only_extensions: Optional[list[str]] = None,
47+
exclude_extensions: Optional[list[str]] = None,
48+
) -> bool:
49+
"""Check if the file extension is supported.
50+
51+
Parameters:
52+
path (Path): The file path to check.
53+
only_extensions (Optional[list[str]]): If provided, only files with these extensions are considered.
54+
exclude_extensions (Optional[list[str]]): If provided, files with these extensions are excluded.
55+
"""
56+
extension = path.suffix.lower().lstrip(".")
57+
58+
supported_extensions = set(FileReader.extensions)
59+
exclude_set = set()
60+
61+
# Only keep those that are in both lists
62+
if only_extensions:
63+
only_set = {ext.lower().lstrip(".") for ext in only_extensions}
64+
supported_extensions &= only_set
65+
66+
if exclude_extensions:
67+
exclude_set = {ext.lower().lstrip(".") for ext in exclude_extensions}
68+
69+
return extension in supported_extensions and extension not in exclude_set
4770

4871
@staticmethod
4972
def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str:
@@ -65,12 +88,19 @@ def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str
6588
raise ValueError(f"Failed to parse file {path}") from exc
6689

6790
@staticmethod
68-
def collect_files(path: Path, recursive: bool = False) -> list[Path]:
91+
def collect_files(
92+
path: Path,
93+
recursive: bool = False,
94+
only_extensions: Optional[list[str]] = None,
95+
exclude_extensions: Optional[list[str]] = None,
96+
) -> list[Path]:
6997
"""Collect files from the path, optionally recursively"""
7098
if not path.exists():
7199
raise FileNotFoundError(f"{path} does not exist.")
72100

73-
if path.is_file() and FileReader.is_supported(path):
101+
if path.is_file() and FileReader.is_supported(
102+
path, only_extensions, exclude_extensions
103+
):
74104
return [path]
75105

76106
files_to_process = []
@@ -83,7 +113,8 @@ def collect_files(path: Path, recursive: bool = False) -> list[Path]:
83113
files_to_process = [
84114
f
85115
for f in files_to_process
86-
if f.is_file() and FileReader.is_supported(f)
116+
if f.is_file()
117+
and FileReader.is_supported(f, only_extensions, exclude_extensions)
87118
]
88119

89120
return files_to_process

src/sqlite_rag/sqliterag.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -72,16 +72,32 @@ def add(
7272
recursive: bool = False,
7373
use_relative_paths: bool = False,
7474
metadata: dict = {},
75+
only_extensions: Optional[list[str]] = None,
76+
exclude_extensions: Optional[list[str]] = None,
7577
) -> int:
76-
"""Add the file content into the database"""
78+
"""Add the file content into the database
79+
80+
Args:
81+
path: File or directory path to add
82+
recursive: Recursively add files in directories
83+
use_relative_paths: Store relative paths instead of absolute paths
84+
metadata: Metadata to associate with documents
85+
only_extensions: Only process these file extensions from the supported list (e.g. ['py', 'js'])
86+
exclude_extensions: Skip these file extensions (e.g. ['py', 'js'])
87+
"""
7788
self._ensure_initialized()
7889

7990
if not Path(path).exists():
8091
raise FileNotFoundError(f"{path} does not exist.")
8192

8293
parent = Path(path).parent
8394

84-
files_to_process = FileReader.collect_files(Path(path), recursive=recursive)
95+
files_to_process = FileReader.collect_files(
96+
Path(path),
97+
recursive=recursive,
98+
only_extensions=only_extensions,
99+
exclude_extensions=exclude_extensions,
100+
)
85101

86102
self._engine.create_new_context()
87103

tests/integration/test_cli.py

Lines changed: 83 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,3 +113,86 @@ def test_change_database_path(self):
113113
assert result.exit_code == 0
114114

115115
assert f"Database: {tmp_db.name}" in result.stdout
116+
117+
def test_add_with_exclude_extensions(self):
118+
with tempfile.TemporaryDirectory() as tmp_dir:
119+
(Path(tmp_dir) / "file1.txt").write_text("This is a text file.")
120+
(Path(tmp_dir) / "file2.md").write_text("# This is a markdown file.")
121+
(Path(tmp_dir) / "file3.py").write_text("print('Hello, world!')")
122+
(Path(tmp_dir) / "file4.js").write_text("console.log('Hello, world!');")
123+
124+
with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db:
125+
runner = CliRunner()
126+
127+
result = runner.invoke(
128+
app,
129+
["--database", tmp_db.name, "add", tmp_dir, "--exclude", "py,js"],
130+
)
131+
assert result.exit_code == 0
132+
133+
# Check that only .txt and .md files were added
134+
assert "Processing 2 files" in result.stdout
135+
assert "file1.txt" in result.stdout
136+
assert "file2.md" in result.stdout
137+
assert "file3.py" not in result.stdout
138+
assert "file4.js" not in result.stdout
139+
140+
def test_add_with_only_extensions(self):
141+
with tempfile.TemporaryDirectory() as tmp_dir:
142+
(Path(tmp_dir) / "file1.txt").write_text("This is a text file.")
143+
(Path(tmp_dir) / "file2.md").write_text("# This is a markdown file.")
144+
(Path(tmp_dir) / "file3.py").write_text("print('Hello, world!')")
145+
(Path(tmp_dir) / "file4.js").write_text("console.log('Hello, world!');")
146+
147+
with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db:
148+
runner = CliRunner()
149+
150+
result = runner.invoke(
151+
app,
152+
[
153+
"--database",
154+
tmp_db.name,
155+
"add",
156+
tmp_dir,
157+
"--only",
158+
"md,txt",
159+
],
160+
)
161+
assert result.exit_code == 0
162+
163+
# Check that only .txt and .md files were added
164+
assert "Processing 2 files" in result.stdout
165+
assert "file1.txt" in result.stdout
166+
assert "file2.md" in result.stdout
167+
assert "file3.py" not in result.stdout
168+
assert "file4.js" not in result.stdout
169+
170+
def test_add_with_only_and_exclude_extensions_are_normilized(self):
171+
with tempfile.TemporaryDirectory() as tmp_dir:
172+
(Path(tmp_dir) / "file1.txt").write_text("This is a text file.")
173+
(Path(tmp_dir) / "file2.md").write_text("# This is a markdown file.")
174+
(Path(tmp_dir) / "file3.py").write_text("print('Hello, world!')")
175+
176+
with tempfile.NamedTemporaryFile(suffix=".tempdb") as tmp_db:
177+
runner = CliRunner()
178+
179+
result = runner.invoke(
180+
app,
181+
[
182+
"--database",
183+
tmp_db.name,
184+
"add",
185+
tmp_dir,
186+
"--only",
187+
".md, .txt,py",
188+
"--exclude",
189+
".py ", # wins over --only
190+
],
191+
)
192+
assert result.exit_code == 0
193+
194+
# Check that only .txt and .md files were added
195+
assert "Processing 2 files" in result.stdout
196+
assert "file1.txt" in result.stdout
197+
assert "file2.md" in result.stdout
198+
assert "file3.py" not in result.stdout

0 commit comments

Comments
 (0)