Commit 50f3d97

Style updates (#96)
* Bump version: 0.4.2 → 0.4.3
* style
1 parent ef351e4 commit 50f3d97

File tree

14 files changed: +142 -99 lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.2
+current_version = 0.4.3
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "milliontrees"
-version = "0.4.2"
+version = "0.4.3"
 description = "Benchmark dataset for Airborne Tree Machine Learning"
 readme = "README.md"
 license = { text = "MIT" }

src/milliontrees/common/data_loaders.py

Lines changed: 2 additions & 1 deletion
@@ -114,7 +114,8 @@ def get_eval_loader(loader, dataset, batch_size, grouper=None, **loader_kwargs):
 
 
 class GroupSampler:
-    """Constructs batches by first sampling groups, then sampling data from those groups.
+    """Constructs batches by first sampling groups, then sampling data from
+    those groups.
 
     It drops the last batch if it's incomplete.
     """

src/milliontrees/common/grouper.py

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@
 class Grouper:
     """Groupers group data points together based on their metadata.
 
-    They are used for training and evaluation, e.g., to measure the accuracies of different groups
-    of data.
+    They are used for training and evaluation, e.g., to measure the
+    accuracies of different groups of data.
     """
 
     def __init__(self):
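
A grouper's job, per the docstring, is to map each point's metadata to a group index that metrics can then be split over. A minimal sketch of that idea, assuming a hypothetical single-column grouper (the class name `SourceGrouper` and the `metadata_to_group` signature here are illustrative assumptions, not necessarily this module's real interface):

import torch

class SourceGrouper:
    """Sketch: derive a group index from one metadata column.

    Assumes metadata is an (n, d) tensor and that one column (e.g., the
    acquisition source) defines the groups. Hypothetical, for illustration.
    """

    def __init__(self, groupby_field_index: int, n_groups: int):
        self.groupby_field_index = groupby_field_index
        self.n_groups = n_groups

    def metadata_to_group(self, metadata: torch.Tensor) -> torch.Tensor:
        # Each point's group is the value in the chosen metadata column.
        return metadata[:, self.groupby_field_index].long()

metadata = torch.tensor([[0, 5], [1, 7], [0, 2]])  # column 0 = source id
grouper = SourceGrouper(groupby_field_index=0, n_groups=2)
print(grouper.metadata_to_group(metadata))  # tensor([0, 1, 0])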

src/milliontrees/common/metrics/all_metrics.py

Lines changed: 16 additions & 12 deletions
@@ -43,7 +43,8 @@ def binary_logits_to_pred(logits):
 
 
 def pseudolabel_binary_logits(logits, confidence_threshold):
-    """Applies a confidence threshold to binary logits and generates pseudo- labels.
+    """Applies a confidence threshold to binary logits and generates pseudo-
+    labels.
 
     Args:
         logits (Tensor): A tensor of shape (batch_size, n_tasks) representing binary logits.
@@ -78,7 +79,8 @@ def pseudolabel_binary_logits(logits, confidence_threshold):
 
 
 def pseudolabel_multiclass_logits(logits, confidence_threshold):
-    """Applies a confidence threshold to multi-class logits and generates pseudo-labels.
+    """Applies a confidence threshold to multi-class logits and generates
+    pseudo-labels.
 
     Args:
         logits (Tensor): A tensor of shape (batch_size, ..., n_classes) representing multi-class logits.
@@ -145,7 +147,8 @@ def pseudolabel_detection(preds, confidence_threshold):
 
 
 def pseudolabel_detection_discard_empty(preds, confidence_threshold):
-    """Filters detection predictions based on a confidence threshold and discards empty entries.
+    """Filters detection predictions based on a confidence threshold and
+    discards empty entries.
 
     Args:
         preds (List[dict]): A list of length `batch_size`, where each entry is a dictionary
@@ -355,7 +358,8 @@ def __init__(self, name=None):
 
 
 class PrecisionAtRecall(Metric):
-    """Given a specific model threshold, determine the precision score achieved."""
+    """Given a specific model threshold, determine the precision score
+    achieved."""
 
     def __init__(self, threshold, score_fn=None, name=None):
         self.score_fn = score_fn
@@ -400,8 +404,8 @@ def worst(self, metrics):
 
 
 class DetectionAccuracy(ElementwiseMetric):
-    """Given a specific Intersection over union threshold, determine the accuracy achieved for a
-    one-class detector."""
+    """Given a specific Intersection over union threshold, determine the
+    accuracy achieved for a one-class detector."""
 
     def __init__(self,
                  iou_threshold=0.3,
@@ -488,8 +492,8 @@ def worst(self, metrics):
 
 
 class KeypointAccuracy(ElementwiseMetric):
-    """Given a specific Intersection over union threshold, determine the accuracy achieved for a
-    one-class detector."""
+    """Given a specific Intersection over union threshold, determine the
+    accuracy achieved for a one-class detector."""
 
     def __init__(self,
                  distance_threshold=0.1,
@@ -563,8 +567,8 @@ def worst(self, metrics):
 
 
 class MaskAccuracy(ElementwiseMetric):
-    """Given a specific Intersection over union threshold, determine the accuracy achieved for a
-    Mask R-CNN detector."""
+    """Given a specific Intersection over union threshold, determine the
+    accuracy achieved for a Mask R-CNN detector."""
 
     def __init__(self,
                  iou_threshold=0.5,
@@ -737,8 +741,8 @@ def worst(self, metrics):
 class CountingError(ElementwiseMetric):
     """Mean Absolute Error between ground truth and predicted detection counts.
 
-    Calculates MAE between the number of detections in ground truth vs predictions for each sample
-    in the batch.
+    Calculates MAE between the number of detections in ground truth vs
+    predictions for each sample in the batch.
     """
 
    def __init__(self, score_threshold=0.1, name=None, geometry_name="y"):
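
Several of the docstrings reflowed above describe the same confidence-threshold pseudo-labeling step: keep only the predictions whose top class probability clears the threshold. A minimal sketch of that idea for the multi-class case (an illustration of the described behavior, not this module's actual implementation, which per its docstrings also returns the associated masks and indices):

import torch
import torch.nn.functional as F

def pseudolabel_multiclass_sketch(logits: torch.Tensor,
                                  confidence_threshold: float):
    """Keep examples whose max softmax probability clears the threshold.

    Returns the retained pseudo-labels and the boolean mask of kept rows.
    Illustrative sketch of the behavior the docstrings above describe.
    """
    probs = F.softmax(logits, dim=-1)            # (batch_size, n_classes)
    confidence, pseudolabels = probs.max(dim=-1)
    mask = confidence > confidence_threshold
    return pseudolabels[mask], mask

logits = torch.tensor([[2.0, 0.1, 0.1],    # confident -> kept
                       [0.5, 0.4, 0.45]])  # uncertain -> dropped
labels, mask = pseudolabel_multiclass_sketch(logits, confidence_threshold=0.6)
print(labels, mask)  # tensor([0]) tensor([ True, False])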

src/milliontrees/common/metrics/loss.py

Lines changed: 10 additions & 5 deletions
@@ -12,7 +12,8 @@ def __init__(self, loss_fn, name=None):
         super().__init__(name=name)
 
     def _compute(self, y_pred, y_true):
-        """Helper for computing element-wise metric, implemented for each metric.
+        """Helper for computing element-wise metric, implemented for each
+        metric.
 
         Args:
             - y_pred (Tensor): Predicted targets or model output
@@ -23,7 +24,8 @@ def _compute(self, y_pred, y_true):
         return self.loss_fn(y_pred, y_true)
 
     def worst(self, metrics):
-        """Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
 
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
@@ -42,7 +44,8 @@ def __init__(self, loss_fn, name=None):
         super().__init__(name=name)
 
     def _compute_element_wise(self, y_pred, y_true):
-        """Helper for computing element-wise metric, implemented for each metric.
+        """Helper for computing element-wise metric, implemented for each
+        metric.
 
         Args:
             - y_pred (Tensor): Predicted targets or model output
@@ -53,7 +56,8 @@ def _compute_element_wise(self, y_pred, y_true):
         return self.loss_fn(y_pred, y_true)
 
     def worst(self, metrics):
-        """Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
 
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
@@ -81,7 +85,8 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true):
         return flattened_loss
 
     def worst(self, metrics):
-        """Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
 
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
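
Every `worst` docstring reflowed in this file promises the same reduction: collapse a collection of per-element or per-group losses into a single worst-case number. For a loss, where larger values mean worse performance, that reduction is a maximum. A hedged sketch of the idea (assuming the loss convention; an accuracy-like metric would take the minimum instead):

import torch

def worst_case(metrics):
    """Worst case of a list/numpy array/Tensor of losses: the maximum.

    Sketch of the behavior the `worst` docstrings describe, under the
    assumption that these are losses (higher is worse).
    """
    if not torch.is_tensor(metrics):
        metrics = torch.as_tensor(metrics)
    return metrics.max()

per_group_losses = [0.12, 0.45, 0.31]
print(worst_case(per_group_losses))  # tensor(0.4500)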

src/milliontrees/common/metrics/metric.py

Lines changed: 22 additions & 15 deletions
@@ -10,7 +10,8 @@ def __init__(self, name):
         self._name = name
 
     def _compute(self, y_pred, y_true):
-        """Helper function for computing the metric. Subclasses should implement this.
+        """Helper function for computing the metric. Subclasses should
+        implement this.
 
         Args:
             - y_pred (Tensor): Predicted targets or model output
@@ -21,7 +22,8 @@ def _compute(self, y_pred, y_true):
         return NotImplementedError
 
     def worst(self, metrics):
-        """Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
 
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
@@ -34,33 +36,35 @@ def worst(self, metrics):
     def name(self):
         """Metric name.
 
-        Used to name the key in the results dictionaries returned by the metric.
+        Used to name the key in the results dictionaries returned by the
+        metric.
         """
         return self._name
 
     @property
     def agg_metric_field(self):
-        """The name of the key in the results dictionary returned by Metric.compute().
+        """The name of the key in the results dictionary returned by
+        Metric.compute().
 
-        This should correspond to the aggregate metric computed on all of y_pred and y_true, in
-        contrast to a group-wise evaluation.
+        This should correspond to the aggregate metric computed on all
+        of y_pred and y_true, in contrast to a group-wise evaluation.
         """
         return f'{self.name}_all'
 
     def group_metric_field(self, group_idx):
-        """The name of the keys corresponding to individual group evaluations in the results
-        dictionary returned by Metric.compute_group_wise()."""
+        """The name of the keys corresponding to individual group evaluations
+        in the results dictionary returned by Metric.compute_group_wise()."""
         return f'{self.name}_group:{group_idx}'
 
     @property
     def worst_group_metric_field(self):
-        """The name of the keys corresponding to the worst-group metric in the results dictionary
-        returned by Metric.compute_group_wise()."""
+        """The name of the keys corresponding to the worst-group metric in the
+        results dictionary returned by Metric.compute_group_wise()."""
        return f'{self.name}_wg'
 
     def group_count_field(self, group_idx):
-        """The name of the keys corresponding to each group's count in the results dictionary
-        returned by Metric.compute_group_wise()."""
+        """The name of the keys corresponding to each group's count in the
+        results dictionary returned by Metric.compute_group_wise()."""
         return f'count_group:{group_idx}'
 
     def compute(self, y_pred, y_true, return_dict=True):
@@ -140,7 +144,8 @@ class ElementwiseMetric(Metric):
     """Averages."""
 
     def _compute_element_wise(self, y_pred, y_true):
-        """Helper for computing element-wise metric, implemented for each metric.
+        """Helper for computing element-wise metric, implemented for each
+        metric.
 
         Args:
             - y_pred (Tensor): Predicted targets or model output
@@ -151,7 +156,8 @@ def _compute_element_wise(self, y_pred, y_true):
         raise NotImplementedError
 
     def worst(self, metrics):
-        """Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
 
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
@@ -182,7 +188,8 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups):
 
     @property
     def agg_metric_field(self):
-        """The name of the key in the results dictionary returned by Metric.compute()."""
+        """The name of the key in the results dictionary returned by
+        Metric.compute()."""
         return f'{self.name}_avg'
 
     def compute_element_wise(self, y_pred, y_true, return_dict=True):
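
The property docstrings reflowed above spell out a naming scheme for results dictionaries, and the f-strings are visible in the diff: `{name}_all` (or `{name}_avg` for element-wise metrics) for the aggregate, `{name}_group:{i}` and `count_group:{i}` per group, and `{name}_wg` for the worst group. For a metric named `detection_acc` evaluated over two groups, a results dict would therefore be keyed like this (the values are placeholders):

# Keys follow the f-strings shown in the diff; values are made up.
results = {
    'detection_acc_avg': 0.71,      # agg_metric_field for an ElementwiseMetric
    'detection_acc_group:0': 0.80,  # group_metric_field(0)
    'count_group:0': 120,           # group_count_field(0)
    'detection_acc_group:1': 0.62,  # group_metric_field(1)
    'count_group:1': 45,            # group_count_field(1)
    'detection_acc_wg': 0.62,       # worst_group_metric_field
}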

src/milliontrees/common/utils.py

Lines changed: 8 additions & 5 deletions
@@ -40,7 +40,8 @@ def maximum(numbers, empty_val=0.):
 
 
 def split_into_groups(g):
-    """Splits the input tensor into unique groups and their corresponding indices.
+    """Splits the input tensor into unique groups and their corresponding
+    indices.
 
     Args:
         g (Tensor): A vector containing group labels.
@@ -64,9 +65,10 @@ def split_into_groups(g):
 
 
 def get_counts(g, n_groups):
-    """This differs from split_into_groups in how it handles missing groups. get_counts always
-    returns a count array of length n_groups, whereas split_into_groups returns a unique_counts
-    array whose length is the number of unique groups present in g.
+    """This differs from split_into_groups in how it handles missing groups.
+    get_counts always returns a count array of length n_groups, whereas
+    split_into_groups returns a unique_counts array whose length is the number
+    of unique groups present in g.
 
     Args:
         - g (ndarray): Vector of groups
@@ -140,7 +142,8 @@ def shuffle_arr(arr, seed=None):
 
 
 def threshold_at_recall(y_pred, y_true, global_recall=60):
-    """Calculate the model threshold used to achieve a desired global_recall level.
+    """Calculate the model threshold used to achieve a desired global_recall
+    level.
 
     Args:
         y_pred (Description of y_pred, Assumes that y_true is a vector of the true binary labels.)
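
The reworded `get_counts` docstring draws a contrast that is easy to miss: with `n_groups = 4` and `g = [0, 0, 2]`, `get_counts` returns a length-4 array with zeros for the absent groups 1 and 3, while `split_into_groups` only reports the groups actually present. A hedged sketch of the two behaviors as described (not the module's exact implementations):

import torch

g = torch.tensor([0, 0, 2])
n_groups = 4

# split_into_groups-style: only the groups present in g.
unique_groups, unique_counts = torch.unique(g, return_counts=True)
print(unique_groups, unique_counts)   # tensor([0, 2]) tensor([2, 1])

# get_counts-style: always length n_groups, zeros for missing groups.
counts = torch.zeros(n_groups, dtype=torch.long)
counts.scatter_add_(0, g, torch.ones_like(g))
print(counts)                         # tensor([2, 0, 1, 0])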

src/milliontrees/datasets/TreeBoxes.py

Lines changed: 8 additions & 5 deletions
@@ -19,7 +19,8 @@
 
 
 class TreeBoxesDataset(MillionTreesDataset):
-    """A dataset of tree annotations with bounding box coordinates from multiple global sources.
+    """A dataset of tree annotations with bounding box coordinates from
+    multiple global sources.
 
     The dataset contains aerial imagery of trees with their corresponding bounding box annotations.
     Each tree is annotated with a 4-point bounding box (x_min, y_min, x_max, y_max).
@@ -225,8 +226,8 @@ def __init__(self,
     def eval(self, y_pred, y_true, metadata):
         """Performs evaluation on the given predictions.
 
-        The main evaluation metric, detection_acc_avg_dom, measures the simple average of the
-        detection accuracies of each domain.
+        The main evaluation metric, detection_acc_avg_dom, measures the
+        simple average of the detection accuracies of each domain.
         """
 
         results = {}
@@ -257,7 +258,8 @@ def eval(self, y_pred, y_true, metadata):
         return results, results_str
 
     def _get_mini_versions_dict(self):
-        """Generate mini versions dict with modified URLs for smaller datasets."""
+        """Generate mini versions dict with modified URLs for smaller
+        datasets."""
         mini_versions = {}
         for version, info in self._versions_dict.items():
             mini_info = info.copy()
@@ -290,7 +292,8 @@ def get_input(self, idx):
 
     @staticmethod
     def _collate_fn(batch):
-        """Collates a batch by stacking `x` (features) and `metadata`, but not `y` (targets).
+        """Collates a batch by stacking `x` (features) and `metadata`, but not
+        `y` (targets).
 
         The batch is initially a tuple of individual data points: (item1, item2, item3, ...).
         After zipping, it transforms into a list of tuples:
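
The `_collate_fn` docstring reflowed above describes stacking `x` and `metadata` while leaving `y` alone, which matters because each image can hold a different number of boxes. A minimal sketch of that zip-then-stack pattern (the `(metadata, x, y)` tuple ordering follows the TreePoints variant's `batch[0]`/`batch[1]` comment below; treat it as an assumption):

import torch

def collate_sketch(batch):
    """Stack metadata and images; keep per-image targets as a list.

    batch arrives as [(metadata, x, y), ...]; after zip(*batch) it becomes
    [(metadata, ...), (x, ...), (y, ...)]. Box counts vary per image, so y
    cannot be stacked. Illustrative sketch only.
    """
    metadata, x, y = zip(*batch)
    return torch.stack(metadata), torch.stack(x), list(y)

# Two images with different numbers of box annotations.
item1 = (torch.tensor([0, 0]), torch.rand(3, 8, 8), torch.rand(2, 4))
item2 = (torch.tensor([1, 0]), torch.rand(3, 8, 8), torch.rand(5, 4))
metadata, x, y = collate_sketch([item1, item2])
print(metadata.shape, x.shape, [t.shape for t in y])
# torch.Size([2, 2]) torch.Size([2, 3, 8, 8]) [torch.Size([2, 4]), torch.Size([5, 4])]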

src/milliontrees/datasets/TreePoints.py

Lines changed: 8 additions & 6 deletions
@@ -17,7 +17,8 @@
 
 
 class TreePointsDataset(MillionTreesDataset):
-    """The TreePoints dataset is a collection of tree annotations annotated as x,y locations.
+    """The TreePoints dataset is a collection of tree annotations annotated as
+    x,y locations.
 
     Dataset Splits:
         - random: For each source, 80% of the data is used for training and 20% for testing.
@@ -186,7 +187,8 @@ def __init__(self,
         super().__init__(root_dir, download, split_scheme)
 
     def _get_mini_versions_dict(self):
-        """Generate mini versions dict with modified URLs for smaller datasets."""
+        """Generate mini versions dict with modified URLs for smaller
+        datasets."""
         mini_versions = {}
         for version, info in self._versions_dict.items():
             mini_info = info.copy()
@@ -204,8 +206,8 @@ def get_annotation_from_filename(self, filename):
         return self._y_array[indices]
 
     def eval(self, y_pred, y_true, metadata):
-        """The main evaluation metric, detection_acc_avg_dom, measures the simple average of the
-        detection accuracies of each domain."""
+        """The main evaluation metric, detection_acc_avg_dom, measures the
+        simple average of the detection accuracies of each domain."""
 
         results = {}
         results_str = ''
@@ -256,8 +258,8 @@ def get_input(self, idx):
     def _collate_fn(batch):
         """Stack x (batch[1]) and metadata (batch[0]), but not y.
 
-        originally, batch = (item1, item2, item3, item4) after zip, batch = [(item1[0], item2[0],
-        ..), ..]
+        originally, batch = (item1, item2, item3, item4) after zip,
+        batch = [(item1[0], item2[0], ..), ..]
         """
         batch = list(zip(*batch))
         batch[1] = torch.stack(batch[1])
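
Both datasets' `eval` docstrings define `detection_acc_avg_dom` as the simple, unweighted average of per-domain detection accuracies, so a small source counts as much as a large one. A one-line sketch with placeholder values:

# detection_acc_avg_dom as described: unweighted mean over domains,
# so a 10-image source weighs the same as a 10,000-image source.
per_domain_acc = {'source_a': 0.81, 'source_b': 0.64, 'source_c': 0.72}
detection_acc_avg_dom = sum(per_domain_acc.values()) / len(per_domain_acc)
print(round(detection_acc_avg_dom, 4))  # 0.7233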
