@@ -20,8 +20,10 @@

from official.common import dataset_fn
from official.core import base_task
+from official.core import config_definitions as cfg
from official.core import task_factory
from official.vision.configs import retinanet as exp_cfg
+from official.vision.dataloaders import input_reader
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import retinanet_input
from official.vision.dataloaders import tf_example_decoder
@@ -130,10 +132,33 @@ def build_inputs(self,
        skip_crowd_during_training=params.parser.skip_crowd_during_training,
        max_num_instances=params.parser.max_num_instances)

+    combine_fn = None
+    if params.is_training and params.weights:
+      # Combine multiple datasets using weighted sampling.
+      if (not isinstance(params.input_path, cfg.base_config.Config) or
+          not isinstance(params.weights, cfg.base_config.Config)):
+        raise ValueError(
+            'input_path and weights must both be a Config to use weighted '
+            'sampling.')
+      input_paths = params.input_path.as_dict()
+      weights = params.weights.as_dict()
+      if len(input_paths) != len(weights):
+        raise ValueError(
+            'The number of input_path and weights must be the same, but got %d '
+            'input_paths and %d weights.' % (len(input_paths), len(weights)))
+
+      for k in input_paths.keys():
+        if k not in weights:
+          raise ValueError(
+              'input_path key \'%s\' does not have a corresponding weight.' % k)
+
+      combine_fn = input_reader.build_weighted_sampling_combine_fn(weights)
+
    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
+        combine_fn=combine_fn,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=input_context)

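For readers unfamiliar with the `combine_fn` hook, the sketch below illustrates the kind of weighted-sampling combine step this change wires in, using plain `tf.data`. It is only an illustration under stated assumptions: it does not reproduce `input_reader.build_weighted_sampling_combine_fn`, the dict-of-datasets calling convention is assumed for clarity, and the source names and weights are made up.

```python
import tensorflow as tf


def make_weighted_combine_fn(weights):
  """Returns a fn that mixes per-source datasets by sampling with `weights`.

  `weights` maps a source key (matching the input_path keys) to a float,
  mirroring the key-for-key pairing the validation in the diff enforces.
  """
  def combine_fn(datasets):
    # `datasets` is assumed to be a dict keyed by the same source names
    # as `weights`; sample elements from each source with its weight.
    keys = sorted(datasets.keys())
    return tf.data.Dataset.sample_from_datasets(
        [datasets[k] for k in keys],
        weights=[weights[k] for k in keys])
  return combine_fn


# Hypothetical usage with two made-up sources.
coco = tf.data.Dataset.from_tensor_slices(tf.range(10)).repeat()
extra = tf.data.Dataset.from_tensor_slices(tf.range(100, 110)).repeat()
combine = make_weighted_combine_fn({'coco': 0.8, 'extra': 0.2})
mixed = combine({'coco': coco, 'extra': extra})
```

Keying both `input_path` and `weights` by the same source names, as the new validation enforces, keeps the weight attached to each dataset unambiguous regardless of how the config dicts are ordered.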