
Commit bc597e9

Add black and flake8 (#532)
* Add black and flake8
* Fix test losses
* Fix pre-commit
* Update README
1 parent a469f86 commit bc597e9


82 files changed (+1703 / -1382 lines)

.flake8

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 119
+exclude =.git,__pycache__,docs/conf.py,build,dist,setup.py,tests
+ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006,D412
+inline-quotes = "
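The new config can be exercised locally before committing; a minimal sketch, assuming you run it from the repository root and use the flake8/flake8-docstrings versions pinned in the pre-commit config further down:

```bash
# Install the linter plus the docstring plugin, then lint with the committed config
pip install flake8==4.0.1 flake8-docstrings==1.6.0
flake8 --config=.flake8
```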

.github/workflows/tests.yml

Lines changed: 26 additions & 10 deletions
@@ -12,23 +12,39 @@ on:

 jobs:
   test:
-
     runs-on: ubuntu-18.04
-
     steps:
     - uses: actions/checkout@v2
-
     - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: 3.6
-
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
-        python -m pip install codecov pytest mock
-        pip3 install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
-        pip install .
-    - name: Test
-      run: |
-        python -m pytest -s tests
+        pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+        pip install .[test]
+    - name: Run Tests
+      run: python -m pytest -s tests
+    - name: Run Flake8
+      run: flake8 --config=.flake8
+
+  check_code_formatting:
+    name: Check code formatting with Black
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.8]
+    steps:
+    - name: Checkout
+      uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Update pip
+      run: python -m pip install --upgrade pip
+    - name: Install Black
+      run: pip install black==21.9b0
+    - name: Run Black
+      run: black --config=pyproject.toml --check .
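Both CI jobs above can be reproduced locally; a rough sketch that mirrors the workflow steps, using the same pins and wheel index as the YAML:

```bash
# "test" job: CPU-only torch, the package with its test extras, pytest, then flake8
python -m pip install --upgrade pip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
python -m pytest -s tests
flake8 --config=.flake8

# "check_code_formatting" job: Black in check-only mode
pip install black==21.9b0
black --config=pyproject.toml --check .
```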

.pre-commit-config.yaml

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: 21.12b0
+    hooks:
+      - id: black
+        args: [ --config=pyproject.toml ]
+  - repo: https://gitlab.com/pycqa/flake8
+    rev: 4.0.1
+    hooks:
+      - id: flake8
+        args: [ --config=.flake8 ]
+        additional_dependencies: [ flake8-docstrings==1.6.0 ]
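With this file in place, the hooks are installed per clone via `pre-commit install`; `pre-commit run --all-files` is standard pre-commit usage (not something this commit adds) for applying both hooks to the whole tree at once:

```bash
pip install pre-commit
pre-commit install           # register the git pre-commit hook for this clone
pre-commit run --all-files   # run black and flake8 against every tracked file
```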

README.md

Lines changed: 15 additions & 3 deletions
@@ -58,7 +58,7 @@ model = smp.Unet(

 #### 2. Configure data preprocessing

-All encoders have pretrained weights. Preparing your data the same way as during weights pre-training may give your better results (higher metric score and faster convergence). But it is relevant only for 1-2-3-channels images and **not necessary** in case you train the whole model, not only decoder.
+All encoders have pretrained weights. Preparing your data the same way as during weights pre-training may give your better results (higher metric score and faster convergence). It is **not necessary** in case you train the whole model, not only decoder.

 ```python
 from segmentation_models_pytorch.encoders import get_preprocessing_fn
@@ -419,11 +419,23 @@ $ pip install git+https://github.com/qubvel/segmentation_models.pytorch

 ### 🤝 Contributing

-##### Run test
+##### Install linting and formatting pre-commit hooks
+```bash
+pip install pre-commit black flake8
+pre-commit install
+```
+
+##### Run tests
+```bash
+pytest -p no:cacheprovider
+```
+
+##### Run tests in docker
 ```bash
 $ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
 ```
-##### Generate table
+
+##### Generate table with encoders (in case you add a new encoder)
 ```bash
 $ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py
 ```

__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

misc/generate_table.py

Lines changed: 3 additions & 1 deletion
@@ -10,10 +10,12 @@
     "Params, M",
 ]

+
 def wrap_row(r):
     return "|{}|".format(r)

-header = "|".join([column.ljust(WIDTH, ' ') for column in COLUMNS])
+
+header = "|".join([column.ljust(WIDTH, " ") for column in COLUMNS])
 separator = "|".join(["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1))

 print(wrap_row(header))

misc/generate_table_timm.py

Lines changed: 9 additions & 7 deletions
@@ -7,32 +7,34 @@ def check_features_and_reduction(name):
     if not encoder.feature_info.reduction() == [2, 4, 8, 16, 32]:
         raise ValueError

+
 def has_dilation_support(name):
     try:
         timm.create_model(name, features_only=True, output_stride=8, pretrained=False)
         timm.create_model(name, features_only=True, output_stride=16, pretrained=False)
         return True
-    except Exception as e:
+    except Exception:
         return False

+
 def make_table(data):
-    names = supported.keys()
+    names = data.keys()
     max_len1 = max([len(x) for x in names]) + 2
     max_len2 = len("support dilation") + 2
-
+
     l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
     l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
     top = "| " + "Encoder name".ljust(max_len1 - 2) + " | " + "Support dilation".center(max_len2 - 2) + " |\n"
-
+
     table = l1 + top + l2
-
+
     for k in sorted(data.keys()):
         support = "✅".center(max_len2 - 3) if data[k]["has_dilation"] else " ".center(max_len2 - 2)
         table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n"
     table += l1
-
+
     return table
-
+

 if __name__ == "__main__":
pyproject.toml

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+[tool.black]
+line-length = 119
+target-version = ['py36', 'py37', 'py38']
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.eggs
+  | \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | docs
+  | _build
+  | buck-out
+  | build
+  | dist
+)/
+'''
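Black picks up this `[tool.black]` table automatically when invoked from the repository root; a short sketch of check vs. reformat mode, with the version pin taken from the CI workflow above:

```bash
pip install black==21.9b0
black --config=pyproject.toml --check .   # only report files that would be reformatted (CI behaviour)
black --config=pyproject.toml .           # rewrite files in place
```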

segmentation_models_pytorch/__init__.py

Lines changed: 20 additions & 6 deletions
@@ -28,17 +28,31 @@ def create_model(
     classes: int = 1,
     **kwargs,
 ) -> _torch.nn.Module:
-    """Models entrypoint, allows to create any model architecture just with
-    parameters, without using its class"""
+    """Models entrypoint, allows to create any model architecture just with
+    parameters, without using its class
+    """

-    archs = [Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
+    archs = [
+        Unet,
+        UnetPlusPlus,
+        MAnet,
+        Linknet,
+        FPN,
+        PSPNet,
+        DeepLabV3,
+        DeepLabV3Plus,
+        PAN,
+    ]
     archs_dict = {a.__name__.lower(): a for a in archs}
     try:
         model_class = archs_dict[arch.lower()]
     except KeyError:
-        raise KeyError("Wrong architecture type `{}`. Available options are: {}".format(
-            arch, list(archs_dict.keys()),
-        ))
+        raise KeyError(
+            "Wrong architecture type `{}`. Available options are: {}".format(
+                arch,
+                list(archs_dict.keys()),
+            )
+        )
     return model_class(
         encoder_name=encoder_name,
         encoder_weights=encoder_weights,
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 VERSION = (0, 2, 1)

-__version__ = '.'.join(map(str, VERSION))
+__version__ = ".".join(map(str, VERSION))
