Skip to content

Commit 40e179e

Browse files
author
Daniele Briggi
committed
feat(add): #8 option to only/exclude files by ext
1 parent 8f8bf27 commit 40e179e

File tree

5 files changed

+298
-42
lines changed

5 files changed

+298
-42
lines changed

src/sqlite_rag/cli.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,17 +242,33 @@ def add(
242242
help="Optional metadata in JSON format to associate with the document",
243243
metavar="JSON",
244244
),
245+
only_extensions: Optional[str] = typer.Option(
246+
None,
247+
"--only",
248+
help="Only process these file extensions from supported list (comma-separated, e.g. 'py,js')",
249+
),
250+
exclude_extensions: Optional[str] = typer.Option(
251+
None,
252+
"--exclude",
253+
help="File extensions to exclude (comma-separated, e.g. 'py,js')",
254+
),
245255
):
246256
"""Add a file path to the database"""
247257
rag_context = ctx.obj["rag_context"]
248258
start_time = time.time()
249259

260+
# Parse extension lists
261+
only_list = only_extensions.split(",") if only_extensions else None
262+
exclude_list = exclude_extensions.split(",") if exclude_extensions else None
263+
250264
rag = rag_context.get_rag()
251265
rag.add(
252266
path,
253267
recursive=recursive,
254268
use_relative_paths=use_relative_paths,
255269
metadata=json.loads(metadata or "{}"),
270+
only_extensions=only_list,
271+
exclude_extensions=exclude_list,
256272
)
257273

258274
elapsed_time = time.time() - start_time

src/sqlite_rag/reader.py

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,44 +6,61 @@
66

77
class FileReader:
88
extensions = [
9-
".c",
10-
".cpp",
11-
".css",
12-
".csv",
13-
".docx",
14-
".go",
15-
".h",
16-
".hpp",
17-
".html",
18-
".java",
19-
".js",
20-
".json",
21-
".kt",
22-
".md",
23-
".mdx",
24-
".mjs",
25-
".pdf",
26-
".php",
27-
".pptx",
28-
".py",
29-
".rb",
30-
".rs",
31-
".svelte",
32-
".swift",
33-
".ts",
34-
".tsx",
35-
".txt",
36-
".vue",
37-
".xml",
38-
".xlsx",
39-
".yaml",
40-
".yml",
9+
"c",
10+
"cpp",
11+
"css",
12+
"csv",
13+
"docx",
14+
"go",
15+
"h",
16+
"hpp",
17+
"html",
18+
"java",
19+
"js",
20+
"json",
21+
"kt",
22+
"md",
23+
"mdx",
24+
"mjs",
25+
"pdf",
26+
"php",
27+
"pptx",
28+
"py",
29+
"rb",
30+
"rs",
31+
"svelte",
32+
"swift",
33+
"ts",
34+
"tsx",
35+
"txt",
36+
"vue",
37+
"xml",
38+
"xlsx",
39+
"yaml",
40+
"yml",
4141
]
4242

4343
@staticmethod
44-
def is_supported(path: Path) -> bool:
44+
def is_supported(
45+
path: Path,
46+
only_extensions: Optional[list[str]] = None,
47+
exclude_extensions: Optional[list[str]] = None,
48+
) -> bool:
4549
"""Check if the file extension is supported"""
46-
return path.suffix.lower() in FileReader.extensions
50+
extension = path.suffix.lower().lstrip(".")
51+
52+
supported_extensions = set(FileReader.extensions)
53+
exclude_set = set()
54+
55+
# Only keep those that are in both lists
56+
if only_extensions:
57+
only_set = {ext.lower().lstrip(".") for ext in only_extensions}
58+
supported_extensions &= only_set
59+
60+
if exclude_extensions:
61+
exclude_set = {ext.lower().lstrip(".") for ext in exclude_extensions}
62+
63+
return extension in supported_extensions and extension not in exclude_set
4764

4865
@staticmethod
4966
def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str:
@@ -65,12 +82,19 @@ def parse_file(path: Path, max_document_size_bytes: Optional[int] = None) -> str
6582
raise ValueError(f"Failed to parse file {path}") from exc
6683

6784
@staticmethod
68-
def collect_files(path: Path, recursive: bool = False) -> list[Path]:
85+
def collect_files(
86+
path: Path,
87+
recursive: bool = False,
88+
only_extensions: Optional[list[str]] = None,
89+
exclude_extensions: Optional[list[str]] = None,
90+
) -> list[Path]:
6991
"""Collect files from the path, optionally recursively"""
7092
if not path.exists():
7193
raise FileNotFoundError(f"{path} does not exist.")
7294

73-
if path.is_file() and FileReader.is_supported(path):
95+
if path.is_file() and FileReader.is_supported(
96+
path, only_extensions, exclude_extensions
97+
):
7498
return [path]
7599

76100
files_to_process = []
@@ -83,7 +107,8 @@ def collect_files(path: Path, recursive: bool = False) -> list[Path]:
83107
files_to_process = [
84108
f
85109
for f in files_to_process
86-
if f.is_file() and FileReader.is_supported(f)
110+
if f.is_file()
111+
and FileReader.is_supported(f, only_extensions, exclude_extensions)
87112
]
88113

89114
return files_to_process

src/sqlite_rag/sqliterag.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,32 @@ def add(
7272
recursive: bool = False,
7373
use_relative_paths: bool = False,
7474
metadata: dict = {},
75+
only_extensions: Optional[list[str]] = None,
76+
exclude_extensions: Optional[list[str]] = None,
7577
) -> int:
76-
"""Add the file content into the database"""
78+
"""Add the file content into the database
79+
80+
Args:
81+
path: File or directory path to add
82+
recursive: Recursively add files in directories
83+
use_relative_paths: Store relative paths instead of absolute paths
84+
metadata: Metadata to associate with documents
85+
only_extensions: Only process these file extensions from the supported list (e.g. ['py', 'js'])
86+
exclude_extensions: Skip these file extensions (e.g. ['py', 'js'])
87+
"""
7788
self._ensure_initialized()
7889

7990
if not Path(path).exists():
8091
raise FileNotFoundError(f"{path} does not exist.")
8192

8293
parent = Path(path).parent
8394

84-
files_to_process = FileReader.collect_files(Path(path), recursive=recursive)
95+
files_to_process = FileReader.collect_files(
96+
Path(path),
97+
recursive=recursive,
98+
only_extensions=only_extensions,
99+
exclude_extensions=exclude_extensions,
100+
)
85101

86102
self._engine.create_new_context()
87103

tests/test_reader.py

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,102 @@ def test_collect_files_recursive_directory(self):
6262
assert file2 in files
6363

6464
def test_is_supported(self):
65-
unsupported_extensions = [".exe", ".bin", ".jpg", ".png"]
65+
unsupported_extensions = ["exe", "bin", "jpg", "png"]
6666

6767
for ext in FileReader.extensions:
68-
assert FileReader.is_supported(Path(f"test{ext}"))
68+
assert FileReader.is_supported(Path(f"test.{ext}"))
6969

7070
for ext in unsupported_extensions:
71-
assert not FileReader.is_supported(Path(f"test{ext}"))
71+
assert not FileReader.is_supported(Path(f"test.{ext}"))
72+
73+
def test_is_supported_with_only_extensions(self):
74+
"""Test is_supported with only_extensions parameter"""
75+
# Test with only_extensions - should only allow specified extensions
76+
assert FileReader.is_supported(Path("test.py"), only_extensions=["py", "js"])
77+
assert FileReader.is_supported(Path("test.js"), only_extensions=["py", "js"])
78+
assert not FileReader.is_supported(
79+
Path("test.txt"), only_extensions=["py", "js"]
80+
)
81+
assert not FileReader.is_supported(
82+
Path("test.md"), only_extensions=["py", "js"]
83+
)
84+
85+
# Test with dots in extensions (should be normalized)
86+
assert FileReader.is_supported(Path("test.py"), only_extensions=[".py", ".js"])
87+
assert FileReader.is_supported(Path("test.js"), only_extensions=[".py", ".js"])
88+
89+
# Test case insensitive
90+
assert FileReader.is_supported(Path("test.py"), only_extensions=["PY", "JS"])
91+
assert FileReader.is_supported(Path("test.JS"), only_extensions=["py", "js"])
92+
93+
def test_is_supported_with_exclude_extensions(self):
94+
"""Test is_supported with exclude_extensions parameter"""
95+
# Test basic exclusion - py files should be excluded
96+
assert not FileReader.is_supported(Path("test.py"), exclude_extensions=["py"])
97+
assert FileReader.is_supported(Path("test.js"), exclude_extensions=["py"])
98+
assert FileReader.is_supported(Path("test.txt"), exclude_extensions=["py"])
99+
100+
# Test with dots in extensions (should be normalized)
101+
assert not FileReader.is_supported(Path("test.py"), exclude_extensions=[".py"])
102+
assert FileReader.is_supported(Path("test.js"), exclude_extensions=[".py"])
103+
104+
# Test case insensitive
105+
assert not FileReader.is_supported(Path("test.py"), exclude_extensions=["PY"])
106+
assert not FileReader.is_supported(Path("test.PY"), exclude_extensions=["py"])
107+
108+
# Test multiple exclusions
109+
assert not FileReader.is_supported(
110+
Path("test.py"), exclude_extensions=["py", "js"]
111+
)
112+
assert not FileReader.is_supported(
113+
Path("test.js"), exclude_extensions=["py", "js"]
114+
)
115+
assert FileReader.is_supported(
116+
Path("test.txt"), exclude_extensions=["py", "js"]
117+
)
118+
119+
def test_is_supported_with_only_and_exclude_extensions(self):
120+
"""Test is_supported with both only_extensions and exclude_extensions"""
121+
# Include py and js, but exclude py - should only allow js
122+
assert not FileReader.is_supported(
123+
Path("test.py"), only_extensions=["py", "js"], exclude_extensions=["py"]
124+
)
125+
assert FileReader.is_supported(
126+
Path("test.js"), only_extensions=["py", "js"], exclude_extensions=["py"]
127+
)
128+
assert not FileReader.is_supported(
129+
Path("test.txt"), only_extensions=["py", "js"], exclude_extensions=["py"]
130+
)
131+
132+
# Include py, txt, md, but exclude md - should only allow py and txt
133+
assert FileReader.is_supported(
134+
Path("test.py"),
135+
only_extensions=["py", "txt", "md"],
136+
exclude_extensions=["md"],
137+
)
138+
assert FileReader.is_supported(
139+
Path("test.txt"),
140+
only_extensions=["py", "txt", "md"],
141+
exclude_extensions=["md"],
142+
)
143+
assert not FileReader.is_supported(
144+
Path("test.md"),
145+
only_extensions=["py", "txt", "md"],
146+
exclude_extensions=["md"],
147+
)
148+
assert not FileReader.is_supported(
149+
Path("test.js"),
150+
only_extensions=["py", "txt", "md"],
151+
exclude_extensions=["md"],
152+
)
153+
154+
def test_is_supported_with_unsupported_extensions_in_only(self):
155+
"""Test that only_extensions can't add unsupported extensions"""
156+
# .exe is not in FileReader.extensions, so should not be supported even if in only_extensions
157+
assert not FileReader.is_supported(
158+
Path("test.exe"), only_extensions=["exe", "py"]
159+
)
160+
assert FileReader.is_supported(Path("test.py"), only_extensions=["exe", "py"])
72161

73162
def test_parse_html_into_markdown(self):
74163
with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as f:

0 commit comments

Comments
 (0)