|
| 1 | +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
"""Load labels for model prediction.

Given the paths of CSV label files, this module reads them and converts the
labels into the form required for mapping to the model output.
"""
import csv
import itertools
from typing import TypedDict
| 22 | + |
| 23 | + |
class ItemDict(TypedDict):
  """Typed dictionary describing a single category entry."""
  # Integer identifier of the category.
  id: int
  # Name of the category.
  name: str
  # Supercategory label attached to the category.
  supercategory: str
| 28 | + |
| 29 | + |
def read_csv_to_list(file_path: str) -> list[str]:
  """Reads the first column of a CSV file, skipping the header row.

  The first row is always treated as a header and discarded. Every following
  non-empty row contributes its first field; any further columns in a row
  are ignored.

  Args:
    file_path: The path to the CSV file.

  Returns:
    The first field of each non-empty data row, in file order. The list is
    empty when the file is empty or contains only the header row.
  """
  # newline='' hands newline handling to the csv module, as recommended by
  # the csv documentation, so quoted fields with embedded newlines are parsed
  # correctly.
  with open(file_path, 'r', newline='') as csvfile:
    reader = csv.reader(csvfile)
    # Supplying a default avoids StopIteration on a completely empty file.
    next(reader, None)
    # csv.reader yields an empty list for blank lines; skip those rows
    # instead of failing on row[0].
    return [row[0] for row in reader if row]
| 50 | + |
| 51 | + |
def categories_dictionary(objects: list[str]) -> dict[int, ItemDict]:
  """Builds a category index from an ordered list of object names.

  Each name is assigned a 1-based ID that follows its position in `objects`,
  and is represented by a dictionary with the following keys:
    - id: The ID of the object (same value as the key it is stored under).
    - name: The name of the object.
    - supercategory: The supercategory of the object, always 'objects'.

  Args:
    objects: A list of strings, where each string is the name of an
      object.

  Returns:
    A single dictionary mapping each ID to the dictionary describing that
    object. The result is empty when `objects` is empty.
  """
  return {
      num: {'id': num, 'name': obj_name, 'supercategory': 'objects'}
      for num, obj_name in enumerate(objects, start=1)
  }
| 75 | + |
| 76 | + |
def load_labels(
    label_paths: dict[str, str]
) -> tuple[list[list[str]], dict[int, ItemDict]]:
  """Loads labels, combines them, and formats them for prediction.

  This function reads the labels of every model, combines them in order to
  predict a single label output, and formats the combined labels into the
  structure required for prediction.

  Args:
    label_paths: Dictionary of label paths for different models. Models are
      processed in the iteration order of this dictionary.

  Returns:
    - A list of lists containing the individual labels of each model, each
      with 'Na' inserted at the front.
    - A dictionary of combined category indices in the desired format for
      prediction.

  Note:
    - A category 'Na' is inserted for every model in case there is no
      detection.
    - A combined label joins one label per model with '_'; the combined
      labels are sorted lexicographically before IDs are assigned.
    - The number of combined labels is the product of each model's label
      count (including 'Na').
    - With exactly two models the output is identical to combining the first
      model's labels with the second's. An empty `label_paths` yields an
      empty list and an empty dictionary.
  """
  # Load the labels of every model.
  category_indices = [read_csv_to_list(path) for path in label_paths.values()]
  if not category_indices:
    # Nothing to combine; avoid producing a single empty-string label.
    return [], {}

  # Insert a category 'Na' for every model in case there is no detection.
  for labels in category_indices:
    labels.insert(0, 'Na')

  # Combine the labels of all models (cartesian product) in order to predict
  # a single label output.
  combined_category_indices = sorted(
      '_'.join(combination)
      for combination in itertools.product(*category_indices)
  )

  # Convert the list of labels into the format required for prediction.
  category_index = categories_dictionary(combined_category_indices)

  return category_indices, category_index
0 commit comments