Skip to content

Commit 4c99ab7

Browse files
vdumoulintensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 381089283
1 parent 8b47c48 commit 4c99ab7

File tree

2 files changed

+32
-18
lines changed

2 files changed

+32
-18
lines changed

official/vision/beta/data/process_coco_few_shot.sh

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
# Processes the COCO few-shot benchmark into TFRecord files. Requires `wget`.
44

55
tmp_dir=$(mktemp -d -t coco-XXXXXXXXXX)
6+
base_image_dir="/tmp/coco_images"
67
output_dir="/tmp/coco_few_shot"
7-
while getopts "o:" o; do
8+
while getopts ":i:o:" o; do
89
case "${o}" in
910
o) output_dir=${OPTARG} ;;
10-
*) echo "Usage: ${0} [-o <output_dir>]" 1>&2; exit 1 ;;
11+
i) base_image_dir=${OPTARG} ;;
12+
*) echo "Usage: ${0} [-i <base_image_dir>] [-o <output_dir>]" 1>&2; exit 1 ;;
1113
esac
1214
done
1315

@@ -25,8 +27,8 @@ for seed in {0..9}; do
2527
for shots in 10 30; do
2628
python create_coco_tf_record.py \
2729
--logtostderr \
28-
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/train2014 \
29-
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/val2014 \
30+
--image_dir="${base_image_dir}/train2014" \
31+
--image_dir="${base_image_dir}/val2014" \
3032
--image_info_file="${tmp_dir}/${shots}shot_seed${seed}.json" \
3133
--object_annotations_file="${tmp_dir}/${shots}shot_seed${seed}.json" \
3234
--caption_annotations_file="" \
@@ -37,8 +39,8 @@ done
3739

3840
python create_coco_tf_record.py \
3941
--logtostderr \
40-
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/train2014 \
41-
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/val2014 \
42+
--image_dir="${base_image_dir}/train2014" \
43+
--image_dir="${base_image_dir}/val2014" \
4244
--image_info_file="${tmp_dir}/datasplit/5k.json" \
4345
--object_annotations_file="${tmp_dir}/datasplit/5k.json" \
4446
--caption_annotations_file="" \
@@ -47,12 +49,22 @@ python create_coco_tf_record.py \
4749

4850
python create_coco_tf_record.py \
4951
--logtostderr \
50-
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/train2014 \
51-
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/val2014 \
52+
--image_dir="${base_image_dir}/train2014" \
53+
--image_dir="${base_image_dir}/val2014" \
5254
--image_info_file="${tmp_dir}/datasplit/trainvalno5k_base.json" \
5355
--object_annotations_file="${tmp_dir}/datasplit/trainvalno5k_base.json" \
5456
--caption_annotations_file="" \
5557
--output_file_prefix="${output_dir}/trainvalno5k_base" \
5658
--num_shards=200
5759

60+
python create_coco_tf_record.py \
61+
--logtostderr \
62+
--image_dir="${base_image_dir}/train2014" \
63+
--image_dir="${base_image_dir}/val2014" \
64+
--image_info_file="${tmp_dir}/datasplit/5k_base.json" \
65+
--object_annotations_file="${tmp_dir}/datasplit/5k_base.json" \
66+
--caption_annotations_file="" \
67+
--output_file_prefix="${output_dir}/5k_base" \
68+
--num_shards=10
69+
5870
rm -rf "${tmp_dir}"

official/vision/beta/data/process_coco_few_shot_json_files.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,18 @@
8787
def main(unused_argv):
8888
workdir = FLAGS.workdir
8989

90-
# Filter novel class annotations from the training set.
91-
file_path = os.path.join(workdir, 'datasplit', 'trainvalno5k.json')
92-
with tf.io.gfile.GFile(file_path, 'r') as f:
93-
json_dict = json.load(f)
94-
95-
json_dict['annotations'] = [a for a in json_dict['annotations']
96-
if a['category_id'] in BASE_CLASS_IDS]
97-
output_path = os.path.join(workdir, 'datasplit', 'trainvalno5k_base.json')
98-
with tf.io.gfile.GFile(output_path, 'w') as f:
99-
json.dump(json_dict, f)
90+
# Filter novel class annotations from the training and validation sets.
91+
for name in ('trainvalno5k', '5k'):
92+
file_path = os.path.join(workdir, 'datasplit', '{}.json'.format(name))
93+
with tf.io.gfile.GFile(file_path, 'r') as f:
94+
json_dict = json.load(f)
95+
96+
json_dict['annotations'] = [a for a in json_dict['annotations']
97+
if a['category_id'] in BASE_CLASS_IDS]
98+
output_path = os.path.join(
99+
workdir, 'datasplit', '{}_base.json'.format(name))
100+
with tf.io.gfile.GFile(output_path, 'w') as f:
101+
json.dump(json_dict, f)
100102

101103
for seed, shots in itertools.product(SEEDS, SHOTS):
102104
# Retrieve all examples for a given seed and shots setting.

0 commit comments

Comments
 (0)