|
28 | 28 | #include "doctest.h" |
29 | 29 | #include "mock_data_loader.h" |
30 | 30 |
|
| 31 | + |
31 | 32 | namespace triton { namespace perfanalyzer { |
32 | 33 |
|
33 | 34 | /// Helper class for testing the DataLoader |
@@ -104,6 +105,199 @@ TEST_CASE("dataloader: GetTotalSteps") |
104 | 105 | CHECK_EQ(dataloader.GetTotalSteps(2), 0); |
105 | 106 | } |
106 | 107 |
|
| 108 | +TEST_CASE("dataloader: ValidateIOExistsInModel") |
| 109 | +{ |
| 110 | + MockDataLoader dataloader; |
| 111 | + std::shared_ptr<ModelTensorMap> inputs = std::make_shared<ModelTensorMap>(); |
| 112 | + std::shared_ptr<ModelTensorMap> outputs = std::make_shared<ModelTensorMap>(); |
| 113 | + ModelTensor input1 = TestDataLoader::CreateTensor("INPUT1"); |
| 114 | + ModelTensor output1 = TestDataLoader::CreateTensor("OUTPUT1"); |
| 115 | + inputs->insert(std::make_pair(input1.name_, input1)); |
| 116 | + outputs->insert(std::make_pair(output1.name_, output1)); |
| 117 | + |
| 118 | + SUBCASE("Directory does not exist") |
| 119 | + { |
| 120 | + std::string data_directory = "non_existent_directory"; |
| 121 | + cb::Error status = |
| 122 | + dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory); |
| 123 | + CHECK( |
| 124 | + status.Message() == |
| 125 | + "Error: Directory does not exist or is not a directory: " |
| 126 | + "non_existent_directory"); |
| 127 | + CHECK(status.Err() == pa::GENERIC_ERROR); |
| 128 | + } |
| 129 | + |
| 130 | + SUBCASE("Directory is not a directory") |
| 131 | + { |
| 132 | + std::string data_directory = "tmp/test.txt"; |
| 133 | + std::ofstream file(data_directory); |
| 134 | + cb::Error status = |
| 135 | + dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory); |
| 136 | + CHECK( |
| 137 | + status.Message() == |
| 138 | + "Error: Directory does not exist or is not a directory: tmp/test.txt"); |
| 139 | + CHECK(status.Err() == pa::GENERIC_ERROR); |
| 140 | + std::remove(data_directory.c_str()); |
| 141 | + } |
| 142 | + |
| 143 | + SUBCASE("Valid directory but no corresponding files") |
| 144 | + { |
| 145 | + std::string data_directory = "valid_directory"; |
| 146 | + std::filesystem::create_directory(data_directory); |
| 147 | + std::ofstream(data_directory + "/invalid_file").close(); |
| 148 | + cb::Error status = |
| 149 | + dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory); |
| 150 | + std::filesystem::remove_all(data_directory); |
| 151 | + CHECK( |
| 152 | + status.Message() == |
| 153 | + "Provided data file 'invalid_file' does not correspond to a valid " |
| 154 | + "model input or output."); |
| 155 | + CHECK(status.Err() == pa::GENERIC_ERROR); |
| 156 | + } |
| 157 | + |
| 158 | + SUBCASE("Valid directory with corresponding files") |
| 159 | + { |
| 160 | + std::string data_directory = "valid_directory"; |
| 161 | + std::filesystem::create_directory(data_directory); |
| 162 | + std::ofstream(data_directory + "/INPUT1").close(); |
| 163 | + std::ofstream(data_directory + "/OUTPUT1").close(); |
| 164 | + cb::Error status = |
| 165 | + dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory); |
| 166 | + std::filesystem::remove_all(data_directory); |
| 167 | + CHECK(status.Message().empty()); |
| 168 | + CHECK(status.IsOk()); |
| 169 | + } |
| 170 | + |
| 171 | + SUBCASE("Valid directory with multiple input and output tensors") |
| 172 | + { |
| 173 | + ModelTensor input2 = TestDataLoader::CreateTensor("INPUT2"); |
| 174 | + ModelTensor output2 = TestDataLoader::CreateTensor("OUTPUT2"); |
| 175 | + |
| 176 | + inputs->insert(std::make_pair(input2.name_, input2)); |
| 177 | + outputs->insert(std::make_pair(output2.name_, output2)); |
| 178 | + |
| 179 | + std::string data_directory = "valid_directory_multiple"; |
| 180 | + std::filesystem::create_directory(data_directory); |
| 181 | + std::ofstream(data_directory + "/INPUT1").close(); |
| 182 | + std::ofstream(data_directory + "/INPUT2").close(); |
| 183 | + std::ofstream(data_directory + "/OUTPUT1").close(); |
| 184 | + std::ofstream(data_directory + "/OUTPUT2").close(); |
| 185 | + |
| 186 | + cb::Error status = |
| 187 | + dataloader.ValidateIOExistsInModel(inputs, outputs, data_directory); |
| 188 | + std::filesystem::remove_all(data_directory); |
| 189 | + CHECK(status.Message().empty()); |
| 190 | + CHECK(status.IsOk()); |
| 191 | + } |
| 192 | +} |
| 193 | + |
| 194 | +TEST_CASE("dataloader: ReadDataFromJSON") |
| 195 | +{ |
| 196 | + DataLoader dataloader; |
| 197 | + std::shared_ptr<ModelTensorMap> inputs = std::make_shared<ModelTensorMap>(); |
| 198 | + std::shared_ptr<ModelTensorMap> outputs = std::make_shared<ModelTensorMap>(); |
| 199 | + ModelTensor input1 = TestDataLoader::CreateTensor("INPUT1"); |
| 200 | + ModelTensor output1 = TestDataLoader::CreateTensor("OUTPUT1"); |
| 201 | + |
| 202 | + inputs->insert(std::make_pair(input1.name_, input1)); |
| 203 | + outputs->insert(std::make_pair(output1.name_, output1)); |
| 204 | + |
| 205 | + SUBCASE("File does not exist") |
| 206 | + { |
| 207 | + std::string json_file = "non_existent_file.json"; |
| 208 | + cb::Error status = dataloader.ReadDataFromJSON(inputs, outputs, json_file); |
| 209 | + CHECK(status.Message() == "failed to open file for reading provided data"); |
| 210 | + CHECK(status.Err() == pa::GENERIC_ERROR); |
| 211 | + } |
| 212 | + |
| 213 | + SUBCASE("Valid JSON file") |
| 214 | + { |
| 215 | + std::string json_file = "valid_file.json"; |
| 216 | + std::ofstream out(json_file); |
| 217 | + out << R"({ |
| 218 | + "data": [ |
| 219 | + { "INPUT1": [1] }, |
| 220 | + { "INPUT1": [2] }, |
| 221 | + { "INPUT1": [3] } |
| 222 | + ], |
| 223 | + "validation_data": [ |
| 224 | + { "OUTPUT1": [4] }, |
| 225 | + { "OUTPUT1": [5] }, |
| 226 | + { "OUTPUT1": [6] } |
| 227 | + ]})"; |
| 228 | + out.close(); |
| 229 | + |
| 230 | + cb::Error status = dataloader.ReadDataFromJSON(inputs, outputs, json_file); |
| 231 | + std::filesystem::remove(json_file); |
| 232 | + CHECK(status.Message().empty()); |
| 233 | + CHECK(status.IsOk()); |
| 234 | + } |
| 235 | + |
| 236 | + SUBCASE("Invalid JSON file") |
| 237 | + { |
| 238 | + std::string json_file = "invalid_file.json"; |
| 239 | + std::ofstream out(json_file); |
| 240 | + out << R"({invalid_json: 1,)"; |
| 241 | + out.close(); |
| 242 | + |
| 243 | + cb::Error status = dataloader.ReadDataFromJSON(inputs, outputs, json_file); |
| 244 | + std::filesystem::remove(json_file); |
| 245 | + |
| 246 | + CHECK( |
| 247 | + status.Message() == |
| 248 | + "failed to parse the specified json file for reading provided data"); |
| 249 | + CHECK(status.Err() == pa::GENERIC_ERROR); |
| 250 | + } |
| 251 | + |
| 252 | + SUBCASE("Multiple input and output tensors") |
| 253 | + { |
| 254 | + ModelTensor input2 = TestDataLoader::CreateTensor("INPUT2"); |
| 255 | + ModelTensor output2 = TestDataLoader::CreateTensor("OUTPUT2"); |
| 256 | + |
| 257 | + inputs->insert(std::make_pair(input2.name_, input2)); |
| 258 | + outputs->insert(std::make_pair(output2.name_, output2)); |
| 259 | + |
| 260 | + std::string json_file = "valid_file_multiple_input_output.json"; |
| 261 | + std::ofstream out(json_file); |
| 262 | + out << R"({ |
| 263 | + "data": [ |
| 264 | + { |
| 265 | + "INPUT1": [1], |
| 266 | + "INPUT2": [4] |
| 267 | + }, |
| 268 | + { |
| 269 | + "INPUT1": [2], |
| 270 | + "INPUT2": [5] |
| 271 | + }, |
| 272 | + { |
| 273 | + "INPUT1": [3], |
| 274 | + "INPUT2": [6] |
| 275 | + } |
| 276 | + ], |
| 277 | + "validation_data": [ |
| 278 | + { |
| 279 | + "OUTPUT1": [4], |
| 280 | + "OUTPUT2": [7] |
| 281 | + }, |
| 282 | + { |
| 283 | + "OUTPUT1": [5], |
| 284 | + "OUTPUT2": [8] |
| 285 | + }, |
| 286 | + { |
| 287 | + "OUTPUT1": [6], |
| 288 | + "OUTPUT2": [9] |
| 289 | + } |
| 290 | + ] |
| 291 | + })"; |
| 292 | + out.close(); |
| 293 | + |
| 294 | + cb::Error status = dataloader.ReadDataFromJSON(inputs, outputs, json_file); |
| 295 | + std::filesystem::remove(json_file); |
| 296 | + CHECK(status.Message().empty()); |
| 297 | + CHECK(status.IsOk()); |
| 298 | + } |
| 299 | +} |
| 300 | + |
107 | 301 | TEST_CASE("dataloader: GetInputData missing data") |
108 | 302 | { |
109 | 303 | MockDataLoader dataloader; |
|
0 commit comments