Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
__pycache__/
*.py[cod]
*$py.class

**.DS_Store
# C extensions
*.so

Expand Down
4 changes: 3 additions & 1 deletion atommic/collections/common/data/mri_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def __init__( # noqa: MC0001
self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols]

self.indices_to_log = np.random.choice(
len(self.examples), int(log_images_rate * len(self.examples)), replace=False # type: ignore
[example[1] for example in self.examples],
int(log_images_rate * len(self.examples)), # type: ignore
replace=False,
)

def _retrieve_metadata(self, fname: Union[str, Path]) -> Tuple[Dict, int]:
Expand Down
48 changes: 29 additions & 19 deletions atommic/collections/multitask/rs/data/mrirs_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,19 +416,22 @@
kspace = kspace[:, :, 0, :] + kspace[:, :, 1, :]
elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1+echo2-mc":
kspace = np.concatenate([kspace[:, :, 0, :], kspace[:, :, 1, :]], axis=-1)
elif not is_none(dataset_format) and dataset_format == "skm-tea-echo1-echo2":
kspace = kspace
else:
warnings.warn(
f"Dataset format {dataset_format} is either not supported or set to None. "
"Using by default only the first echo."
)
kspace = kspace[:, :, 0, :]

kspace = kspace[48:-48, 40:-40]

sensitivity_map = self.get_consecutive_slices(hf, "maps", dataslice).astype(np.complex64)
sensitivity_map = sensitivity_map[..., 0]

sensitivity_map = sensitivity_map[48:-48, 40:-40]
if self.consecutive_slices > 1:
sensitivity_map = sensitivity_map[:, 48:-48, 40:-40]
kspace = kspace[:, 48:-48, 40:-40]
else:
sensitivity_map = sensitivity_map[48:-48, 40:-40]
kspace = kspace[48:-48, 40:-40]

if masking == "custom":
mask = np.array([])
Expand Down Expand Up @@ -470,22 +473,17 @@
# combine Lateral Meniscus and Medial Meniscus
medial_meniscus = lateral_meniscus + medial_meniscus

if self.consecutive_slices > 1:
segmentation_labels_dim = 1
else:
segmentation_labels_dim = 0

# stack the labels in the last dimension
segmentation_labels = np.stack(
[patellar_cartilage, femoral_cartilage, tibial_cartilage, medial_meniscus],
axis=segmentation_labels_dim,
axis=-1,
)

# TODO: This is hardcoded on the SKM-TEA side, how to generalize this?
# We need to crop the segmentation labels in the frequency domain to reduce the FOV.
segmentation_labels = np.fft.fftshift(np.fft.fft2(segmentation_labels))
segmentation_labels = segmentation_labels[:, 48:-48, 40:-40]
segmentation_labels = np.fft.ifft2(np.fft.ifftshift(segmentation_labels)).real
segmentation_labels = np.fft.fftshift(np.fft.fft2(segmentation_labels, axes=(-3, -2)))
segmentation_labels = segmentation_labels[..., 48:-48, 40:-40, :]
segmentation_labels = np.fft.ifft2(np.fft.ifftshift(segmentation_labels), axes=(-3, -2)).real

imspace = np.empty([])

Expand All @@ -499,12 +497,24 @@
metadata["noise"] = 1.0

attrs.update(metadata)

kspace = np.transpose(kspace, (2, 0, 1))
sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))

if not is_none(dataset_format) and dataset_format == "skm-tea-echo1-echo2":
if self.consecutive_slices > 1:
segmentation_labels = np.transpose(segmentation_labels, (0, 3, 1, 2))
kspace = np.transpose(kspace, (3, 0, 4, 1, 2))
sensitivity_map = np.transpose(sensitivity_map, (4, 0, 3, 1, 2))
else:
segmentation_labels = np.transpose(segmentation_labels, (2, 0, 1))
kspace = np.transpose(kspace, (2, 3, 0, 1))
sensitivity_map = np.transpose(sensitivity_map, (3, 2, 0, 1))
elif self.consecutive_slices > 1 and not is_none(dataset_format) and dataset_format != "skm-tea-echo1-echo2":
segmentation_labels = np.transpose(segmentation_labels, (0, 3, 1, 2))
kspace = np.transpose(kspace, (0, 3, 1, 2))
sensitivity_map = np.transpose(sensitivity_map.squeeze(), (0, 3, 1, 2))
else:
segmentation_labels = np.transpose(segmentation_labels, (2, 0, 1))
kspace = np.transpose(kspace, (2, 0, 1))
sensitivity_map = np.transpose(sensitivity_map.squeeze(), (2, 0, 1))
attrs["log_image"] = bool(dataslice in self.indices_to_log)

return (
(
kspace,
Expand Down
142 changes: 100 additions & 42 deletions atommic/collections/multitask/rs/nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
# Initialize the dimensionality of the data. It can be 2D or 2.5D -> meaning 2D with > 1 slices or 3D.
self.dimensionality = cfg_dict.get("dimensionality", 2)
self.consecutive_slices = cfg_dict.get("consecutive_slices", 1)

self.num_echoes = cfg_dict.get("num_echoes", 1)
# Initialize the coil combination method. It can be either "SENSE" or "RSS" (root-sum-of-squares) or
# "RSS-complex" (root-sum-of-squares of the complex-valued data).
self.coil_combination_method = cfg_dict.get("coil_combination_method", "SENSE")
Expand Down Expand Up @@ -601,9 +601,6 @@
If self.accumulate_loss is True, returns an accumulative result of all intermediate losses.
Otherwise, returns the loss of the last intermediate loss.
"""
if self.consecutive_slices > 1:
batch_size, slices = target_segmentation.shape[:2]
target_segmentation = target_segmentation.reshape(batch_size * slices, *target_segmentation.shape[2:])

segmentation_loss = self.process_segmentation_loss(target_segmentation, predictions_segmentation, attrs)

Expand Down Expand Up @@ -675,27 +672,31 @@
if isinstance(predictions_segmentation, list):
while isinstance(predictions_segmentation, list):
predictions_segmentation = predictions_segmentation[-1]

if self.consecutive_slices > 1:
# reshape the target and prediction to [batch_size, self.consecutive_slices, nr_classes, n_x, n_y]
batch_size = target_segmentation.shape[0] // self.consecutive_slices
target_segmentation = target_segmentation.reshape(
batch_size, self.consecutive_slices, *target_segmentation.shape[1:]
)
target_reconstruction = target_reconstruction.reshape(
batch_size, self.consecutive_slices, *target_reconstruction.shape[2:]
)
batch_size = int(target_segmentation.shape[0] / self.consecutive_slices)
predictions_segmentation = predictions_segmentation.reshape(
batch_size, self.consecutive_slices, *predictions_segmentation.shape[2:]
batch_size, self.consecutive_slices, *predictions_segmentation.shape[1:]
)
predictions_reconstruction = predictions_reconstruction.reshape(
batch_size, self.consecutive_slices, *predictions_reconstruction.shape[1:]
target_segmentation = target_segmentation.reshape(
batch_size, self.consecutive_slices, *target_segmentation.shape[1:]
)

target_segmentation = target_segmentation[:, self.consecutive_slices // 2]
target_reconstruction = target_reconstruction[:, self.consecutive_slices // 2]
predictions_segmentation = predictions_segmentation[:, self.consecutive_slices // 2]
predictions_reconstruction = predictions_reconstruction[:, self.consecutive_slices // 2]

if self.num_echoes > 1:
# find the batch size
batch_size = target_reconstruction.shape[0] / self.num_echoes
# reshape to [batch_size, num_echoes, n_x, n_y]
target_reconstruction = target_reconstruction.reshape(
(int(batch_size), self.num_echoes, *target_reconstruction.shape[1:])
)
predictions_reconstruction = predictions_reconstruction.reshape(
(int(batch_size), self.num_echoes, *predictions_reconstruction.shape[1:])
)
fname = attrs["fname"]
slice_idx = attrs["slice_idx"]

Expand Down Expand Up @@ -734,11 +735,6 @@
batch_idx=_batch_idx_,
)

output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu()
output_target_reconstruction = output_target_reconstruction.detach().cpu()
output_target_segmentation = output_target_segmentation.detach().cpu()
output_predictions_segmentation = output_predictions_segmentation.detach().cpu()

# Normalize target and predictions to [0, 1] for logging.
if torch.is_complex(output_target_reconstruction) and output_target_reconstruction.shape[-1] != 2:
output_target_reconstruction = torch.view_as_real(output_target_reconstruction)
Expand All @@ -747,7 +743,6 @@
output_target_reconstruction = output_target_reconstruction / torch.max(
torch.abs(output_target_reconstruction)
)
output_target_reconstruction = output_target_reconstruction.detach().cpu()

if (
torch.is_complex(output_predictions_reconstruction)
Expand All @@ -759,7 +754,11 @@
output_predictions_reconstruction = output_predictions_reconstruction / torch.max(
torch.abs(output_predictions_reconstruction)
)
output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu()

output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu().float()
output_target_reconstruction = output_target_reconstruction.detach().cpu().float()
output_target_segmentation = output_target_segmentation.detach().cpu().float()
output_predictions_segmentation = output_predictions_segmentation.detach().cpu().float()

# Log target and predictions, if log_image is True for this slice.
if attrs["log_image"][_batch_idx_]:
Expand All @@ -772,17 +771,33 @@
)

if self.use_reconstruction_module:
self.log_image(
f"{key}/a/reconstruction/target/predictions/error",
torch.cat(
[
output_target_reconstruction,
output_predictions_reconstruction,
torch.abs(output_target_reconstruction - output_predictions_reconstruction),
],
dim=-1,
),
)
if self.num_echoes > 1:
for i in range(output_target_reconstruction.shape[0]):
self.log_image(
f"{key}/a/reconstruction_abs/target echo: {i+1}/predictions echo: {i+1}/error echo: {i+1}",
torch.cat(
[
output_target_reconstruction[i],
output_predictions_reconstruction[i],
torch.abs(
output_target_reconstruction[i] - output_predictions_reconstruction[i]
),
],
dim=-1,
),
)
else:
self.log_image(
f"{key}/a/reconstruction_abs/target/predictions/error",
torch.cat(
[
output_target_reconstruction,
output_predictions_reconstruction,
torch.abs(output_target_reconstruction - output_predictions_reconstruction),
],
dim=-1,
),
)

# concatenate the segmentation classes for logging
target_segmentation_class = torch.cat(
Expand Down Expand Up @@ -1120,7 +1135,16 @@
self.coil_combination_method,
self.coil_dim,
)

if self.num_echoes > 1:
# stack the echoes along the batch dimension
kspace = kspace.view(-1, *kspace.shape[2:])
y = y.view(-1, *y.shape[2:])
mask = mask.view(-1, *mask.shape[2:])
initial_prediction_reconstruction = initial_prediction_reconstruction.view(
-1, *initial_prediction_reconstruction.shape[2:]
)
target_reconstruction = target_reconstruction.view(-1, *target_reconstruction.shape[2:])
sensitivity_maps = torch.repeat_interleave(sensitivity_maps, repeats=kspace.shape[0], dim=0).squeeze(1)
# Model forward pass
predictions_reconstruction, predictions_segmentation = self.forward(
y,
Expand All @@ -1130,6 +1154,19 @@
target_reconstruction,
attrs["noise"],
)
if self.consecutive_slices > 1:
## reshape the target and prediction segmentation to [batch_size * consecutive_slices, nr_classes, n_x, n_y]
batch_size, slices = target_segmentation.shape[:2]
target_segmentation = target_segmentation.reshape(batch_size * slices, *target_segmentation.shape[2:])
if isinstance(predictions_segmentation, list):
for i, prediction_segmentation in enumerate(predictions_segmentation):
predictions_segmentation[i] = prediction_segmentation.reshape(
batch_size * slices, *prediction_segmentation.shape[2:]
)
else:
predictions_segmentation = predictions_segmentation.reshape(
batch_size * slices, *predictions_segmentation.shape[2:]
)

if not is_none(self.segmentation_classes_thresholds):
for class_idx, thres in enumerate(self.segmentation_classes_thresholds):
Expand Down Expand Up @@ -1482,6 +1519,26 @@
while isinstance(predictions_reconstruction, list):
predictions_reconstruction = predictions_reconstruction[-1]

if self.consecutive_slices > 1:
# reshape the target and prediction to [batch_size, self.consecutive_slices, nr_classes, n_x, n_y]
batch_size = int(target_segmentation.shape[0] / self.consecutive_slices)
predictions_segmentation = predictions_segmentation.reshape(
batch_size, self.consecutive_slices, *predictions_segmentation.shape[1:]
)
predictions_segmentation = predictions_segmentation[:, self.consecutive_slices // 2]
predictions_reconstruction = predictions_reconstruction[:, self.consecutive_slices // 2]

if self.num_echoes > 1:
# find the batch size
batch_size = target_reconstruction.shape[0] / self.num_echoes
# reshape to [batch_size, num_echoes, n_x, n_y]
target_reconstruction = target_reconstruction.reshape(
(int(batch_size), self.num_echoes, *target_reconstruction.shape[1:])
)
predictions_reconstruction = predictions_reconstruction.reshape(
(int(batch_size), self.num_echoes, *predictions_reconstruction.shape[1:])
)

# If "16" or "16-mixed" fp is used, ensure complex type will be supported when saving the predictions.
predictions_reconstruction = (
torch.view_as_complex(torch.view_as_real(predictions_reconstruction).type(torch.float32))
Expand Down Expand Up @@ -1670,10 +1727,10 @@
for fname in segmentations:
segmentations[fname] = np.stack([out for _, out in sorted(segmentations[fname])])

if self.consecutive_slices > 1:
# iterate over the slices and always keep the middle slice
for fname in segmentations:
segmentations[fname] = segmentations[fname][:, self.consecutive_slices // 2]
# if self.consecutive_slices > 1:
# # iterate over the slices and always keep the middle slice
# for fname in segmentations:
# segmentations[fname] = segmentations[fname][:, self.consecutive_slices // 2] #TODO remove, is already done in the test_step to minimize memory load

if self.use_reconstruction_module:
reconstructions = defaultdict(list)
Expand All @@ -1684,10 +1741,10 @@
for fname in reconstructions:
reconstructions[fname] = np.stack([out for _, out in sorted(reconstructions[fname])])

if self.consecutive_slices > 1:
# iterate over the slices and always keep the middle slice
for fname in reconstructions:
reconstructions[fname] = reconstructions[fname][:, self.consecutive_slices // 2]
# if self.consecutive_slices > 1: #TODO remove, is already done in the test_step to minimize memory load
# # iterate over the slices and always keep the middle slice
# for fname in reconstructions:
# reconstructions[fname] = reconstructions[fname][:, self.consecutive_slices // 2]
else:
reconstructions = None

Expand Down Expand Up @@ -1752,6 +1809,7 @@
"skm-tea-echo2",
"skm-tea-echo1+echo2",
"skm-tea-echo1+echo2-mc",
"skm-tea-echo1-echo2",
):
dataloader = mrirs_loader.SKMTEARSMRIDataset
else:
Expand Down
Loading
Loading