Skip to content

Commit c0c9a96

Browse files
No public description
PiperOrigin-RevId: 569017854
1 parent 5a64033 commit c0c9a96

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""This module provides utilities for executing shell commands.
16+
17+
It particularly downloads and extracts Mask RCNN models from the TensorFlow
18+
model garden. It includes a function to execute shell commands and
19+
a custom exception to handle errors that arise from command execution.
20+
21+
Functions:
22+
- execute_command(cmd: str) -> str: Executes a shell command and returns its
23+
standard output. Raises
24+
a CommandExecutionError if the command execution fails.
25+
26+
Exceptions:
27+
- CommandExecutionError: Custom exception that's raised when there's an
28+
error executing a shell command.
29+
30+
Usage:
31+
The main purpose of this module is to download two specific Mask RCNN models
32+
and unzip them. The module
33+
performs these operations when imported.
34+
35+
Note:
36+
It's recommended to not perform actions like downloading files on module
37+
import in production applications.
38+
It's better to move such tasks inside a function or a main block to allow
39+
for more controlled execution.
40+
"""
41+
import argparse
42+
import os
43+
import subprocess
44+
45+
46+
class CommandExecutionError(Exception):
47+
"""Raised when there's an error executing a shell command."""
48+
49+
def __init__(self, cmd, returncode, stderr):
50+
super().__init__(f"Error executing command: {cmd}. Error: {stderr}")
51+
self.cmd = cmd
52+
self.returncode = returncode
53+
self.stderr = stderr
54+
55+
56+
def execute_command(cmd: str) -> str:
57+
"""Executes a shell command and returns its output."""
58+
result = subprocess.run(
59+
cmd,
60+
shell=True,
61+
stdout=subprocess.PIPE,
62+
stderr=subprocess.PIPE,
63+
check=False,
64+
)
65+
66+
if result.returncode != 0:
67+
raise CommandExecutionError(
68+
cmd, result.returncode, result.stderr.decode("utf-8")
69+
)
70+
71+
return result.stdout.decode("utf-8")
72+
73+
74+
def main(_) -> None:
75+
# Download the provided files
76+
execute_command(f"wget {args.url1}")
77+
execute_command(f"wget {args.url2}")
78+
79+
# Create directories
80+
os.makedirs("material", exist_ok=True)
81+
os.makedirs("material_form", exist_ok=True)
82+
83+
# Unzip the provided files
84+
zip_file1 = os.path.basename(args.url1)
85+
zip_file2 = os.path.basename(args.url2)
86+
execute_command(f"unzip {zip_file1} -d material/")
87+
execute_command(f"unzip {zip_file2} -d material_form/")
88+
89+
90+
if __name__ == "__main__":
91+
parser = argparse.ArgumentParser(
92+
description="Download and extract Mask RCNN models."
93+
)
94+
parser.add_argument("material_url", help="repo url for material model")
95+
parser.add_argument(
96+
"material_form_url", help="repo url for material form model"
97+
)
98+
99+
args = parser.parse_args()
100+
main(args)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Load labels for model prediction.
16+
17+
Given paths of CSV files, task is to import them and convert into a
18+
form required for mapping with the model output.
19+
"""
20+
import csv
21+
from typing import TypedDict
22+
23+
24+
class ItemDict(TypedDict):
25+
id: int
26+
name: str
27+
supercategory: str
28+
29+
30+
def read_csv_to_list(file_path: str) -> list[str]:
31+
"""Reads a CSV file and returns its contents as a list.
32+
33+
This function reads the given CSV file, skips the header, and assumes
34+
there is only one column in the CSV. It returns the contents as a list of
35+
strings.
36+
37+
Args:
38+
file_path: The path to the CSV file.
39+
40+
Returns:
41+
The contents of the CSV file as a list of strings.
42+
"""
43+
data_list = []
44+
with open(file_path, 'r') as csvfile:
45+
reader = csv.reader(csvfile)
46+
next(reader) # Skip the header row if present
47+
for row in reader:
48+
data_list.append(row[0]) # Assuming there is only one column in the CSV
49+
return data_list
50+
51+
52+
def categories_dictionary(objects: list[str]) -> dict[int, ItemDict]:
53+
"""This function takes a list of objects and returns a dictionaries.
54+
55+
A dictionary of objects, where each object is represented by a dictionary
56+
with the following keys:
57+
- id: The ID of the object.
58+
- name: The name of the object.
59+
- supercategory: The supercategory of the object.
60+
61+
Args:
62+
objects: A list of strings, where each string is the name of an
63+
object.
64+
65+
Returns:
66+
A tuple of two dictionaries, as described above.
67+
"""
68+
category_index = {}
69+
70+
for num, obj_name in enumerate(objects, start=1):
71+
obj_dict = {'id': num, 'name': obj_name, 'supercategory': 'objects'}
72+
category_index[num] = obj_dict
73+
74+
return category_index
75+
76+
77+
def load_labels(
78+
label_paths: dict[str, str]
79+
) -> tuple[list[list[str]], dict[int, ItemDict]]:
80+
"""Loads labels, combines them, and formats them for prediction.
81+
82+
This function reads labels for multiple models, combines the labels in
83+
order to predict a single label output, and formats them into the desired
84+
structure required for prediction.
85+
86+
Args:
87+
label_paths: Dictionary of label paths for different models.
88+
89+
Returns:
90+
- A list of lists containing individual category indices for each
91+
model.
92+
- A dictionary of combined category indices in the desired format for
93+
prediction.
94+
95+
Note:
96+
- The function assumes there are exactly two models.
97+
- Inserts a category 'Na' for both models in case there is no detection.
98+
- The total number of predicted labels for a combined model is
99+
predetermined.
100+
"""
101+
# loading labels for both models
102+
category_indices = [read_csv_to_list(label) for label in label_paths.values()]
103+
104+
# insert a cateory 'Na' for both models in case there is no detection
105+
for i in [0, 1]:
106+
category_indices[i].insert(0, 'Na')
107+
108+
# combine the labels for both models in order to predict a single label output
109+
combined_category_indices = []
110+
for i in category_indices[0]:
111+
for j in category_indices[1]:
112+
combined_category_indices.append(f'{i}_{j}')
113+
combined_category_indices.sort()
114+
115+
# convert the list of labels into a desired format required for prediction
116+
category_index = categories_dictionary(combined_category_indices)
117+
118+
return category_indices, category_index

0 commit comments

Comments
 (0)