Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
2,676 changes: 2,676 additions & 0 deletions src/cryo_challenge/_map_to_map/alignment/all_to_all_alignment.ipynb

Large diffs are not rendered by default.

343 changes: 343 additions & 0 deletions src/cryo_challenge/_map_to_map/alignment/map_alignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,343 @@
import torch
import torch.nn.functional as F
import argparse
import time
import multiprocessing as mp
import logging
import os
from copy import deepcopy

import pymanopt
from pymanopt import Problem
from pymanopt.manifolds import SpecialOrthogonalGroup, Euclidean, Product
from pymanopt.optimizers.conjugate_gradient import ConjugateGradient

from cryo_challenge._preprocessing.fourier_utils import downsample_volume
from cryo_challenge._map_to_map.map_to_map_distance import normalize


def interpolate_volume(volume, rotation, translation, grid):
    """Resample a cubic volume on a rigidly transformed sampling grid.

    Every grid point ``p`` is moved to ``p @ rotation.T + translation`` and the
    volume is linearly interpolated at the moved location. Samples that fall
    outside the volume are filled with zeros.

    Parameters
    ----------
    volume : torch.Tensor
        Cubic volume holding ``n_pix ** 3`` voxels (``n_pix = len(volume)``).
    rotation : torch.Tensor or numpy.ndarray
        3x3 matrix applied to the grid coordinates.
    translation : torch.Tensor or numpy.ndarray
        Length-3 shift, expressed in normalized grid coordinates.
    grid : torch.Tensor
        Sampling grid of shape (1, n_pix, n_pix, n_pix, 3) with coordinates in
        [-1, +1], e.g. the output of ``prepare_grid``.

    Returns
    -------
    torch.Tensor
        The resampled volume, shape (n_pix, n_pix, n_pix).

    Notes
    -----
    translation is in normalized coordinates, since grid is from [-1,+1].
    It is therefore invariant to n_pix (e.g. after downsampling the volume).
    """
    n_pix = len(volume)

    # Optimizer results arrive as NumPy arrays while the cost function receives
    # torch tensors. as_tensor() accepts both, and returns an already matching
    # tensor unchanged so gradients still flow through it.
    rotation = torch.as_tensor(rotation, dtype=grid.dtype, device=grid.device)
    translation = torch.as_tensor(translation, dtype=grid.dtype, device=grid.device)

    sampling_points = grid @ rotation.T + translation

    # grid_sample expects the last axis ordered (x, y, z) = (W, H, D), whereas
    # the grid stores coordinates in (D, H, W) order, hence the [2, 1, 0] flip.
    resampled = F.grid_sample(
        volume.reshape(1, 1, n_pix, n_pix, n_pix),
        sampling_points[..., [2, 1, 0]],
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )
    return resampled.reshape(n_pix, n_pix, n_pix)


def loss_l2(volume_i, volume_j):
    """Return the squared L2 norm of the difference between two volumes."""
    difference = volume_i - volume_j
    return torch.linalg.norm(difference) ** 2


def prepare_grid(n_pix, torch_dtype):
    """Build a regular 3D coordinate grid spanning [-1, 1] along each axis.

    Returns a tensor of shape (1, n_pix, n_pix, n_pix, 3) whose last axis holds
    the coordinates along the first, second and third volume axis.
    """
    axis = torch.linspace(-1, 1, n_pix).to(torch_dtype)
    mesh = torch.meshgrid(axis, axis, axis, indexing="ij")
    coordinates = torch.stack(mesh, dim=-1)  # Shape: (D, H, W, 3)
    # Leading batch axis, as needed by grid_sample: (1, D, H, W, 3)
    return coordinates[None]


def align(volume_i, volume_j):
    """Rigidly align ``volume_i`` onto ``volume_j``.

    Minimizes the squared L2 distance between the transformed ``volume_i`` and
    ``volume_j`` over rotations and translations, using pymanopt's conjugate
    gradient optimizer (at most 100 iterations) started at the identity pose.

    Both inputs must be 3D tensors of identical shape. Returns the optimizer
    result; its ``point`` attribute holds the (rotation, translation) pair.
    """
    assert volume_i.shape == volume_j.shape
    assert volume_i.ndim == 3

    dtype = torch.float32
    sampling_grid = prepare_grid(len(volume_i), dtype)

    # Search space: product manifold SO(3) x E(3)
    se3_manifold = Product([SpecialOrthogonalGroup(3), Euclidean(3)])

    @pymanopt.function.pytorch(se3_manifold)
    def loss(rotation, translation):
        """Cost: squared L2 distance after moving ``volume_i``.

        Arguments come as rotation then translation, following the factor
        order of the product manifold.
        """
        moved_volume = interpolate_volume(
            volume_i, rotation, translation, sampling_grid
        )
        return loss_l2(moved_volume, volume_j)

    # Start from "no rotation, no shift"
    identity_pose = (
        torch.eye(3).to(dtype).numpy(),
        torch.zeros(3).to(dtype).numpy(),
    )

    solver = ConjugateGradient(max_iterations=100)
    alignment_problem = Problem(manifold=se3_manifold, cost=loss)
    return solver.run(alignment_problem, initial_point=identity_pose)


def parse_args(argv=None):
    """Parse the command-line options of the all-by-all alignment script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When omitted, ``sys.argv[1:]`` is used, which is
        the behavior of a plain ``parse_args()`` call.

    Returns
    -------
    argparse.Namespace
        Attributes: ``fname_i``, ``fname_j``, ``n_i``, ``n_j``, ``n_cpus``,
        ``downsample_box_size``, ``no_normalize`` and ``apply_alignments``.
        Note that ``no_normalize`` is True by default (volumes ARE normalized)
        and becomes False when the ``--no_normalize`` flag is given.
    """
    parser = argparse.ArgumentParser(
        description="All-by-all rigid (SE(3)) alignment of two sets of volumes."
    )
    parser.add_argument(
        "--fname_i", type=str, default=None, help="Input volume set i (submission file)"
    )
    parser.add_argument(
        "--fname_j", type=str, default=None, help="Input volume set j (submission file)"
    )
    parser.add_argument(
        "--n_i", type=int, default=80, help="Number of volumes in set i"
    )
    parser.add_argument(
        "--n_j", type=int, default=80, help="Number of volumes in set j"
    )
    parser.add_argument(
        "--n_cpus",
        type=int,
        default=mp.cpu_count(),
        help="Number of cpus for multiprocessing",
    )
    parser.add_argument(
        "--downsample_box_size",
        type=int,
        default=32,
        help="Box size to downsample volumes to",
    )
    # store_false: the attribute stays True (normalize) unless the flag is passed
    parser.add_argument(
        "--no_normalize",
        action="store_false",
        help="Disable normalization of volumes (volumes are normalized unless this flag is given)",
    )
    parser.add_argument(
        "--apply_alignments",
        action="store_true",
        help="Apply alignments and compute loss before/after (default: False)",
    )

    return parser.parse_args(argv)


def run_all_by_all_alignment_naive_loop(volumes_i, volumes_j, args):
    """Align every volume of set i onto every volume of set j, sequentially.

    The alignment is estimated on downsampled volumes and then applied to the
    full-resolution ``volume_i`` to measure the loss before and after.

    Parameters
    ----------
    volumes_i, volumes_j : torch.Tensor
        Stacks of cubic volumes; the first axis enumerates the volumes.
    args : argparse.Namespace
        Uses ``args.no_normalize`` (True means "do normalize") and
        ``args.downsample_box_size``.

    Returns
    -------
    dict
        ``rotations`` (n_i, n_j, 3, 3), ``translations`` (n_i, n_j, 3),
        ``loss_initial`` and ``loss_final`` (both (n_i, n_j)), where n_i and
        n_j are the number of volumes actually passed in.
    """
    if args.no_normalize:  # if use flag, no_normalize = False and do not normalize
        volumes_i = normalize(volumes_i, "l2")
        volumes_j = normalize(volumes_j, "l2")

    box_size_ds = args.downsample_box_size
    torch_dtype = volumes_i.dtype

    # Size the outputs from the data itself so that no row is left
    # uninitialized (or overrun) when fewer/more volumes than requested exist.
    n_i, n_j = len(volumes_i), len(volumes_j)
    rotations = torch.empty(n_i, n_j, 3, 3)
    translations = torch.empty(n_i, n_j, 3)
    loss_initial = torch.empty(n_i, n_j)
    loss_final = torch.empty(n_i, n_j)

    n_pix = len(volumes_i[0])
    grid = prepare_grid(n_pix, torch_dtype)

    # Downsample set j once instead of once per volume of set i.
    volumes_j_ds = [downsample_volume(volume_j, box_size_ds) for volume_j in volumes_j]

    for idx_i, volume_i in enumerate(volumes_i):
        volume_i_ds = downsample_volume(volume_i, box_size_ds)

        for idx_j, (volume_j, volume_j_ds) in enumerate(zip(volumes_j, volumes_j_ds)):
            result = align(volume_i_ds, volume_j_ds)

            # The optimizer returns NumPy arrays; convert them before mixing
            # them with the torch grid.
            rotation, translation = (
                torch.as_tensor(component, dtype=torch_dtype)
                for component in result.point
            )
            rotations[idx_i, idx_j] = rotation
            translations[idx_i, idx_j] = translation

            volume_i_aligned_to_j = interpolate_volume(
                volume_i, rotation, translation, grid
            )
            loss_initial[idx_i, idx_j] = loss_l2(volume_i, volume_j)
            loss_final[idx_i, idx_j] = loss_l2(volume_i_aligned_to_j, volume_j)

    return {
        "rotations": rotations,
        "translations": translations,
        "loss_initial": loss_initial,
        "loss_final": loss_final,
    }


# Silence pymanopt's own logger below ERROR level.
# Uncomment the next line to also configure the root logger while debugging:
# logging.basicConfig(level=logging.ERROR)
logging.getLogger("pymanopt").setLevel(logging.ERROR)


# Ensure the multiprocessing context uses 'spawn'.
# NOTE: this runs at import time and, because of force=True, overrides any
# start method that was selected before this module was imported.
mp.set_start_method("spawn", force=True)


def process_pair(idx_i, idx_j, volume_i, volume_j, box_size_ds):
    """Align one pair of volumes; intended to run inside a worker process.

    Parameters
    ----------
    idx_i, idx_j : int
        Indices of the pair, passed through so the caller can place the result.
    volume_i, volume_j : torch.Tensor
        Volumes to align (``volume_i`` is moved onto ``volume_j``).
    box_size_ds : int
        Unused here; kept so the task tuples built by the caller stay valid.

    Returns
    -------
    tuple
        Always ``(idx_i, idx_j, rotation, translation)``. When the alignment
        fails, the error is logged and rotation and translation are ``None``.
    """
    try:
        volume_i = volume_i.clone()
        volume_j = volume_j.clone()
        logging.info(f"Starting alignment for pair ({idx_i}, {idx_j})")
        result = align(volume_i, volume_j)
        logging.info(f"Finished alignment for pair ({idx_i}, {idx_j})")
        rotation, translation = result.point
    except Exception as e:
        # Deliberately broad: one failing pair must not abort the whole pool.
        logging.error(f"Error in alignment for pair ({idx_i}, {idx_j}): {e}")
        # Same arity as the success path, so the caller can unpack uniformly.
        return idx_i, idx_j, None, None

    return idx_i, idx_j, rotation, translation


def _downsample_and_rescale(volumes, box_size_ds, torch_dtype):
    """Downsample each volume to ``box_size_ds`` voxels per side and divide it by its norm."""
    downsampled = torch.empty(
        (len(volumes), box_size_ds, box_size_ds, box_size_ds), dtype=torch_dtype
    )
    for idx, volume in enumerate(volumes):
        downsampled[idx] = downsample_volume(volume, box_size_ds)
        downsampled[idx] /= torch.norm(downsampled[idx], keepdim=True)
    return downsampled


def run_all_by_all_alignment_mp(volumes_i, volumes_j, args):
    """Align every volume of set i onto every volume of set j with a process pool.

    Parameters
    ----------
    volumes_i, volumes_j : torch.Tensor
        Stacks of cubic volumes with identical dtype; the first axis enumerates
        the volumes. The inputs are copied, not modified.
    args : argparse.Namespace
        Uses ``args.no_normalize`` (True means "do normalize"),
        ``args.downsample_box_size`` and ``args.n_cpus``.

    Returns
    -------
    dict
        ``rotations`` of shape (len(volumes_i), len(volumes_j), 3, 3) and
        ``translations`` of shape (len(volumes_i), len(volumes_j), 3).
        Entries of pairs whose alignment failed are NaN.
    """
    torch_dtype = volumes_i.dtype
    assert torch_dtype == volumes_j.dtype
    box_size_ds = args.downsample_box_size

    volumes_i = deepcopy(volumes_i)
    volumes_j = deepcopy(volumes_j)

    if args.no_normalize:  # if use flag, no_normalize = False and do not normalize
        volumes_i = normalize(volumes_i, "l2")
        volumes_j = normalize(volumes_j, "l2")

    # Buffers are sized from the volumes actually present (not args.n_i /
    # args.n_j), consistent with the output tensors below. Note that the
    # downsampled volumes are rescaled by their norm regardless of no_normalize.
    volumes_downsampled_i = _downsample_and_rescale(volumes_i, box_size_ds, torch_dtype)
    volumes_downsampled_j = _downsample_and_rescale(volumes_j, box_size_ds, torch_dtype)

    rotations = torch.empty(len(volumes_i), len(volumes_j), 3, 3)
    translations = torch.empty(len(volumes_i), len(volumes_j), 3)

    # Prepare arguments for starmap
    tasks = [
        (idx_i, idx_j, volume_i.clone(), volume_j.clone(), box_size_ds)
        for idx_i, volume_i in enumerate(volumes_downsampled_i)
        for idx_j, volume_j in enumerate(volumes_downsampled_j)
    ]

    # Use multiprocessing with starmap
    s = time.time()
    with mp.Pool(processes=args.n_cpus) as pool:
        results = pool.starmap(process_pair, tasks)
    e = time.time()
    logging.info(f"Time taken: {e-s:.2f}s")

    # Store results. `*_` tolerates extra trailing entries in a worker's tuple.
    for idx_i, idx_j, rotation, translation, *_ in results:
        if rotation is None or translation is None:
            # Failed pair: mark it with NaN instead of leaving garbage memory.
            rotations[idx_i, idx_j] = torch.nan
            translations[idx_i, idx_j] = torch.nan
            continue
        rotations[idx_i, idx_j] = torch.from_numpy(rotation)
        translations[idx_i, idx_j] = torch.from_numpy(translation)

    return {
        "rotations": rotations,
        "translations": translations,
    }


def apply_alignments(volumes, rotations, translations, volumes_j=None):
    """Apply precomputed rigid alignments to a set of volumes.

    Parameters
    ----------
    volumes : torch.Tensor
        Stack of I cubic volumes (set i).
    rotations : torch.Tensor
        Rotation matrices, shape (I, J, 3, 3).
    translations : torch.Tensor
        Translations in normalized grid coordinates, shape (I, J, 3).
    volumes_j : torch.Tensor, optional
        Target volumes (set j). When given, the squared L2 loss to each target
        is computed before and after applying the alignment.

    Returns
    -------
    tuple
        ``(interpolated_volumes_i_to_j, loss_initial, loss_final)`` with shapes
        (I, J, n_pix, n_pix, n_pix), (I, J) and (I, J). Both loss tensors are
        filled with NaN when ``volumes_j`` is None.
    """
    _I, J = rotations.shape[:2]
    assert len(volumes) == _I == translations.shape[0]
    n_pix = volumes.shape[-1]
    torch_dtype = volumes.dtype
    grid = prepare_grid(n_pix, torch_dtype)
    interpolated_volumes_i_to_j = torch.empty(_I, J, n_pix, n_pix, n_pix)

    # NaN rather than uninitialized memory, so that "no loss computed" is
    # recognizable when volumes_j is not supplied.
    loss_initial = torch.full((_I, J), torch.nan)
    loss_final = torch.full((_I, J), torch.nan)

    for idx_i, volume_i in enumerate(volumes):
        for idx_j in range(J):
            # Match the grid dtype; the stored alignments may use another one.
            rotation_ij = rotations[idx_i, idx_j].to(torch_dtype)
            translation_ij = translations[idx_i, idx_j].to(torch_dtype)
            interpolated_volume_i_to_j = interpolate_volume(
                volume_i, rotation_ij, translation_ij, grid
            ).reshape(*volume_i.shape)

            if volumes_j is not None:
                volume_j = volumes_j[idx_j]
                loss_initial[idx_i, idx_j] = loss_l2(volume_i, volume_j)
                loss_final[idx_i, idx_j] = loss_l2(interpolated_volume_i_to_j, volume_j)

            interpolated_volumes_i_to_j[idx_i, idx_j] = interpolated_volume_i_to_j

    return interpolated_volumes_i_to_j, loss_initial, loss_final


if __name__ == "__main__":
    # Script entry point: load two submissions, align every volume of set i
    # onto every volume of set j, optionally apply the alignments, and save
    # everything to a single .pt file.
    args = parse_args()

    fname_i = args.fname_i  # "/mnt/home/smbp/ceph/smbpchallenge/round2/set2/processed_submissions/submission_23.pt"
    fname_j = args.fname_j  # "/mnt/home/smbp/ceph/smbpchallenge/round2/set2/processed_submissions/submission_23.pt"
    if fname_i is None or fname_j is None:
        raise SystemExit("Both --fname_i and --fname_j must be provided.")

    # Check the output location up front: otherwise a missing directory is only
    # discovered by torch.save, after the whole alignment run has completed.
    odir = "/mnt/home/gwoollard/ceph/repos/Cryo-EM-Heterogeneity-Challenge-1/src/cryo_challenge/_map_to_map/alignment/"
    if not os.path.isdir(odir):
        raise SystemExit(f"Output directory does not exist: {odir}")

    torch_dtype = torch.float32

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects. Only load submission files that come from a trusted source.
    submission = torch.load(fname_i, weights_only=False)
    volumes = submission["volumes"].to(torch_dtype)
    volumes_i = volumes[: args.n_i]

    submission = torch.load(fname_j, weights_only=False)
    volumes = submission["volumes"].to(torch_dtype)
    volumes_j = volumes[: args.n_j]

    results = run_all_by_all_alignment_mp(volumes_i, volumes_j, args)
    rotations = results["rotations"]
    translations = results["translations"]

    if args.apply_alignments:
        (
            results["interpolated_volumes_i_to_j"],
            results["loss_initial"],
            results["loss_final"],
        ) = apply_alignments(volumes_i, rotations, translations, volumes_j)

    # The output name encodes the run settings and both input file names.
    basename_without_extension_i = os.path.splitext(os.path.basename(fname_i))[0]
    basename_without_extension_j = os.path.splitext(os.path.basename(fname_j))[0]

    torch.save(
        results,
        os.path.join(
            odir,
            f"alignments_se3_ni{args.n_i}_nj{args.n_j}_ds{args.downsample_box_size}_ConjugateGradient_{basename_without_extension_i}-vs-{basename_without_extension_j}.pt",
        ),
    )
23 changes: 23 additions & 0 deletions src/cryo_challenge/_map_to_map/alignment/run_torchmp_benchmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/bin/bash
#SBATCH --job-name=torchmp_benchmark
#SBATCH --output=slurm/logs/%j.out
#SBATCH --error=slurm/logs/%j.err
#SBATCH --partition=ccb
#SBATCH -n 1
#SBATCH -c 128
#SBATCH --time=99:00:00

# Benchmark driver: runs torch_mp.py once per parameter combination and
# redirects the standard output of each run to its own log file.

# Define parameter ranges
n_pix_values=(244)
nn_values=(1)
num_workers_values=(16 32 64 128)

# Loop through all combinations of n_pix, nn, and num_workers
for n_pix in "${n_pix_values[@]}"; do
    for nn in "${nn_values[@]}"; do
        for num_workers in "${num_workers_values[@]}"; do
            echo "Running with n_pix=$n_pix, nn=$nn, num_workers=$num_workers"
            # Only stdout goes to the log file; stderr is not redirected here.
            python torch_mp.py +n_pix=$n_pix +nn=$nn +num_workers=$num_workers > torchmp_benchmark_npix_${n_pix}_nn_${nn}_num_workers_${num_workers}.log
        done
    done
done
Loading
Loading