@@ -726,6 +726,9 @@ def __compute_loss__(
726726 if self .consecutive_slices > 1 :
727727 batch_size , slices = target_segmentation .shape [:2 ]
728728 target_segmentation = target_segmentation .reshape (batch_size * slices , * target_segmentation .shape [2 :])
729+ predictions_segmentation = predictions_segmentation .reshape ( # type: ignore
730+ batch_size * slices , * predictions_segmentation .shape [2 :] # type: ignore
731+ )
729732
730733 segmentation_loss = self .process_segmentation_loss (target_segmentation , predictions_segmentation , attrs )
731734
@@ -798,26 +801,6 @@ def __compute_and_log_metrics_and_outputs__( # noqa: MC0001
798801 while isinstance (predictions_segmentation , list ):
799802 predictions_segmentation = predictions_segmentation [- 1 ]
800803
801- if self .consecutive_slices > 1 :
802- # reshape the target and prediction to [batch_size, self.consecutive_slices, nr_classes, n_x, n_y]
803- batch_size = target_segmentation .shape [0 ] // self .consecutive_slices
804- target_segmentation = target_segmentation .reshape (
805- batch_size , self .consecutive_slices , * target_segmentation .shape [1 :]
806- )
807- target_reconstruction = target_reconstruction .reshape (
808- batch_size , self .consecutive_slices , * target_reconstruction .shape [2 :]
809- )
810- predictions_segmentation = predictions_segmentation .reshape (
811- batch_size , self .consecutive_slices , * predictions_segmentation .shape [2 :]
812- )
813- predictions_reconstruction = predictions_reconstruction .reshape (
814- batch_size , self .consecutive_slices , * predictions_reconstruction .shape [1 :]
815- )
816- target_segmentation = target_segmentation [:, self .consecutive_slices // 2 ]
817- target_reconstruction = target_reconstruction [:, self .consecutive_slices // 2 ]
818- predictions_segmentation = predictions_segmentation [:, self .consecutive_slices // 2 ]
819- predictions_reconstruction = predictions_reconstruction [:, self .consecutive_slices // 2 ]
820-
821804 fname = attrs ["fname" ]
822805 slice_idx = attrs ["slice_idx" ]
823806
@@ -1267,33 +1250,74 @@ def inference_step( # noqa: MC0001
12671250 for class_idx , thres in enumerate (self .segmentation_classes_thresholds ):
12681251 if self .segmentation_activation == "sigmoid" :
12691252 if isinstance (predictions_segmentation , list ):
1270- cond = [torch .sigmoid (pred [:, class_idx ]) for pred in predictions_segmentation ]
1253+ cond = [
1254+ torch .sigmoid (pred [:, class_idx ])
1255+ if (self .consecutive_slices == 1 or self .dimensionality == 2 )
1256+ else torch .sigmoid (pred [:, :, class_idx ])
1257+ for pred in predictions_segmentation
1258+ ]
12711259 else :
1272- cond = torch .sigmoid (predictions_segmentation [:, class_idx ])
1260+ cond = (
1261+ torch .sigmoid (predictions_segmentation [:, class_idx ])
1262+ if (self .consecutive_slices == 1 or self .dimensionality == 2 )
1263+ else torch .sigmoid (predictions_segmentation [:, :, class_idx ])
1264+ )
12731265 elif self .segmentation_activation == "softmax" :
12741266 if isinstance (predictions_segmentation , list ):
1275- cond = [torch .softmax (pred [:, class_idx ], dim = 1 ) for pred in predictions_segmentation ]
1267+ cond = [
1268+ torch .softmax (pred [:, class_idx ], dim = 1 )
1269+ if (self .consecutive_slices == 1 or self .dimensionality == 2 )
1270+ else torch .softmax (pred [:, :, class_idx ], dim = 1 )
1271+ for pred in predictions_segmentation
1272+ ]
12761273 else :
1277- cond = torch .softmax (predictions_segmentation [:, class_idx ], dim = 1 )
1274+ cond = (
1275+ torch .softmax (predictions_segmentation [:, class_idx ], dim = 1 )
1276+ if (self .consecutive_slices == 1 or self .dimensionality == 2 )
1277+ else torch .softmax (predictions_segmentation [:, :, class_idx ], dim = 1 )
1278+ )
12781279 else :
12791280 if isinstance (predictions_segmentation , list ):
1280- cond = [pred [:, class_idx ] for pred in predictions_segmentation ]
1281+ cond = [
1282+ pred [:, class_idx ]
1283+ if self .consecutive_slices == 1 or self .dimensionality == 2
1284+ else pred [:, :, class_idx ]
1285+ for pred in predictions_segmentation
1286+ ]
12811287 else :
1282- cond = predictions_segmentation [:, class_idx ]
1288+ cond = (
1289+ predictions_segmentation [:, class_idx ]
1290+ if (self .consecutive_slices == 1 or self .dimensionality == 2 )
1291+ else predictions_segmentation [:, :, class_idx ]
1292+ )
12831293
12841294 if isinstance (predictions_segmentation , list ):
12851295 for idx , pred in enumerate (predictions_segmentation ):
1286- predictions_segmentation [idx ][:, class_idx ] = torch .where (
1287- cond [idx ] >= thres ,
1288- predictions_segmentation [idx ][:, class_idx ],
1289- torch .zeros_like (predictions_segmentation [idx ][:, class_idx ]),
1290- )
1296+ if self .consecutive_slices == 1 or self .dimensionality == 2 :
1297+ predictions_segmentation [idx ][:, class_idx ] = torch .where (
1298+ cond [idx ] >= thres ,
1299+ predictions_segmentation [idx ][:, class_idx ],
1300+ torch .zeros_like (predictions_segmentation [idx ][:, class_idx ]),
1301+ )
1302+ else :
1303+ predictions_segmentation [idx ][:, :, class_idx ] = torch .where (
1304+ cond [idx ] >= thres ,
1305+ predictions_segmentation [idx ][:, :, class_idx ],
1306+ torch .zeros_like (predictions_segmentation [idx ][:, :, class_idx ]),
1307+ )
12911308 else :
1292- predictions_segmentation [:, class_idx ] = torch .where (
1293- cond >= thres ,
1294- predictions_segmentation [:, class_idx ],
1295- torch .zeros_like (predictions_segmentation [:, class_idx ]),
1296- )
1309+ if self .consecutive_slices == 1 or self .dimensionality == 2 :
1310+ predictions_segmentation [:, class_idx ] = torch .where (
1311+ cond >= thres ,
1312+ predictions_segmentation [:, class_idx ],
1313+ torch .zeros_like (predictions_segmentation [:, class_idx ]),
1314+ )
1315+ else :
1316+ predictions_segmentation [:, :, class_idx ] = torch .where (
1317+ cond >= thres ,
1318+ predictions_segmentation [:, :, class_idx ],
1319+ torch .zeros_like (predictions_segmentation [:, :, class_idx ]),
1320+ )
12971321
12981322 # Noise-to-Recon forward pass, if Noise-to-Recon is used.
12991323 predictions_reconstruction_n2r = None
@@ -1310,11 +1334,13 @@ def inference_step( # noqa: MC0001
13101334 # Get acceleration factor from acceleration list, if multiple accelerations are used. Or if batch size > 1.
13111335 if isinstance (acceleration , list ):
13121336 if acceleration [0 ].shape [0 ] > 1 :
1313- acceleration [0 ] = acceleration [0 ][0 ]
1337+ for i in enumerate (acceleration ):
1338+ acceleration [i ] = acceleration [i ][0 ]
13141339 acceleration = np .round (acceleration [r ].item ())
13151340 else :
1316- if acceleration .shape [0 ] > 1 : # type: ignore
1317- acceleration = acceleration [0 ] # type: ignore
1341+ if acceleration [0 ].shape [0 ] > 1 : # type: ignore
1342+ for i in enumerate (acceleration ): # type: ignore
1343+ acceleration [i ] = acceleration [i ][0 ] # type: ignore
13181344 acceleration = np .round (acceleration .item ()) # type: ignore
13191345
13201346 # Pass r to the attrs dictionary, so that it can be used in unnormalize_for_loss_or_log if needed.
0 commit comments