Commit de129d6

Merge pull request #27 from DefTruth/dev

Dev

2 parents ec2ead2 + 1d19f51

File tree: 9 files changed (+186 −100 lines changed)


README.md

Lines changed: 143 additions & 97 deletions
@@ -12,30 +12,27 @@

## 🤗 Introduction
-**torchlm** is a PyTorch landmarks-only library with **100+ data augmentations** that supports **training** and **inference**. **torchlm** aims to focus only on landmark detection, such as face landmarks, hand keypoints and body keypoints, etc. It provides **30+** native data augmentations and can **bind** with **80+** transforms from torchvision and albumentations; no matter whether the input is a np.ndarray or a torch Tensor, **torchlm** automatically adapts to the data type and wraps the result back to the original type through an **autodtype** wrapper. Further, **torchlm** will add modules for **training** and **inference** in the future.
+**torchlm** aims to build a high-level pipeline for face landmarks detection; it supports **100+ data augmentations**, **training** and **inference**, and can easily be installed with **pip**.
<div align='center'>
-<img src='docs/res/605.jpg' height="100px" width="100px">
-<img src='docs/res/802.jpg' height="100px" width="100px">
-<img src='docs/res/92.jpg' height="100px" width="100px">
-<img src='docs/res/234.jpg' height="100px" width="100px">
-<img src='docs/res/906.jpg' height="100px" width="100px">
-<img src='docs/res/825.jpg' height="100px" width="100px">
-<img src='docs/res/388.jpg' height="100px" width="100px">
-<br>
<img src='docs/res/2_wflw_44.jpg' height="100px" width="100px">
<img src='docs/res/2_wflw_67.jpg' height="100px" width="100px">
<img src='docs/res/2_wflw_76.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_162.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_229.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_440.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_478.jpg' height="100px" width="100px">
+<img src='docs/assets/pipnet0.jpg' height="100px" width="100px">
+<img src='docs/assets/pipnet_300W_CELEBA_model.gif' height="100px" width="100px">
+<img src='docs/assets/pipnet_shaolin_soccer.gif' height="100px" width="100px">
+<img src='docs/assets/pipnet_WFLW_model.gif' height="100px" width="100px">
</div>

<p align="center"> ❤️ Star 🌟👆🏻 this repo to support me if it helps you, thanks ~ </p>

+## 👋 Core Features
+* High-level pipeline for **training** and **inference**.
+* **30+** native landmarks data augmentations.
+* Can **bind 80+** transforms from torchvision and albumentations with one line of code.
+* Support awesome models for face landmarks detection, such as YOLOX, YOLOv5, ResNet, MobileNet, ShuffleNet and PIPNet, etc.

-# 🆕 What's New
+## 🆕 What's New
+* [2022/03/08]: Add **PIPNet**: [Towards Efficient Facial Landmark Detection in the Wild, CVPR2021](https://github.com/jhb86253817/PIPNet)
* [2022/02/13]: Add **30+** native data augmentations and **bind** **80+** transforms from torchvision and albumentations.

## 🛠️ Usage
@@ -44,53 +41,56 @@
* opencv-python-headless>=4.5.2
* numpy>=1.14.4
* torch>=1.6.0
-* torchvision>=0.9.0
+* torchvision>=0.8.0
* albumentations>=1.1.0
+* onnx>=1.8.0
+* onnxruntime>=1.7.0
+* tqdm>=4.10.0

### Installation
-you can install **torchlm** directly from [pypi](https://pypi.org/project/torchlm/).
+You can install **torchlm** directly from [pypi](https://pypi.org/project/torchlm/). See [NOTE](#torchlm-NOTE) before installation!
```shell
pip3 install torchlm
# install from a specific pypi mirror with '-i'
pip3 install torchlm -i https://pypi.org/simple/
```
-or install from source.
+Or install from source if you want the latest torchlm, installing it in editable mode with `-e`.
```shell
-# clone torchlm repository locally
+# clone the torchlm repository locally if you want the latest torchlm
git clone --depth=1 https://github.com/DefTruth/torchlm.git
cd torchlm
# install in editable mode
pip install -e .
```
+<div id="torchlm-NOTE"></div>
+
+**NOTE**: If you hit a conflict between different installed versions of opencv (opencv-python vs opencv-python-headless; `albumentations` needs opencv-python-headless), uninstall both opencv packages first and then reinstall torchlm. See [albumentations#1139](https://github.com/albumentations-team/albumentations/issues/1139) for more details.
+
+```shell
+# first, uninstall the conflicting opencv packages
+pip uninstall opencv-python
+pip uninstall opencv-python-headless
+pip uninstall torchlm  # if you have already installed torchlm
+# then reinstall torchlm
+pip install torchlm  # will also install deps, e.g. opencv
+```

-### Data Augmentation
+### 🌟🌟 Data Augmentation
**torchlm** provides **30+** native data augmentations for landmarks and can **bind** with **80+** transforms from torchvision and albumentations through the **torchlm.bind** method. Further, **torchlm.bind** provides a `prob` param at bind-level to force any transform or callable to behave as a random-style augmentation. The data augmentations in **torchlm** are `safe` and `simple`: any transform operation that would push landmarks outside the image at runtime is auto dropped, to keep the number of landmarks unchanged. The layout format of landmarks is `xy` with shape `(N, 2)`, where `N` denotes the number of input landmarks. No matter whether the input is a np.ndarray or a torch Tensor, **torchlm** automatically adapts to the data type and wraps the result back to the original type through an **autodtype** wrapper.

* use almost **30+** native transforms from **torchlm** directly
```python
import torchlm
transform = torchlm.LandmarksCompose([
-    # use native torchlm transforms
-    torchlm.LandmarksRandomScale(prob=0.5),
-    torchlm.LandmarksRandomTranslate(prob=0.5),
-    torchlm.LandmarksRandomShear(prob=0.5),
-    torchlm.LandmarksRandomMask(prob=0.5),
-    torchlm.LandmarksRandomBlur(kernel_range=(5, 25), prob=0.5),
-    torchlm.LandmarksRandomBrightness(prob=0.),
-    torchlm.LandmarksRandomRotate(40, prob=0.5, bins=8),
-    torchlm.LandmarksRandomCenterCrop((0.5, 1.0), (0.5, 1.0), prob=0.5),
-    # ...
-])
+    torchlm.LandmarksRandomScale(prob=0.5),
+    torchlm.LandmarksRandomMask(prob=0.5),
+    torchlm.LandmarksRandomBlur(kernel_range=(5, 25), prob=0.5),
+    torchlm.LandmarksRandomBrightness(prob=0.),
+    torchlm.LandmarksRandomRotate(40, prob=0.5, bins=8),
+    torchlm.LandmarksRandomCenterCrop((0.5, 1.0), (0.5, 1.0), prob=0.5)
+])
```
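The `safe` behavior described above — a transform whose output would push landmarks outside the image is skipped, so the landmark count never changes — can be sketched in plain Python. This is an illustrative toy, not torchlm's actual implementation; `SafeCompose`, `shift_right` and `shift_down` are made-up names.

```python
from typing import Callable, List, Tuple

Landmarks = List[Tuple[float, float]]  # xy layout, shape (N, 2)

class SafeCompose:
    """Toy landmarks-safe compose: skip any transform whose output
    would leave landmarks outside the image bounds."""
    def __init__(self, transforms: List[Callable], width: int, height: int):
        self.transforms = transforms
        self.width = width
        self.height = height

    def __call__(self, landmarks: Landmarks) -> Landmarks:
        for t in self.transforms:
            candidate = t(landmarks)
            # keep the result only if every landmark stays in-bounds,
            # so the number of landmarks is never reduced
            if all(0 <= x < self.width and 0 <= y < self.height
                   for x, y in candidate):
                landmarks = candidate
        return landmarks

shift_right = lambda lms: [(x + 90, y) for x, y in lms]  # would go out of bounds
shift_down = lambda lms: [(x, y + 10) for x, y in lms]   # stays in bounds

pipeline = SafeCompose([shift_right, shift_down], width=100, height=100)
out = pipeline([(50.0, 50.0), (20.0, 30.0)])
# → [(50.0, 60.0), (20.0, 40.0)]: shift_right is dropped, shift_down applies
```

torchlm's real transforms carry this logic internally per-operation; the sketch only illustrates the "drop unsafe results, keep N constant" contract.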
<div align='center'>
-<img src='docs/res/605.jpg' height="100px" width="100px">
-<img src='docs/res/802.jpg' height="100px" width="100px">
-<img src='docs/res/92.jpg' height="100px" width="100px">
-<img src='docs/res/234.jpg' height="100px" width="100px">
-<img src='docs/res/906.jpg' height="100px" width="100px">
-<img src='docs/res/825.jpg' height="100px" width="100px">
-<img src='docs/res/388.jpg' height="100px" width="100px">
-<br>
<img src='docs/res/2_wflw_44.jpg' height="100px" width="100px">
<img src='docs/res/2_wflw_67.jpg' height="100px" width="100px">
<img src='docs/res/2_wflw_76.jpg' height="100px" width="100px">
@@ -102,76 +102,45 @@ transform = torchlm.LandmarksCompose([

* **bind** **80+** torchvision and albumentations transforms through **torchlm.bind**
```python
-import torchvision
-import albumentations
-import torchlm
transform = torchlm.LandmarksCompose([
-    # use native torchlm transforms
-    torchlm.LandmarksRandomScale(prob=0.5),
-    # bind torchvision image-only transforms, with a given prob
-    torchlm.bind(torchvision.transforms.GaussianBlur(kernel_size=(5, 25)), prob=0.5),
-    torchlm.bind(torchvision.transforms.RandomAutocontrast(p=0.5)),
-    # bind albumentations image-only transforms
-    torchlm.bind(albumentations.ColorJitter(p=0.5)),
-    torchlm.bind(albumentations.GlassBlur(p=0.5)),
-    # bind albumentations dual transforms
-    torchlm.bind(albumentations.RandomCrop(height=200, width=200, p=0.5)),
-    torchlm.bind(albumentations.Rotate(p=0.5)),
-    # ...
-])
+    torchlm.bind(torchvision.transforms.GaussianBlur(kernel_size=(5, 25)), prob=0.5),
+    torchlm.bind(albumentations.ColorJitter(p=0.5))
+])
```
-* **bind** custom callable array or Tensor functions through **torchlm.bind**
+See [transforms.md](docs/api/transforms.md) for the supported transform sets; more examples can be found at [test/transforms.py](test/transforms.py).
+
+<details>
+<summary> bind custom callable array or Tensor functions through torchlm.bind </summary>

```python
# First, define your custom functions
-def callable_array_noop(img: np.ndarray, landmarks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
-    # do some transform here ...
+def callable_array_noop(img: np.ndarray, landmarks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:  # do some transform here ...
    return img.astype(np.uint32), landmarks.astype(np.float32)

-def callable_tensor_noop(img: Tensor, landmarks: Tensor) -> Tuple[Tensor, Tensor]:
-    # do some transform here ...
+def callable_tensor_noop(img: Tensor, landmarks: Tensor) -> Tuple[Tensor, Tensor]:  # do some transform here ...
    return img, landmarks
```

```python
# Then, bind your functions and put them into the transforms pipeline.
transform = torchlm.LandmarksCompose([
-    # use native torchlm transforms
-    torchlm.LandmarksRandomScale(prob=0.5),
-    # bind custom callable array functions
    torchlm.bind(callable_array_noop, bind_type=torchlm.BindEnum.Callable_Array),
-    # bind custom callable Tensor functions with a given prob
-    torchlm.bind(callable_tensor_noop, bind_type=torchlm.BindEnum.Callable_Tensor, prob=0.5),
-    # ...
-])
+    torchlm.bind(callable_tensor_noop, bind_type=torchlm.BindEnum.Callable_Tensor, prob=0.5)
+])
```
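The bind-level `prob` param mentioned earlier — forcing any transform or callable to behave as a random-style augmentation — boils down to a wrapper like the following. This is a hypothetical sketch (`bind_with_prob` is a made-up name), not torchlm's real `bind`.

```python
import random
from typing import Callable

def bind_with_prob(fn: Callable, prob: float = 1.0) -> Callable:
    """Wrap any (img, landmarks) callable so it only fires with
    probability `prob`; otherwise it acts as the identity."""
    def wrapper(img, landmarks):
        if random.random() < prob:
            return fn(img, landmarks)
        return img, landmarks
    return wrapper

# a toy transform that doubles every landmark coordinate
double = bind_with_prob(
    lambda img, lms: (img, [(2 * x, 2 * y) for x, y in lms]), prob=0.5)

random.seed(0)
# count how often the transform actually fired over 1000 calls
applied = sum(
    double(None, [(1.0, 1.0)])[1] == [(2.0, 2.0)] for _ in range(1000))
# with prob=0.5, roughly half of the calls apply the transform
```

With `prob=0.0` the wrapper always returns its inputs untouched, which is why binding a deterministic transform with a `prob` makes it random-style.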
-<div align='center'>
-<img src='docs/res/124.jpg' height="100px" width="100px">
-<img src='docs/res/158.jpg' height="100px" width="100px">
-<img src='docs/res/386.jpg' height="100px" width="100px">
-<img src='docs/res/478.jpg' height="100px" width="100px">
-<img src='docs/res/537.jpg' height="100px" width="100px">
-<img src='docs/res/605.jpg' height="100px" width="100px">
-<img src='docs/res/802.jpg' height="100px" width="100px">
-<br>
-<img src='docs/res/2_wflw_484.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_505.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_529.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_536.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_669.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_672.jpg' height="100px" width="100px">
-<img src='docs/res/2_wflw_741.jpg' height="100px" width="100px">
-</div>
+</details>

+<details>
+<summary> some global debug settings for torchlm's transforms </summary>

* setting the logging mode to `True` globally might help you figure out the runtime details
```python
-import torchlm
# some global settings
torchlm.set_transforms_debug(True)
torchlm.set_transforms_logging(True)
torchlm.set_autodtype_logging(True)
```
Some detailed information will be shown at runtime; the infos might look like
```shell
LandmarksRandomScale() AutoDtype Info: AutoDtypeEnum.Array_InOut
@@ -194,21 +163,98 @@ LandmarksRandomTranslate() Execution Flag: False

But it is ok if you pass a Tensor to a np.ndarray-like transform; **torchlm** automatically adapts to the data type and wraps the result back to the original type through the **autodtype** wrapper.

+</details>
+
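The **autodtype** idea — run a transform in its native type, then wrap the result back to whatever type the caller passed in — can be imitated with a small decorator. The sketch below is a simplified stand-in that converts between tuples and lists instead of torch Tensors and np.ndarrays; `autodtype` and `flip_x` here are illustrative names, not torchlm's actual code.

```python
from functools import wraps
from typing import Callable, List, Tuple

def autodtype(fn: Callable) -> Callable:
    """Run `fn` on a list; if the caller passed a tuple, convert it to a
    list first and convert the result back to a tuple afterwards."""
    @wraps(fn)
    def wrapper(landmarks):
        came_as_tuple = isinstance(landmarks, tuple)
        result = fn(list(landmarks))
        # wrap the output back to the caller's original container type
        return tuple(result) if came_as_tuple else result
    return wrapper

@autodtype
def flip_x(landmarks: List[Tuple[float, float]], width: float = 100.0):
    return [(width - x, y) for x, y in landmarks]

as_list = flip_x([(10.0, 20.0)])    # list in  -> list out
as_tuple = flip_x(((10.0, 20.0),))  # tuple in -> tuple out
# → [(90.0, 20.0)] and ((90.0, 20.0),)
```

torchlm's real wrapper does the same dance between np.ndarray and torch Tensor, which is why an ndarray-only transform still accepts Tensor inputs.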
+### 🎉🎉 Training
+In **torchlm**, each model has a high-level and user-friendly API named `training`. Here is an example for [PIPNet](https://github.com/jhb86253817/PIPNet).
+```python
+from torchlm.models import pipnet
+
+model = pipnet(
+    backbone="resnet18",
+    pretrained=False,
+    num_nb=10,
+    num_lms=98,
+    net_stride=32,
+    input_size=256,
+    meanface_type="wflw",
+    backbone_pretrained=True,
+    map_location="cuda",
+    checkpoint=None
+)
+
+# the signature of the training API, shown with its defaults:
+model.training(
+    self,
+    annotation_path: str,
+    criterion_cls: nn.Module = nn.MSELoss(),
+    criterion_reg: nn.Module = nn.L1Loss(),
+    learning_rate: float = 0.0001,
+    cls_loss_weight: float = 10.,
+    reg_loss_weight: float = 1.,
+    num_nb: int = 10,
+    num_epochs: int = 60,
+    save_dir: Optional[str] = "./save",
+    save_interval: Optional[int] = 10,
+    save_prefix: Optional[str] = "",
+    decay_steps: Optional[List[int]] = (30, 50),
+    decay_gamma: Optional[float] = 0.1,
+    device: Optional[Union[str, torch.device]] = "cuda",
+    transform: Optional[transforms.LandmarksCompose] = None,
+    coordinates_already_normalized: Optional[bool] = False,
+    **kwargs: Any  # params for DataLoader
+) -> nn.Module:
+```
+Please jump to the entry point of the function for the detailed documentation of the **training** API for each model defined in torchlm, e.g. [pipnet/_impls.py#L159](https://github.com/DefTruth/torchlm/blob/main/torchlm/models/pipnet/_impls.py#L159). Further, the model implementation plan is as follows:
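The `decay_steps`/`decay_gamma` pair in the training signature describes a step schedule: the learning rate is multiplied by `decay_gamma` each time training passes one of the `decay_steps` epochs. The helper below only illustrates that semantics (a MultiStepLR-style rule; `stepped_lr` is a made-up name, not torchlm code).

```python
from typing import Sequence

def stepped_lr(base_lr: float, epoch: int,
               decay_steps: Sequence[int] = (30, 50),
               decay_gamma: float = 0.1) -> float:
    """lr = base_lr * gamma^(number of decay steps already passed)."""
    passed = sum(1 for s in decay_steps if epoch >= s)
    return base_lr * (decay_gamma ** passed)

lrs = [stepped_lr(0.0001, e) for e in (0, 29, 30, 49, 50, 59)]
# → 1e-4 up to epoch 29, 1e-5 from epoch 30, 1e-6 from epoch 50
```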

-* Supported Transforms Sets, see [transforms.md](docs/api/transforms.md). A detailed example can be found at [test/transforms.py](test/transforms.py).
+❔ YOLOX ❔ YOLOv5 ❔ NanoDet ✅ [PIPNet](https://github.com/jhb86253817/PIPNet) ❔ ResNet ❔ MobileNet ❔ ShuffleNet ❔ ...

-### Training(TODO)
-* [ ] YOLOX
-* [ ] YOLOv5
-* [ ] NanoDet
-* [ ] PIPNet
-* [ ] ResNet
-* [ ] MobileNet
-* [ ] ShuffleNet
-* [ ] ...
+✅ = known to work and officially supported, ❔ = in my plan, but not coming soon.

-### Inference
+### 👀👇 Inference
+#### C++ API
The ONNXRuntime (CPU/GPU), MNN, NCNN and TNN C++ inference of **torchlm** will be released at [lite.ai.toolkit](https://github.com/DefTruth/lite.ai.toolkit).
+#### Python API
+In **torchlm**, we offer a high-level API named `runtime.bind` to bind any model in torchlm; you can then run the `runtime.forward` API to get the output landmarks and bboxes. Here is an example for [PIPNet](https://github.com/jhb86253817/PIPNet).
+```python
+import cv2
+import torchlm
+from torchlm.tools import faceboxesv2
+from torchlm.models import pipnet
+
+def test_pipnet_runtime():
+    img_path = "./1.jpg"
+    save_path = "./1.jpg"
+    checkpoint = "./pipnet_resnet18_10x98x32x256_wflw.pth"
+    image = cv2.imread(img_path)
+
+    torchlm.runtime.bind(faceboxesv2())
+    torchlm.runtime.bind(
+        pipnet(
+            backbone="resnet18",
+            pretrained=True,
+            num_nb=10,
+            num_lms=98,
+            net_stride=32,
+            input_size=256,
+            meanface_type="wflw",
+            backbone_pretrained=True,
+            map_location="cpu",
+            checkpoint=checkpoint
+        )
+    )
+    landmarks, bboxes = torchlm.runtime.forward(image)
+    image = torchlm.utils.draw_bboxes(image, bboxes=bboxes)
+    image = torchlm.utils.draw_landmarks(image, landmarks=landmarks)
+
+    cv2.imwrite(save_path, image)
+```
+<div align='center'>
+<img src='docs/assets/pipnet0.jpg' height="180px" width="180px">
+<img src='docs/assets/pipnet_300W_CELEBA_model.gif' height="180px" width="180px">
+<img src='docs/assets/pipnet_shaolin_soccer.gif' height="180px" width="180px">
+<img src='docs/assets/pipnet_WFLW_model.gif' height="180px" width="180px">
+</div>
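The `runtime.bind`/`runtime.forward` flow above is a two-stage pattern: a face detector produces bboxes, then a landmark model runs per box. Stripped of the real models, the control flow might look like the sketch below — `StubDetector`, `StubLandmarker` and `Runtime` are purely illustrative stubs, not torchlm's implementation.

```python
from typing import List, Tuple

Box = Tuple[int, int, int, int]  # x1, y1, x2, y2

class StubDetector:
    def detect(self, image) -> List[Box]:
        return [(10, 10, 50, 50)]  # pretend one face was found

class StubLandmarker:
    def predict(self, image, box: Box) -> List[Tuple[float, float]]:
        x1, y1, x2, y2 = box
        # pretend the only landmark is the box center, in image coordinates
        return [((x1 + x2) / 2.0, (y1 + y2) / 2.0)]

class Runtime:
    """Mimic the torchlm.runtime flow: bind a detector and a landmark
    model, then forward an image through both stages."""
    def __init__(self):
        self.detector = None
        self.landmarker = None

    def bind(self, model):
        if isinstance(model, StubDetector):
            self.detector = model
        else:
            self.landmarker = model

    def forward(self, image):
        bboxes = self.detector.detect(image)
        landmarks = [self.landmarker.predict(image, box) for box in bboxes]
        return landmarks, bboxes

rt = Runtime()
rt.bind(StubDetector())
rt.bind(StubLandmarker())
landmarks, bboxes = rt.forward(image=None)
# → one box, one landmark at its center (30.0, 30.0)
```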

## 📖 Documentations
* [x] [Data Augmentation's API](docs/api/transforms.md)

docs/assets/pipnet0.jpg

201 KB (binary file added)
18.8 MB (binary file added)

docs/assets/pipnet_WFLW_model.gif

19.3 MB (binary file added)
14.1 MB (binary file added)

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
# torchlm
-opencv-python-headless>=4.5.2
+opencv-python-headless>=4.3.0
numpy>=1.14.4
torch>=1.6.0
torchvision>=0.9.0

setup.py

Lines changed: 2 additions & 2 deletions
@@ -25,14 +25,14 @@ def get_long_description():
    url="https://github.com/DefTruth/torchlm",
    packages=setuptools.find_packages(),
    install_requires=[
-        "opencv-python-headless>=4.5.2",
+        "opencv-python-headless>=4.3.0",
        "numpy>=1.14.4",
        "torch>=1.6.0",
        "torchvision>=0.8.0",
        "albumentations>=1.1.0",
        "onnx>=1.8.0",
        "onnxruntime>=1.7.0",
-        "tqdm>=4.60.0"
+        "tqdm>=4.10.0"
    ],
    classifiers=[
        "Programming Language :: Python :: 3",

torchlm/data/_converters.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import os
+import cv2
+import numpy as np
+from abc import ABCMeta, abstractmethod
+from typing import Tuple, Optional, List
+
+
+class BaseConverter(object):
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def convert(self, *args, **kwargs):
+        raise NotImplementedError
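`BaseConverter` above uses the abstract-base pattern; note that in Python 3 the `__metaclass__` attribute is ignored, so the `metaclass=ABCMeta` spelling is what actually makes `@abstractmethod` block instantiation. A hypothetical concrete converter (the `ScaleConverter` name and behavior are invented for illustration) might look like:

```python
from abc import ABCMeta, abstractmethod

class BaseConverter(metaclass=ABCMeta):  # Python 3 spelling of the metaclass
    @abstractmethod
    def convert(self, *args, **kwargs):
        raise NotImplementedError

class ScaleConverter(BaseConverter):
    """Hypothetical converter: scale xy landmark coordinates by a factor."""
    def __init__(self, factor: float):
        self.factor = factor

    def convert(self, landmarks):
        return [(x * self.factor, y * self.factor) for x, y in landmarks]

converter = ScaleConverter(2.0)
out = converter.convert([(1.0, 2.0)])
# → [(2.0, 4.0)]; instantiating BaseConverter itself raises TypeError
```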

torchlm/models/pipnet/_impls.py

Lines changed: 27 additions & 0 deletions
@@ -157,6 +157,33 @@ def training(
    coordinates_already_normalized: Optional[bool] = False,
    **kwargs: Any  # params for DataLoader
) -> nn.Module:
+    """
+    :param annotation_path: the path to an annotation file; the format must be
+        "img0_path x0 y0 x1 y1 ... xn-1 yn-1"
+        "img1_path x0 y0 x1 y1 ... xn-1 yn-1"
+        ...
+    :param criterion_cls: loss criterion for PIPNet heatmap classification, default MSELoss
+    :param criterion_reg: loss criterion for PIPNet offsets regression, default L1Loss
+    :param learning_rate: learning rate, default 0.0001
+    :param cls_loss_weight: weight for the heatmap classification loss
+    :param reg_loss_weight: weight for the offsets regression loss
+    :param num_nb: the number of nearest-neighbor landmarks for NRM, default 10
+    :param num_epochs: the number of training epochs
+    :param save_dir: the dir in which to save checkpoints
+    :param save_interval: the interval at which to save checkpoints
+    :param save_prefix: the prefix for saved checkpoints; a saved name looks like
+        {save_prefix}-epoch{epoch}-loss{epoch_loss}.pth
+    :param decay_steps: decay steps for the learning rate
+    :param decay_gamma: decay gamma for the learning rate
+    :param device: training device, default cuda.
+    :param transform: user-specified transform. If None, torchlm builds a default transform;
+        more details can be found at `torchlm.transforms.build_default_transform`
+    :param coordinates_already_normalized: denotes whether the labels in annotation_path are
+        already normalized (by image size) or not
+    :param kwargs: params for DataLoader
+    :return: A trained model.
+    """
    device = device if torch.cuda.is_available() else "cpu"
    # prepare dataset
    default_dataset = _PIPTrainDataset(
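The annotation format documented above — an image path followed by flattened x y pairs per line — can be parsed in a few lines of plain Python. `parse_annotation_line` is a hypothetical helper written for illustration, not part of torchlm.

```python
from typing import List, Tuple

def parse_annotation_line(line: str) -> Tuple[str, List[Tuple[float, float]]]:
    """Split 'img_path x0 y0 x1 y1 ...' into the path and (N, 2) xy pairs."""
    parts = line.split()
    img_path, coords = parts[0], [float(v) for v in parts[1:]]
    assert len(coords) % 2 == 0, "expected an even number of coordinates"
    # pair up even-indexed x values with odd-indexed y values
    landmarks = list(zip(coords[0::2], coords[1::2]))
    return img_path, landmarks

path, lms = parse_annotation_line("faces/img0.jpg 10 20 30 40 50 60")
# → path 'faces/img0.jpg', lms [(10.0, 20.0), (30.0, 40.0), (50.0, 60.0)]
```

Whether the coordinates are pixel values or already normalized by image size is exactly what the `coordinates_already_normalized` flag tells `training`.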
