Skip to content

Commit 0a75709

Browse files
Add pyright github action (#118)
* Add pyright github action

  Adds type checking with pyright

* Force color in pyright
* Fix pyright issues
* Fix overly long line
* Fix pyright issues
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 89b4471 commit 0a75709

File tree

5 files changed

+39
-3
lines changed

5 files changed

+39
-3
lines changed

.github/workflows/pyright.yml

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,27 @@
1+
name: Pyright Type Checking
2+
3+
on:
4+
push:
5+
branches: [ master ]
6+
pull_request:
7+
branches: [ master ]
8+
9+
jobs:
10+
pyright:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
15+
- name: Set up Python
16+
uses: actions/setup-python@v5
17+
with:
18+
python-version: '3.11'
19+
20+
- name: Install dependencies
21+
run: |
22+
pip install .[dev]
23+
24+
- name: Run pyright
25+
run: pyright
26+
env:
27+
FORCE_COLOR: "1"

preprocessing_utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,13 +32,14 @@ def create_val_set(csv_file, val_fraction):
3232
out of it specified by val_fraction.
3333
"""
3434
csv_file = Path(csv_file)
35-
dataset = pd.read_csv(csv_file)
35+
dataset: pd.DataFrame = pd.read_csv(csv_file)
3636
np.random.seed(0)
3737
dataset_mod = dataset[dataset.toxic != -1]
3838
indices = np.random.rand(len(dataset_mod)) > val_fraction
3939
val_set = dataset_mod[~indices]
4040
output_file = csv_file.parent / "val.csv"
4141
logger.info("Validation set saved to %s", output_file)
42+
assert isinstance(val_set, (pd.DataFrame, pd.Series))
4243
val_set.to_csv(output_file)
4344

4445

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@ dev = [
3333
"scikit-learn >= 0.23.2",
3434
"tqdm",
3535
"pre-commit",
36-
"numpy>=2"
36+
"numpy>=2",
37+
"pyright"
3738
]
3839

3940
[tool.ruff]

run_prediction.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,10 @@ def run(model_name, input_obj, dest_file, from_ckpt, device="cpu"):
3232
model = Detoxify(checkpoint=from_ckpt, device=device)
3333
res = model.predict(text)
3434

35-
res_df = pd.DataFrame(res, index=[text] if isinstance(text, str) else text).round(5)
35+
res_df = pd.DataFrame(
36+
res,
37+
index=[text] if isinstance(text, str) else text, # pyright: ignore[reportArgumentType]
38+
).round(5)
3639
print(res_df)
3740
if dest_file is not None:
3841
res_df.index.name = "input_text"

src/data_loaders.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,8 @@ def load_data(self, csv_file):
4949
filtered_change_names = {k: v for k, v in change_names.items() if k in final_df.columns}
5050
if len(filtered_change_names) > 0:
5151
final_df.rename(columns=filtered_change_names, inplace=True)
52+
else:
53+
raise TypeError("Invalid input type for csv_file, must be a string or a list of strings")
5254
return final_df
5355

5456
def load_val(self, test_csv_file, add_labels=False):
@@ -155,6 +157,8 @@ def __getitem__(self, index):
155157
meta["text_id"] = text_id
156158

157159
if self.train:
160+
if self.weights is None:
161+
raise Exception("self.weights must not be None")
158162
meta["weights"] = self.weights[index]
159163
toxic_weight = self.weights[index] * self.loss_weight * 1.0 / len(self.classes)
160164
identity_weight = (1 - self.loss_weight) * 1.0 / len(self.identity_classes)

0 commit comments

Comments
 (0)