
Commit b50c213

fyangf authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 494223959
1 parent a619f51 commit b50c213

File tree

3 files changed: +53 −4 lines changed

  official/vision/configs/retinanet.py
  official/vision/dataloaders/input_reader.py
  official/vision/tasks/retinanet.py

official/vision/configs/retinanet.py (10 additions, 3 deletions)

@@ -16,12 +16,13 @@
 
 import dataclasses
 import os
-from typing import List, Optional, Union
+from typing import Optional, List, Sequence, Union
 
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import hyperparams
 from official.modeling import optimization
+from official.modeling.hyperparams import base_config
 from official.vision.configs import common
 from official.vision.configs import decoders
 from official.vision.configs import backbones
@@ -65,8 +66,14 @@ class Parser(hyperparams.Config):
 
 @dataclasses.dataclass
 class DataConfig(cfg.DataConfig):
-  """Input config for training."""
-  input_path: str = ''
+  """Input config for training.
+
+  Attributes:
+    weights: Sampling weights for each corresponding input_path. If used, then
+      input_path must be a config with matching keys.
+  """
+  input_path: Union[Sequence[str], str, base_config.Config] = ''
+  weights: Optional[base_config.Config] = None
   global_batch_size: int = 0
   is_training: bool = False
   dtype: str = 'bfloat16'

official/vision/dataloaders/input_reader.py (18 additions, 1 deletion)

@@ -14,14 +14,31 @@
 
 """Dataset reader for vision model garden."""
 
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Mapping, Optional, Tuple
 
 import tensorflow as tf
 
 from official.core import config_definitions as cfg
 from official.core import input_reader
 
 
+def build_weighted_sampling_combine_fn(
+    weights: Mapping[Any, Any]) -> Callable[[tf.data.Dataset], tf.data.Dataset]:
+  """Builds a combine_fn using weighted sampling."""
+
+  def combine_fn(datasets: Mapping[Any, tf.data.Dataset]) -> tf.data.Dataset:
+    """Combines multiple datasets using weighted sampling."""
+    ds = []
+    ws = []
+    for k, dataset in datasets.items():
+      ds.append(dataset)
+      ws.append(weights[k])
+    return tf.data.Dataset.sample_from_datasets(
+        ds, ws, stop_on_empty_dataset=True)
+
+  return combine_fn
+
+
 def calculate_batch_sizes(total_batch_size: int,
                           pseudo_label_ratio: float) -> Tuple[int, int]:
   """Calculates labeled and pseudo-labeled dataset batch sizes.

official/vision/tasks/retinanet.py (25 additions, 0 deletions)

@@ -20,8 +20,10 @@
 
 from official.common import dataset_fn
 from official.core import base_task
+from official.core import config_definitions as cfg
 from official.core import task_factory
 from official.vision.configs import retinanet as exp_cfg
+from official.vision.dataloaders import input_reader
 from official.vision.dataloaders import input_reader_factory
 from official.vision.dataloaders import retinanet_input
 from official.vision.dataloaders import tf_example_decoder
@@ -130,10 +132,33 @@ def build_inputs(self,
         skip_crowd_during_training=params.parser.skip_crowd_during_training,
         max_num_instances=params.parser.max_num_instances)
 
+    combine_fn = None
+    if params.is_training and params.weights:
+      # Combine multiple datasets using weighted sampling.
+      if (not isinstance(params.input_path, cfg.base_config.Config) or
+          not isinstance(params.weights, cfg.base_config.Config)):
+        raise ValueError(
+            'input_path and weights must both be a Config to use weighted '
+            'sampling.')
+      input_paths = params.input_path.as_dict()
+      weights = params.weights.as_dict()
+      if len(input_paths) != len(weights):
+        raise ValueError(
+            'The number of input_path and weights must be the same, but got %d '
+            'input_paths and %d weights.' % (len(input_paths), len(weights)))
+
+      for k in input_paths.keys():
+        if k not in weights:
+          raise ValueError(
+              'input_path key \'%s\' does not have a corresponding weight.' % k)
+
+      combine_fn = input_reader.build_weighted_sampling_combine_fn(weights)
+
     reader = input_reader_factory.input_reader_generator(
         params,
         dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
         decoder_fn=decoder.decode,
+        combine_fn=combine_fn,
         parser_fn=parser.parse_fn(params.is_training))
     dataset = reader.read(input_context=input_context)
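End to end, a user could take a registered RetinaNet experiment and point its training data at multiple weighted sources. The sketch below is illustrative only: retinanet_resnetfpn_coco is one of the experiments registered in official/vision/configs/retinanet.py, the dataset keys and paths are made up, and build_inputs raises ValueError unless input_path and weights are both Configs with matching keys.

from official.core import exp_factory
from official.modeling.hyperparams import base_config

# Start from a registered RetinaNet experiment and switch its training input
# to two weighted sources (paths and keys are hypothetical).
config = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
config.task.train_data.input_path = base_config.Config({
    'coco': '/data/coco/train*',
    'extra': '/data/extra/train*',
})
config.task.train_data.weights = base_config.Config({'coco': 0.9, 'extra': 0.1})

# During training, RetinaNetTask.build_inputs validates the matching keys and
# passes a weighted-sampling combine_fn to the input reader.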
