Skip to content

Commit 1e76de8

Browse files
committed
* Fixes consecutive-slices 2D & 3D settings for MTLRS
1 parent 1abb819 commit 1e76de8

File tree

6 files changed

+145
-107
lines changed

6 files changed

+145
-107
lines changed

atommic/collections/multitask/rs/data/mrirs_loader.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -409,26 +409,32 @@ def __getitem__(self, i: int): # noqa: MC0001
409409
kspace = self.get_consecutive_slices(hf, "kspace", dataslice).astype(np.complex64)
410410

411411
if not is_none(dataset_format) and dataset_format == "skm-tea-echo1":
412-
kspace = kspace[:, :, 0, :]
412+
kspace = kspace[..., 0, :]
413413
elif not is_none(dataset_format) and dataset_format == "skm-tea-echo2":
414-
kspace = kspace[:, :, 1, :]
414+
kspace = kspace[..., 1, :]
415415
elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1+echo2":
416-
kspace = kspace[:, :, 0, :] + kspace[:, :, 1, :]
416+
kspace = kspace[..., 0, :] + kspace[..., 1, :]
417417
elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1+echo2-mc":
418-
kspace = np.concatenate([kspace[:, :, 0, :], kspace[:, :, 1, :]], axis=-1)
418+
kspace = np.concatenate([kspace[..., 0, :], kspace[..., 1, :]], axis=-1)
419419
else:
420420
warnings.warn(
421421
f"Dataset format {dataset_format} is either not supported or set to None. "
422422
"Using by default only the first echo."
423423
)
424-
kspace = kspace[:, :, 0, :]
425-
426-
kspace = kspace[48:-48, 40:-40]
424+
kspace = kspace[..., 0, :]
427425

428426
sensitivity_map = self.get_consecutive_slices(hf, "maps", dataslice).astype(np.complex64)
429427
sensitivity_map = sensitivity_map[..., 0]
430428

431-
sensitivity_map = sensitivity_map[48:-48, 40:-40]
429+
if dataset_format == "skm-tea-echo1+echo2-mc":
430+
sensitivity_map = np.concatenate([sensitivity_map, sensitivity_map], axis=-1)
431+
432+
if self.consecutive_slices > 1:
433+
sensitivity_map = sensitivity_map[:, 48:-48, 40:-40]
434+
kspace = kspace[:, 48:-48, 40:-40]
435+
else:
436+
sensitivity_map = sensitivity_map[48:-48, 40:-40]
437+
kspace = kspace[48:-48, 40:-40]
432438

433439
if masking == "custom":
434440
mask = np.array([])
@@ -484,12 +490,11 @@ def __getitem__(self, i: int): # noqa: MC0001
484490
# TODO: This is hardcoded on the SKM-TEA side, how to generalize this?
485491
# We need to crop the segmentation labels in the frequency domain to reduce the FOV.
486492
segmentation_labels = np.fft.fftshift(np.fft.fft2(segmentation_labels))
487-
segmentation_labels = segmentation_labels[:, 48:-48, 40:-40]
493+
segmentation_labels = segmentation_labels[..., 48:-48, 40:-40]
488494
segmentation_labels = np.fft.ifft2(np.fft.ifftshift(segmentation_labels)).real
489495
segmentation_labels = np.where(segmentation_labels > 0.5, 1.0, 0.0) # Make sure the labels are binary.
490496

491497
imspace = np.empty([])
492-
493498
initial_prediction = np.empty([])
494499
attrs = dict(hf.attrs)
495500

@@ -501,8 +506,12 @@ def __getitem__(self, i: int): # noqa: MC0001
501506

502507
attrs.update(metadata)
503508

504-
kspace = np.transpose(kspace, (2, 0, 1))
505-
sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))
509+
if self.consecutive_slices == 1:
510+
kspace = np.transpose(kspace, (2, 0, 1))
511+
sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))
512+
else:
513+
kspace = np.transpose(kspace, (0, 3, 1, 2))
514+
sensitivity_map = np.transpose(sensitivity_map.squeeze(), (0, 3, 1, 2))
506515

507516
attrs["log_image"] = bool(dataslice in self.indices_to_log)
508517

atommic/collections/multitask/rs/nn/base.py

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,9 @@ def __compute_loss__(
726726
if self.consecutive_slices > 1:
727727
batch_size, slices = target_segmentation.shape[:2]
728728
target_segmentation = target_segmentation.reshape(batch_size * slices, *target_segmentation.shape[2:])
729+
predictions_segmentation = predictions_segmentation.reshape( # type: ignore
730+
batch_size * slices, *predictions_segmentation.shape[2:] # type: ignore
731+
)
729732

730733
segmentation_loss = self.process_segmentation_loss(target_segmentation, predictions_segmentation, attrs)
731734

@@ -798,26 +801,6 @@ def __compute_and_log_metrics_and_outputs__( # noqa: MC0001
798801
while isinstance(predictions_segmentation, list):
799802
predictions_segmentation = predictions_segmentation[-1]
800803

801-
if self.consecutive_slices > 1:
802-
# reshape the target and prediction to [batch_size, self.consecutive_slices, nr_classes, n_x, n_y]
803-
batch_size = target_segmentation.shape[0] // self.consecutive_slices
804-
target_segmentation = target_segmentation.reshape(
805-
batch_size, self.consecutive_slices, *target_segmentation.shape[1:]
806-
)
807-
target_reconstruction = target_reconstruction.reshape(
808-
batch_size, self.consecutive_slices, *target_reconstruction.shape[2:]
809-
)
810-
predictions_segmentation = predictions_segmentation.reshape(
811-
batch_size, self.consecutive_slices, *predictions_segmentation.shape[2:]
812-
)
813-
predictions_reconstruction = predictions_reconstruction.reshape(
814-
batch_size, self.consecutive_slices, *predictions_reconstruction.shape[1:]
815-
)
816-
target_segmentation = target_segmentation[:, self.consecutive_slices // 2]
817-
target_reconstruction = target_reconstruction[:, self.consecutive_slices // 2]
818-
predictions_segmentation = predictions_segmentation[:, self.consecutive_slices // 2]
819-
predictions_reconstruction = predictions_reconstruction[:, self.consecutive_slices // 2]
820-
821804
fname = attrs["fname"]
822805
slice_idx = attrs["slice_idx"]
823806

@@ -1267,33 +1250,74 @@ def inference_step( # noqa: MC0001
12671250
for class_idx, thres in enumerate(self.segmentation_classes_thresholds):
12681251
if self.segmentation_activation == "sigmoid":
12691252
if isinstance(predictions_segmentation, list):
1270-
cond = [torch.sigmoid(pred[:, class_idx]) for pred in predictions_segmentation]
1253+
cond = [
1254+
torch.sigmoid(pred[:, class_idx])
1255+
if (self.consecutive_slices == 1 or self.dimensionality == 2)
1256+
else torch.sigmoid(pred[:, :, class_idx])
1257+
for pred in predictions_segmentation
1258+
]
12711259
else:
1272-
cond = torch.sigmoid(predictions_segmentation[:, class_idx])
1260+
cond = (
1261+
torch.sigmoid(predictions_segmentation[:, class_idx])
1262+
if (self.consecutive_slices == 1 or self.dimensionality == 2)
1263+
else torch.sigmoid(predictions_segmentation[:, :, class_idx])
1264+
)
12731265
elif self.segmentation_activation == "softmax":
12741266
if isinstance(predictions_segmentation, list):
1275-
cond = [torch.softmax(pred[:, class_idx], dim=1) for pred in predictions_segmentation]
1267+
cond = [
1268+
torch.softmax(pred[:, class_idx], dim=1)
1269+
if (self.consecutive_slices == 1 or self.dimensionality == 2)
1270+
else torch.softmax(pred[:, :, class_idx], dim=1)
1271+
for pred in predictions_segmentation
1272+
]
12761273
else:
1277-
cond = torch.softmax(predictions_segmentation[:, class_idx], dim=1)
1274+
cond = (
1275+
torch.softmax(predictions_segmentation[:, class_idx], dim=1)
1276+
if (self.consecutive_slices == 1 or self.dimensionality == 2)
1277+
else torch.softmax(predictions_segmentation[:, :, class_idx], dim=1)
1278+
)
12781279
else:
12791280
if isinstance(predictions_segmentation, list):
1280-
cond = [pred[:, class_idx] for pred in predictions_segmentation]
1281+
cond = [
1282+
pred[:, class_idx]
1283+
if self.consecutive_slices == 1 or self.dimensionality == 2
1284+
else pred[:, :, class_idx]
1285+
for pred in predictions_segmentation
1286+
]
12811287
else:
1282-
cond = predictions_segmentation[:, class_idx]
1288+
cond = (
1289+
predictions_segmentation[:, class_idx]
1290+
if (self.consecutive_slices == 1 or self.dimensionality == 2)
1291+
else predictions_segmentation[:, :, class_idx]
1292+
)
12831293

12841294
if isinstance(predictions_segmentation, list):
12851295
for idx, pred in enumerate(predictions_segmentation):
1286-
predictions_segmentation[idx][:, class_idx] = torch.where(
1287-
cond[idx] >= thres,
1288-
predictions_segmentation[idx][:, class_idx],
1289-
torch.zeros_like(predictions_segmentation[idx][:, class_idx]),
1290-
)
1296+
if self.consecutive_slices == 1 or self.dimensionality == 2:
1297+
predictions_segmentation[idx][:, class_idx] = torch.where(
1298+
cond[idx] >= thres,
1299+
predictions_segmentation[idx][:, class_idx],
1300+
torch.zeros_like(predictions_segmentation[idx][:, class_idx]),
1301+
)
1302+
else:
1303+
predictions_segmentation[idx][:, :, class_idx] = torch.where(
1304+
cond[idx] >= thres,
1305+
predictions_segmentation[idx][:, :, class_idx],
1306+
torch.zeros_like(predictions_segmentation[idx][:, :, class_idx]),
1307+
)
12911308
else:
1292-
predictions_segmentation[:, class_idx] = torch.where(
1293-
cond >= thres,
1294-
predictions_segmentation[:, class_idx],
1295-
torch.zeros_like(predictions_segmentation[:, class_idx]),
1296-
)
1309+
if self.consecutive_slices == 1 or self.dimensionality == 2:
1310+
predictions_segmentation[:, class_idx] = torch.where(
1311+
cond >= thres,
1312+
predictions_segmentation[:, class_idx],
1313+
torch.zeros_like(predictions_segmentation[:, class_idx]),
1314+
)
1315+
else:
1316+
predictions_segmentation[:, :, class_idx] = torch.where(
1317+
cond >= thres,
1318+
predictions_segmentation[:, :, class_idx],
1319+
torch.zeros_like(predictions_segmentation[:, :, class_idx]),
1320+
)
12971321

12981322
# Noise-to-Recon forward pass, if Noise-to-Recon is used.
12991323
predictions_reconstruction_n2r = None
@@ -1310,11 +1334,13 @@ def inference_step( # noqa: MC0001
13101334
# Get acceleration factor from acceleration list, if multiple accelerations are used. Or if batch size > 1.
13111335
if isinstance(acceleration, list):
13121336
if acceleration[0].shape[0] > 1:
1313-
acceleration[0] = acceleration[0][0]
1337+
for i in enumerate(acceleration):
1338+
acceleration[i] = acceleration[i][0]
13141339
acceleration = np.round(acceleration[r].item())
13151340
else:
1316-
if acceleration.shape[0] > 1: # type: ignore
1317-
acceleration = acceleration[0] # type: ignore
1341+
if acceleration[0].shape[0] > 1: # type: ignore
1342+
for i in enumerate(acceleration): # type: ignore
1343+
acceleration[i] = acceleration[i][0] # type: ignore
13181344
acceleration = np.round(acceleration.item()) # type: ignore
13191345

13201346
# Pass r to the attrs dictionary, so that it can be used in unnormalize_for_loss_or_log if needed.

atommic/collections/multitask/rs/nn/mtlrs.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def forward(
160160
init_reconstruction_pred = pred_reconstruction[-1][-1]
161161

162162
if self.task_adaption_type == "multi_task_learning":
163+
# this is the most important step of the mtlrs model, where the segmentation mask is applied to the
164+
# hidden states and updates them
163165
hidden_states = [
164166
torch.cat(
165167
[torch.abs(init_reconstruction_pred.unsqueeze(self.coil_dim) * pred_segmentation)]
@@ -170,21 +172,17 @@ def forward(
170172
if f != 0
171173
]
172174

173-
if self.consecutive_slices > 1:
174-
hx = [x.unsqueeze(1) for x in hx]
175-
176175
# Check if the concatenated hidden states are the same size as the hidden state of the RNN
177-
if hidden_states[0].shape[self.coil_dim] != hx[0].shape[self.coil_dim]:
178-
prev_hidden_states = hidden_states
179-
hidden_states = []
180-
for hs in prev_hidden_states:
181-
new_hidden_state = hs
182-
for _ in range(hx[0].shape[1] - prev_hidden_states[0].shape[1]):
183-
new_hidden_state = torch.cat(
184-
[new_hidden_state, torch.zeros_like(hx[0][:, 0, :, :]).unsqueeze(self.coil_dim)],
185-
dim=self.coil_dim,
186-
)
187-
hidden_states.append(new_hidden_state)
176+
if (
177+
hidden_states[0].shape[self.coil_dim if self.consecutive_slices == 1 else self.coil_dim - 1]
178+
!= hx[0].shape[self.coil_dim if self.consecutive_slices == 1 else self.coil_dim - 1]
179+
):
180+
hidden_states = [
181+
hidden_states[i].reshape(
182+
hidden_states[0].shape[0] * self.consecutive_slices, *hidden_states[i].shape[2:]
183+
)
184+
for i in range(len(hidden_states))
185+
]
188186

189187
hx = [hx[i] + hidden_states[i] for i in range(len(hx))]
190188

atommic/collections/multitask/rs/nn/mtlrs_base/mtlrs_block.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ def __init__(
8282
self.spatial_dims = spatial_dims
8383
self.coil_dim = coil_dim
8484
self.dimensionality = dimensionality
85-
if self.dimensionality != 2:
86-
raise NotImplementedError(f"Currently only 2D is supported for segmentation, got {self.dimensionality}D.")
8785
self.consecutive_slices = consecutive_slices
8886
self.coil_combination_method = coil_combination_method
8987

@@ -237,7 +235,6 @@ def forward( # noqa: MC0001
237235
)
238236
cascades_predictions = []
239237
for i, cascade in enumerate(self.reconstruction_module):
240-
# Forward pass through the cascades
241238
prediction_slice, hx = cascade(
242239
prediction_slice,
243240
y_slice,
@@ -248,6 +245,11 @@ def forward( # noqa: MC0001
248245
sigma,
249246
keep_prediction=False if i == 0 else self.keep_prediction,
250247
)
248+
if (prediction_slice[0].shape[0] == self.consecutive_slices) and (
249+
hx[0].shape[0] == self.consecutive_slices
250+
):
251+
prediction_slice = [pred[slice_idx].unsqueeze(0) for pred in prediction_slice]
252+
hx = [h[slice_idx].unsqueeze(0) for h in hx]
251253
time_steps_predictions = [torch.view_as_complex(pred) for pred in prediction_slice]
252254
cascades_predictions.append(torch.stack(time_steps_predictions, dim=0))
253255
pred_reconstruction_slices.append(torch.stack(cascades_predictions, dim=0))
@@ -267,6 +269,8 @@ def forward( # noqa: MC0001
267269
if init_reconstruction_pred is None or init_reconstruction_pred.dim() < 4
268270
else init_reconstruction_pred
269271
)
272+
if self.consecutive_slices > 1 and self.reconstruction_module_dimensionality == 3:
273+
mask = torch.concatenate([mask.unsqueeze(self.coil_dim)] * y.shape[1], 1)
270274
sigma = 1.0
271275
cascades_predictions = []
272276
for i, cascade in enumerate(self.reconstruction_module):
@@ -292,11 +296,13 @@ def forward( # noqa: MC0001
292296
_pred_reconstruction = _pred_reconstruction[-1]
293297
if _pred_reconstruction.shape[-1] != 2:
294298
_pred_reconstruction = torch.view_as_real(_pred_reconstruction)
299+
295300
if self.consecutive_slices > 1 and _pred_reconstruction.dim() == 5:
296301
_pred_reconstruction = _pred_reconstruction.reshape(
297302
_pred_reconstruction.shape[0] * _pred_reconstruction.shape[1],
298303
*_pred_reconstruction.shape[2:],
299304
)
305+
300306
if _pred_reconstruction.shape[-1] == 2:
301307
if self.input_channels == 1:
302308
_pred_reconstruction = torch.view_as_complex(_pred_reconstruction).unsqueeze(1)

0 commit comments

Comments (0)