diff --git a/.gitignore b/.gitignore index c996a7c0..573745f9 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,5 @@ ProteinMPNN/ ProteinMPNN/* ProteinMPNN/examples ProteinMPNN/outputs -ProteinMPNN/.gitignore \ No newline at end of file +ProteinMPNN/.gitignore +user_setting.py \ No newline at end of file diff --git a/README.md b/README.md index ddd0a9ff..c3f58c03 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ sequences. CUDA-enabled GPU (required for AlphaFold) Modeller (requires a license key) Git + git-lfs (for large files) ``` Follow these steps to set up PMGen on your system: @@ -32,6 +33,8 @@ Follow these steps to set up PMGen on your system: ```bash git clone https://github.com/AmirAsgary/PMGen.git cd PMGen + conda init bash + source ~/.bashrc bash -l install.sh conda activate PMGen ``` @@ -43,7 +46,9 @@ Follow these steps to set up PMGen on your system: Edit `user_setting.py` to adjust netMHCpan and netMHCIIpan installation paths. - +2. **Tip for installing NetMHCpan**: + - NetMHCpan requires tcsh to be installed. You can install it using `sudo apt-get install tcsh`. + - Read the readme file in the `netMHCpan` folder and follow the instructions. ## Usage diff --git a/data/ESM/esmc_600m/PMGen sequences/process.py b/data/ESM/esmc_600m/PMGen sequences/process.py new file mode 100644 index 00000000..e8ea838a --- /dev/null +++ b/data/ESM/esmc_600m/PMGen sequences/process.py @@ -0,0 +1,512 @@ +# load npz files and print the shape of the data +import numpy as np +import pandas as pd +import os + +def load_npz_files(directory): + keys = [] + embeddings = [] + sequences = [] + # load sequence from the csv file + df = pd.read_csv(os.path.join(directory, 'sequences.csv')) + for filename in os.listdir(directory): + if filename.endswith('.npz'): + print("Loading file:", filename) + file_path = os.path.join(directory, filename) + with np.load(file_path) as npz_file: + print(f"{filename}: contains {len(npz_file.files)} arrays") + for key in npz_file.files: + embedding = npz_file[key] + if embedding.shape[0] == 187 and embedding.shape[1] == 1152: + embeddings.append(embedding) + sequences.append(df[df['id'] == key]['sequence'].values[0]) + keys.append(key) # append only accepted keys so keys stay aligned with embeddings and sequences + else: + print(f"Skipping {key} due to unexpected shape: {embedding.shape}") + + + return embeddings, keys, sequences + +def print_shapes(data, keys, seqs): + for arr, key, seq in zip(data, keys, seqs): + print(f"Shape of array {key}: {arr.shape} with sequence {seq}") + +def print_samples(data, keys): + for arr, key in zip(data, keys): + print(f"Sample of array {key}: {arr[:1]}") # Print first samples + +def plot_pairwise_comparison(selected_keys, keys, data, seqs): + """ + Compare sequence and embedding similarities between selected pairs of protein sequences.
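As a minimal stand-alone illustration of the two similarity measures this function relies on (normalized Levenshtein distance over the sequences and cosine similarity over the flattened embeddings), here is a hedged sketch; the sequences are made up and random arrays stand in for real ESM-C embeddings, using the `Levenshtein` and `scipy` packages listed in `pip_requirements.txt`:

```python
import numpy as np
from scipy.spatial.distance import cosine
from Levenshtein import distance as levenshtein_distance

# Toy inputs: two similar sequences and two random (positions x features) embeddings.
seq1, seq2 = "MKTAYIAKQR", "MKTAYLAKQR"
emb1, emb2 = np.random.rand(187, 1152), np.random.rand(187, 1152)

# Sequence similarity = 1 - edit distance normalized by the longer sequence.
seq_similarity = 1 - levenshtein_distance(seq1, seq2) / max(len(seq1), len(seq2))

# Embedding similarity = cosine similarity of the flattened embedding matrices.
cosine_sim = 1 - cosine(emb1.flatten(), emb2.flatten())

print(f"sequence similarity: {seq_similarity:.2f}, cosine similarity: {cosine_sim:.2f}")
```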
+ + Args: + selected_keys: List of key pairs to compare (e.g., [(key1, key2), (key3, key4)]) + keys: All available keys + data: Embedding data (36, 1152) + seqs: Protein sequences + """ + import matplotlib.pyplot as plt + import numpy as np + from scipy.spatial.distance import cosine + import seaborn as sns + from Levenshtein import distance as levenshtein_distance + + for key1, key2 in selected_keys: + # Get indices of the keys in the main arrays + idx1 = keys.index(key1) + idx2 = keys.index(key2) + + # Get sequences and embeddings + seq1, seq2 = seqs[idx1], seqs[idx2] + emb1, emb2 = data[idx1], data[idx2] + + # Calculate sequence similarity + seq_distance = levenshtein_distance(seq1, seq2) + seq_similarity = 1 - (seq_distance / max(len(seq1), len(seq2))) + + # Calculate embedding similarity metrics + cosine_sim = 1 - cosine(emb1.flatten(), emb2.flatten()) + euclidean_dist = np.linalg.norm(emb1 - emb2) + + # Calculate positional differences + position_diff = np.abs(emb1 - emb2) # Shape: (36, 1152) + mean_diff_by_pos = position_diff.mean(axis=1) # Average diff across 1152 dims for each position + max_diff_by_pos = position_diff.max(axis=1) # Max diff across 1152 dims for each position + + # Calculate feature differences + feature_diff = position_diff.mean(axis=0) # Average diff across 36 positions for each feature + top_features_idx = np.argsort(feature_diff)[::-1][:20] # Top 20 most different features + + # Create visualization + fig, axs = plt.subplots(2, 2, figsize=(18, 14)) + plt.suptitle(f'Comparison: {key1} vs {key2}', fontsize=16) + + # 1. Heatmap of differences across all positions and features + im = axs[0, 0].imshow(position_diff, aspect='auto', cmap='viridis') + axs[0, 0].set_title(f'Embedding Differences (Sequence Sim: {seq_similarity:.2f}, Cosine Sim: {cosine_sim:.2f})') + axs[0, 0].set_xlabel('Feature Dimension') + axs[0, 0].set_ylabel('Sequence Position') + plt.colorbar(im, ax=axs[0, 0], label='Absolute Difference') + + # 2. Bar plot of position differences (which positions differ most) + axs[0, 1].bar(range(len(mean_diff_by_pos)), mean_diff_by_pos, alpha=0.7, label='Mean Diff') + axs[0, 1].bar(range(len(max_diff_by_pos)), max_diff_by_pos, alpha=0.5, label='Max Diff') + axs[0, 1].set_title('Differences by Position') + axs[0, 1].set_xlabel('Position in Sequence') + axs[0, 1].set_ylabel('Difference Magnitude') + axs[0, 1].legend() + + # 3. Top different features + axs[1, 0].bar(range(len(top_features_idx)), feature_diff[top_features_idx]) + axs[1, 0].set_title('Top 20 Most Different Features') + axs[1, 0].set_xlabel('Feature Rank') + axs[1, 0].set_ylabel('Mean Difference') + axs[1, 0].set_xticks(range(len(top_features_idx))) + axs[1, 0].set_xticklabels([f"{i}" for i in top_features_idx], rotation=90) + + # # 4. 
Correlation matrix between the two embeddings + # corr_matrix = np.corrcoef(emb1, emb2) + # sns.heatmap(corr_matrix, ax=axs[1, 1], cmap='coolwarm', vmin=-1, vmax=1) + # axs[1, 1].set_title('Correlation Between Embeddings') + # + # plt.tight_layout(rect=[0, 0, 1, 0.96]) + # plt.savefig(f'comparison_{key1}_vs_{key2}.png', dpi=300) + # print(f"Saved comparison to comparison_{key1}_vs_{key2}.png") + # + # # Print summary + # print(f"\nComparison of {key1} vs {key2}:") + # print(f" Sequence similarity: {seq_similarity:.4f}") + # print(f" Embedding cosine similarity: {cosine_sim:.4f}") + # print(f" Embedding euclidean distance: {euclidean_dist:.4f}") + # print(f" Top 5 most different positions: {np.argsort(mean_diff_by_pos)[::-1][:5]}") + # print(f" Top 5 most different features: {top_features_idx[:5]}") + + # 4. Sequence comparison visualization + ax = axs[1, 1] + ax.set_title('Sequence Comparison') + ax.axis('off') # Turn off axes + + # Find differences between sequences + differences = [] + for i, (a, b) in enumerate(zip(seq1, seq2)): + if a != b: + differences.append((i, a, b)) + + # Calculate sequence identity + seq_len = max(len(seq1), len(seq2)) + percent_identity = (seq_len - len(differences)) / seq_len * 100 + + # Show sequence statistics + ax.text(0.05, 0.95, f"Sequence length: {len(seq1)} and {len(seq2)} amino acids", fontsize=11) + ax.text(0.05, 0.9, f"Sequence identity: {percent_identity:.1f}%", fontsize=11) + ax.text(0.05, 0.85, f"Number of differences: {len(differences)}", fontsize=11) + + # Display sequence alignment based on sequence length + if len(seq1) > 80: # For longer sequences, show regions with differences + if differences: + ax.text(0.05, 0.75, "Key differences:", fontsize=11, fontweight='bold') + y_pos = 0.7 + + # Show up to 5 difference regions + for i, (pos, aa1, aa2) in enumerate(differences[:5]): + context_start = max(0, pos - 5) + context_end = min(len(seq1), pos + 6) + + # Extract regions around differences + region1 = seq1[context_start:context_end] + region2 = seq2[context_start:context_end] + + # Create marker pointing to difference + marker = " " * (pos - context_start) + "^" + + ax.text(0.05, y_pos, f"Diff {i + 1} (pos {pos + 1}):", fontsize=10) + ax.text(0.05, y_pos - 0.05, region1, fontfamily='monospace', fontsize=10) + ax.text(0.05, y_pos - 0.1, region2, fontfamily='monospace', fontsize=10) + ax.text(0.05, y_pos - 0.15, marker, fontfamily='monospace', fontsize=10) + + y_pos -= 0.2 + + if len(differences) > 5: + ax.text(0.05, y_pos, f"...and {len(differences) - 5} more differences", fontsize=10) + else: + ax.text(0.05, 0.7, "The sequences are identical.", fontsize=11) + else: + # For shorter sequences, show complete alignment + y_pos = 0.75 + ax.text(0.05, y_pos, "Sequence alignment:", fontsize=11, fontweight='bold') + y_pos -= 0.05 + + # Split into chunks of 50 for better display + chunk_size = 50 + for i in range(0, len(seq1), chunk_size): + chunk1 = seq1[i:i + chunk_size] + chunk2 = seq2[i:i + chunk_size] if i < len(seq2) else "" + + # Create marker for differences + marker = "" + for j in range(len(chunk1)): + if j < len(chunk2) and chunk1[j] != chunk2[j]: + marker += "^" + else: + marker += " " + + ax.text(0.05, y_pos, f"Pos {i + 1}:", fontsize=9) + y_pos -= 0.05 + ax.text(0.05, y_pos, chunk1, fontfamily='monospace', fontsize=10) + y_pos -= 0.05 + ax.text(0.05, y_pos, chunk2, fontfamily='monospace', fontsize=10) + y_pos -= 0.05 + ax.text(0.05, y_pos, marker, fontfamily='monospace', fontsize=10) + y_pos -= 0.05 + +def plot_3d_PCA(data, keys, seqs=None, 
num_samples=100): + """ + Plot 3D PCA visualization of the embeddings. + + Args: + data: List of embeddings with shape (36, 1152) + keys: List of identifiers for each embedding + seqs: Optional list of sequences for annotation + num_samples: Number of samples to plot (default: 10) + """ + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D + from sklearn.decomposition import PCA + import numpy as np + import random + + # Limit to num_samples if there are more samples + if len(data) > num_samples: + # Select random indices + indices = random.sample(range(len(data)), num_samples) + sampled_data = [data[i] for i in indices] + sampled_keys = [keys[i] for i in indices] + sampled_seqs = [seqs[i] for i in indices] if seqs is not None else None + else: + sampled_data = data + sampled_keys = keys + sampled_seqs = seqs + + # Convert embeddings to feature vectors + feature_vectors = [] + for embedding in sampled_data: + # Mean across features to get (36,) + feature_vectors.append(embedding.mean(axis=1)) + + feature_vectors = np.array(feature_vectors) + + # Apply PCA to reduce to 3 dimensions + pca = PCA(n_components=3) + pca_result = pca.fit_transform(feature_vectors) + + # Create 3D plot with larger size + fig = plt.figure(figsize=(16, 14)) + ax = fig.add_subplot(111, projection='3d') + + # Use a colormap with distinct colors for each sample + cmap = plt.cm.get_cmap('tab10', num_samples) + + scatter = ax.scatter( + pca_result[:, 0], + pca_result[:, 1], + pca_result[:, 2], + c=range(len(pca_result)), # Each sample gets its own color + cmap=cmap, + s=50, # Smaller dot size + alpha=0.8 + ) + + # Add annotations for all points + for i, key in enumerate(sampled_keys): + ax.text(pca_result[i, 0], pca_result[i, 1], pca_result[i, 2], key, fontsize=9) + + # Set labels and title + ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)') + ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)') + ax.set_zlabel(f'PC3 ({pca.explained_variance_ratio_[2]:.2%} variance)') + ax.set_title(f'3D PCA of Protein Embeddings ({num_samples} Samples)', fontsize=16) + + # Add a color bar with sample IDs + cbar = plt.colorbar(scatter, ax=ax, pad=0.1, ticks=range(len(sampled_keys))) + cbar.set_label('Sample ID') + cbar.ax.set_yticklabels(sampled_keys) + + # Add total explained variance as text + total_var = sum(pca.explained_variance_ratio_[:3]) + fig.text(0.5, 0.01, f'Total variance explained: {total_var:.2%}', ha='center', fontsize=12) + + plt.tight_layout() + plt.savefig('3d_pca_embeddings_10samples.png', dpi=300) + print(f"Saved 3D PCA visualization to 3d_pca_embeddings_10samples.png") + + # Create interactive plot if plotly is available + try: + import plotly.express as px + import pandas as pd + + df = pd.DataFrame({ + 'PC1': pca_result[:, 0], + 'PC2': pca_result[:, 1], + 'PC3': pca_result[:, 2], + 'Sample': sampled_keys + }) + + if sampled_seqs is not None: + df['Sequence'] = sampled_seqs + + fig = px.scatter_3d( + df, x='PC1', y='PC2', z='PC3', + color='Sample', # Each sample gets unique color + hover_data=['Sample'], + title='Interactive 3D PCA of Protein Embeddings (10 Samples)', + labels={ + 'PC1': f'PC1 ({pca.explained_variance_ratio_[0]:.2%})', + 'PC2': f'PC2 ({pca.explained_variance_ratio_[1]:.2%})', + 'PC3': f'PC3 ({pca.explained_variance_ratio_[2]:.2%})' + } + ) + fig.write_html('3d_pca_interactive_10samples.html') + print(f"Saved interactive 3D PCA visualization to 3d_pca_interactive_10samples.html") + except ImportError: + print("Plotly not available. 
Only static plot created.") + + plt.show() + return pca_result, sampled_keys + + +def plot_2d_TSNE(data, keys, seqs=None, num_samples=100, perplexity=30, learning_rate=200): + """ + Plot 2D t-SNE visualization of the embeddings. + + Args: + data: List of embeddings with shape (36, 1152) + keys: List of identifiers for each embedding + seqs: Optional list of sequences for annotation + num_samples: Number of samples to plot (default: 100) + perplexity: t-SNE perplexity parameter (default: 30) + learning_rate: t-SNE learning rate (default: 200) + """ + import matplotlib.pyplot as plt + from sklearn.manifold import TSNE + import numpy as np + import random + + # Limit to num_samples if there are more samples + if len(data) > num_samples: + # Select random indices + indices = random.sample(range(len(data)), num_samples) + sampled_data = [data[i] for i in indices] + sampled_keys = [keys[i] for i in indices] + sampled_seqs = [seqs[i] for i in indices] if seqs is not None else None + else: + sampled_data = data + sampled_keys = keys + sampled_seqs = seqs + + # Convert embeddings to feature vectors + feature_vectors = [] + for embedding in sampled_data: + # Mean across features to get (36,) + feature_vectors.append(embedding.mean(axis=1)) + + feature_vectors = np.array(feature_vectors) + + # Apply t-SNE to reduce to 2 dimensions + tsne = TSNE(n_components=2, perplexity=min(perplexity, len(feature_vectors)-1), + learning_rate=learning_rate, random_state=42) + tsne_result = tsne.fit_transform(feature_vectors) + + # Create 2D plot with larger size + fig = plt.figure(figsize=(16, 14)) + ax = fig.add_subplot(111) + + # Use a colormap with distinct colors for each sample + cmap = plt.cm.get_cmap('tab10', num_samples) + + scatter = ax.scatter( + tsne_result[:, 0], + tsne_result[:, 1], + c=range(len(tsne_result)), # Each sample gets its own color + cmap=cmap, + s=100, # Larger dot size for 2D plot + alpha=0.8 + ) + + # Add annotations for all points + for i, key in enumerate(sampled_keys): + ax.text(tsne_result[i, 0], tsne_result[i, 1], key, fontsize=9) + + # Set labels and title + ax.set_xlabel('t-SNE dimension 1') + ax.set_ylabel('t-SNE dimension 2') + ax.set_title(f'2D t-SNE of Protein Embeddings ({num_samples} Samples)', fontsize=16) + + # Add a color bar with sample IDs + cbar = plt.colorbar(scatter, ax=ax, pad=0.1, ticks=range(len(sampled_keys))) + cbar.set_label('Sample ID') + cbar.ax.set_yticklabels(sampled_keys) + + plt.tight_layout() + plt.savefig('2d_tsne_embeddings.png', dpi=300) + print(f"Saved 2D t-SNE visualization to 2d_tsne_embeddings.png") + + # Create interactive plot if plotly is available + try: + import plotly.express as px + import pandas as pd + + df = pd.DataFrame({ + 'Dimension 1': tsne_result[:, 0], + 'Dimension 2': tsne_result[:, 1], + 'Sample': sampled_keys + }) + + if sampled_seqs is not None: + df['Sequence'] = sampled_seqs + + fig = px.scatter( + df, x='Dimension 1', y='Dimension 2', + color='Sample', # Each sample gets unique color + hover_data=['Sample'], + title=f'Interactive 2D t-SNE of Protein Embeddings ({num_samples} Samples)', + ) + fig.write_html('2d_tsne_interactive.html') + print(f"Saved interactive 2D t-SNE visualization to 2d_tsne_interactive.html") + except ImportError: + print("Plotly not available. 
Only static plot created.") + + plt.show() + return tsne_result, sampled_keys + + +# def plot_(data, keys): +# import matplotlib.pyplot as plt +# from sklearn.decomposition import PCA +# from sklearn.manifold import TSNE +# from sklearn.cluster import KMeans +# import seaborn as sns +# import numpy as np +# from scipy.spatial.distance import pdist, squareform +# +# for arr, key in zip(data, keys): +# try: +# if len(arr.shape) != 2: +# print(f"Skipping {key} as it's not 2D") +# continue +# +# n_samples, n_features = arr.shape +# print(f"Visualizing {key}: {n_samples} samples × {n_features} features") +# +# fig = plt.figure(figsize=(18, 15)) +# +# # 1. PCA visualization +# ax1 = plt.subplot(2, 2, 1) +# pca = PCA(n_components=2) +# pca_result = pca.fit_transform(arr) +# +# # Apply KMeans to color points by cluster +# n_clusters = min(5, n_samples) +# kmeans = KMeans(n_clusters=n_clusters, random_state=42) +# clusters = kmeans.fit_predict(arr) +# +# scatter = ax1.scatter(pca_result[:, 0], pca_result[:, 1], +# c=clusters, cmap='viridis', alpha=0.8, s=100) +# ax1.set_title(f'PCA Projection') +# ax1.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)') +# ax1.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)') +# plt.colorbar(scatter, ax=ax1, label='Cluster') +# +# # 2. t-SNE visualization +# ax2 = plt.subplot(2, 2, 2) +# perplexity = min(30, max(5, n_samples-1)) +# tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity) +# tsne_result = tsne.fit_transform(arr) +# +# scatter2 = ax2.scatter(tsne_result[:, 0], tsne_result[:, 1], +# c=clusters, cmap='viridis', alpha=0.8, s=100) +# ax2.set_title(f't-SNE Projection') +# plt.colorbar(scatter2, ax=ax2, label='Cluster') +# +# # 3. Sequence similarity heatmap +# ax3 = plt.subplot(2, 2, 3) +# distances = squareform(pdist(arr, metric='euclidean')) +# sns.heatmap(distances, cmap='coolwarm', ax=ax3) +# ax3.set_title('Pairwise Distance Matrix') +# ax3.set_xlabel('Sequence Index') +# ax3.set_ylabel('Sequence Index') +# +# # 4. Feature variance visualization +# ax4 = plt.subplot(2, 2, 4) +# feature_variance = np.var(arr, axis=0) +# sorted_idx = np.argsort(feature_variance)[::-1] +# top_k = 50 # Show top 50 most variable features +# ax4.bar(range(top_k), feature_variance[sorted_idx[:top_k]]) +# ax4.set_title('Top Variable Features') +# ax4.set_xlabel('Feature Rank') +# ax4.set_ylabel('Variance') +# +# plt.suptitle(f'Visualization of {key} Embeddings', fontsize=16) +# plt.tight_layout(rect=[0, 0, 1, 0.96]) +# +# plt.savefig(f'visualization_{key}.png', dpi=300) +# print(f"Saved visualization to visualization_{key}.png") +# plt.close() +# +# except Exception as e: +# print(f"Error visualizing {key}: {str(e)}") + +def main(): + current_path = os.path.dirname(os.path.abspath(__file__)) + directory = current_path + data, keys, seqs = load_npz_files(directory) + print_shapes(data, keys, seqs) + ## print_samples(data, keys) + ## plot_(data, keys) + # selected_keys = [('HLA-B3813', 'HLA-B3814'), ('HLA-B38:01', 'HLA-B38:05')] # second sample has same pseudoseq + selected_keys = [('HLA-B13020103', 'HLA-B13020105')] + plot_pairwise_comparison(selected_keys, keys, data, seqs) + plot_3d_PCA(data, keys, seqs) + plot_2d_TSNE(data, keys, seqs) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/install.sh b/install.sh index d47dfdc1..b43201e9 100755 --- a/install.sh +++ b/install.sh @@ -42,6 +42,8 @@ echo "4. 
AlphaFold - GitHub: https://github.com/google-deepmind/alphafold" echo " Paper: https://www.nature.com/articles/s41586-021-03819-2" echo "5. ProteinMPNN - Github https://github.com/dauparas/ProteinMPNN" echo " Paper: https://www.science.org/doi/10.1126/science.add2187" +echo "6. Pep2Vec - GitHub: https://github.com/Genentech/Pep2Vec" +echo " Paper: https://www.biorxiv.org/content/10.1101/2024.10.14.618255v1" echo "########################################################" # Step 1: Ask for Modeller License Key @@ -81,7 +83,8 @@ else fi # Activate the environment -$ACTIVATE_CMD "$ENV_NAME" +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate "$ENV_NAME" # Step 5: Clone and Install PANDORA echo "✔ Cloning and installing PANDORA..." @@ -124,7 +127,37 @@ mv "data/modified_files/PMHC.py" "$PANDORA_PMHC_PATH" echo "ProteinMPNN installation" git clone https://github.com/dauparas/ProteinMPNN.git echo "✔ ProteinMPNN setup is done. " -# Step 12: Cleanup and Completion + +# Step 11: Install Pep2Vec +echo "Pep2Vec installation" +git lfs install +git clone https://github.com/Genentech/Pep2Vec +git lfs pull +# verify pep2vec.bin is downloaded +if [ ! -f "Pep2Vec/pep2vec.bin" ]; then + echo "⚠ pep2vec.bin not found. Please check the Pep2Vec repository." + exit 1 +fi +echo "✔ Pep2Vec setup is done. " + +# Step 12: Install graphviz +echo "Graphviz installation" +if ! command -v dot &>/dev/null; then + echo "⚠ Graphviz not found. Installing..." + if [ "$(uname)" == "Linux" ]; then + sudo apt-get install graphviz + elif [ "$(uname)" == "Darwin" ]; then + brew install graphviz + else + echo "⚠ Unsupported OS. Please install Graphviz manually." + exit 1 + fi +else + echo "✔ Graphviz is already installed." +fi + + +# Step 13: Cleanup and Completion cd "$CURRENT_DIR" echo "✔ Installation completed successfully!" echo "Please check and modify 'user_setting.py' file to customize for your usage" diff --git a/pip_requirements.txt b/pip_requirements.txt index 11a63105..0d12a7b9 100644 --- a/pip_requirements.txt +++ b/pip_requirements.txt @@ -16,5 +16,17 @@ toolz==1.0.0 torch==1.10.1 typing-extensions==3.7.4.3 urllib3==1.26.14 - - +Levenshtein==0.27.1 +tqdm==4.67.1 +pyarrow==19.0.1 +scikit-learn==1.6.1 +scipy==1.9.3 # required by tensorflow 2.11 +umap-learn==0.5.7 +scikit-image==0.22.0 +pandas==2.2.3 +matplotlib==3.6.3 +datashader==0.17.0 +bokeh==3.4.3 +holoviews==1.20.2 +colorcet==3.1.0 +pydot==4.0.0 \ No newline at end of file diff --git a/run_ESM.py b/run_ESM.py new file mode 100644 index 00000000..e5534c65 --- /dev/null +++ b/run_ESM.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python +""" +requires python3.10+ +esm_embed.py +============ + +Generate protein embeddings with ESM-C or any other EvolutionaryScale/Meta +protein language model. + +Examples +-------- +# 1. Local ESM-C-300M on GPU 0, output NPZ +python esm_embed.py --input proteins.fa --model esmc_300m \ + --device cuda:0 --outfile embeddings.npz + +# 2. 
Remote ESM-3-large (98 B) via Forge API +export ESM_API_TOKEN="hf_xxxxxxxxxxxxxxxxxx" +python esm_embed.py --input proteins.fa --model esm3-98b-2024-08 \ + --remote --outfile embeds.parquet +""" +from __future__ import annotations +import argparse, os, sys, json, time, itertools, pathlib, warnings +from typing import List, Tuple, Dict, Iterable +from esm.sdk.api import ESMProtein, LogitsConfig + +import torch +import numpy as np +import pandas as pd +import tqdm +import csv +############################################################################### +# --------------------------- I/O utilities --------------------------------- +############################################################################### +def read_dat(path: str) -> List[Tuple[str, str]]: + """Read tab-separated file: first col is id, second is sequence.""" + seqs = [] + with open(path, "r") as f: + for line in f: + line = line.strip() + if "\t" in line: + line = line.split("\t", 1) + elif " " in line: + line = line.split(" ") + else: + line = line.split(maxsplit=1) + if not line or len(line) < 2: + print(line) + continue + + if len(line) == 2: + seqs.append((line[0], line[1])) + + return seqs + + +def read_csv(path: str, mhc_class: int): + seqs: List[Tuple[str, ...]] = [] + selected_cols = ["simple_allele", "sequence", "mhc_types", "repres_pseudo_positions"] + file = pd.read_csv(path, sep=",", usecols=selected_cols) + file = file[file["mhc_types"] == mhc_class] + print(file.columns) + # convert simple allele, remove * and : to mtach with netmhcpan + file[selected_cols[0]] = file[selected_cols[0]].str.replace("*", "") + file[selected_cols[0]] = file[selected_cols[0]].str.replace(":", "") + # convert pseudosequence positions to list of integers (convert string to list split by ;) + pseudoseq_indices = file[selected_cols[3]].apply(lambda x: [int(i) for i in x.split(";") if i.isdigit()]) + # return as list of tuples + for index, row in file.iterrows(): + seqs.append((row[selected_cols[0]], row[selected_cols[1]])) + return seqs, pseudoseq_indices.tolist() + +############################################################################### +# ------------- Local model loader (ESM-C, ESM-2, ESM3-open) ---------------- +############################################################################### +def load_local_model(model_name: str, device: str): + """ + Return (model, to_tensor_fn) for **new** ESM-C / ESM-3 + or (model, batch_converter_fn) for **legacy** ESM-2. 
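Note that, despite this docstring, the function below returns only the `embed_one` closure rather than a (model, fn) tuple. A rough usage sketch, assuming `run_ESM.py` is importable and that torch, the `esm` package, and the model weights are available locally (the model name and sequence are placeholders):

```python
# Hypothetical usage of load_local_model(); not part of the original script.
from run_ESM import load_local_model

embed_one = load_local_model("esmc_300m", device="cpu")
emb = embed_one("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")  # toy sequence
print(type(emb), getattr(emb, "shape", None))  # array returned by the model-specific branch
```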
+ """ + try: # new-style ESM-C + if model_name.startswith("esmc"): + from esm.models.esmc import ESMC + from esm.sdk.api import ESMProtein, LogitsConfig + model = ESMC.from_pretrained(model_name).to(device) + def embed_one(seq: str): + protein = ESMProtein(sequence=seq) + t = model.encode(protein) + out = model.logits(t, LogitsConfig(sequence=True, + return_embeddings=True)) + return out.embeddings.mean(0).cpu().numpy() + return embed_one + # new-style ESM-3 open weight + if model_name.startswith("esm3"): + from esm.models.esm3 import ESM3 + from esm.sdk.api import ESMProtein, LogitsConfig + model = ESM3.from_pretrained(model_name).to(device) + def embed_one(seq: str): + protein = ESMProtein(sequence=seq) + t = model.encode(protein) + out = model.logits(t, LogitsConfig(sequence=True, + return_embeddings=True)) + return out.embeddings.mean(0).cpu().numpy() + return embed_one + except ImportError: + pass # will try legacy route below + + # ---------- legacy facebookresearch/esm (ESM-1/2) ---------- + try: + from esm import pretrained + create = getattr(pretrained, model_name, + pretrained.load_model_and_alphabet) + model, alphabet = create(model_name) if callable(create) else create() + model.eval().to(device) + converter = alphabet.get_batch_converter() + def embed_one(seq: str): + _, _, toks = converter([("x", seq)]) + with torch.inference_mode(): + rep = model(toks.to(device), + repr_layers=[model.num_layers])["representations"] + return rep[model.num_layers][0, 1:len(seq)+1].mean(0).cpu().numpy() + return embed_one + except Exception as e: + raise RuntimeError(f"Don’t know how to load {model_name}: {e}") + +# --------------------------------------------------------------------- # +# Remote (Forge / AWS) wrapper +# --------------------------------------------------------------------- # +def load_remote_client(model_name: str, token: str): + import esm + client = esm.sdk.client(model_name, token=token) + from esm.sdk.api import ESMProtein, LogitsConfig + def embed_one(seq: str): + protein = ESMProtein(sequence=seq) + t = client.encode(protein) + out = client.logits(t, LogitsConfig(sequence=True, + return_embeddings=True)) + return out.embeddings.mean(0) + return embed_one + +############################################################################### +# -------------------- Embedding extraction pipeline ------------------------ +############################################################################### +# @torch.inference_mode() +# def embed_local( +# model, +# batch_converter, +# sequences: List[Tuple[str, str]], +# batch_size: int = 8, +# device: str = "cpu", +# pooling: str = "mean", +# ) -> Dict[str, np.ndarray]: +# """ +# Run local model and return {id: embedding(ndarray)} dict. +# Pooling: "mean" (default) or "cls". 
+# """ +# embeds = {} +# for i in range(0, len(sequences), batch_size): +# sub = sequences[i : i + batch_size] +# labels, strs, toks = batch_converter(sub) +# toks = toks.to(device) +# reps = model(toks, repr_layers=[model.num_layers])["representations"][ +# model.num_layers +# ] # (B, L, D) +# for label, tok, rep in zip(labels, toks, reps): +# if pooling == "mean": +# mask = tok != model.alphabet.padding_idx +# embed = rep[mask].mean(0) +# else: # CLS – first token after padding_tok +# embed = rep[0] +# embeds[label] = embed.cpu().numpy() +# return embeds + + +def embed_remote( + client, + sequences: List[Tuple[str, str]], + batch_size: int = 16, + pooling: str = "mean", +) -> Dict[str, np.ndarray]: + """ + Use cloud client; supports .embed() (ESM-C) and .generate_embeddings() (ESM-3). + """ + embeds = {} + for i in range(0, len(sequences), batch_size): + sub = sequences[i : i + batch_size] + ids, seqs = zip(*sub) + if hasattr(client, "embed"): + out = client.embed(seqs, pooling) + else: + out = client.generate_embeddings(seqs, pooling) + embeds.update({idx: emb for idx, emb in zip(ids, out)}) + return embeds + + +############################################################################### +# -------------------------- Main CLI handler ------------------------------ +############################################################################### +def main(**local_args): + parser = argparse.ArgumentParser(description="ESM embedding generator") + + if local_args: + args = parser.parse_args([]) + for k, v in local_args.items(): + setattr(args, k, v) + else: + parser.add_argument("--input", required=True, help="dat file") + parser.add_argument("--model", default="esmc_300m", help="Model name/id") + parser.add_argument("--outfile", required=True, help="Output .npz or .parquet") + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu") + parser.add_argument("--pooling", choices=["mean", "cls"], default="mean") + parser.add_argument("--remote", action="store_true", help="Use cloud API") + parser.add_argument("--api_token", default=os.getenv("ESM_API_TOKEN"), help="Forge/NIM token") + parser.add_argument("--mhc_class", type=int, default=1, help="MHC class (1 or 2) for CSV input") + parser.add_argument("--filter_ps", action="store_true", help="Filter out pseudosequence positions") + parser.add_argument("--noise_augmentation", choices=["Gaussian", None], default=None, + help="Add Gaussian noise to embeddings (10% of std_dev)") + args = parser.parse_args() + + seqs = None + pseudoseq_indices = None + if args.input.endswith(".csv"): + seqs, pseudoseq_indices = read_csv(args.input, mhc_class=args.mhc_class) + elif args.input.endswith(".dat"): + seqs = read_dat(args.input) + if not seqs: + sys.exit("No sequences found!") + + # Generate embeddings and filter out pseudosequence positions if needed + if args.remote: + if args.api_token is None: + sys.exit("Provide --api_token or set ESM_API_TOKEN") + client = load_remote_client(args.model, args.api_token) + embeddings = embed_remote(client, seqs, args.batch_size, args.pooling) + if args.filter_ps: + print("Filtering out pseudosequence positions...") + # select embedding indexes equal to the values in the pseudosequence positions + #TODO + + print("embeddings with", args.model) + sequences = {seq_id: seq for seq_id, seq in seqs} + else: + embed_one = load_local_model(args.model, args.device) + print("embeddings with", args.model) + # print len of sequences + 
print(f"Number of sequences: {len(seqs)}") + # print first 5 sequences + print("First 5 sequences:", seqs[:5]) + embeddings = {} + for idx, (seq_id, seq) in enumerate(tqdm.tqdm(seqs, desc="Embedding sequences")): + emb = embed_one(seq) + if noise_augmentation is not None: + if noise_augmentation == "Gaussian": + # determine the noise based on the embeddings value range + # calculate the standard deviation of the embedding + std_dev = np.std(emb) + noise = np.random.normal(0, std_dev * 0.1, emb.shape) # 10% of std_dev + emb += noise + else: + raise ValueError(f"Unknown noise augmentation: {noise_augmentation}") + embeddings[seq_id] = emb + # print shape of the embedding + # print(f"Embedding shape for {seq_id}: {embeddings[seq_id].shape}") # (sequence_length, embedding_dim) + if args.filter_ps: + print("Filtering out pseudosequence positions...") + # select embedding indexes equal to the values in the pseudosequence positions + if pseudoseq_indices: + # filter out positions in the sequence that are in the pseudosequence positions + embeddings[seq_id] = embeddings[seq_id][pseudoseq_indices[idx]] + print("Filtered embedding shape for", seq_id, ":", embeddings[seq_id].shape) + + sequences = {seq_id: seq for seq_id, seq in seqs} + + + # ------------------ save ------------------ + out_path = pathlib.Path(args.outfile) + # make directory if it does not exist + out_path.parent.mkdir(parents=True, exist_ok=True) + if out_path.suffix == ".npz": + np.savez_compressed(out_path, **embeddings) + # create a .csv file and save the sequences + with open(out_path.with_suffix(".csv"), "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["id", "sequence"]) + for seq_id, seq in sequences.items(): + writer.writerow([seq_id, seq]) + elif out_path.suffix == ".parquet": + df = pd.DataFrame( + [(k, v.astype(np.float32)) for k, v in embeddings.items()], + columns=["id", "embedding"], + ) + df.to_parquet(out_path, index=False) + else: + sys.exit("outfile must end with .npz or .parquet") + + print(f"[✓] Saved embeddings for {len(embeddings)} sequences to {out_path}") + + +if __name__ == "__main__": + # Config: + mhc_class = 1 # 1 for MHC-I, 2 for MHC-II + model = "esmc_600m" # "esm3-98b-2024-08" "esm3_sm_open_v1", "esm3-open", "esmc_300m", "esmc_600m" + filter_out_pseudoseq_positions = False # filter out positions with pseudosequence in MHC-I and MHC-II + noise_augmentation = "Gaussian" # None + # Input: + dat_path = "data/NetMHCpan_dataset/NetMHCpan_train/MHC_pseudo.dat" # for MHC-I + # dat_path = "data/NetMHCpan_dataset/NetMHCIIpan_train/pseudosequence.2016.all.X.dat" # for MHC-II + # dat_path = "data/HLA_alleles/pseudoseqs/PMGen_pseudoseq.csv" # for both MHC-I and MHC-II + out_path = "data/ESM/NetMHCpan_dataset/esmc_600m/mhc1_encodings.npz" + remote = False + device = "cuda:0" if torch.cuda.is_available() else "cpu" + batch_size = 1 + pooling = "mean" + # run the script + main(input=dat_path, model=model, outfile=out_path, remote=remote, + device=device, batch_size=batch_size, pooling=pooling, mhc_class=mhc_class, filter_ps=filter_out_pseudoseq_positions, + noise_augmentation=noise_augmentation) diff --git a/run_netmhcpan_script.py b/run_netmhcpan_script.py new file mode 100644 index 00000000..1421f939 --- /dev/null +++ b/run_netmhcpan_script.py @@ -0,0 +1,865 @@ +import difflib +import json +import sys +from run_utils import run_and_parse_netmhcpan +import os +import pandas as pd +import numpy as np +from tqdm import tqdm +import pyarrow as pa +import pyarrow.parquet as pq +import shutil +from 
concurrent.futures import ProcessPoolExecutor + + +""" +Guide to run the script: +1. Run the script with the argument process_data to process the data + - arg1: process_data arg2: None +2. Run the script with the argument run_netmhcpan to run the NetMHCpan for the processed data + - arg1: run_netmhcpan arg2: chunk_number (0-9) +3. Run the script with the argument combine_results to combine the results from the NetMHCpan runs + - arg1: combine_results arg2: None + +""" + + +def load_data(path): + data = pd.read_csv(path, delim_whitespace=True, header=None) + return data + + +def add_mhc_sequence_column(input_df): + """ + Add MHC sequence information to the dataset by looking up allele sequences. + + Args: + input_df: DataFrame with an 'allele' column (like el_data_0 or el_data_1) + + Returns: + DataFrame with added 'mhc_sequence' column + """ + # TODO needs revision and validation + dict_file = "data/HLA_alleles/pseudoseqs/PMGen_pseudoseq.csv" + + # Load allele-sequence mapping + try: + dict_df = pd.read_csv(dict_file, usecols=['simple_allele', 'sequence']) + dict_df['net_mhc_allele'] = ( + dict_df['simple_allele'] + .str.replace('*', '', regex=False) + .str.replace(':', '', regex=False) + ) + mhc_sequence_map = dict(zip(dict_df['net_mhc_allele'], dict_df['sequence'])) + print(f"Loaded allele sequence mapping with {len(mhc_sequence_map)} entries") + except FileNotFoundError: + print(f"Warning: Dictionary file {dict_file} not found. No sequences will be mapped.") + mhc_sequence_map = {} + except Exception as e: + print(f"Error loading dictionary file: {str(e)}") + mhc_sequence_map = {} + + def lookup_sequence(allele_str): + """Find the sequence for an allele with fallback strategies.""" + if pd.isna(allele_str): + return None + + print(f"Processing allele: {allele_str}") + + # Handle list of alleles (from el_data_0 or el_data_1) + if isinstance(allele_str, str) and allele_str.startswith('[') and allele_str.endswith(']'): + try: + allele_list = eval(allele_str) + if isinstance(allele_list, list): + # Get sequence for each allele in list and join with '/' + seqs = [lookup_sequence(a) for a in allele_list] + found = [s for s in seqs if s is not None] + return '/'.join(found) if found else None + except Exception as e: + print(f" Failed to eval list: {e}") + + # Split on '-' if there are two alleles mashed together + elif 'H-2' in allele_str: + parts = [allele_str] + elif '-' in allele_str: + parts = allele_str.split('-') # eg. HLA-DRB101:01-DQA101:01 to [HLA, DRB101:01, DQA101:01] or HLA-DRB101:01 to [HLA, DRB101:01] + if len(parts) > 2: # eg. [HLA, DRB101:01, DQA101:01] + pref = parts[0] # HLA + part1 = "-".join(parts[0:2]) # HLA-DRB101:01 + part2 = pref + "-" + parts[2] # HLA-DQA101:01 + parts = [part1, part2] # [HLA-DRB101:01, HLA-DQA101:01] + elif len(parts) == 2: # eg. [HLA, DRB101:01] + pref = parts[0] # HLA + part1 = parts[1] # DRB101:01 + parts = [pref + "-" + part1] # [HLA-DRB101:01] + else: + parts = [allele_str] # No split needed eg. 
DRB101:01 + else: + parts = [allele_str] + + seqs = [] + for part in parts: + net = part.replace('*', '').replace(':', '') + print(f" Normalized part: {part!r} → net key: {net!r}") + seq = mhc_sequence_map.get(net) + if seq: + print(f"Found exact match for {net}: sequence length {len(seq)}") + else: + # First try to find keys where our name is contained within + containing_matches = [k for k in mhc_sequence_map.keys() if net in k] + if containing_matches: + # Sort by length to prefer closest match in length + best = sorted(containing_matches, key=len)[0] + seq = mhc_sequence_map[best] + print(f"No exact match for '{net}', using containing match '{best}'") + print(f"No exact match for {net}, using containing '{best}'") + else: + # fallback to closest match (cutoff can be tuned) + matches = difflib.get_close_matches(net, + mhc_sequence_map.keys(), + n=1, + cutoff=0.6) + if matches: + best = matches[0] + seq = mhc_sequence_map[best] + print(f"No exact match for '{net}', using closest match '{best}'") + if seq is None: + print(f"*No sequence found for part {net}*") + + seqs.append(seq) + + # if *all* parts failed, return None + if all(s is None for s in seqs): + return None + + # otherwise join only the found sequences with '/' + found = [s for s in seqs if s is not None] + return '/'.join(found) + + # Process unique alleles for efficiency + # Build map for unique alleles + unique = input_df['allele'].dropna().unique() + allele_seq_map = {} + for a in unique: + allele_seq_map[a] = lookup_sequence(a) + + # Now apply + updated_df = input_df.copy() + updated_df['mhc_sequence'] = updated_df['allele'].map(allele_seq_map) + + # Report summary of misses + missing = [a for a, seq in allele_seq_map.items() if seq is None] + if missing: + print(f"\nTotal missing sequences: {len(missing)}") + print("Examples of alleles without a match:", missing[:10]) + + # TODO fix later + # Drop rows where 'mhc_sequence' is None or NaN + updated_df = updated_df[updated_df['mhc_sequence'].notna() & (updated_df['mhc_sequence'] != None)] + + return updated_df + + +def process_data(mhc_path = "data/NetMHCpan_dataset/NetMHCIIpan_train/", + tmp_path = "data/NetMHCpan_dataset/tmp/", mhc_type=2): + """ + Load and process MHCII data for NetMHCpan + This function works for only HLA inputs + Args: + mhc_path: Path to the MHC data directory + tmp_path: Path to the temporary directory for storing intermediate files + + Returns: + el_data: DataFrame containing EL data + ba_data: DataFrame containing BA data + """ + + if not os.path.isdir(tmp_path): + os.mkdir(tmp_path) + + if mhc_type == 1: + el_data = pd.DataFrame(columns=["peptide", "label", "allele"]) + ba_data = pd.DataFrame(columns=["peptide", "label", "allele"]) + else: + el_data = pd.DataFrame(columns=["peptide", "label", "allele", "core"]) + ba_data = pd.DataFrame(columns=["peptide", "label", "allele", "core"]) + + # Check directory exists + if not os.path.exists(mhc_path): + print(f"Directory not found: {mhc_path}") + return + + if mhc_type == 1: + train_files = [f for f in os.listdir(mhc_path) if "_ba" in f or "_el" in f and "test" not in f] + else: + train_files = [f for f in os.listdir(mhc_path) if "train" in f] + print(f"Found {len(train_files)} train files") + + for file in tqdm(train_files): + if "BA" in file or "ba" in file: + data = load_data(mhc_path + file) + print(f"Loaded BA data from {file}: {ba_data.shape}") + + if mhc_type == 2: + # Rename columns if necessary + if data.shape[1] >= 4: + data.columns = ["peptide", "label", "allele", "core"] + 
list(range(data.shape[1] - 4)) + # add ba data to the dataframe + ba_data = pd.concat([ba_data, data], ignore_index=True) + else: + print(f"Warning: File {file} has insufficient columns: {data.shape[1]}") + if mhc_type == 1: + # Rename columns if necessary + if data.shape[1] >= 3: + data.columns = ["peptide", "label", "allele"] + list(range(data.shape[1] - 3)) + # add ba data to the dataframe + ba_data = pd.concat([ba_data, data], ignore_index=True) + else: + print(f"Warning: File {file} has insufficient columns: {data.shape[1]}") + + elif "EL" in file or "el" in file: + data = load_data(mhc_path + file) + print(f"Loaded EL data from {file}: {data.shape}") + + if mhc_type == 2: + # Rename columns if necessary + if data.shape[1] >= 4: # Ensure data has enough columns + data.columns = ["peptide", "label", "allele", "core"] + list(range(data.shape[1] - 4)) + # add el data to the dataframe + el_data = pd.concat([el_data, data], ignore_index=True) + else: + print(f"Warning: File {file} has insufficient columns: {data.shape[1]}") + else: + # skip the file and print a warning + print(f"Skipping file: {file}") + if mhc_type == 1: + # Rename columns if necessary + if data.shape[1] >= 3: + data.columns = ["peptide", "label", "allele"] + list(range(data.shape[1] - 3)) + # add el data to the dataframe + el_data = pd.concat([el_data, data], ignore_index=True) + else: + print(f"Warning: File {file} has insufficient columns: {data.shape[1]}") + + print(f"After loading data: {el_data.shape}") + if el_data.empty: + print("No data was loaded. Check file paths and content.") + return + + # get the cell lines rows, drop the rows that has HLA- in their allele column + el_data = el_data[el_data["allele"].str.contains("HLA-") == False] + print(f"After removing HLA- rows: {el_data.shape}") + + ## drop rows with the same allele and peptide (1) + el_data = el_data.drop_duplicates(subset=["allele", "peptide"], keep="first") + print(f"After removing duplicates: {el_data.shape}") + + ## get allele names of cell lines + if mhc_type == 1: + allelelist = mhc_path + "allelelist" + else: + allelelist = mhc_path + "allelelist.txt" # change to allelelist for MHCI + + # Check if allele list file exists + if not os.path.exists(allelelist): + print(f"Allele list file not found: {allelelist}") + return + + allele_map = pd.read_csv(allelelist, delim_whitespace=True, header=None) + allele_map.columns = ["key", "allele_list"] + print(f"Loaded allele map with {len(allele_map)} entries") + + # convert the second column to list if , is present, else return a one element list + allele_map["allele_list"] = allele_map["allele_list"].apply(lambda x: x.split(",") if "," in x else [x]) + + # Create a dictionary for more efficient lookup + allele_dict = dict(zip(allele_map["key"], allele_map["allele_list"])) + + # update the el_data['allele'] dataframe with the allele names + el_data["allele"] = el_data["allele"].apply(lambda x: allele_dict.get(x, [])) + print(f"After mapping alleles: {el_data.shape}") + + # Check if any allele mappings are empty + empty_alleles = el_data[el_data["allele"].apply(lambda x: len(x) == 0)].shape[0] + print(f"Rows with empty allele mappings: {empty_alleles}") + + # add the identifiers + # el_data['cell_line_id'] = el_data.index + # print the length of the el_data + # print(f"Length of el_data: {len(el_data)}") + + # decouple rows with multiple alleles, if the row['allele'] has multiple alleles, then create a new row for each allele + # el_data = el_data.explode("allele") + # print(f"After exploding alleles: 
{el_data.shape}") + + # reset index for unique identifier + # el_data = el_data.reset_index(drop=True) + + # First add DRA/ before the allele names for alleles with DRB + el_data["allele"] = el_data["allele"].apply(lambda x: ["HLA-DRA/HLA-" + a if "DRB" in a else a for a in x]) + + # all DRB alleles should have HLA-DRA/ in front of them, assert this + assert el_data[el_data["allele"].apply(lambda x: any("DRB" in a and "HLA-DRA/HLA-" not in a for a in x))].empty + + # if H-2 is present in the allele name, add a mice-DRA/mice- in front of it + el_data["allele"] = el_data["allele"].apply(lambda x: ["mice-DRA/mice-" + a if "H-2" in a else a for a in x]) + + # split double alleles with / + # if / not in allele name, replace the second - with / + el_data["allele"] = el_data["allele"].apply( + lambda x: [a[:a.find("-", a.find("-") + 1)] + "/HLA-" + a[a.find("-", a.find("-") + 1) + 1:] if a.count( + "-") >= 2 and "/" not in a else a for a in x]) + + # First add DRA/ before the allele names for alleles with DRB + el_data["allele"] = el_data["allele"].apply(lambda x: "HLA-DRA/HLA-" + x if "DRB" in x else x) + + # all DRB alleles should have HLA-DRA/ in front of them, assert this + assert el_data[el_data["allele"].apply(lambda x: "DRB" in x and "HLA-DRA/HLA-" not in x)].empty + + # if H-2 is present in the allele name, add a mice-DRA/mice- in front of it + el_data["allele"] = el_data["allele"].apply(lambda x: "mice-DRA/mice-" + x if "H-2" in x else x) + + # split double alleles with / + # if / not in allele name, replace the second - with / + el_data["allele"] = el_data["allele"].apply( + lambda x: x[:x.find("-", x.find("-") + 1)] + "/HLA-" + x[x.find("-", x.find("-") + 1) + 1:] if x.count("-") >= 2 and "/" not in x else x) + + # split to two variables, one with labels = 0 and one with labels = 1 + el_data_0 = el_data[el_data["label"] == 0] + el_data_1 = el_data[el_data["label"] == 1] + + # drop the labels column from el_data_1 + # el_data_1 = el_data_1.drop(columns=["label"]) + print(f"Data with label 0: {el_data_0.shape}") + + # Print some sample data for debugging + # print(el_data_1.head()) + + # drop the rows with the same allele and peptide (2) + # el_data_1 = el_data_1.drop_duplicates(subset=["allele", "peptide"], keep="first") + # print(f"After removing duplicates: {el_data_1.shape}") + + # Verify that "/" is present in all allele names + invalid_alleles = el_data_0[el_data_0["allele"].apply(lambda x: "/" in x)] + if not invalid_alleles.empty: + print("Alleles with '/' separator:") + print(invalid_alleles) + + # # Get unique alleles and run NetMHCpan for each + # unique_alleles = el_data_1["allele"].unique() + # print(f"Processing {len(unique_alleles)} unique alleles") + + # # get unique cell line ids + # unique_cell_lines = el_data_1["cell_line_id"].unique() + # + # # assert that the cell_line_id is unique + # print(f"Processing {len(unique_cell_lines)} unique cell line ids") + + return el_data_0, el_data_1, ba_data + + +def safe_remove(path, is_dir=False): + try: + if is_dir: + shutil.rmtree(path) + else: + os.remove(path) + print(f"Successfully removed: {path}") + except PermissionError: + print(f"Permission denied: {path}") + except FileNotFoundError: + print(f"File/Directory not found: {path}") + except OSError as e: + print(f"Error removing {path}: {str(e)}") + except Exception as e: + print(f"Unexpected error with {path}: {str(e)}") + + +def run_netmhcpan_(el_data, true_label, tmp_path, results_dir, chunk_number, mhc_class): + # chunk_dataframe = pd.DataFrame(columns=["MHC", "Peptide", 
"Of", "Core", "Core_Rel", "Inverted", "Identity", "Score_EL", "%Rank_EL", "Exp_Bind", "Score_BA", "%Rank_BA", "Affinity(nM)", "long_mer", "cell_line_id"]) + chunk_df_path = os.path.join(results_dir, f"el{true_label}_chunk_{chunk_number}.csv") + dropped_rows_path = os.path.join(results_dir, f"el{true_label}_dropped_rows_{chunk_number}.csv") + for idx, cell_row in tqdm(el_data.iterrows(), total=el_data.shape[0], desc=f"Processing chunk {chunk_number}"): + dropped_rows = pd.DataFrame( + columns=["MHC", "Peptide", "Of", "Core", "Core_Rel", "Inverted", "Identity", "Score_EL", "%Rank_EL", + "Exp_Bind", "Score_BA", "%Rank_BA", "Affinity(nM)", "long_mer", "cell_line_id"]) + number_of_alleles = len(eval(cell_row["allele"])) + result_data = pd.DataFrame( + columns=["MHC", "Peptide", "Of", "Core", "Core_Rel", "Inverted", "Identity", "Score_EL", "%Rank_EL", + "Exp_Bind", "Score_BA", "%Rank_BA", "Affinity(nM)", "long_mer", "cell_line_id"]) + peptide = cell_row["peptide"] + for allele in eval(cell_row["allele"]): + # get the unique id for peptide + unique_id = f"{peptide}_{allele}" + + # define path + peptide_fasta_path = os.path.join(tmp_path, f"fasta/el{true_label}_peptide_{unique_id}.fasta") + peptide_fasta_dir = os.path.dirname(peptide_fasta_path) + + # Ensure the directory exists + if not os.path.exists(peptide_fasta_dir): + os.makedirs(peptide_fasta_dir) + + # Write single peptide to fasta file + with open(peptide_fasta_path, "w") as f: + f.write(f">peptide\n{peptide}\n") + + # Output directory for this specific cell_line_id + output_dir = os.path.join(tmp_path, f"output/{true_label}_output_{unique_id}") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + try: + # Run NetMHCpan for this specific peptide + result_ = run_and_parse_netmhcpan( + peptide_fasta_file=peptide_fasta_path, + mhc_type=mhc_class, + output_dir=output_dir, + mhc_allele=allele, + save_csv=False + ) + + # select the top 1 result from netmhcpan output + if result_ is not None and not result_.empty: + result_ = result_.iloc[[0]] + result_['long_mer'] = peptide + result_['cell_line_id'] = str(chunk_number) + "_" + str(idx) + result_['convoluted_label'] = true_label + result_data = pd.concat([result_data, result_]) + + except Exception as e: + print(f"Error processing peptide {peptide} with allele {allele}: {str(e)}") + # Clean up temporary files + safe_remove(peptide_fasta_path) + + # Clean up peptide_fasta directory + peptide_fasta_dir = os.path.dirname(peptide_fasta_path) + if peptide_fasta_dir: # Prevent removing root/current directory + safe_remove(peptide_fasta_dir, is_dir=True) + + # Clean up output directory + safe_remove(output_dir, is_dir=True) + + # save the results to disk + if not result_data.empty: + # assert len(result_data) == number_of_alleles + if len(result_data) != number_of_alleles: + print(f"Warning: Expected {number_of_alleles} results but got {len(result_data)}") + result_data['assigned_label'] = result_data['Score_BA'].apply(lambda x: 1 if eval(x) >= 0.426 else 0) + if true_label == 1: + # Identify cell lines with at least one positive label + positive_cell_lines = result_data.groupby('cell_line_id')['assigned_label'].transform('max') == 1 + + # Separate dropped and kept rows without creating full copies + mask = result_data['cell_line_id'].isin( + result_data.loc[positive_cell_lines, 'cell_line_id'].unique() + ) + dropped_rows = pd.concat([dropped_rows, result_data[~mask]]) + result_data = result_data[mask] + + if true_label == 0: + # Directly concat and filter rows with label == 1 + 
dropped_rows = pd.concat([dropped_rows, result_data[result_data['assigned_label'] == 1]]) + result_data = result_data[result_data['assigned_label'] == 0] + + # save the results to the chunk_dataframe directly to the disk using append method + result_data.to_csv(chunk_df_path, mode="a", header=not os.path.exists(chunk_df_path), index=False) + dropped_rows.to_csv(dropped_rows_path, mode="a", header=not os.path.exists(dropped_rows_path), index=False) + + +def combine_datasets_(results_dir, include_dropped=False): + """ + dropped rows are the samples that received the wrong label from netmhcpan + Args: + results_dir: + include_dropped: + + Returns: + + """ + # read all CSV files in the results directory + all_csv_files = [os.path.join(results_dir, f) for f in os.listdir(results_dir) if f.endswith('.csv')] + # select el1 files and ignore files with dropped in the names + if include_dropped: + all_csv_files = [f for f in all_csv_files if "el1" in f] + else: + all_csv_files = [f for f in all_csv_files if "el1" in f and "dropped" not in f] + print(f"Combining {len(all_csv_files)} CSV files into final output") + + for csv_file in all_csv_files: + try: + df = pd.read_csv(csv_file) + df.rename(columns={'Peptide': 'peptide', 'MHC': 'allele'}, inplace=True) + # select only the columns that are needed + df = df[["allele", "peptide", "assigned_label", "convoluted_label"]] + except Exception as e: + print(f"Error reading {csv_file}: {str(e)}") + continue + + combined_df = pd.concat([pd.read_csv(f) for f in all_csv_files if os.path.getsize(f) > 0], ignore_index=True) + + return combined_df + +# def run_netmhcpan(el_data_1, unique_alleles, tmp_path, results_dir): +# for allele in unique_alleles: +# # Filter data for the current allele +# allele_data = el_data_1[el_data_1["allele"] == allele] +# +# # Create allele-specific filename base +# allele_filename = allele.replace("/", "_").replace("*", "").replace(":", "") +# +# # Get peptides +# peptides = allele_data['peptide'].values +# print(f"Total peptides for {allele}: {len(peptides)}") +# +# # Remove duplicate peptides +# unique_peptides = np.unique(peptides) +# print(f"Unique peptides for {allele}: {len(unique_peptides)}") +# +# # Process each peptide individually +# for peptide_idx, peptide in enumerate(unique_peptides): +# # Check if this peptide was already processed +# with open(os.path.join(results_dir, "processed_peptides.txt"), "r") if os.path.exists( +# os.path.join(results_dir, "processed_peptides.txt")) else open(os.devnull, "r") as f: +# processed = any(f"{allele}\t{peptide}\n" == line for line in f) +# if processed: +# continue +# +# # get the cell_line_id +# cell_line_id = allele_data[allele_data["peptide"] == peptide]["cell_line_id"].iloc[0] +# +# # Use a timestamp to ensure unique filenames across processes +# timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f") +# unique_id = f"{peptide_idx}_{timestamp}" +# +# # Create fasta file with single peptide +# peptide_fasta_path = os.path.join(tmp_path, f"results/el0_peptide_{allele_filename}_{unique_id}.fasta") +# +# # Write single peptide to fasta file +# with open(peptide_fasta_path, "w") as f: +# f.write(f">peptide\n{peptide}\n") +# +# # Output directory for this specific peptide +# output_dir = os.path.join(tmp_path, f"results/el0_output_{allele_filename}_{unique_id}") +# +# try: +# # Run NetMHCpan for this specific peptide +# allele_result = run_and_parse_netmhcpan( +# peptide_fasta_file=peptide_fasta_path, +# mhc_type=2, +# output_dir=output_dir, +# mhc_allele=allele +# ) +# +# # Save 
results to disk immediately +# if allele_result is not None: +# # take the top 1 result +# allele_result = allele_result.head(1) # taking the first row, because the output is sorted by BF score and we take the best score +# result_path = os.path.join(results_dir, f"{allele_filename}.parquet") +# # add long_mer +# allele_result['long_mer'] = peptide +# allele_result['cell_line_id'] = cell_line_id +# new_table = pa.Table.from_pandas(allele_result) +# if os.path.exists(result_path): +# existing_table = pq.read_table(result_path) +# combined_table = pa.concat_tables([existing_table, new_table]) +# pq.write_table(combined_table, result_path) +# else: +# pq.write_table(new_table, result_path) +# +# with open(os.path.join(results_dir, "processed_peptides.txt"), "a") as processed_file: +# processed_file.write(f"{allele}\t{peptide}\n") +# +# except Exception as e: +# print(f"Error processing peptide {peptide} with allele {allele}: {str(e)}") +# finally: +# # Clean up temporary files immediately +# try: +# if os.path.exists(peptide_fasta_path): +# os.remove(peptide_fasta_path) +# if os.path.exists(output_dir) and os.path.isdir(output_dir): +# import shutil +# shutil.rmtree(output_dir) +# except Exception as e: +# print(f"Failed to remove temporary files: {str(e)}") + + +# def combine_results_(results_dir, final_output_path): +# +# for cell_line_id: +# +# pass +# # combine all csv files +# all_csv_files = [os.path.join(results_dir, f) for f in os.listdir(results_dir) if f.endswith('.csv')] +# dropped_rows = pd.DataFrame() +# if all_csv_files: +# print(f"Combining {len(all_csv_files)} csv files into final output") +# df = None +# for csv_file in all_csv_files: +# try: +# df = pd.read_csv(csv_file) +# except Exception as e: +# print(f"Error reading {csv_file}: {str(e)}") +# continue +# if df: +# # Standardize column names +# if 'Peptide' in df.columns and 'peptide' not in df.columns: +# df['peptide'] = df['Peptide'] +# elif 'Peptide' in df.columns and 'peptide' in df.columns: +# df['peptide'] = df['peptide'].fillna(df['Peptide']) +# if 'MHC' in df.columns and 'allele' not in df.columns: +# df['allele'] = df['MHC'] +# elif 'MHC' in df.columns and 'allele' in df.columns: +# df['allele'] = df['allele'].fillna(df['MHC']) +# # Calculate label from Score_EL if label is empty or doesn't exist +# if 'label' not in df.columns: +# if 'Score_BA' in df.columns: +# df['label'] = df['Score_BA'].apply(lambda x: 1 if eval(x) > 0.426 else 0) +# +# # if at least one label is not 1, select the highest score_BA and set the labels to 1, then save the allele in a list +# if 'cell_line_id' in df.columns and 'label' in df.columns: +# # Create a mask for cell lines with at least one positive label +# positive_cell_lines = df.groupby('cell_line_id')['label'].max() == 1 +# positive_cell_lines = positive_cell_lines[positive_cell_lines].index.tolist() +# # Filter rows +# mask = df['cell_line_id'].isin(positive_cell_lines) +# dropped_df = df[~mask].copy() +# df = df[mask].copy() +# # Save dropped rows to separate file if there are any +# if not dropped_df.empty: +# dropped_rows = pd.concat([dropped_rows, dropped_df]) +# +# # print the length of the dropped rows +# print(f"Length of dropped rows: {len(dropped_rows)}") + +# def combine_results(results_dir, final_output_path): +# # Combine all parquet files at the end +# try: +# all_parquet_files = [os.path.join(results_dir, f) for f in os.listdir(results_dir) if f.endswith('.parquet')] +# if all_parquet_files: +# print(f"Combining {len(all_parquet_files)} parquet files into final 
output") +# # Read all parquet files as pandas DataFrames +# dfs = [] +# for parquet_file in all_parquet_files: +# try: +# df = pq.read_table(parquet_file).to_pandas() +# dfs.append(df) +# except Exception as e: +# print(f"Error reading {parquet_file}: {str(e)}") +# continue +# if dfs: +# # Concatenate DataFrames (pandas handles schema differences automatically) +# df = pd.concat(dfs, ignore_index=True) +# # Standardize column names +# if 'Peptide' in df.columns and 'peptide' not in df.columns: +# df['peptide'] = df['Peptide'] +# elif 'Peptide' in df.columns and 'peptide' in df.columns: +# df['peptide'] = df['peptide'].fillna(df['Peptide']) +# if 'MHC' in df.columns and 'allele' not in df.columns: +# df['allele'] = df['MHC'] +# elif 'MHC' in df.columns and 'allele' in df.columns: +# df['allele'] = df['allele'].fillna(df['MHC']) +# # Calculate label from Score_EL if label is empty or doesn't exist +# if 'label' not in df.columns: +# if 'Score_BA' in df.columns: +# df['label'] = df['Score_BA'].apply(lambda x: 1 if eval(x) > 0.4 else 0) # Threshold can be adjusted +# else: +# df['label'] = None +# # Handle cell_line_id filtering +# if 'cell_line_id' in df.columns and 'label' in df.columns: +# # Create a mask for cell lines with at least one positive label +# positive_cell_lines = df.groupby('cell_line_id')['label'].max() == 1 +# positive_cell_lines = positive_cell_lines[positive_cell_lines].index.tolist() +# # Filter rows +# mask = df['cell_line_id'].isin(positive_cell_lines) +# dropped_df = df[~mask].copy() +# df = df[mask].copy() +# # Save dropped rows to separate file if there are any +# if not dropped_df.empty: +# dropped_output_path = final_output_path.replace('.parquet', '_dropped.parquet') +# pq.write_table(pa.Table.from_pandas(dropped_df), dropped_output_path) +# print( +# f"Dropped {dropped_df.shape[0]} rows from {len(set(dropped_df['cell_line_id']))} cell lines with no positive labels") +# print(f"Dropped rows saved to: {dropped_output_path}") +# print(f"Kept {df.shape[0]} rows from {len(set(df['cell_line_id']))} cell lines with at least one positive label") +# # Keep only required columns if they exist +# keep_cols = ["allele", "peptide", "label"] +# existing_cols = [col for col in keep_cols if col in df.columns] +# if existing_cols: +# df = df[existing_cols] +# # Save the combined DataFrame to parquet +# pq.write_table(pa.Table.from_pandas(df), final_output_path) +# print(f"Combined results shape: {df.shape}") +# print(f"Final output saved to: {final_output_path}") +# else: +# df = pd.DataFrame(columns=["allele", "peptide", "label"]) +# print("No valid DataFrames were created from parquet files") +# else: +# df = pd.DataFrame(columns=["allele", "peptide", "label"]) +# print("No parquet files were generated") +# except Exception as e: +# print(f"Error combining parquet files: {str(e)}") +# import traceback +# traceback.print_exc() +# df = pd.DataFrame(columns=["allele", "peptide", "label"]) +# return df + + +# def parallelize_netmhcpan(el_data, tmp_path, results_dir): +# run_netmhcpan_(el_data, 1, tmp_path, results_dir, 0) +# +# # with ProcessPoolExecutor() as executor: +# # tasks = [] +# # for allele in unique_cell_line: +# # subset_data = el_data_1[el_data_1["allele"] == allele] +# # tasks.append(executor.submit(run_netmhcpan_, subset_data, [allele], tmp_path, results_dir)) +# # for task in tasks: +# # task.result() + +def run_(arg1, arg2): + mhcII_path = "data/NetMHCpan_dataset/NetMHCIIpan_train/" + # mhcI_path = "data/NetMHCpan_dataset/NetMHCpan_train/" + tmp_path = 
"data/NetMHCpan_dataset/tmp_II/" + # tmp_path = "data/NetMHCpan_dataset/tmp_I/" + # results_dir = "data/NetMHCpan_dataset/results_I/" + results_dir = "data/NetMHCpan_dataset/results_II/" + + mhc_class = 2 + + # make tmp directory + if not os.path.isdir(tmp_path): + os.mkdir(tmp_path) + + # make results directory + if not os.path.isdir(results_dir): + os.mkdir(results_dir) + + if arg1 == "process_data": + ######################## + # Load data + el_data_0, el_data_1, ba_data = process_data( + mhc_path=mhcI_path, + tmp_path=tmp_path, + mhc_type=mhc_class) + + # save the variables + el_data_0.to_csv(f"{tmp_path}/el_data_0.csv", index=False) + el_data_1.to_csv(f"{tmp_path}/el_data_1.csv", index=False) + ba_data.to_csv(f"{tmp_path}/ba_data.csv", index=False) + params = { + # "unique_cell_lines": unique_cell_lines.tolist() if hasattr(unique_cell_lines, "tolist") else unique_cell_lines, + "tmp_path": tmp_path, + "results_dir": results_dir + } + with open(f"{tmp_path}/params.json", "w") as f: + json.dump(params, f) + + # split el_data to 10 chunks and save separately + chunk_size = len(el_data_1) // 256 + for i in range(256): + el_data_1_chunk = el_data_1.iloc[i * chunk_size: (i + 1) * chunk_size] + el_data_1_chunk.to_csv(f"{tmp_path}/el_data_1_chunk_{i}.csv", index=False) + ######################## + + if arg1 == "run_netmhcpan": + # load the variables and parameters from JSON files + with open(f"{tmp_path}/params.json", "r") as f: + params = json.load(f) + # unique_alleles = np.array(params["unique_alleles"]) + tmp_path = params["tmp_path"] + results_dir = params["results_dir"] + + # load the variables + el_data_1 = pd.read_csv( + f"{tmp_path}/el_data_1_chunk_{arg2}.csv") + el_data_0 = pd.read_csv(f"{tmp_path}/el_data_0.csv") + + # print the len of el_data_1 + print(f"Length of el_data_1: {len(el_data_1)}") + + # print the len of el_data_0 + print(f"Length of el_data_0: {len(el_data_0)}") + + # subset 1000 random samples from the data for testing + # el_data_1 = el_data_1.sample(n=1000, random_state=42) + # unique_cell_lines = el_data_1["cell_line_id"].unique() + + # Run NetMHCpan in parallel + # parallelize_netmhcpan(el_data_1, tmp_path, results_dir) + + # run the netmhcpan for the el_data_1 + # run_netmhcpan_(el_data_1, 1, tmp_path, results_dir, arg2, mhc_class=mhc_class) + + # run the netmhcpan for the el_data_0 + # run_netmhcpan_(el_data_0, 0, tmp_path, results_dir, arg2) + + if arg1 == "combine_results": + # Combine results + df = combine_datasets_(results_dir) + + el_data_0 = pd.read_csv(f"{tmp_path}/el_data_0.csv") + ba_data = pd.read_csv(f"{tmp_path}/ba_data.csv") + + # add labels to ba_data with threshold of 0.426 + ba_data["label"] = ba_data["label"].apply(lambda x: 1 if float(x) >= 0.426 else 0) + ba_data.rename(columns={'label': 'assigned_label'}, inplace=True) + + # explode the el_data_0['allele'] column (we do this because all samples in the el_data_0 have the same label with 100% confidence) + el_data_0["allele"] = el_data_0["allele"].apply(lambda x: eval(x)) + el_data_0 = el_data_0.explode("allele") + el_data_0.rename(columns={'label': 'assigned_label'}, inplace=True) + + # concatenate the dataframes df and el_data_0 and ba_data + df = pd.concat([df, el_data_0], ignore_index=True) + df = pd.concat([df, ba_data], ignore_index=True) + + # add mhc class column + df["mhc_class"] = mhc_class + + # Standardize column names + if 'Peptide' in df.columns and 'peptide' not in df.columns: + df['peptide'] = df['Peptide'] + df.drop(columns=['Peptide'], inplace=True) + elif 'Peptide' in 
df.columns and 'peptide' in df.columns: + df['peptide'] = df['peptide'].fillna(df['Peptide']) + df.drop(columns=['Peptide'], inplace=True) + if 'MHC' in df.columns and 'allele' not in df.columns: + df['allele'] = df['MHC'] + df.drop(columns=['MHC'], inplace=True) + elif 'MHC' in df.columns and 'allele' in df.columns: + df['allele'] = df['allele'].fillna(df['MHC']) + df.drop(columns=['MHC'], inplace=True) + if "peptide" in df.columns: + if "long_mer" in df.columns: + df["peptide"] = df["peptide"].fillna(df["long_mer"]) + df.drop(columns=["long_mer"], inplace=True) + # rename the peptide to long_mer + df.rename(columns={"peptide": "long_mer"}, inplace=True) + + # drop duplicates + print(f"Before dropping duplicates: {df.shape}") + df = df.drop_duplicates(subset=["allele", "long_mer"], keep="first") + print(f"After dropping duplicates: {df.shape}") + + # add mhc_sequence column + df = add_mhc_sequence_column(df) + + # save the final output + df.to_csv(f"data/NetMHCpan_dataset/combined_data_{mhc_class}.csv", index=False) + +def main(): + if len(sys.argv) < 1: + print("Usage: python script1.py (only required when running run_netmhcpan)") + sys.exit(1) + + if sys.argv[1] not in ["process_data", "run_netmhcpan", "combine_results"]: + print("Invalid argument. Please specify one of the following: process_data, run_netmhcpan, combine_results") + + if sys.argv[1] == "run_netmhcpan": + run_(sys.argv[1], sys.argv[2]) + + else: + run_(sys.argv[1], None) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/run_pMHC_DL.py b/run_pMHC_DL.py new file mode 100644 index 00000000..cace4349 --- /dev/null +++ b/run_pMHC_DL.py @@ -0,0 +1,2720 @@ +import json +import os +import uuid + +import tensorflow as tf +import numpy as np +import pandas as pd +import time +import matplotlib.pyplot as plt +from sklearn.metrics import roc_curve, precision_recall_curve +from sklearn.model_selection import train_test_split +from sklearn.cluster import KMeans +import seaborn as sns +from sklearn.decomposition import PCA +from collections import Counter + +from utils.model import SCQ1DAutoEncoder, MoEModel, BinaryMLP, TabularTransformer, EmbeddingCNN, EnhancedMoEModel + +''' +# Set TensorFlow logging level for more information +tf.get_logger().setLevel('INFO') + + +# def create_dataset(X, batch_size=1, is_training=True): +# """Create TensorFlow dataset with consistent approach for both training and validation.""" +# # Debug info +# print(f"Creating {'training' if is_training else 'validation'} dataset") +# print(f"Data shape: {X.shape}") +# print(f"Data dtype: {X.dtype}") +# print(f"Data range: min={np.min(X):.4f}, max={np.max(X):.4f}, mean={np.mean(X):.4f}") +# +# # Ensure data is float32 (TensorFlow works best with float32) +# X = X.astype(np.float32) +# +# dataset = tf.data.Dataset.from_tensor_slices(X) +# +# # Apply shuffling only for training data +# if is_training: +# dataset = dataset.shuffle(buffer_size=1000) +# +# # Apply batching with drop_remainder=False to handle all data +# dataset = dataset.batch(batch_size) +# +# # Prefetch for better performance +# dataset = dataset.prefetch(tf.data.AUTOTUNE) +# +# # Debug dataset +# for batch in dataset.take(1): +# print(f"Sample batch shape: {batch.shape}") +# print(f"Sample batch dtype: {batch.dtype}") +# print(f"Sample batch range: min={tf.reduce_min(batch):.4f}, max={tf.reduce_max(batch):.4f}") +# +# return dataset + +def create_dataset(X, batch_size=1, is_training=True): + """ + simple function to create a TensorFlow dataset. 
+ Args: + X: + batch_size: + is_training: + + Returns: + + """ + X = tf.convert_to_tensor(X, dtype=tf.float32) + + # Create dataset from tensor slices + dataset = tf.data.Dataset.from_tensor_slices(X) + + # Apply shuffling only for training data + if is_training: + dataset = dataset.shuffle(buffer_size=min(1000, len(X))) + + # Apply batching (drop_remainder=False to handle all data) + dataset = dataset.batch(batch_size) + + # Prefetch for better performance + dataset = dataset.prefetch(tf.data.AUTOTUNE) + + return dataset + +def plot_metrics(history, save_path='training_metrics.png'): + """ + Visualize training metrics over epochs. + + Args: + history: A Keras History object or dictionary containing training history + save_path: Path to save the plot image + """ + # Handle both Keras History objects and dictionaries + history_dict = history.history if hasattr(history, 'history') else history + + # Print available keys to debug + print(f"Available keys in history: {list(history_dict.keys())}") + + metrics = ['loss', 'recon', 'vq', 'perplexity'] + # Check for standard Keras naming (total_loss instead of loss, etc.) + metric_mapping = { + 'loss': ['loss', 'total_loss'], + 'recon': ['recon', 'recon_loss'], + 'vq': ['vq', 'vq_loss'], + 'perplexity': ['perplexity'] + } + + fig, axes = plt.subplots(len(metrics), 1, figsize=(12, 3*len(metrics)), sharex=True) + + # Handle case with only one metric + if len(metrics) == 1: + axes = [axes] + + for i, metric_base in enumerate(metrics): + ax = axes[i] + + # Try different possible metric names + for metric in metric_mapping[metric_base]: + # Plot training metric + if metric in history_dict: + ax.plot(history_dict[metric], 'b-', label=f'Train {metric}') + print(f"Plotting {metric} with {len(history_dict[metric])} points") + break + + # Try different possible validation metric names + for metric in metric_mapping[metric_base]: + val_metric = f'val_{metric}' + if val_metric in history_dict: + ax.plot(history_dict[val_metric], 'r-', label=f'Validation {metric}') + print(f"Plotting {val_metric} with {len(history_dict[val_metric])} points") + break + + ax.set_title(f'{metric_base.capitalize()} over epochs') + ax.set_ylabel('Value') + ax.grid(True) + ax.legend(loc='best') + + plt.xlabel('Epochs') + plt.tight_layout() + + # Save the figure in case display doesn't work + plt.savefig(save_path) + print(f"Plot saved to {save_path}") + + # Try to display + try: + plt.show() + except Exception as e: + print(f"Could not display plot: {e}") + + +def main(): + print("Starting SCQ parameter search...") + + try: + print("Loading peptide embeddings...") + mhc1_pep2vec_embeddings = pd.read_parquet("data/Pep2Vec/wrapper_mhc1.parquet") + mhc2_pep2vec_embeddings = pd.read_parquet("data/Pep2Vec/wrapper_mhc2.parquet") + + # Select the latent columns (the columns that has latent in their name) + mhc1_latent_columns = [col for col in mhc1_pep2vec_embeddings.columns if 'latent' in col] + mhc2_latent_columns = [col for col in mhc2_pep2vec_embeddings.columns if 'latent' in col] + + print(f"Found {len(mhc1_latent_columns)} latent columns for MHC1") + print(f"Found {len(mhc2_latent_columns)} latent columns for MHC2") + + if len(mhc1_latent_columns) == 0 or len(mhc2_latent_columns) == 0: + print("WARNING: No latent columns found. 
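
The `create_dataset` helper in the legacy block above wraps a NumPy matrix in a shuffled, batched, prefetched `tf.data` pipeline. A minimal self-contained sketch of that pattern, assuming a random stand-in for the Pep2Vec latent matrix (the 1024-dim size is illustrative):

```python
import numpy as np
import tensorflow as tf

# Stand-in for the Pep2Vec latent matrix.
X = np.random.rand(256, 1024).astype(np.float32)

dataset = tf.data.Dataset.from_tensor_slices(X)
dataset = dataset.shuffle(buffer_size=min(1000, len(X)))  # shuffle only for training data
dataset = dataset.batch(32)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

for batch in dataset.take(1):
    print(batch.shape)  # (32, 1024)
```
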
Check column names.") + + # Extract latent features + X_mhc1 = mhc1_pep2vec_embeddings[mhc1_latent_columns].values + X_mhc2 = mhc2_pep2vec_embeddings[mhc2_latent_columns].values + + print(f"MHC1 data shape: {X_mhc1.shape}") + print(f"MHC2 data shape: {X_mhc2.shape}") + + # Data sanity check + print("Data overview:") + print( + f"MHC1 - min: {np.min(X_mhc1):.4f}, max: {np.max(X_mhc1):.4f}, mean: {np.mean(X_mhc1):.4f}, std: {np.std(X_mhc1):.4f}") + print( + f"MHC2 - min: {np.min(X_mhc2):.4f}, max: {np.max(X_mhc2):.4f}, mean: {np.mean(X_mhc2):.4f}, std: {np.std(X_mhc2):.4f}") + + # Check for NaN or infinity values + print(f"MHC1 has NaN: {np.isnan(X_mhc1).any()}, has inf: {np.isinf(X_mhc1).any()}") + print(f"MHC2 has NaN: {np.isnan(X_mhc2).any()}, has inf: {np.isinf(X_mhc2).any()}") + + # Replace any NaN values with zeros + if np.isnan(X_mhc1).any(): + print("Replacing NaN values in MHC1 with zeros") + X_mhc1 = np.nan_to_num(X_mhc1) + + if np.isnan(X_mhc2).any(): + print("Replacing NaN values in MHC2 with zeros") + X_mhc2 = np.nan_to_num(X_mhc2) + + except Exception as e: + print(f"Error loading data: {e}") + return + + # Create output directory for results + output_dir = "output/scq_parameter_search" + os.makedirs(output_dir, exist_ok=True) + + # Define parameter search space - simplified to just one configuration for testing + param_grid = [ + # Explore different codebook_num + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 8, 'heads': 4}, + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 16, 'heads': 4}, + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 32, 'heads': 4}, + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 64, 'heads': 4}, + {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 128, 'heads': 4}, + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 256, 'heads': 4}, + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 512, 'heads': 4}, + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 1024, 'heads': 4}, + # # Explore codebook_dim + # {'general_embed_dim': 128, 'codebook_dim': 32, 'codebook_num': 16, 'heads': 4}, + # {'general_embed_dim': 128, 'codebook_dim': 64, 'codebook_num': 16, 'heads': 4}, + # {'general_embed_dim': 128, 'codebook_dim': 128, 'codebook_num': 16, 'heads': 4}, + # # Explore heads + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 16, 'heads': 2}, + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 16, 'heads': 8}, + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 16, 'heads': 16}, + # {'general_embed_dim': 128, 'codebook_dim': 16, 'codebook_num': 16, 'heads': 32}, + # # Explore general_embed_dim + # {'general_embed_dim': 64, 'codebook_dim': 16, 'codebook_num': 16, 'heads': 4}, + # {'general_embed_dim': 256, 'codebook_dim': 16, 'codebook_num': 16, 'heads': 4}, + # {'general_embed_dim': 512, 'codebook_dim': 16, 'codebook_num': 16, 'heads': 4}, + ] + + # Common parameters for all configurations + common_params = { + 'descrete_loss': True, + 'weight_recon': 1.0, + 'weight_vq': 1.0, + } + + # Set up cross-validation - reduced to 2 folds for testing + n_folds = 2 + kf = KFold(n_splits=n_folds, shuffle=True, random_state=42) + + # Batch size for all datasets + batch_size = 1 + + # Store results + results = [] + + # Run parameter search for each dataset + # Just use MHC1 for testing + for dataset_name, X, latent_columns in [ + ("MHC1", X_mhc1, mhc1_latent_columns), + ("MHC2", X_mhc2, 
mhc2_latent_columns) + ]: + print(f"\n{'=' * 50}") + print(f"Processing {dataset_name} dataset") + print(f"{'=' * 50}\n") + + # Create directory for this dataset + dataset_dir = os.path.join(output_dir, dataset_name) + os.makedirs(dataset_dir, exist_ok=True) + + # Model selection loop + for param_idx, param_config in enumerate(param_grid): + config_name = f"config_{param_idx + 1}" + print(f"\n{'-' * 50}") + print(f"Testing {config_name}: {param_config}") + print(f"{'-' * 50}\n") + + # Merge with common parameters + model_params = {**common_params, **param_config} + + # Create config directory + config_dir = os.path.join(dataset_dir, config_name) + os.makedirs(config_dir, exist_ok=True) + + # Cross-validation metrics + cv_metrics = { + 'loss': [], + 'recon': [], + 'vq': [], + 'perplexity': [] + } + + # Prepare folds + fold_indices = list(kf.split(X)) + + for fold_idx, (train_idx, val_idx) in enumerate(fold_indices): + print(f"Training fold {fold_idx + 1}/{n_folds}") + + # Split data + X_train, X_val = X[train_idx], X[val_idx] + + # Check data shapes + print(f"Train data shape: {X_train.shape}") + print(f"Validation data shape: {X_val.shape}") + + # Convert to TensorFlow dataset using the new function + train_dataset = create_dataset(X_train, batch_size=batch_size, is_training=True) + val_dataset = create_dataset(X_val, batch_size=batch_size, is_training=False) + + # Define the model + model = SCQ_model(**model_params, input_dim=X_train.shape[1]) + + # Print model summary + model.build((None, X_train.shape[1])) + print("Model summary:") + model.summary() + + # Compile the model + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001)) + + # Prepare callbacks + fold_dir = os.path.join(config_dir, f"fold_{fold_idx}") + os.makedirs(fold_dir, exist_ok=True) + + # Create a custom callback to print metrics after each epoch + class MetricsPrinter(tf.keras.callbacks.Callback): + def on_epoch_end(self, epoch, logs=None): + logs = logs or {} + print(f"\nEpoch {epoch + 1} metrics:") + for metric, value in logs.items(): + print(f" {metric}: {value:.6f}") + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(fold_dir, 'checkpoint'), + save_best_only=True, + monitor='val_total_loss', + mode='min' + ), + tf.keras.callbacks.EarlyStopping( + monitor='val_total_loss', + patience=5, + restore_best_weights=True + ), + MetricsPrinter() # Add the custom callback + ] + + # Train the model with fewer epochs for testing + print("Starting model training...") + try: + history = model.fit( + train_dataset, + epochs=30, # Reduced for testing + validation_data=val_dataset, + callbacks=callbacks, + verbose=1 + ) + + # Check if history contains expected metrics + print("\nHistory keys:", history.history.keys()) + for key, values in history.history.items(): + print(f"{key}: {values}") + + # plot + plot_metrics(history) + + except Exception as e: + print(f"Error during training: {e}") + continue + + # Save training history + try: + with open(os.path.join(fold_dir, 'history.json'), 'w') as f: + json.dump(history.history, f) + except Exception as e: + print(f"Error saving history: {e}") + + # Reset metrics before evaluation + print("Resetting metrics before evaluation") + model.reset_metrics() + + # Evaluate on validation set + print("Evaluating model on validation data...") + # After getting validation metrics + try: + val_metrics = model.evaluate(val_dataset, return_dict=True) + print("Validation metrics:", val_metrics) + + # Map the metric names + metric_mapping = { + 'loss': 
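
The quoted-out parameter search trains each configuration across sklearn `KFold` splits. A minimal sketch of that cross-validation loop (the sketch imports `KFold` itself and uses a toy latent matrix):

```python
import numpy as np
from sklearn.model_selection import KFold

X = np.random.rand(100, 1024).astype(np.float32)  # illustrative latent matrix

kf = KFold(n_splits=2, shuffle=True, random_state=42)
for fold_idx, (train_idx, val_idx) in enumerate(kf.split(X)):
    X_train, X_val = X[train_idx], X[val_idx]
    print(f"fold {fold_idx}: train {X_train.shape}, val {X_val.shape}")
```
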
'total_loss', + 'recon': 'recon_loss', + 'vq': 'vq_loss', + 'perplexity': 'perplexity' + } + + # Record metrics with mapping + for returned_name, expected_name in metric_mapping.items(): + if returned_name in val_metrics: + cv_metrics[expected_name].append(val_metrics[returned_name]) + else: + cv_metrics[expected_name].append(0) + + except Exception as e: + print(f"Error during evaluation: {e}") + # Rest of your error handling + + # Record metrics + for metric in cv_metrics: + cv_metrics[metric].append(val_metrics[metric]) + + # Save fold metrics + with open(os.path.join(fold_dir, 'metrics.json'), 'w') as f: + json.dump(val_metrics, f, indent=2) + + # Calculate average metrics across folds + avg_metrics = {k: np.mean(v) for k, v in cv_metrics.items()} + std_metrics = {k: np.std(v) for k, v in cv_metrics.items()} + + # Report results + print(f"\nAverage metrics for {config_name}:") + for metric, value in avg_metrics.items(): + print(f"{metric}: {value:.6f} (±{std_metrics[metric]:.6f})") + + # Store results + result = { + 'dataset': dataset_name, + 'config': config_name, + **model_params, + **avg_metrics, + **{f"{k}_std": v for k, v in std_metrics.items()}, + 'cv_metrics': cv_metrics + } + results.append(result) + + # Save config results + with open(os.path.join(config_dir, 'results.json'), 'w') as f: + json.dump(result, f, indent=2) + + # Save all results to CSV + results_df = pd.DataFrame([{k: v for k, v in r.items() if k != 'cv_metrics'} for r in results]) + results_df.to_csv(os.path.join(output_dir, 'all_results.csv'), index=False) + + print("\nParameter search complete!") + print(f"Results saved to {output_dir}") + +import os +import traceback + +def simple_run(batch_size=1): + """A simplified run using the MHC dataset instead of random data.""" + print("Starting simplified SCQ model run on MHC data...") + + try: + # Create output directory for results + output_dir = "output/scq_simple_run" + os.makedirs(output_dir, exist_ok=True) + + # Load MHC data + print("Loading peptide embeddings...") + mhc1_pep2vec_embeddings = pd.read_parquet("data/Pep2Vec/wrapper_mhc1.parquet") + + # Select the latent columns + latent_columns = [col for col in mhc1_pep2vec_embeddings.columns if 'latent' in col] + print(f"Found {len(latent_columns)} latent columns for MHC1") + + # Extract latent features + X = mhc1_pep2vec_embeddings[latent_columns].values + print(f"MHC1 data shape: {X.shape}") + + # Data sanity check + print("Data overview:") + print(f"MHC1 - min: {np.min(X):.4f}, max: {np.max(X):.4f}, mean: {np.mean(X):.4f}, std: {np.std(X):.4f}") + + # Replace any NaN values with zeros + if np.isnan(X).any(): + print("Replacing NaN values in MHC1 with zeros") + X = np.nan_to_num(X) + + # Split data into train and test sets (80/20 split) + from sklearn.model_selection import train_test_split + X_train, X_test = train_test_split(X, test_size=0.2, random_state=42) + + # Create datasets + train_dataset = create_dataset(X_train, batch_size=batch_size, is_training=True) + test_dataset = create_dataset(X_test, batch_size=batch_size, is_training=False) + + # Initialize model with dimensions matching the input data + model = SCQ_model( + general_embed_dim=128, + codebook_dim=16, + codebook_num=8, + descrete_loss=True, + heads=4, + input_dim=X_train.shape[1] + ) + + # Print model summary + model.build(input_shape=(None, X_train.shape[1])) + print("Model summary:") + model.summary() + + # Compile model + model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001)) + + # Train with early stopping + callbacks = [ + 
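
The callback list being assembled at this point in the legacy code pairs early stopping with best-weights checkpointing, a combination used throughout these training loops. A standalone sketch of that pair (the checkpoint path is illustrative):

```python
import tensorflow as tf

# Stop when the monitored loss plateaus and keep only the best weights on disk.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor="loss", patience=5, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint(
        filepath="checkpoint.weights.h5",  # illustrative path
        monitor="loss",
        save_best_only=True,
        save_weights_only=True,
    ),
]
# These would then be passed to model.fit(..., callbacks=callbacks).
```
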
tf.keras.callbacks.EarlyStopping(monitor='loss', patience=5), + tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(output_dir, 'model_checkpoint.weights.h5'), + save_best_only=True, + save_weights_only=True, # Only save weights, not the full model + monitor='loss' + ) + ] + + # Train model and capture history + history = model.fit( + train_dataset, + epochs=1000, # Reduced from 1000 for quicker testing + callbacks=callbacks, + verbose=1 + ) + + # Plot training metrics + plot_metrics(history, save_path=os.path.join(output_dir, 'training_metrics.png')) + + # Save training history + pd.DataFrame(history.history).to_csv(os.path.join(output_dir, "model_history.csv"), index=False) + + # Test model on test data + print("Evaluating model on test data...") + evaluation = model.evaluate(test_dataset, return_dict=True) + print("Test metrics:", evaluation) + + # Generate and save example outputs + for batch in test_dataset.take(1): + output = model(batch) + # Save sample input and output + np.save(os.path.join(output_dir, "sample_input.npy"), batch.numpy()) + np.savez( + os.path.join(output_dir, "sample_output.npz"), + decoded=output[0].numpy(), + zq=output[1].numpy(), + pj=output[2].numpy() + ) + + print(f"Training complete. Results saved to {output_dir}") + + except Exception as e: + print(f"Error in simple_run: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + # try: + # main() + # except Exception as e: + # print(f"Error in main function: {e}") + # import traceback + # + # traceback.print_exc() + # simple run + simple_run()''' + + +'''import os +import numpy as np +import pandas as pd +import tensorflow as tf +from tensorflow import keras +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split +from sklearn.metrics import mean_squared_error +import itertools +from utils.model import SCQ_model + + +def create_dataset(X, batch_size=4, is_training=True): + """Create a TensorFlow dataset from input array X.""" + dataset = tf.data.Dataset.from_tensor_slices(X) + if is_training: + dataset = dataset.shuffle(buffer_size=len(X)) + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(tf.data.AUTOTUNE) + return dataset + + +def load_data(data_path, test_size=0.2, random_state=42): + """Load and prepare data for training.""" + # Load data + embeddings = pd.read_parquet(data_path) + + # Select the latent columns + latent_columns = [col for col in embeddings.columns if 'latent' in col] + print(f"Found {len(latent_columns)} latent columns") + + # Extract values and reshape + X = embeddings[latent_columns].values + seq_length = len(latent_columns) + feature_dim = 1 + X = X.reshape(-1, seq_length, feature_dim) + + # Split into train and test sets + X_train, X_test = train_test_split(X, test_size=test_size, random_state=random_state) + + return X_train, X_test, seq_length, feature_dim + + +def train_scq_model(X_train, X_test, feature_dim, general_embed_dim, codebook_dim, + codebook_num, batch_size=4, epochs=20, learning_rate=0.001, + heads=8, descrete_loss=False, output_dir="test_tmp"): + """Train an SCQ model with the given parameters.""" + # Create directories if they don't exist + os.makedirs(output_dir, exist_ok=True) + + # Create datasets + train_dataset = create_dataset(X_train, batch_size=batch_size, is_training=True) + test_dataset = create_dataset(X_test, batch_size=batch_size, is_training=False) + + # Initialize SCQ model + model = SCQ_model(input_dim=int(feature_dim), + general_embed_dim=int(general_embed_dim), + 
codebook_dim=int(codebook_dim), + codebook_num=int(codebook_num), + descrete_loss=descrete_loss, + heads=int(heads)) + + # Compile model + model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate)) + + # Train model and capture history + history = model.fit(train_dataset, epochs=epochs, batch_size=batch_size) + + # Save training history + history_df = pd.DataFrame(history.history) + history_path = os.path.join(output_dir, f"model_history_cb{codebook_num}_emb{general_embed_dim}.csv") + history_df.to_csv(history_path, index=False) + + # Evaluate model on test data + decoded_outputs, zq_outputs, pj_outputs = evaluate_model(model, test_dataset) + + # Save outputs + output_path = os.path.join(output_dir, f"output_data_cb{codebook_num}_emb{general_embed_dim}.npz") + np.savez(output_path, + decoded=np.vstack(decoded_outputs), + zq=np.vstack(zq_outputs), + pj=np.vstack(pj_outputs)) + + # Calculate metrics + mse = calculate_reconstruction_mse(X_test, np.vstack(decoded_outputs)) + + return model, history, mse + + +def evaluate_model(model, test_dataset): + """Evaluate the model on test data.""" + decoded_outputs = [] + zq_outputs = [] + pj_outputs = [] + + for batch in test_dataset: + output = model(batch) + decoded_outputs.append(output[0].numpy()) + zq_outputs.append(output[1].numpy()) + pj_outputs.append(output[2].numpy()) + + return decoded_outputs, zq_outputs, pj_outputs + + +def calculate_reconstruction_mse(X_test, decoded_output): + """Calculate mean squared error between input and reconstructed output.""" + # Reshape if necessary to match dimensions + if X_test.shape != decoded_output.shape: + # Adjust shapes as needed based on your model's output + pass + + return mean_squared_error(X_test.reshape(-1), decoded_output.reshape(-1)) + + +def plot_training_history(history_df, title="Training History", output_path=None): + """Plot training metrics from history dataframe.""" + plt.figure(figsize=(12, 6)) + + # Plot all metrics in the history + for column in history_df.columns: + plt.plot(history_df[column], label=column) + + plt.title(title) + plt.xlabel('Epoch') + plt.ylabel('Value') + plt.legend() + plt.grid(True) + + if output_path: + plt.savefig(output_path) + plt.close() + else: + plt.show() + + +def plot_reconstruction_comparison(original, reconstructed, n_samples=5, output_path=None): + """Plot comparison between original and reconstructed sequences.""" + # Randomly select n_samples sequences to visualize + indices = np.random.choice(len(original), size=min(n_samples, len(original)), replace=False) + + plt.figure(figsize=(15, 3 * n_samples)) + + for i, idx in enumerate(indices): + # Plot original sequence + plt.subplot(n_samples, 2, 2 * i + 1) + plt.plot(original[idx].flatten()) + plt.title(f"Original Sequence {idx}") + plt.grid(True) + + # Plot reconstructed sequence + plt.subplot(n_samples, 2, 2 * i + 2) + plt.plot(reconstructed[idx].flatten()) + plt.title(f"Reconstructed Sequence {idx}") + plt.grid(True) + + plt.tight_layout() + + if output_path: + plt.savefig(output_path) + plt.close() + else: + plt.show() + + +def parameter_search(X_train, X_test, feature_dim, batch_size=4, epochs=20, learning_rate=0.001, + codebook_nums=[4,8,16,32,64,128,256,512,1024], embed_dims=[64, 128, 256], + codebook_dim=21, heads=8, output_dir="parameter_search"): + """ + Perform grid search over codebook_num and general_embed_dim parameters. + Returns the best parameters based on reconstruction MSE. 
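
The grid search in this legacy block enumerates every `(codebook_num, general_embed_dim)` pair with `itertools.product` and scores each by reconstruction MSE. A minimal sketch of that enumeration:

```python
import itertools

codebook_nums = [4, 8, 16]
embed_dims = [64, 128, 256]

# Every (codebook_num, general_embed_dim) pair gets trained and scored separately.
param_combinations = list(itertools.product(codebook_nums, embed_dims))
for i, (codebook_num, embed_dim) in enumerate(param_combinations):
    print(f"combination {i + 1}/{len(param_combinations)}: "
          f"codebook_num={codebook_num}, embed_dim={embed_dim}")
```
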
+ """ + os.makedirs(output_dir, exist_ok=True) + + results = [] + + # Create all combinations of parameters + param_combinations = list(itertools.product(codebook_nums, embed_dims)) + total_combinations = len(param_combinations) + + print(f"Starting parameter search with {total_combinations} combinations...") + + for i, (codebook_num, embed_dim) in enumerate(param_combinations): + print(f"Training combination {i + 1}/{total_combinations}: codebook_num={codebook_num}, embed_dim={embed_dim}") + + try: + # Train the model with this parameter combination + model, history, mse = train_scq_model( + X_train, X_test, feature_dim, + general_embed_dim=embed_dim, + codebook_dim=codebook_dim, + codebook_num=codebook_num, + batch_size=batch_size, + epochs=epochs, + learning_rate=learning_rate, + heads=heads, + output_dir=output_dir + ) + + # Store results + results.append({ + 'codebook_num': codebook_num, + 'general_embed_dim': embed_dim, + 'mse': mse, + 'history': history + }) + + # Plot training history + history_df = pd.DataFrame(history.history) + plot_training_history( + history_df, + title=f"Training History (codebook_num={codebook_num}, embed_dim={embed_dim})", + output_path=os.path.join(output_dir, f"history_plot_cb{codebook_num}_emb{embed_dim}.png") + ) + + except Exception as e: + print(f"Error training with codebook_num={codebook_num}, embed_dim={embed_dim}: {e}") + + # Create results dataframe + results_df = pd.DataFrame([(r['codebook_num'], r['general_embed_dim'], r['mse']) + for r in results], + columns=['codebook_num', 'general_embed_dim', 'mse']) + + # Save results + results_df.to_csv(os.path.join(output_dir, "parameter_search_results.csv"), index=False) + + # Find best parameters + best_idx = results_df['mse'].idxmin() + best_params = results_df.loc[best_idx] + + print(f"Parameter search complete. Best parameters:") + print(f" codebook_num: {best_params['codebook_num']}") + print(f" general_embed_dim: {best_params['general_embed_dim']}") + print(f" MSE: {best_params['mse']}") + + # Plot results heatmap + plot_parameter_search_results(results_df, output_dir) + + return best_params, results_df + + +def plot_parameter_search_results(results_df, output_dir): + """Plot heatmap of parameter search results.""" + # Create pivot table for heatmap + pivot_df = results_df.pivot(index='codebook_num', columns='general_embed_dim', values='mse') + + plt.figure(figsize=(10, 8)) + plt.imshow(pivot_df, cmap='viridis_r') # Reverse colormap so darker is better (lower MSE) + + # Set labels + plt.colorbar(label='MSE (lower is better)') + plt.title('Parameter Search Results') + plt.xlabel('General Embedding Dimension') + plt.ylabel('Codebook Number') + + # Set tick labels + plt.xticks(range(len(pivot_df.columns)), pivot_df.columns) + plt.yticks(range(len(pivot_df.index)), pivot_df.index) + + # Add text annotations + for i in range(len(pivot_df.index)): + for j in range(len(pivot_df.columns)): + value = pivot_df.iloc[i, j] + if not np.isnan(value): + plt.text(j, i, f'{value:.4f}', ha='center', va='center', + color='white' if value > pivot_df.values.mean() else 'black') + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, "parameter_search_heatmap.png")) + plt.close() + + +def run_scq_pipeline(data_path, output_dir="test_tmp", run_param_search=True, + batch_size=4, epochs=3, learning_rate=0.001, + codebook_num=5, general_embed_dim=128, codebook_dim=32, heads=8): + """ + Main function to run the complete SCQ pipeline with optional parameter search. 
+ """ + # Load and prepare data + X_train, X_test, seq_length, feature_dim = load_data(data_path) + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Save input data + np.save(os.path.join(output_dir, "input_data.npy"), X_test) + + if run_param_search: + # Run parameter search to find optimal values + best_params, results_df = parameter_search( + X_train, X_test, feature_dim, + batch_size=batch_size, + epochs=epochs, + learning_rate=learning_rate, + output_dir=os.path.join(output_dir, "param_search") + ) + + # Use best parameters for final model - ensure they are integers + codebook_num = int(best_params['codebook_num']) + general_embed_dim = int(best_params['general_embed_dim']) + + # Train final model with selected/best parameters + print(f"Training final model with codebook_num={codebook_num}, general_embed_dim={general_embed_dim}") + final_model, final_history, final_mse = train_scq_model( + X_train, X_test, feature_dim, + general_embed_dim=general_embed_dim, + codebook_dim=codebook_dim, + codebook_num=codebook_num, + batch_size=batch_size, + epochs=epochs, + learning_rate=learning_rate, + heads=heads, + output_dir=output_dir + ) + + # Evaluate the final model + test_dataset = create_dataset(X_test, batch_size=batch_size, is_training=False) + decoded_outputs, zq_outputs, pj_outputs = evaluate_model(final_model, test_dataset) + + # Create visualizations + history_df = pd.DataFrame(final_history.history) + plot_training_history( + history_df, + title=f"Final Model Training History (codebook_num={codebook_num}, embed_dim={general_embed_dim})", + output_path=os.path.join(output_dir, "final_model_history_plot.png") + ) + + plot_reconstruction_comparison( + X_test, + np.vstack(decoded_outputs), + n_samples=5, + output_path=os.path.join(output_dir, "reconstruction_comparison.png") + ) + + print(f"Pipeline completed successfully.") + print(f"Final model MSE: {final_mse}") + print(f"Final model parameters: codebook_num={codebook_num}, general_embed_dim={general_embed_dim}") + print(f"Results saved to {output_dir}") + + return final_model, final_mse, (codebook_num, general_embed_dim) + + +# Example usage: +if __name__ == "__main__": + # Example usage of the pipeline + data_path = "data/Pep2Vec/wrapper_mhc1.parquet" + + # # Run full pipeline with parameter search + # model, mse, best_params = run_scq_pipeline( + # data_path=data_path, + # output_dir="scq_results", + # run_param_search=True, + # batch_size=4, + # epochs=20 + # ) + + # Alternatively, run with specific parameters (no search) + model, mse, params = run_scq_pipeline( + data_path=data_path, + output_dir="scq_results_fixed", + run_param_search=False, + codebook_num=16, + codebook_dim=64, + heads=8, + batch_size=1, + general_embed_dim=256, + epochs=20 + )''' + + +# TODO implement a simple training pipeline for VQUnet +def train_and_evaluate_scqvae( + data_path, + val_data_path=None, + input_type='latent1024', # Options: 'latent1024', 'pMHC-sequence' + num_embeddings=32, + embedding_dim=64, # Increased from 32 + batch_size=32, # Increased from 1 + epochs=20, # Increased from 10 + learning_rate=1e-5, # Adjusted from 1e-4 + commitment_beta=0.25, + output_dir='data/SCQvae', + visualize=True, + save_model=True, + init_k_means=False, + random_state=42, + output_data="train_val_seperate", # Options: "val", "train", "train_val_seperate" + **kwargs +): + """ + Train and evaluate a VQ-VAE model on peptide embedding data with improved codebook utilization. 
+ + Args: + data_path: Path to the parquet file containing peptide embeddings + input_type: Type of input data ('latent1024' or 'pMHC-sequence', 'attention51') + num_embeddings: Number of clusters/codes in the codebook + embedding_dim: Dimension of each codebook vector + batch_size: Batch size for training + epochs: Number of training epochs + learning_rate: Learning rate for the optimizer + commitment_beta: Beta parameter for commitment loss + output_dir: Directory to save outputs + visualize: Whether to generate visualizations + save_model: Whether to save model weights + random_state: Random seed for reproducibility + test_size: Fraction of data to use for validation + + Returns: + model: Trained model + history: Training history + latent_data: Dictionary containing latent representations and indices + """ + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # --- Helper Functions --- + def create_dataset(X, y=None, batch_size=32, is_training=True): + """Create a TensorFlow dataset from input array X and optional labels y.""" + # X shape: (num_samples, seq_length) + # y shape: (num_samples, ...) or None + if y is not None: + # Create dataset with both features and labels, ensuring features are float32 + dataset = tf.data.Dataset.from_tensor_slices((tf.cast(X, tf.float32), y)) + else: + # Create dataset with just features, ensuring features are float32 + dataset = tf.data.Dataset.from_tensor_slices(tf.cast(X, tf.float32)) + + if is_training: + # dataset = dataset.shuffle(buffer_size=len(X)) # Shuffle all samples + pass + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(tf.data.AUTOTUNE) # Prefetch for better performance + return dataset + + # def load_data(data_path, val_data_path=None, test_size=0.2, random_state=42, input_type='latent1024'): + # """ + # Load and prepare data for training with simplified validation. 
+ # + # Args: + # data_path: Path to training data parquet file + # val_data_path: Optional path to validation data + # test_size: Fraction of data to use for validation if val_data_path is None + # random_state: Random seed for reproducibility + # input_type: Type of input data ('latent1024' or 'pMHC-sequence', 'attention51') + # + # Returns: + # X_train: Training data features with unique indices + # X_test: Test/validation data features with unique indices + # X_val: Validation data features with unique indices + # y_train: Training data labels + # y_test: Test/validation data labels + # y_val: Validation data labels + # seq_length: Length of each sequence including the index feature + # """ + # print(f"Loading training data from {data_path}") + # df = pd.read_parquet(data_path) + # + # # Add original indices to track samples + # df['original_idx'] = np.arange(len(df)) + # + # # Handle different input types + # if input_type == 'latent1024': + # columns = [col for col in df.columns if 'latent' in col] + # if not columns: + # raise ValueError("No latent columns found in the dataset") + # seq_length = len(columns) + # print(f"Found {seq_length} latent columns") + # elif input_type == 'pMHC-sequence': + # # Only check for peptide column as mhc_sequence is now optional + # if 'peptide' not in df.columns: + # raise ValueError("Required column 'peptide' not found in the dataset") + # columns = ['peptide'] + # if 'mhc_sequence' in df.columns: + # columns.append('mhc_sequence') + # print(f"Using sequence columns: {columns}") + # # Note: pMHC-sequence processing would be implemented here + # raise NotImplementedError("pMHC-sequence input not yet implemented") + # elif input_type == 'attention51': + # # Check for attention columns + # columns = [col for col in df.columns if 'attn_' in col] + # if not columns: + # raise ValueError("No attention columns found in the dataset") + # seq_length = len(columns) + # print(f"Found {seq_length} attention columns") + # else: + # raise ValueError(f"Unknown input_type: {input_type}. Expected 'latent1024' or 'pMHC-sequence'") + # + # label_col = 'binding_label' + # y = df[label_col].values + # + # # Extract original indices to ensure uniqueness across splits + # original_indices = df['original_idx'].values + # + # # Extract features and clean data + # X = df[columns].values + # if np.isnan(X).any() or np.isinf(X).any(): + # print("Warning: Dataset contains NaN or Inf values. 
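
The live `load_data` below selects latent feature columns, drops rows with missing binding labels, and zero-fills remaining NaNs. A compact sketch of that cleaning step on a toy fold frame (column names mirror the parquet files but the values are invented):

```python
import numpy as np
import pandas as pd

# Toy stand-in for a pep2vec_output_*_fold parquet.
df = pd.DataFrame({
    "latent_0": [0.1, np.nan, 0.3],
    "latent_1": [0.2, 0.5, 0.6],
    "binding_label": [1.0, 0.0, np.nan],
})

df = df.dropna(subset=["binding_label"])                # drop rows with no label
latent_cols = [c for c in df.columns if "latent" in c]  # select the latent feature columns
X = np.nan_to_num(df[latent_cols].values)               # replace remaining NaNs with zeros
y = df["binding_label"].values
print(X.shape, y.shape)  # (2, 2) (2,)
```
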
Replacing with zeros.") + # X = np.nan_to_num(X) + # + # # Add index as an additional feature + # X_with_idx = np.column_stack((original_indices.reshape(-1, 1), X)) + # + # # Reshape with the added index column + # seq_length_with_idx = seq_length + 1 + # X_with_idx = X_with_idx.reshape(-1, seq_length_with_idx) + # + # print(f"Data shape with indices: {X_with_idx.shape}, Data range: [{np.min(X):.4f}, {np.max(X):.4f}]") + # print(f"Label distribution: {np.bincount(y.astype(int) if y.dtype != object else [0])}") + # + # # Initialize test variables + # X_test = None + # y_test = None + # + # # Handle validation data if provided + # if val_data_path: + # try: + # print(f"Loading validation data from {val_data_path}") + # val_df = pd.read_parquet(val_data_path) + # + # # Create unique indices for validation data by offsetting from training data + # val_df['original_idx'] = np.arange(len(val_df)) + len(df) + 1000 # Add 1000 as buffer + # + # val_X = val_df[columns].values + # val_indices = val_df['original_idx'].values + # + # if np.isnan(val_X).any() or np.isinf(val_X).any(): + # print("Warning: Validation set contains NaN or Inf values. Replacing with zeros.") + # val_X = np.nan_to_num(val_X) + # + # # Extract validation labels + # if label_col is not None and label_col in val_df.columns: + # val_y = val_df[label_col].values + # else: + # print("Warning: No binding label column found in validation data. Using dummy labels.") + # val_y = np.zeros(len(val_df)) + # + # # Add indices to validation features + # X_test = np.column_stack((val_indices.reshape(-1, 1), val_X)) + # X_test = X_test.reshape(-1, seq_length_with_idx) + # y_test = val_y + # + # print(f"Validation data shape with indices: {X_test.shape}") + # print(f"Validation label distribution: {np.bincount(y_test.astype(int) if y_test.dtype != object else [0])}") + # except Exception as e: + # print(f"Error loading validation data: {e}") + # X_test = None + # y_test = None + # + # # Split training data for validation set (the indices will automatically be split correctly) + # X_train, X_val, y_train, y_val = train_test_split(X_with_idx, y, test_size=test_size, random_state=random_state) + # print(f"Training data shape with indices: {X_train.shape}, Validation data shape: {X_val.shape}") + # print(f"Training label distribution: {np.bincount(y_train.astype(int) if y_train.dtype != object else [0])}") + # print(f"Validation label distribution: {np.bincount(y_val.astype(int) if y_val.dtype != object else [0])}") + # + # # Verify index uniqueness + # all_indices = np.concatenate([ + # X_train[:, 0], + # X_val[:, 0], + # X_test[:, 0] if X_test is not None else np.array([]) + # ]) + # unique_indices = np.unique(all_indices) + # if len(unique_indices) != len(all_indices): + # print("Warning: Some indices are not unique across data splits!") + # else: + # print(f"Successfully created {len(unique_indices)} unique indices across all data splits") + # + # return X_train, X_test, X_val, y_train, y_test, y_val, seq_length_with_idx + + def load_data(data_path, input_type='latent1024'): + """ Load training and validation folds along with test datasets. + Args: + data_path (str): Path to the directory containing fold data + input_type (str): Type of input data to extract ('latent1024', 'attention51', etc.) 
+ + Returns: + tuple: (folds, test1, test2, seq_length_with_idx) + """ + print(f"[load_data] Loading data from {data_path}") + + # Initialize containers + folds = [] + seq_length = 0 + + # Load all available folds + fold_index = 0 + while True: + train_path = os.path.join(data_path, f"pep2vec_output_train_fold_{fold_index}.parquet") + val_path = os.path.join(data_path, f"pep2vec_output_val_fold_{fold_index}.parquet") + + print(f"[load_data] Checking for fold {fold_index}:") + print(f" train_path: {train_path} exists: {os.path.exists(train_path)}") + print(f" val_path: {val_path} exists: {os.path.exists(val_path)}") + + if not (os.path.exists(train_path) and os.path.exists(val_path)): + print(f"[load_data] No more folds found at index {fold_index}.") + break + + # Load train and validation data for this fold + train_df = pd.read_parquet(train_path) + val_df = pd.read_parquet(val_path) + print(f"[load_data] Fold {fold_index} train_df shape: {train_df.shape}, columns: {list(train_df.columns)}") + print(f"[load_data] Fold {fold_index} val_df shape: {val_df.shape}, columns: {list(val_df.columns)}") + + num_rows_train = len(train_df) + num_rows_val = len(val_df) + # drop rows with NaN binding_label + train_df = train_df.dropna(subset=['binding_label']) + val_df = val_df.dropna(subset=['binding_label']) + print(f"Number of dropped rows in train_df: {num_rows_train - len(train_df)}") + print(f"Number of dropped rows in val_df: {num_rows_val - len(val_df)}") + + # Process features based on input_type + if input_type == 'latent1024': + train_features = [col for col in train_df.columns if 'latent' in col] + val_features = [col for col in val_df.columns if 'latent' in col] + + print(f"[load_data] Fold {fold_index} latent features: {train_features}") + if not train_features or not val_features: + raise ValueError(f"No latent features found in fold {fold_index}") + + X_train = train_df[train_features].values + X_val = val_df[val_features].values + seq_length = len(train_features) + + elif input_type == 'attention51': + train_features = [col for col in train_df.columns if 'attn_' in col] + val_features = [col for col in val_df.columns if 'attn_' in col] + + print(f"[load_data] Fold {fold_index} attention features: {train_features}") + if not train_features or not val_features: + raise ValueError(f"No attention features found in fold {fold_index}") + + X_train = train_df[train_features].values + X_val = val_df[val_features].values + seq_length = len(train_features) + + elif input_type == 'pMHC-sequence': + print(f"[load_data] Fold {fold_index} pMHC-sequence not implemented") + raise NotImplementedError("pMHC-sequence input type not yet implemented") + else: + raise ValueError(f"Unknown input_type: {input_type}") + + print(f"[load_data] Fold {fold_index} X_train shape: {X_train.shape}, X_val shape: {X_val.shape}") + + # Check for idx column + print(f"[load_data] Fold {fold_index} train_df has 'idx': {'idx' in train_df.columns}") + print(f"[load_data] Fold {fold_index} val_df has 'idx': {'idx' in val_df.columns}") + + # Extract labels if available + y_train = train_df['binding_label'].values if 'binding_label' in train_df.columns else None + y_val = val_df['binding_label'].values if 'binding_label' in val_df.columns else None + + # Append fold data + folds.append((X_train, y_train, X_val, y_val)) + fold_index += 1 + + print(f"[load_data] Loaded {len(folds)} folds") + + X_test1, y_test1 = None, None + X_test2, y_test2 = None, None + + # Load test1 (stratified) dataset + test1_path = os.path.join(data_path, 
"pep2vec_output_test1_stratified.parquet") + print(f"[load_data] Checking for test1: {test1_path} exists: {os.path.exists(test1_path)}") + if os.path.exists(test1_path): + test1_df = pd.read_parquet(test1_path) + print(f"[load_data] test1_df shape: {test1_df.shape}, columns: {list(test1_df.columns)}") + + if input_type == 'latent1024': + latent_cols = [col for col in test1_df.columns if 'latent' in col] + test1_df['idx_feature'] = [int(uuid.uuid4().int) % (2 ** 31 - 1) for _ in range(len(test1_df))] + if 'binding_label' in test1_df.columns: + y_test1 = test1_df['binding_label'].values + select_cols = ['idx_feature'] + latent_cols + else: + select_cols = ['idx_feature'] + latent_cols + print(f"[load_data] test1 selected columns: {select_cols}") + if latent_cols: + X_test1 = test1_df[select_cols].values + print(f"[load_data] Loaded test1 dataset: {X_test1.shape}") + print(f"[load_data] test1_df has 'idx_feature': {'idx_feature' in test1_df.columns}") + + # Load test2 (single unique allele) dataset + test2_path = os.path.join(data_path, "pep2vec_output_test2_single_unique_allele.parquet") + print(f"[load_data] Checking for test2: {test2_path} exists: {os.path.exists(test2_path)}") + if os.path.exists(test2_path): + test2_df = pd.read_parquet(test2_path) + print(f"[load_data] test2_df shape: {test2_df.shape}, columns: {list(test2_df.columns)}") + + if input_type == 'latent1024': + latent_cols = [col for col in test2_df.columns if 'latent' in col] + test2_df['idx_feature'] = [int(uuid.uuid4().int) % (2 ** 31 - 1) for _ in range(len(test2_df))] + if 'binding_label' in test2_df.columns: + y_test2 = test2_df['binding_label'].values + select_cols = ['idx_feature'] + latent_cols + else: + select_cols = ['idx_feature'] + latent_cols + print(f"[load_data] test2 selected columns: {select_cols}") + if latent_cols: + X_test2 = test2_df[select_cols].values + print(f"[load_data] Loaded test2 dataset: {X_test2.shape}") + print(f"[load_data] test2_df has 'idx_feature': {'idx_feature' in test2_df.columns}") + + print(f"[load_data] seq_length: {seq_length}") + + return folds, X_test1, y_test1, X_test2, y_test2, seq_length + + + # def load_data_hash(data_path, val_data_path=None, test_size=0.2, random_state=42, input_type='latent1024', cache=True): + # """ + # Load and prepare data for training with advanced features. 
+ # + # Args: + # data_path: Path to training data parquet file + # val_data_path: Optional path to validation data + # test_size: Fraction of data to use for validation if val_data_path is None + # random_state: Random seed for reproducibility + # input_type: Type of input data ('latent1024' or 'pMHC-sequence') + # cache: Whether to cache results to avoid reprocessing + # + # Returns: + # X_train: Training data + # X_test: Test/validation data + # X_val: Validation data (None if val_data_path is None and data is loaded from cache) + # seq_length: Length of each sequence + # """ + # import hashlib + # import os + # from functools import lru_cache + # + # # Create a cache key based on the input parameters + # cache_key = f"{os.path.basename(data_path)}_{test_size}_{random_state}_{input_type}" + # if val_data_path: + # cache_key += f"_{os.path.basename(val_data_path)}" + # cache_file = f".cache_{hashlib.md5(cache_key.encode()).hexdigest()}.npz" + # + # # Check if cached results exist + # if cache and os.path.exists(cache_file): + # print(f"Loading cached data from {cache_file}") + # cached = np.load(cache_file, allow_pickle=True) + # X_val = cached['X_val'] if 'X_val' in cached else None + # return cached['X_train'], cached['X_test'], X_val, int(cached['seq_length']) + # + # # Efficient ID generation with caching + # @lru_cache(maxsize=10000) + # def generate_id(f1, f2): + # combined = f"{f1}-{f2}" + # return hashlib.sha256(combined.encode('utf-8')).hexdigest() + # + # # Load and validate training data + # try: + # print(f"Loading training data from {data_path}") + # df = pd.read_parquet(data_path) + # + # # Validate required columns exist + # required_cols = ['mhc_sequence', 'peptide'] + # missing_cols = [col for col in required_cols if col not in df.columns] + # if missing_cols: + # raise ValueError(f"Missing required columns: {missing_cols}") + # + # # Generate unique IDs + # df['unique_id'] = df.apply(lambda row: generate_id(row['mhc_sequence'], row['peptide']), axis=1) + # + # # Handle different input types + # if input_type == 'latent1024': + # columns = [col for col in df.columns if 'latent' in col] + # if not columns: + # raise ValueError("No latent columns found in the dataset") + # seq_length = len(columns) + # print(f"Found {seq_length} latent columns") + # elif input_type == 'pMHC-sequence': + # columns = ['mhc_sequence', 'peptide'] + # print(f"Using pMHC sequence columns: {columns}") + # # Note: pMHC-sequence processing would be implemented here + # raise NotImplementedError("pMHC-sequence input not yet implemented") + # else: + # raise ValueError(f"Unknown input_type: {input_type}. Expected 'latent1024' or 'pMHC-sequence'") + # + # # Extract features and check for invalid values + # X = df[columns].values + # if np.isnan(X).any() or np.isinf(X).any(): + # print("Warning: Dataset contains NaN or Inf values. 
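
The `initialize_codebook_with_kmeans` helper defined further below seeds the codebook from k-means centroids of the training vectors, so the quantizer starts in occupied regions of the latent space. A minimal self-contained sketch of that idea, with an illustrative random training matrix:

```python
import numpy as np
from sklearn.cluster import KMeans

X_train = np.random.rand(500, 64).astype(np.float32)  # illustrative latent features
num_embeddings = 32                                    # number of codebook vectors

# Fit k-means on the flattened training vectors and use the centroids as initial codebook entries.
kmeans = KMeans(n_clusters=num_embeddings, random_state=42, n_init=10)
kmeans.fit(X_train.reshape(-1, X_train.shape[-1]))
initial_codebook = kmeans.cluster_centers_.astype(np.float32)
print(initial_codebook.shape)  # (32, 64)
```
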
Replacing with zeros.") + # X = np.nan_to_num(X) + # + # X = X.reshape(-1, seq_length) + # print(f"Data shape: {X.shape}, Data range: [{np.min(X):.4f}, {np.max(X):.4f}]") + # + # # Initialize X_val + # X_val = None + # + # # Handle validation data + # if val_data_path: + # try: + # print(f"Loading validation data from {val_data_path}") + # val_df = pd.read_parquet(val_data_path) + # val_df['unique_id'] = val_df.apply(lambda row: generate_id(row['mhc_sequence'], row['peptide']), axis=1) + # + # # Check for data leakage + # train_ids = set(df['unique_id']) + # val_ids = set(val_df['unique_id']) + # overlap = train_ids.intersection(val_ids) + # if overlap: + # print(f"Warning: {len(overlap)} overlapping samples between training and validation sets") + # + # val_X = val_df[columns].values + # if np.isnan(val_X).any() or np.isinf(val_X).any(): + # print("Warning: Validation set contains NaN or Inf values. Replacing with zeros.") + # val_X = np.nan_to_num(val_X) + # + # X_val = val_X.reshape(-1, seq_length) + # print(f"Validation data shape: {X_val.shape}") + # + # # Still split training data for test set + # X_train, X_test = train_test_split(X, test_size=test_size, random_state=random_state) + # except Exception as e: + # print(f"Error loading validation data: {e}") + # # Continue with default split if validation data loading fails + # X_train, X_test = train_test_split(X, test_size=test_size, random_state=random_state) + # else: + # print(f"Splitting data with test_size={test_size}") + # X_train, X_test = train_test_split(X, test_size=test_size, random_state=random_state) + # + # # Save to cache + # if cache: + # print(f"Caching results to {cache_file}") + # if X_val is not None: + # np.savez(cache_file, X_train=X_train, X_test=X_test, X_val=X_val, seq_length=seq_length) + # else: + # np.savez(cache_file, X_train=X_train, X_test=X_test, seq_length=seq_length) + # + # return X_train, X_test, X_val, seq_length + # + # except Exception as e: + # print(f"Error loading data: {str(e)}") + # import traceback + # traceback.print_exc() + # raise + + def initialize_codebook_with_kmeans(X_train, num_embeddings, embedding_dim): + """Initialize codebook vectors using k-means clustering.""" + kmeans = KMeans(n_clusters=num_embeddings, random_state=random_state) + flat_data = X_train.reshape(-1, X_train.shape[-1]) + kmeans.fit(flat_data) + return kmeans.cluster_centers_.astype(np.float32) + + def plot_training_metrics(history, save_path=None): + """Plot the training metrics over epochs.""" + fig, axes = plt.subplots(4, 1, figsize=(12, 20), sharex=True) + axes[0].plot(history.history['total_loss'], 'b-', label='Train') + if 'val_total_loss' in history.history: + axes[0].plot(history.history['val_total_loss'], 'r-', label='Validation') + axes[0].set_title('Total Loss') + axes[0].set_ylabel('Loss') + axes[0].grid(True) + axes[0].legend() + + axes[1].plot(history.history['recon_loss'], 'b-', label='Train') + if 'val_recon_loss' in history.history: + axes[1].plot(history.history['val_recon_loss'], 'r-', label='Validation') + axes[1].set_title('Reconstruction Loss') + axes[1].set_ylabel('Loss') + axes[1].grid(True) + axes[1].legend() + + axes[2].plot(history.history['vq_loss'], 'b-', label='Train') + if 'val_vq_loss' in history.history: + axes[2].plot(history.history['val_vq_loss'], 'r-', label='Validation') + axes[2].set_title('VQ Loss') + axes[2].set_ylabel('Loss') + axes[2].grid(True) + axes[2].legend() + + axes[3].plot(history.history['perplexity'], 'b-', label='Train') + if 'val_perplexity' in 
history.history: + axes[3].plot(history.history['val_perplexity'], 'r-', label='Validation') + axes[3].set_title('Perplexity') + axes[3].set_xlabel('Epochs') + axes[3].set_ylabel('Perplexity') + axes[3].grid(True) + axes[3].legend() + + plt.tight_layout() + if save_path: + plt.savefig(save_path) + if visualize: + plt.show() + else: + plt.close() + + def plot_reconstructions(original, reconstructed, n_samples=5, save_path=None): + """Plot comparison between original and reconstructed sequences.""" + n_samples = min(n_samples, len(original)) + plt.figure(figsize=(15, 3 * n_samples)) + for i in range(n_samples): + plt.subplot(n_samples, 2, 2 * i + 1) + plt.plot(original[i]) + plt.title(f"Original Sequence {i + 1}") + plt.grid(True) + plt.subplot(n_samples, 2, 2 * i + 2) + plt.plot(reconstructed[i]) + plt.title(f"Reconstructed Sequence {i + 1}") + plt.grid(True) + plt.tight_layout() + if save_path: + plt.savefig(save_path) + if visualize: + plt.show() + else: + plt.close() + + def plot_codebook_usage(indices, num_embeddings, save_path=None): + """Visualize the usage distribution of codebook vectors. + + Args: + indices: Hard indices from model output (integer indices of assigned codes) + num_embeddings: Total number of vectors in the codebook + save_path: Path to save the visualization + + Returns: + used_vectors: Number of vectors used + usage_percentage: Percentage of codebook utilized + """ + # Convert from tensor to numpy if needed + if isinstance(indices, tf.Tensor): + indices = indices.numpy() + + # Ensure indices is a NumPy array before proceeding + if not isinstance(indices, np.ndarray): + try: + # Attempt conversion if it's list-like + indices = np.array(indices) + except Exception as e: + print( + f"Error in plot_codebook_usage: Input 'indices' is not a NumPy array or convertible. Type: {type(indices)}. Error: {e}") + # Return default values or raise an error + return 0, 0.0 + + # Flatten indices to 1D array for counting + try: + flat_indices = indices.flatten() + except AttributeError: + print(f"Error in plot_codebook_usage: Cannot flatten 'indices'. Type: {type(indices)}") + return 0, 0.0 # Return default values + + # Count occurrences of each codebook vector + try: + unique, counts = np.unique(flat_indices, return_counts=True) + except TypeError as e: + print( + f"Error in plot_codebook_usage: Cannot compute unique values for 'flat_indices'. Type: {type(flat_indices)}. 
Error: {e}") + return 0, 0.0 # Return default values + + # Create full distribution including zeros for unused vectors + full_distribution = np.zeros(num_embeddings) + for idx, count in zip(unique, counts): + if 0 <= idx < num_embeddings: # Ensure index is valid + full_distribution[int(idx)] = count + + # Calculate usage statistics + used_vectors = np.sum(full_distribution > 0) + usage_percentage = (used_vectors / num_embeddings) * 100 + + # Create the plot + plt.figure(figsize=(12, 6)) + bar_positions = np.arange(num_embeddings) + bars = plt.bar(bar_positions, full_distribution) + + # Color bars by frequency + max_count = np.max(full_distribution) + if max_count > 0: + for i, bar in enumerate(bars): + intensity = full_distribution[i] / max_count + bar.set_color(plt.cm.viridis(intensity)) + + plt.xlabel('Codebook Vector Index') + plt.ylabel('Usage Count') + plt.title( + f'Codebook Vector Usage Distribution\n{used_vectors}/{num_embeddings} vectors used ({usage_percentage:.1f}%)') + plt.grid(axis='y', linestyle='--', alpha=0.7) + + # Add colorbar + sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=plt.Normalize(0, max_count)) + sm.set_array([]) + cbar = plt.colorbar(sm) + cbar.set_label('Frequency') + plt.tight_layout() + + if save_path: + plt.savefig(save_path) + print(f"Codebook usage: {used_vectors}/{num_embeddings} vectors used ({usage_percentage:.1f}%)") + if visualize: + plt.show() + else: + plt.close() + return used_vectors, usage_percentage + + def plot_soft_cluster_distribution(soft_cluster_probs, num_samples=None, save_path=None): + """ + Visualize the distribution of soft clustering probabilities across samples. + + Args: + soft_cluster_probs: List of arrays or single array containing probability + distributions across clusters for each sample + num_samples: Number of samples to visualize (default: 10) + save_path: Path to save the visualization (default: None) + + Returns: + None + """ + if num_samples is None: + num_samples = len(soft_cluster_probs) + # Process input to get a clean 2D array (samples x clusters) + if isinstance(soft_cluster_probs, list): + if not soft_cluster_probs: + print("Warning: Empty list provided to plot_soft_cluster_distribution") + return + sample_batch = soft_cluster_probs[0] + if isinstance(sample_batch, tf.Tensor): + sample_batch = sample_batch.numpy() + soft_cluster_probs = sample_batch + elif isinstance(soft_cluster_probs, tf.Tensor): + soft_cluster_probs = soft_cluster_probs.numpy() + + # Reshape if needed + if soft_cluster_probs.ndim > 2: + print(f"Reshaping array from shape {soft_cluster_probs.shape} to 2D format") + if soft_cluster_probs.shape[1] == 1: + soft_cluster_probs = soft_cluster_probs.reshape(soft_cluster_probs.shape[0], + soft_cluster_probs.shape[2]) + else: + soft_cluster_probs = soft_cluster_probs.reshape(soft_cluster_probs.shape[0], -1) + + # Verify valid data + if soft_cluster_probs.size == 0: + print("Warning: Empty array provided to plot_soft_cluster_distribution") + return + + n_samples = min(len(soft_cluster_probs), 1000) # Limit to prevent memory issues + n_clusters = soft_cluster_probs.shape[1] + + # Create a figure with 2 subplots + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12), gridspec_kw={'height_ratios': [2, 1]}) + + # 1. 
Heatmap visualization (top) + display_samples = min(num_samples, n_samples) + im = ax1.imshow(soft_cluster_probs[:display_samples], aspect='auto', cmap='viridis') + ax1.set_xlabel('Cluster Index') + ax1.set_ylabel('Sample Index') + ax1.set_title(f'Soft Cluster Probability Heatmap (First {display_samples} Samples)') + plt.colorbar(im, ax=ax1, label='Probability') + + # 2. Aggregated statistics (bottom) + # Calculate statistics across all samples for each cluster + cluster_means = np.mean(soft_cluster_probs, axis=0) + cluster_max_counts = np.sum(np.argmax(soft_cluster_probs, axis=1)[:, np.newaxis] == np.arange(n_clusters), axis=0) + + # Create a twin axis for the bar plot + ax2_twin = ax2.twinx() + + # Plot mean probability for each cluster (line) + x = np.arange(n_clusters) + ax2.plot(x, cluster_means, 'r-', linewidth=2, label='Mean Probability') + ax2.set_ylabel('Mean Probability', color='r') + ax2.tick_params(axis='y', labelcolor='r') + ax2.set_ylim(0, max(cluster_means) * 1.2) + + # Plot histogram of cluster assignments (bars) + ax2_twin.bar(x, cluster_max_counts, alpha=0.3, label='Assignment Count') + ax2_twin.set_ylabel('Number of Samples\nwith Highest Probability', color='b') + ax2_twin.tick_params(axis='y', labelcolor='b') + + # Add labels and grid + ax2.set_xlabel('Cluster Index') + ax2.set_title('Cluster Usage Statistics Across All Samples') + ax2.set_xticks(np.arange(0, n_clusters, max(1, n_clusters // 20))) + ax2.grid(True, linestyle='--', alpha=0.5, axis='y') + + # Create custom legend + lines, labels = ax2.get_legend_handles_labels() + lines2, labels2 = ax2_twin.get_legend_handles_labels() + ax2.legend(lines + lines2, labels + labels2, loc='upper right') + + # Add overall statistics as text + active_clusters = np.sum(np.max(soft_cluster_probs, axis=0) > 0.01) + most_used_cluster = np.argmax(cluster_max_counts) + ax2.text(0.02, 0.95, + f"Active clusters: {active_clusters}/{n_clusters} ({active_clusters/n_clusters:.1%})\n" + f"Most used cluster: {most_used_cluster} ({cluster_max_counts[most_used_cluster]} samples)", + transform=ax2.transAxes, verticalalignment='top', + bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) + + plt.tight_layout() + + # Save if path is provided + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches='tight') + + # Show or close based on global visualize flag + if visualize: + plt.show() + else: + plt.close() + + + def plot_cluster_distribution(soft_cluster_probs, save_path=None): + """ + Plot distribution of samples across clusters based on one-hot encodings. 
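+        Hard assignments are derived here by taking the argmax over the soft
+        probabilities (via a one-hot encoding), so the plot reflects each
+        sample's single most likely cluster rather than the full soft distribution.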
+ + Args: + soft_cluster_probs: Soft cluster probabilities + save_path: Path to save the plot + + Returns: + used_clusters: Number of clusters used + usage_percentage: Percentage of clusters used + """ + # Convert soft cluster probabilities to one-hot encodings + one_hot_encodings = tf.one_hot(tf.argmax(soft_cluster_probs, axis=-1), depth=soft_cluster_probs.shape[-1]) + one_hot_encodings = tf.cast(one_hot_encodings, tf.float32) + + # print(f"one_hot_encodings shape: {one_hot_encodings.shape}") + # print first 5 values + # print(f"one_hot_encodings values: {one_hot_encodings[:5]}") + + # Convert one-hot to cluster indices if needed + if isinstance(one_hot_encodings, tf.Tensor): + one_hot_encodings = one_hot_encodings.numpy() + + # Handle different shapes of one_hot_encodings + if len(one_hot_encodings.shape) == 3: # (batch, seq_len, num_embeddings) + cluster_assignments = np.argmax(one_hot_encodings, axis=-1).flatten() + else: # (batch, num_embeddings) + cluster_assignments = np.argmax(one_hot_encodings, axis=-1).flatten() + + # Count occurrences of each cluster + unique_clusters, counts = np.unique(cluster_assignments, return_counts=True) + + # Create a full distribution including zeros for unused clusters + num_clusters = one_hot_encodings.shape[-1] + full_distribution = np.zeros(num_clusters) + for cluster, count in zip(unique_clusters, counts): + full_distribution[cluster] = count + + # Calculate usage statistics + used_clusters = np.sum(full_distribution > 0) + usage_percentage = (used_clusters / num_clusters) * 100 + + # Create the plot + plt.figure(figsize=(12, 6)) + bars = plt.bar(np.arange(num_clusters), full_distribution) + + # Color bars by frequency + max_count = np.max(full_distribution) + if max_count > 0: + for i, bar in enumerate(bars): + intensity = full_distribution[i] / max_count + bar.set_color(plt.cm.plasma(intensity)) + + plt.xlabel('Cluster Index') + plt.ylabel('Number of Samples') + plt.title( + f'Sample Distribution Across Clusters\n{used_clusters}/{num_clusters} clusters used ({usage_percentage:.1f}%)') + plt.grid(axis='y', linestyle='--', alpha=0.7) + + # Add a colorbar + sm = plt.cm.ScalarMappable(cmap=plt.cm.plasma, norm=plt.Normalize(0, max_count)) + sm.set_array([]) + cbar = plt.colorbar(sm) + cbar.set_label('Sample Count') + + plt.tight_layout() + + if save_path: + plt.savefig(save_path) + print(f"Cluster usage: {used_clusters}/{num_clusters} clusters contain samples ({usage_percentage:.1f}%)") + if visualize: + plt.show() + else: + plt.close() + return used_clusters, usage_percentage + + def process_and_save(dataset: tf.data.Dataset, + split_name: str, + model: tf.keras.Model, + output_dir: str, + num_embeddings: int): + """ + Quantize `dataset` through `model`, assemble into a DataFrame, + save to parquet, and plot distributions with split-specific filenames. 
+ """ + quantized_latents = [] + cluster_indices_soft = [] + cluster_indices_hard = [] + labels = [] + + # Extract latent codes for every batch + for batch_X, batch_y in dataset: + Zq, out_P_proj, _, _ = model.encode_(batch) + quantized_latents.append(Zq.numpy()) + cluster_indices_soft.append(out_P_proj.numpy()) + cluster_indices_hard.append(tf.argmax(out_P_proj, axis=-1).numpy()) + labels.append(batch_y.numpy()) + + + # Concatenate across batches + quantized_latent = np.concatenate(quantized_latents, axis=0) + soft_probs = np.concatenate(cluster_indices_soft, axis=0) + hard_assign = np.concatenate(cluster_indices_hard, axis=0) + labels = np.concatenate(labels, axis=0) + + # Build DataFrame + records = [] + for i in range(len(quantized_latent)): + rec = {} + flat_latent = quantized_latent[i].flatten() + for j, v in enumerate(flat_latent): + rec[f'latent_{j}'] = float(v) + + flat_soft = soft_probs[i].flatten() + for j, v in enumerate(flat_soft): + rec[f'soft_cluster_{j}'] = float(v) + + rec['hard_cluster'] = int(hard_assign[i]) + # binding label if available + if i < len(labels): + lbl = labels[i] + # Handle scalar, 0-dim, or 1-dim arrays + if isinstance(lbl, (np.ndarray, list)) and np.asarray(lbl).size > 0: + rec['binding_label'] = float(np.asarray(lbl).flatten()[0]) + else: + rec['binding_label'] = float(lbl) + else: + rec['binding_label'] = np.nan + records.append(rec) + + df = pd.DataFrame(records) + # ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + + # Save parquet + parquet_path = os.path.join(output_dir, f'quantized_outputs_{split_name}.parquet') + df.to_parquet(parquet_path, index=False) + print(f"[{split_name}] saved parquet → {parquet_path}") + + # Plot distributions + plot_cluster_distribution(soft_probs, + save_path=os.path.join(output_dir, f'cluster_distribution_soft_{split_name}.png')) + plot_codebook_usage(hard_assign, num_embeddings, + save_path=os.path.join(output_dir, f'codebook_usage_{split_name}.png')) + plot_soft_cluster_distribution(soft_probs, 20, + save_path=os.path.join(output_dir, + f'soft_cluster_distribution_{split_name}.png')) + + def remove_id(dataset): + """Remove the first column (ID) from the feature tensor, keep y if present.""" + return dataset.map(lambda x, y=None: (x[:, 1:], y) if y is not None else x[:, 1:]) + + def remove_y(dataset): + """Remove the last column (y) from the feature tensor.""" + return dataset.map(lambda x, y: x) + + def plot_PCA(zq, y, num_embeddings, save_path=None): + """Plot PCA of the quantized latents.""" + pca = PCA(n_components=2) + # print the shape of zq + print(f"zq shape: {zq.shape}") + # reshape from B, 1, N to B, N + if len(zq.shape) == 3 and zq.shape[1] == 1: + zq = zq.reshape(zq.shape[0], zq.shape[2]) + zq_2d = pca.fit_transform(zq) + + plt.figure(figsize=(10, 8)) + scatter = plt.scatter(zq_2d[:, 0], zq_2d[:, 1], c=y, cmap='viridis', alpha=0.5) + plt.colorbar(scatter, label='Binding Label') + plt.title(f'PCA of Quantized Latents (num_embeddings={num_embeddings})') + plt.xlabel('PCA Component 1') + plt.ylabel('PCA Component 2') + if save_path: + plt.savefig(save_path) + if visualize: + plt.show() + else: + plt.close() + + # ======================= End of Helper Functions ======================= + # --- Main Pipeline for Folds --- + print("--- Main Pipeline (Folds) ---") + + print("Loading/Generating Training Data...") + folds, X_test1, y_test1, X_test2, y_test2, seq_length = load_data(data_path, input_type=input_type) + + print(f"Folds shape: {[f[0].shape for f in folds]}") + print(f"X_test1 
shape: {X_test1.shape if X_test1 is not None else 'None'}") + print(f"X_test2 shape: {X_test2.shape if X_test2 is not None else 'None'}") + print(f"Sequence length: {seq_length}") + + for fold_idx, (X_train, y_train, X_val, y_val) in enumerate(folds): + if fold_idx == 0: # only for the first fold # TODO: remove this + print(f"\n=== Processing Fold {fold_idx + 1}/{len(folds)} ===") + + # Preserve original labels before dropping + if output_data == "all": + original_data = np.concatenate([X_train, X_val], axis=0) + original_labels = np.concatenate([y_train, y_val], axis=0) + elif output_data == "train": + original_data = X_train + original_labels = y_train + else: + original_data = X_val + original_labels = y_val + + # Initialize codebook with k-means + print("Initializing codebook with k-means...") + if init_k_means: + print("Initializing codebook with k-means") + initial_codebook = initialize_codebook_with_kmeans(X_train, num_embeddings, seq_length) + else: + print("Not initializing codebook with k-means") + initial_codebook = None + + # Create datasets + print("Creating TensorFlow Datasets...") + train_dataset = create_dataset(X_train, y_train, batch_size=batch_size, is_training=True) + val_dataset = create_dataset(X_val, y_val, batch_size=batch_size, is_training=False) + print("Datasets created.") + + # --- Model Instantiation --- + print("Building the SCQ1DAutoEncoder model...") + input_shape = (seq_length,) + model = SCQ1DAutoEncoder( + input_dim=input_shape, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + commitment_beta=commitment_beta, + scq_params={ + 'lambda_reg': 1.0, + 'discrete_loss': False, + 'reset_dead_codes': True, + 'usage_threshold': 1e-4, + 'reset_interval': 5 + }, + cluster_lambda=1 + ) + print("Model built.") + + # --- Compile and Train --- + print("Compiling the model...") + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate)) + print("Model compiled.") + + # Create balanced dataset for training + print("Creating balanced dataset...") + label_counts = Counter(y_train) + min_count = min(label_counts.values()) + balanced_indices = [] + for label in label_counts: + inds = np.where(y_train == label)[0] + sampled = np.random.choice(inds, min_count, replace=False) + balanced_indices.extend(sampled) + balanced_indices = np.array(balanced_indices) + + # Create balanced dataset + X_bal = X_train[balanced_indices] + y_bal = y_train[balanced_indices] + print(f"Balanced dataset shape: {X_bal.shape}") + balanced_dataset_y = create_dataset(X_bal, y_bal, batch_size=batch_size, is_training=True) + balanced_dataset = balanced_dataset_y.map(lambda x, y: x) # Remove labels + + # Print shapes for debugging + for batch in balanced_dataset.take(1): + print(f"Balanced dataset batch shape: {batch.shape}") + break + for features, labels in val_dataset.take(1): + print(f"Validation dataset batch features shape: {features.shape}, labels shape: {labels.shape}") + break + + # Initial training on balanced set + print(f"Training on balanced set for {epochs} epochs...") + start_time = time.time() + history = model.fit( + balanced_dataset, + epochs=epochs, + validation_data=val_dataset + ) + end_time = time.time() + print(f"Balanced training finished in {end_time - start_time:.2f} seconds.") + + # Freeze encoder layers before fine-tuning + model.encoder.trainable = False + for layer in model.encoder.layers: + layer.trainable = False + + # Recompile model after freezing layers + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate)) + 
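+            # The commented block below sketches an optional second training stage:
+            # after freezing the encoder and recompiling (Keras only applies changes
+            # to `trainable` once compile() is called again), the remaining layers
+            # could be fine-tuned on the full, unbalanced training set.
+            # Optional sanity check (kept commented out; uses the standard Keras
+            # weight attributes) to confirm the encoder is actually frozen:
+            # print(f"Weights after freezing encoder: "
+            #       f"{len(model.trainable_weights)} trainable, "
+            #       f"{len(model.non_trainable_weights)} frozen")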
+ # # Fine-tune on full dataset + # print(f"Fine-tuning on full dataset for {epochs} epochs...") + # start_time = time.time() + # history = model.fit( + # train_dataset, + # epochs=epochs // 10, + # validation_data=val_dataset + # ) + # end_time = time.time() + # print(f"Fine-tuning finished in {end_time - start_time:.2f} seconds.") + + # --- Evaluation and Visualization --- + print("\nEvaluating the model...") + + if visualize: + print("\nPlotting training history...") + plot_training_metrics(history, save_path=os.path.join(output_dir, f'vqvae_training_metrics_fold{fold_idx}.png')) + print("\nPerforming example inference...") + try: + example_batch, y_batch = next(iter(val_dataset)) + except Exception: + example_batch = next(iter(val_dataset)) + output = model(example_batch, training=False) + reconstruction, quantized_latent, cluster_indices, vq_loss, perplexity = output + print("\nVisualizing reconstructions...") + plot_reconstructions(example_batch.numpy(), reconstruction.numpy(), + save_path=os.path.join(output_dir, f'vqvae_reconstructions_fold{fold_idx}.png')) + Zq, out_P_proj, vq_loss, perplexity = model.encode_(example_batch) + print("\nVisualizing encoder output...") + plot_PCA(Zq.numpy(), y_batch, num_embeddings, + save_path=os.path.join(output_dir, f'scqvae_PCA_fold{fold_idx}.png')) + + # --- Save Model --- + if save_model: + model_dir = os.path.join(output_dir, f'model_fold{fold_idx}') + os.makedirs(model_dir, exist_ok=True) + model.save_weights(os.path.join(model_dir, 'vqvae_model_weights.h5')) + print(f"\nModel weights saved to '{os.path.join(model_dir, 'vqvae_model_weights.h5')}'") + + # --- Extract Latent Space --- + print("\nExtracting quantized latent space...") + + if output_data == "train_val_combined": + out_ds = tf.data.Dataset.concatenate(train_dataset, val_dataset) + process_and_save(out_ds, f"combined_fold{fold_idx}", model, original_labels, output_dir, num_embeddings) + + elif output_data in ("train", "val"): + ds_name = output_data + ds = train_dataset if ds_name == "train" else val_dataset + process_and_save(ds, f"{ds_name}_fold{fold_idx}", model, original_labels, output_dir, num_embeddings) + + elif output_data == "train_val_seperate": + process_and_save(train_dataset, f"train_fold{fold_idx}", model, output_dir, num_embeddings) + process_and_save(val_dataset, f"val_fold{fold_idx}", model, output_dir, num_embeddings) + + else: + raise ValueError("Invalid output_data. Must be 'train', 'val', 'train_val_combined', or 'train_val_seperate'.") + + print(f"\nAll requested splits processed for all folds. Outputs are in: {output_dir}") + + # out_dataset2 = None + # if output_data == "train_val_combined": + # out_dataset1 = tf.data.Dataset.concatenate(train_dataset, val_dataset) + # elif output_data == "train": + # out_dataset1 = train_dataset # using only training dataset for quantization + # elif output_data == "val": + # out_dataset1 = val_dataset # using only validation dataset for quantization + # elif output_data == "train_val_seperate": + # out_dataset1 = train_dataset + # out_dataset2 = val_dataset # TODO + # else: + # raise ValueError("Invalid output_data option. 
Choose 'train', 'val', or 'train_val_combined'.") + # + # quantized_latents, cluster_indices_hard_assign, cluster_indices_soft_assign = [], [], [] + # for batch in out_dataset1: + # output, Zq, out_P_proj, _, _ = model(batch, training=False) + # print(f"\nBatch shape: {batch.shape}") + # print(f"\nOutput shape: {output.shape}") + # print(f"\nZq shape: {Zq.shape}") + # print(f"\nout_P_proj shape: {out_P_proj.shape}") + # + # quantized_latents.append(Zq.numpy()) + # # Append soft assignments + # cluster_indices_soft_assign.append(out_P_proj.numpy()) # shows the probability of each index + # cluster_indices_hard_assign.append(tf.argmax(out_P_proj, axis=-1).numpy()) # shows which index has the highest probability + # quantized_latent = np.concatenate(quantized_latents, axis=0) + # cluster_indices_soft = np.concatenate(cluster_indices_soft_assign, axis=0) + # cluster_indices_hard = np.concatenate(cluster_indices_hard_assign, axis=0) + # + # # Compile quantized outputs into a DataFrame and include original binding labels + # # TODO save + # # Check dimensions of our data + # print(f"Quantized latent shape: {quantized_latent.shape}") + # print(f"Cluster indices soft shape: {cluster_indices_soft.shape}") + # print(f"Cluster indices hard shape: {cluster_indices_hard.shape}") + # print( + # f"Original labels shape: {original_labels.shape if hasattr(original_labels, 'shape') else len(original_labels)}") + # + # # Create a list of records instead of column-wise dictionary + # # This avoids dimension issues when creating the DataFrame + # records = [] + # + # for i in range(len(quantized_latent)): + # record = {} + # + # # Add latent features (flattened if needed) + # if len(quantized_latent.shape) > 2: + # flat_latent = quantized_latent[i].flatten() + # for j in range(len(flat_latent)): + # record[f'latent_{j}'] = float(flat_latent[j]) + # else: + # for j in range(quantized_latent.shape[1]): + # record[f'latent_{j}'] = float(quantized_latent[i, j]) + # + # # Add soft cluster probabilities (flattened if needed) + # if len(cluster_indices_soft.shape) > 2: + # flat_soft = cluster_indices_soft[i].flatten() + # for j in range(len(flat_soft)): + # record[f'soft_cluster_{j}'] = float(flat_soft[j]) + # else: + # for j in range(cluster_indices_soft.shape[1]): + # record[f'soft_cluster_{j}'] = float(cluster_indices_soft[i, j]) + # + # # Add hard cluster assignment + # record['hard_cluster'] = int(cluster_indices_hard[i]) + # + # # Add binding label (if available) + # if i < len(original_labels): + # # Convert to scalar if it's an array + # if hasattr(original_labels[i], 'shape') and original_labels[i].shape: + # record['binding_label'] = float(original_labels[i][0]) + # else: + # record['binding_label'] = float(original_labels[i]) + # else: + # record['binding_label'] = np.nan + # + # records.append(record) + # + # # Create DataFrame from records + # results_df = pd.DataFrame(records) + # + # # Print head of the output + # print("\nHead of the output DataFrame:") + # print(results_df.head()) + # + # # Save as parquet + # parquet_path = os.path.join(output_dir, 'quantized_outputs.parquet') + # results_df.to_parquet(parquet_path, index=False) + # print(f"Quantized outputs saved to: {parquet_path}") + # + # print(f"\n head of Cluster indices hard: {cluster_indices_hard[:5]}") + # print(f"\n head of Cluster indices soft: {cluster_indices_soft[:5]}") + # # print shape of soft cluster indices + # print(f"\n softs shape: {np.array(cluster_indices_soft).shape}") + # print(f"\n head of Quantized latents: 
{quantized_latent[:5]}") + # + # # Create distribution plots for all data + # print("\nCreating distribution plots for all processed data...") + # plot_cluster_distribution(cluster_indices_soft, save_path=os.path.join(output_dir, 'cluster_distribution_soft.png')) + # plot_codebook_usage(cluster_indices_hard, num_embeddings, + # save_path=os.path.join(output_dir, 'full_codebook_usage.png')) + # plot_soft_cluster_distribution(np.array(cluster_indices_soft), 20, 'soft_cluster_distribution.png') + # print(f"Latent representations saved to {output_dir}") + + +# write a pipeline that runs the MoE model on the data. +# +# first load the data and if exists the val_data. +# +# then if set latent, load the quantized_latent.npy and load cluster_indices.npy as the cluster_indices_probs, set the number of experts as the length of on of the cluster cluster_indices_probs vector. +# +# then if val_data_path not set, split the data to train and validation and train the model. +# +# then evaluate the model and draw accuracy plots. + +def train_and_evaluate_moe( + data_path, + val_data_path=None, + latent_path=None, + input_type='latent', # Options: 'latent', 'pMHC-sequence', 'attention' + model_type='MoE', # Options: 'MoE', 'MLP', 'transformer', 'CNN' + batch_size=32, + epochs=20, + learning_rate=1e-4, + hidden_dim=64, + output_dir='data/MoE', + visualize=True, + save_model=True, + random_state=42, + test_size=0.2, + use_soft_clusters_as_gates= True, + **kwargs + ): + """ + Load data from a parquet file, train a Mixture-of-Experts (MoE) model, and evaluate performance. + + Parameters: + data_path (str): Path to the input parquet file containing latent representations and cluster probabilities. + val_data_path (str, optional): Path to a separate validation parquet file. If not provided, a train/val split is used. + latent_path (str, optional): Path to a pretrained encoder model for sequence input. + input_type (str): 'latent' to use precomputed features, 'pMHC-sequence' to encode sequences, and 'attention' for attention-based features. + batch_size (int): Batch size for training. + epochs (int): Number of training epochs. + learning_rate (float): Learning rate for the optimizer. + hidden_dim (int): Size of the hidden layer in the MoE model. + output_dir (str): Directory to save models and figures. + visualize (bool): Whether to generate PCA visualizations. + save_model (bool): Whether to save the trained model. + random_state (int): Seed for reproducibility. + test_size (float): Fraction of data to use for validation if no val_data_path. + **kwargs: Additional keyword arguments forwarded to model.fit (e.g., callbacks). + + Returns: + model (tf.keras.Model): Trained MoE model. + history (tf.keras.callbacks.History): Training history object. + eval_results (list): Evaluation results on validation data. 
+ """ + # ======================= Helper Functions ======================= + from sklearn.metrics import roc_curve, auc, precision_recall_curve + + def create_visualizations(X, X_val, y, y_val, y_pred, y_proba, soft_clusters, soft_clusters_val, history, output_dir, + random_state, visualize): + print("Creating PCA visualizations...") + pca = PCA(n_components=2, random_state=random_state) + X_all = np.vstack([X, X_val]) + y_all = np.concatenate([y, y_val]) + soft_clusters_all = np.vstack([soft_clusters, soft_clusters_val]) + pca_proj = pca.fit_transform(X_all) + print("Unique labels in y_all:", np.unique(y_all)) + + df_vis = pd.DataFrame({ + 'PC1': pca_proj[:, 0], + 'PC2': pca_proj[:, 1], + 'label': y_all, + 'cluster': np.argmax(soft_clusters_all, axis=1) + }) + + plt.figure(figsize=(8, 6)) + sns.scatterplot(data=df_vis, x='PC1', y='PC2', hue='label', alpha=0.6) + plt.title('PCA of Dataset Colored by Binding Label') + plt.savefig(os.path.join(output_dir, 'pca_binding_label.png')) + print(f"Saved binding label PCA plot to {os.path.join(output_dir, 'pca_binding_label.png')}") + if visualize: + plt.show() + plt.close() + + plt.figure(figsize=(8, 6)) + sns.scatterplot(data=df_vis, x='PC1', y='PC2', hue='cluster', legend=False, alpha=0.6) + plt.title('PCA of Dataset Colored by Soft Cluster Assignment') + plt.savefig(os.path.join(output_dir, 'pca_cluster_assignment.png')) + print(f"Saved cluster assignment PCA plot to {os.path.join(output_dir, 'pca_cluster_assignment.png')}") + if visualize: + plt.show() + plt.close() + + plt.figure(figsize=(12, 5)) + plt.subplot(1, 2, 1) + plt.plot(history.history['loss'], label='Training Loss') + plt.plot(history.history['val_loss'], label='Validation Loss') + plt.title('Model Loss') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.legend() + + plt.subplot(1, 2, 2) + plt.plot(history.history['accuracy'], label='Training Accuracy') + plt.plot(history.history['val_accuracy'], label='Validation Accuracy') + plt.title('Model Accuracy') + plt.xlabel('Epoch') + plt.ylabel('Accuracy') + plt.legend() + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, 'training_history.png')) + print(f"Saved training history plot to {os.path.join(output_dir, 'training_history.png')}") + if visualize: + plt.show() + plt.close() + + # Add ROC and PRC plots if predictions are available + if y_proba is not None: + fpr, tpr, _ = roc_curve(y_val, y_proba) + roc_auc = auc(fpr, tpr) + + plt.figure(figsize=(8, 6)) + plt.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:.2f})') + plt.plot([0, 1], [0, 1], 'k--') + plt.title('Receiver Operating Characteristic') + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.legend(loc='lower right') + plt.savefig(os.path.join(output_dir, 'roc_curve.png')) + print(f"Saved ROC curve plot to {os.path.join(output_dir, 'roc_curve.png')}") + if visualize: + plt.show() + plt.close() + + precision, recall, _ = precision_recall_curve(y_val, y_proba) + pr_auc = auc(recall, precision) + + plt.figure(figsize=(8, 6)) + plt.plot(recall, precision, label=f'PR curve (area = {pr_auc:.2f})') + plt.title('Precision-Recall Curve') + plt.xlabel('Recall') + plt.ylabel('Precision') + plt.legend(loc='lower left') + plt.savefig(os.path.join(output_dir, 'precision_recall_curve.png')) + print(f"Saved Precision-Recall curve plot to {os.path.join(output_dir, 'precision_recall_curve.png')}") + if visualize: + plt.show() + plt.close() + + def evaluation_metrics(y_true, y_pred, y_prob=None): + """ + Calculate evaluation metrics for the model predictions. 
+ + Args: + y_true (np.ndarray): True labels. + y_pred (np.ndarray): Predicted labels (class predictions). + y_prob (np.ndarray, optional): Predicted probabilities for the positive class. + Required for metrics like AUROC and AUPRC. + + Returns: + dict: Dictionary of evaluation metrics. + """ + from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, + roc_auc_score, average_precision_score, balanced_accuracy_score, + matthews_corrcoef, cohen_kappa_score, confusion_matrix) + + # Basic classification metrics + accuracy = accuracy_score(y_true, y_pred) + precision = precision_score(y_true, y_pred, average='weighted') + recall = recall_score(y_true, y_pred, average='weighted') + f1 = f1_score(y_true, y_pred, average='weighted') + + # Additional classification metrics + balanced_acc = balanced_accuracy_score(y_true, y_pred) + mcc = matthews_corrcoef(y_true, y_pred) + kappa = cohen_kappa_score(y_true, y_pred) + + # Confusion matrix-based metrics + tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() + specificity = tn / (tn + fp) if (tn + fp) > 0 else 0 + npv = tn / (tn + fn) if (tn + fn) > 0 else 0 + + metrics = { + 'accuracy': accuracy, + 'precision': precision, + 'recall': recall, + 'f1_score': f1, + 'balanced_accuracy': balanced_acc, + 'mcc': mcc, + 'kappa': kappa, + 'specificity': specificity, + 'npv': npv + } + + # Probability-based metrics (only if probabilities are provided) + if y_prob is not None: + try: + auroc = roc_auc_score(y_true, y_prob) + auprc = average_precision_score(y_true, y_prob) + metrics.update({ + 'auroc': auroc, + 'auprc': auprc + }) + except Exception as e: + print(f"Warning: Could not calculate probability-based metrics: {e}") + + return metrics + + # ============================================== + + # Set random seeds for reproducibility + np.random.seed(random_state) + tf.random.set_seed(random_state) + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Load the main dataset + print(f"Loading training data from {data_path}") + df = pd.read_parquet(data_path) + + # drop rows with nan labels + if 'binding_label' in df.columns: + df = df.dropna(subset=['binding_label']) + + # Extract features from latent columns + if input_type == 'latent': + # Get all latent columns + latent_cols = [col for col in df.columns if col.startswith('latent_')] + if not latent_cols: + raise ValueError("No latent_* columns found in the dataset") + print(f"Found {len(latent_cols)} latent columns") + X = df[latent_cols].values + + elif input_type == 'pMHC-sequence': + if latent_path is None: + raise ValueError("latent_path must be provided when input_type='pMHC-sequence'") + # Load sequence encoder (pretrained model) + encoder = tf.keras.models.load_model(latent_path) + if 'peptide' not in df.columns or 'mhc_sequence' not in df.columns: + raise ValueError("Required columns 'peptide' and 'mhc_sequence' not found") + sequences = df[['peptide', 'mhc_sequence']].values + # Encode sequences into latent vectors + X = encoder.predict(sequences, batch_size=batch_size) + + elif input_type == 'attention': + # get all attention columns + attention_cols = [col for col in df.columns if col.startswith('attn_')] + if not attention_cols: + raise ValueError("No attention_* columns found in the dataset") + print(f"Found {len(attention_cols)} attention columns") + X = df[attention_cols].values + + else: + raise ValueError(f"Unsupported input_type: {input_type}") + + # Extract soft cluster probabilities + if 
use_soft_clusters_as_gates: + soft_cluster_cols = [col for col in df.columns if col.startswith('soft_cluster_')] + if not soft_cluster_cols: + print(f"Dataframe columns: {df.columns}") + raise ValueError("No soft_cluster_* columns found in the dataset") + print(f"Found {len(soft_cluster_cols)} soft cluster probability columns") + soft_clusters = df[soft_cluster_cols].values + else: + soft_clusters = np.zeros((X.shape[0], 32)) # Dummy soft clusters if not used + + # Extract binding labels + if 'binding_label' not in df.columns: + raise ValueError("Required column 'binding_label' not found in the dataset") + y = df['binding_label'].values + + print(f"Data loaded: X shape={X.shape}, soft_clusters shape={soft_clusters.shape}, y shape={y.shape}") + + # Prepare validation data if provided + if val_data_path: + print(f"Loading validation data from {val_data_path}") + df_val = pd.read_parquet(val_data_path) + + # drop rows with nan labels + if 'binding_label' in df_val.columns: + df_val = df_val.dropna(subset=['binding_label']) + + if input_type == 'latent': + # Use same latent columns as training + X_val = df_val[latent_cols].values + elif input_type == 'pMHC-sequence': + raise NotImplementedError("Validation data for pMHC-sequence input type is not implemented") + elif input_type == 'attention': + # Use same attention columns as training + attention_cols_val = [col for col in df_val.columns if col.startswith('attn_')] + if not attention_cols_val: + raise ValueError("No attention_* columns found in the validation dataset") + print(f"Found {len(attention_cols_val)} attention columns in validation data") + X_val = df_val[attention_cols_val].values + else: + # Encode sequence data + sequences_val = df_val[['peptide', 'mhc_sequence']].values + X_val = encoder.predict(sequences_val, batch_size=batch_size) + + if use_soft_clusters_as_gates: + # Use same soft cluster columns + soft_clusters_val = df_val[soft_cluster_cols].values + else: + soft_clusters_val = np.zeros((X_val.shape[0], 32)) + + y_val = df_val['binding_label'].values + + print(f"Validation data loaded: X_val shape={X_val.shape}, soft_clusters_val shape={soft_clusters_val.shape}") + else: + # Split training data for validation + print(f"Splitting data with test_size={test_size}") + X, X_val, soft_clusters, soft_clusters_val, y, y_val = train_test_split( + X, soft_clusters, y, + test_size=test_size, + random_state=random_state, + stratify=y if len(np.unique(y)) > 1 else None + ) + print(f"Data split: train={X.shape[0]} samples, validation={X_val.shape[0]} samples") + + # Build tf.data datasets + train_dataset = tf.data.Dataset.from_tensor_slices(((X, soft_clusters), y)) + train_dataset = train_dataset.shuffle(buffer_size=1000, seed=random_state).batch(batch_size) + + val_dataset = tf.data.Dataset.from_tensor_slices(((X_val, soft_clusters_val), y_val)) + val_dataset = val_dataset.batch(batch_size) + + # Instantiate and compile the MoE model + num_experts = soft_clusters.shape[1] + feature_dim = X.shape[1] + print(f"Building MoE model with {num_experts} experts and feature_dim={feature_dim}") + + if model_type == 'MoE': + model = MoEModel( + feature_dim, + hidden_dim=hidden_dim, + num_experts=num_experts, + use_provided_gates=use_soft_clusters_as_gates, + ) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=['accuracy'], + ) + elif model_type == 'MoE_2': + model = EnhancedMoEModel( + feature_dim, + hidden_dim=hidden_dim, + num_experts=num_experts, + 
use_hard_clustering=use_soft_clusters_as_gates, + ) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=['accuracy'], + ) + elif model_type == 'MLP': + model = BinaryMLP( + feature_dim, + ) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=['accuracy'], + ) + elif model_type == 'transformer': + model = TabularTransformer( + feature_dim, + ) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=['accuracy'], + ) + elif model_type == 'CNN': + model = EmbeddingCNN( + feature_dim, + ) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=['accuracy'], + ) + else: + raise ValueError(f"Unsupported model_type: {model_type}") + + # --- create a balanced dataset --- + # label_counts = Counter(y) + # min_count = min(label_counts.values()) + # balanced_indices = np.concatenate([ + # np.random.choice(np.where(y == lbl)[0], min_count, replace=False) + # for lbl in label_counts + # ]) + # np.random.shuffle(balanced_indices) + # X_bal = X[balanced_indices] + # soft_bal = soft_clusters[balanced_indices] + # y_bal = y[balanced_indices] + # + # train_bal_ds = tf.data.Dataset.from_tensor_slices(((X_bal, soft_bal), y_bal)) \ + # .shuffle(buffer_size=1000, seed=random_state) \ + # .batch(batch_size) + # + # # --- Train the model --- + # print(f"Training on balanced set for {epochs} epochs...") + # history = model.fit( + # train_bal_ds, + # validation_data=val_dataset, + # epochs=epochs, + # **kwargs + # ) + + # TODO experimental + # # Freeze initial half of layers, then fine-tune on full dataset + # for layer in model.layers[: len(model.layers) // 2]: + # layer.trainable = False + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=["accuracy"], + # tf.keras.metrics.BinaryAccuracy(), + # tf.keras.metrics.Precision(), + # tf.keras.metrics.Recall(), + # tf.keras.metrics.AUC() + # ], + # run_eagerly=False + ) + # tf.config.run_functions_eagerly(True) + print(f"Fine-tuning on full dataset for {epochs} epochs...") + history = model.fit( + train_dataset, + validation_data=val_dataset, + epochs=epochs, + **kwargs + ) + + eval_results = model.evaluate(val_dataset) + print(f"Validation loss: {eval_results[0]:.4f}, accuracy: {eval_results[1]:.4f}") + # Get predictions + y_proba = model.predict((X_val, soft_clusters_val)) + + # Ensure y_proba is a NumPy array before using NumPy operations + if isinstance(y_proba, tf.Tensor): + y_prob_np = y_proba.numpy() + else: + y_prob_np = y_proba + + # Convert probabilities to binary predictions + y_pred = (y_prob_np > 0.5).astype(int) + eval_dict = evaluation_metrics(y_val, y_pred, y_prob=y_prob_np) + print(f"Evaluation metrics: {eval_dict}") + + # TODO fix later + # Save the trained model + # if save_model: + # save_path = os.path.join(output_dir, 'moe_model') + # os.makedirs(save_path, exist_ok=True) # Ensure the directory exists + # model.save(save_path) + # print(f"Model saved to {save_path}") + + # Visualization using PCA + if visualize: + print("Creating PCA visualizations...") + create_visualizations(X, X_val, y, y_val, y_pred, y_proba, soft_clusters, soft_clusters_val, history, output_dir, + random_state, visualize) + + return model, history, 
eval_results + + +if __name__ == "__main__": + dataset_folder = "NetMHCIpan_dataset" + # Call train_and_evaluate_scqvae without trying to unpack return values + train_and_evaluate_scqvae( + data_path=f"data/Pep2Vec/{dataset_folder}_new_subset", + # val_data_path=f"data/Pep2Vec/{dataset_folder}_new_subset/pep2vec_output_val_fold_0.parquet", + input_type="latent1024", + num_embeddings=32, + embedding_dim=32, + batch_size=1024, + epochs=20, + output_dir=f"data/SCQvae/{dataset_folder}", + visualize=True, + save_model=True, + init_k_means=False, + random_state=42, + output_data="train_val_seperate", # Options: "val", "train", "train_val_seperate" + ) + m = "MLP" # Options: "MoE", "MLP", "transformer", "CNN" + print(f"Running {m}") + train_and_evaluate_moe( + data_path=f"data/SCQvae/{dataset_folder}/quantized_outputs_train_fold0.parquet", + val_data_path=f"data/SCQvae/{dataset_folder}/quantized_outputs_val_fold0.parquet", + # data_path=f"data/Pep2Vec/{dataset_folder}_new_subset/pep2vec_output_train_fold_0.parquet", + # val_data_path=f"data/Pep2Vec/{dataset_folder}_new_subset/pep2vec_output_val_fold_0.parquet", + input_type="latent", + model_type=m, + batch_size=128, + epochs=10, + learning_rate=1e-4, + hidden_dim=16, + output_dir=f"data/MoE/{dataset_folder}", + visualize=True, + save_model=True, + random_state=42, + test_size=0.2, # fraction of data to use for validation if no val_data_path + use_soft_clusters_as_gates=True, + ) + + +# # Create a 1D UMAP projection colored by cluster indices +# print("Computing 1D UMAP colored by cluster indices...") +# # Use the same 1D UMAP projection, but color by cluster indices +# plt.figure(figsize=(12, 4)) +# cluster_indices_flat = cluster_indices.flatten().reshape(-1, 1) +# embedding_1d = mapper_1d.fit_transform(cluster_indices_flat) +# +# # Regenerate y_jitter to match the new embedding size +# y_jitter = np.random.rand(embedding_1d.shape[0]) * 0.1 +# +# # Check shapes and ensure they match +# print(f"embedding_1d shape: {embedding_1d.shape}") +# print(f"cluster_indices shape: {cluster_indices.shape}") +# +# # Ensure cluster_indices is properly shaped for plotting +# if cluster_indices.ndim > 1: +# cluster_indices_plot = cluster_indices.flatten() +# else: +# cluster_indices_plot = cluster_indices +# +# # Make sure lengths match +# if len(cluster_indices_plot) != len(embedding_1d): +# # Reshape or slice cluster_indices to match embedding_1d +# if len(cluster_indices_plot) > embedding_1d.shape[0]: +# cluster_indices_plot = cluster_indices_plot[:embedding_1d.shape[0]] +# else: +# # If you need to extend, you might repeat values or use a different strategy +# cluster_indices_plot = np.pad(cluster_indices_plot, (0, embedding_1d.shape[0] - len(cluster_indices_plot)), 'edge') +# print(f"Adjusted cluster_indices shape: {cluster_indices_plot.shape}") +# +# scatter_clusters = plt.scatter(embedding_1d[:, 0], y_jitter, c=cluster_indices_plot, cmap='viridis', s=10, alpha=0.7) +# cbar = plt.colorbar(scatter_clusters, label='Cluster Index') +# plt.title('1D UMAP Projection Colored by Cluster Indices', fontsize=14) +# plt.xlabel('UMAP Dimension 1', fontsize=12) +# plt.yticks([]) +# plt.grid(axis='x', linestyle='--', alpha=0.6) +# plt.savefig('umap_clusters.png') +# plt.show() + +# # Apply UMAP dimensionality reduction +# print("Computing UMAP projection...") +# mapper = umap.UMAP(n_neighbors=30, min_dist=0.1, metric='euclidean', random_state=42) +# embedding = mapper.fit_transform(latent_2d) +# +# # Visualize the embedding with distinct colors for each sample +# 
plt.figure(figsize=(12, 10)) +# +# # Use sample IDs for coloring (each sample gets a unique color) +# scatter = plt.scatter( +# embedding[:, 0], +# embedding[:, 1], +# c=sample_ids, +# cmap='tab20', # Colormap with distinct colors +# s=10, +# alpha=0.7 +# ) +# +# # Add legend and labels +# cbar = plt.colorbar(scatter, label='Sample ID') +# cbar.set_label('Sample ID') +# plt.title('UMAP Projection of Quantized Latent Space (colored by sample)', fontsize=14) +# plt.xlabel('UMAP Dimension 1', fontsize=12) +# plt.ylabel('UMAP Dimension 2', fontsize=12) +# +# # Add a grid for better readability +# plt.grid(linestyle='--', alpha=0.6) +# +# # Save with higher DPI for better quality +# plt.savefig('scqvae_latent_umap_by_sample.png', dpi=300, bbox_inches='tight') +# plt.show() +# +# # Create a second visualization colored by cluster assignment +# plt.figure(figsize=(12, 10)) +# scatter = plt.scatter( +# embedding[:, 0], +# embedding[:, 1], +# c=cluster_indices, +# cmap='viridis', +# s=10, +# alpha=0.7 +# ) +# +# # Add legend and labels +# cbar = plt.colorbar(scatter, label='Codebook Vector Index') +# plt.title('UMAP Projection of Quantized Latent Space (colored by codebook vector)', fontsize=14) +# plt.xlabel('UMAP Dimension 1', fontsize=12) +# plt.ylabel('UMAP Dimension 2', fontsize=12) +# plt.grid(linestyle='--', alpha=0.6) +# plt.savefig('scqvae_latent_umap_by_cluster.png', dpi=300, bbox_inches='tight') +# plt.show() +# +# # Save the embeddings for potential further analysis +# np.savez( +# 'umap_results.npz', +# embedding=embedding, +# sample_ids=sample_ids, +# cluster_indices=cluster_indices, +# latent_vectors=latent_2d +# ) +# print("UMAP visualization complete. Results saved with sample-based coloring.") diff --git a/run_pep2vec.py b/run_pep2vec.py new file mode 100644 index 00000000..0e632185 --- /dev/null +++ b/run_pep2vec.py @@ -0,0 +1,540 @@ +# load Conbotnet subset data +import os +import random +import sys +import pandas as pd +import numpy as np +from utils.processing_functions import create_k_fold_leave_one_out_stratified_cross_validation, create_progressive_k_fold_cross_validation + + +def load_data(file_path, sep=","): + """ + Load data from a file and return a DataFrame. + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File {file_path} does not exist.") + + df = pd.read_csv(file_path, sep=sep) + return df + + +def select_columns(df, columns): + """ + Select specific columns from a DataFrame. + """ + print("DF columns:", df.columns) + print("Selected columns:", columns) + return df[columns] + + +# def get_netmhcpan_allele(allele, netmhcpan_dataset=None, allele_cache={}): +# """ +# Get the allele from the sequence in NetMHCpan dataset. +# Returns the allele with the highest sequence similarity. 
+# +# Parameters: +# ----------- +# allele : str or list +# Allele(s) to format and match in NetMHCpan format +# netmhcpan_dataset : pd.DataFrame, optional +# Pre-loaded NetMHCpan dataset to avoid repeated file reading +# allele_cache : dict, optional +# Cache of previously processed alleles +# +# Returns: +# -------- +# str or dict +# Matching allele(s) in NetMHCpan format +# """ +# # Check if we're processing a single allele or multiple alleles +# if isinstance(allele, list) or isinstance(allele, np.ndarray): +# # Process unique alleles and create a mapping +# unique_alleles = set(allele) +# allele_map = {} +# for a in unique_alleles: +# allele_map[a] = get_netmhcpan_allele(a, netmhcpan_dataset, allele_cache) +# return allele_map +# +# # Check if this allele is already in cache +# if allele in allele_cache: +# return allele_cache[allele] +# +# # Load dataset if not provided +# if netmhcpan_dataset is None: +# netmhcpan_dataset = pd.read_csv("data/HLA_alleles/pseudoseqs/PMGen_pseudoseq.csv") +# +# # Format allele name to NetMHCpan format +# formatted_allele = format_allele(allele) +# +# # First, try exact match (case-insensitive) +# exact_matches = netmhcpan_dataset[netmhcpan_dataset['allele'].str.lower() == formatted_allele.lower()] +# +# if not exact_matches.empty: +# result = exact_matches.iloc[0]['allele'] +# else: +# # If no exact match, try partial match +# partial_matches = netmhcpan_dataset[netmhcpan_dataset['allele'].str.contains(formatted_allele, case=False)] +# +# if not partial_matches.empty: +# result = partial_matches.iloc[0]['allele'] +# else: +# # If no match at all, report and use formatted allele +# print(f"No match found for allele: {allele} (formatted as {formatted_allele})") +# result = formatted_allele +# +# # Cache the result +# allele_cache[allele] = result +# return result +# +# def format_allele(allele): +# """Helper function to format allele names to NetMHCpan format""" +# # Format DRB alleles (like DRB11402 to DRB1*14:02) +# if allele.startswith('DRB') and not '-' in allele: +# locus = allele[:4] +# allele_num = allele[4:] +# if len(allele_num) >= 4: +# return f"{locus}*{allele_num[:2]}:{allele_num[2:]}" +# +# # Format MHC-II heterodimers (like HLA-DQA10501-DQB10301 to HLA-DQA1*05:01-DQB1*03:01) +# elif '-' in allele and not '*' in allele: +# parts = allele.split('-') +# if len(parts) == 2 and '-' in parts[1]: # Handle cases with another hyphen +# prefix, alpha_beta = parts +# alpha, beta = alpha_beta.split('-') +# +# # Format alpha and beta chains +# alpha = format_chain(alpha) +# beta = format_chain(beta) +# +# return f"{prefix}-{alpha}-{beta}" +# elif len(parts) == 3: # Format like HLA-DQA10501-DQB10301 +# prefix, alpha, beta = parts +# +# # Format alpha and beta chains +# alpha = format_chain(alpha) +# beta = format_chain(beta) +# +# return f"{prefix}-{alpha}-{beta}" +# +# # Default case - return allele as is +# return allele +# +# def format_chain(chain): +# """Format an MHC chain like DQA10501 to DQA1*05:01""" +# if len(chain) >= 7: # e.g., DQA10501 +# locus = chain[:4] +# num = chain[4:] +# return f"{locus}*{num[:2]}:{num[2:]}" +# return chain + +from tqdm import tqdm +# Optional: for parquet streaming +import pyarrow as pa +import pyarrow.parquet as pq + + +def process_chunk_df(df_chunk, binding_labels, mhc_sequences): + """ + Process a DataFrame chunk: assign binding_label and mhc_sequence, map labels to ints. 
+ """ + # Standardize column names + if 'allotype' not in df_chunk.columns and 'MHC' in df_chunk.columns: + df_chunk = df_chunk.rename(columns={'MHC': 'allotype'}) + if 'peptide' not in df_chunk.columns and 'long_mer' in df_chunk.columns: + df_chunk = df_chunk.rename(columns={'long_mer': 'peptide'}) + + # Assign binding labels via map on MultiIndex for vectorized lookup + mi = pd.MultiIndex.from_frame(df_chunk[['allotype', 'peptide']]) + df_chunk['binding_label'] = mi.map(binding_labels) + + # Assign mhc_sequence via map + df_chunk['mhc_sequence'] = df_chunk['allotype'].map(mhc_sequences) + + # Map string labels to integers + unique_labels = pd.unique(df_chunk['binding_label'].dropna()) + label_mapping = {lbl: i for i, lbl in enumerate(sorted(unique_labels))} + df_chunk['binding_label'] = df_chunk['binding_label'].map(label_mapping) + + return df_chunk + + +def add_binding_label_streaming(input_path, map_csv, output_dir=None, is_netmhcpan=False, + csv_chunksize=1_000_000): + """ + Stream-process large files (CSV or Parquet) on disk, without loading fully into memory. + + Parameters: + ----------- + input_path : str + Path to a file or directory of files (.csv or .parquet) to process. + csv_path : str + Path to small CSV file for binding_label and mhc_sequence mappings. + output_dir : str, optional + Directory to write processed files; if None, original files will be overwritten. + is_netmhcpan : bool + Whether the CSV mapping file uses 'assigned_label' instead of 'binding_label'. + csv_chunksize : int + Number of rows per chunk when reading mapping CSV (if large), default 1e6. + """ + print(f"Starting streaming addition of binding labels for: {input_path}") + df_map = map_csv + + # Ensure column consistency + df_map = df_map.rename(columns={ + 'long_mer': 'peptide' + }) + print(f"Loaded csv file with shape: {df_map.shape}") + # Build dicts for mapping + binding_labels = dict(zip(zip(df_map['allele'], df_map['peptide']), df_map['binding_label'])) + mhc_sequences = dict(zip(df_map['allele'], df_map['mhc_sequence'])) if 'mhc_sequence' in df_map.columns else {} + + def get_output_path(in_path): + fname = os.path.basename(in_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + return os.path.join(output_dir, fname) + return in_path + + def stream_file(file_path): + fname = os.path.basename(file_path) + print(f"\n--- Processing file: {fname} ---") + out_path = get_output_path(file_path) + ext = os.path.splitext(file_path)[1].lower() + + if ext == '.csv': + first_write = True + for i, chunk in enumerate(tqdm(pd.read_csv(file_path, chunksize=csv_chunksize), desc=f"Chunks ({fname})")): + print(f"Processing chunk {i + 1}") + processed = process_chunk_df(chunk, binding_labels, mhc_sequences) + processed.to_csv(out_path, mode='w' if first_write else 'a', index=False, header=first_write) + first_write = False + print(f"Finished writing CSV to: {out_path}") + + elif ext in ['.parquet', '.pq']: + pf = pq.ParquetFile(file_path) + writer = None + print(f"Parquet row groups: {pf.num_row_groups}") + for rg in tqdm(range(pf.num_row_groups), desc=f"RowGroups ({fname})"): + print(f"Reading row group {rg}") + partial = pf.read_row_group(rg).to_pandas() + processed = process_chunk_df(partial, binding_labels, mhc_sequences) + table = pa.Table.from_pandas(processed) + if writer is None: + writer = pq.ParquetWriter(get_output_path(file_path), table.schema) + writer.write_table(table) + if writer: + writer.close() + print(f"Finished writing Parquet to: {out_path}") + + else: + print(f"Skipped unsupported 
file type: {fname}") + + # Handle directory or single file + if os.path.isdir(input_path): + files = [f for f in os.listdir(input_path) if os.path.splitext(f)[1].lower() in ['.csv', '.parquet', '.pq']] + print(f"Found {len(files)} files to process in directory.") + for fname in files: + file_path = os.path.join(input_path, fname) + stream_file(file_path) + else: + stream_file(input_path) + + +def main(dataset_name="Conbotnet", mhc_type="mhc2", subset_prop=1.0, n_folds=5, process_fold_n=None, chunk_n=0): + """ + Prepare datasets for Pep2Vec training and evaluation. + + Parameters: + ----------- + dataset_name : str + Name of the dataset directory (e.g., "Conbotnet", "ConvNeXT-MHC", "NetMHCpan_dataset") + mhc_type : str + Type of MHC ("mhc1" or "mhc2") + subset_prop : float + Proportion of data to use (1.0 = all data) + n_folds : int + Number of cross-validation folds + """ + # Setup paths + base_dir = os.path.dirname(__file__) + data_dir = os.path.join(base_dir, "data") + dataset_dir = os.path.join(data_dir, dataset_name) + folds_dir = os.path.join(dataset_dir, "folds") + output_dir = os.path.join(data_dir, "Pep2Vec", dataset_name) + + # Create directories + os.makedirs(output_dir, exist_ok=True) + os.makedirs(folds_dir, exist_ok=True) + + # Define paths and columns based on dataset + is_netmhcpan = dataset_name.lower() == "netmhcpan_dataset" + + if not is_netmhcpan: + # Other datasets have separate train/test files + paths = { + 'train': os.path.join(dataset_dir, "train.csv"), + 'test': os.path.join(dataset_dir, "test_all.csv") + } + columns = ["allele", "long_mer", "binding_label", "mhc_sequence"] + rename_map = {"allele": "allotype", "long_mer": "peptide"} # required for pep2vec + + # Column configuration + # Rename label column for NetMHCpan + if is_netmhcpan: + # NetMHCpan has a single cleaned_data.csv file + # cleaned_data_path = os.path.join(dataset_dir, f"combined_data_{mhc_type[-1]}.csv") + cleaned_data_path = os.path.join(dataset_dir, f"chunks_I/subset_balanced_300k.csv") + print(cleaned_data_path) + rename_map = {"allele": "allotype", "Peptide": "peptide"} # required for pep2vec + rename_map["assigned_label"] = "binding_label" + columns = ["allotype", "peptide", "binding_label"] + + # Load and process datasets + datasets = {} + + if is_netmhcpan: + # Load NetMHCpan dataset from single file + print(f"Loading NetMHCpan dataset from {cleaned_data_path}") + # df = load_data(cleaned_data_path) + usecols = ['allele', 'peptide', 'assigned_label', 'mhc_sequence', 'mhc_class'] + rng = random.Random(42) + df = pd.read_csv(cleaned_data_path, usecols=usecols) + if subset_prop < 1.0: + df = df.sample(frac=subset_prop, random_state=42) + df_map = df.rename(columns={'assigned_label': 'binding_label', 'peptide': 'long_mer'}) + print(f"Full dataset shape: {df.shape}") + + # rename columns + if "binding_label" not in df.columns and "assigned_label" in df.columns: + df.rename(columns=rename_map, inplace=True) + + # Check the dtype + print("mhc_class dtype:", df["mhc_class"].dtype) + + # And get counts of each class to see their distribution + print(df["mhc_class"].value_counts(dropna=False)) + + # Filter by MHC class based on mhc_type parameter + mhc_class = 2 if mhc_type.lower() == "mhc2" else 1 + df = df[df["mhc_class"] == mhc_class].copy() + print(f"Dataset after filtering for MHC class {mhc_class}: {df.shape}") + + subset = df[["allotype", "peptide", "binding_label"]] + print(subset.info()) + print(subset.head(10)) + print(subset.isna().sum()) + total = len(df) + uniques = df[["allotype", 
"peptide"]].dropna().drop_duplicates().shape[0] + print(f"Total rows: {total:,}; unique (allotype, peptide): {uniques:,}") + print(df.filter(like="label").columns) + print(df["binding_label"].notna().sum()) + + # TODO fix + # Process dataframes + # select + clean + train_df = ( + df[["allotype", "peptide", "binding_label"]] + .dropna() + .drop_duplicates(subset=["allotype", "peptide"]) + ) + + print(f"Train dataset shape after processing: {train_df.shape}") + print(train_df["binding_label"].value_counts()) + + datasets['train'] = train_df + datasets['test'] = None + else: + # Regular dataset processing with separate files + for name, path in paths.items(): + df = pd.read_csv(path) + df_map = df.copy() + + # Process dataframe + df = (select_columns(df, columns) + .rename(columns=rename_map) + .drop_duplicates(subset=["allotype", "peptide"]) + .dropna()) + + print(f"{name.capitalize()} dataset shape after processing: {df.shape}") + datasets[name] = df + + # # change the allele names to the ones in the NetMHCpan dataset + # # Load the NetMHCpan dataset once + # netmhcpan_dataset = pd.read_csv("data/HLA_alleles/pseudoseqs/PMGen_pseudoseq.csv") + # # Create a shared cache for alleles + # allele_cache = {} + # + # # Apply to train and test datasets with shared resources + # datasets['train']['allotype'] = datasets['train']['allotype'].apply( + # lambda x: get_netmhcpan_allele(x, netmhcpan_dataset, allele_cache)) + # datasets['test']['allotype'] = datasets['test']['allotype'].apply( + # lambda x: get_netmhcpan_allele(x, netmhcpan_dataset, allele_cache)) + + # define test1 and test2 datasets + # test1: balanced sampling with equal representation from each binding_label + train_df_ = datasets['train'] + + # Determine sample size (minimum of 1000 or smallest class size) + samples_per_label = min(1000, min(train_df_['binding_label'].value_counts())) + print(f"Creating balanced test1 with {samples_per_label} samples per label") + + # Sample equally from each label + test1 = (train_df_ + .groupby('binding_label', group_keys=False) # no re‑index shuffle + .sample(n=samples_per_label, random_state=42) # vectorised sample + .reset_index(drop=True)) + + train_mask = ~train_df_.index.isin(test1.index) + train_updated = train_df_.loc[train_mask] + + datasets['train'] = train_updated + + # test2: select allele with lowest sample count + allele_counts = datasets['train']['allotype'].value_counts() + lowest_allele = allele_counts.idxmin() + test2 = datasets['train'][datasets['train']['allotype'] == lowest_allele].copy() + datasets['train'] = datasets['train'][datasets['train']['allotype'] != lowest_allele].reset_index(drop=True) + + if process_fold_n: + process_fold_n = int(process_fold_n) + if process_fold_n < 0 or process_fold_n >= n_folds: + raise ValueError(f"Invalid fold number: {process_fold_n}. 
Must be between 0 and {n_folds - 1}.") + if not process_fold_n or process_fold_n == 0: + # Create k-fold cross-validation splits + k = n_folds + folds = create_progressive_k_fold_cross_validation( + datasets['train'], k=k, target_col="binding_label", + id_col="allotype" + ) + + held_out_ids_path = os.path.join(folds_dir, "held_out_ids.txt") + if os.path.exists(held_out_ids_path): + os.remove(held_out_ids_path) + + # Save folds + for i, (train_set, val_set, held_out_id) in enumerate(folds): + # only keep the allele and peptide columns + train_set = train_set[["allotype", "peptide"]] + val_set = val_set[["allotype", "peptide"]] + # drop nan and duplicates + train_set = train_set.dropna().drop_duplicates(subset=["allotype", "peptide"]) + val_set = val_set.dropna().drop_duplicates(subset=["allotype", "peptide"]) + # TODO find a better solution later - pep2vec can't handle long peptides longer than ~25 might work upto 29 + train_set = train_set[train_set["peptide"].str.len() <= 25] + val_set = val_set[val_set["peptide"].str.len() <= 25] + + # reset the index + train_set.reset_index(drop=True, inplace=True) + val_set.reset_index(drop=True, inplace=True) + with open(os.path.join(os.path.dirname(__file__), "data", dataset_name, "folds", f"train_set_fold_{i}.csv"), + mode="w", encoding="utf-8") as train_file: + train_set.to_csv(train_file, sep=",", index=True, header=True) + with open(os.path.join(os.path.dirname(__file__), "data", dataset_name, "folds", f"val_set_fold_{i}.csv"), + mode="w", encoding="utf-8") as val_file: + val_set.to_csv(val_file, sep=",", index=True, header=True) + with open(held_out_ids_path, "a") as f: + f.write(f"Fold {i}: {held_out_id}\n") + + # Save the test set + # if dataset has a test set, save it as well + if datasets.get('test') is not None: + datasets['test'].to_csv(os.path.join(folds_dir, "test_original.csv"), index=False, header=True) + # Save the test1 and test2 sets + test1.to_csv(os.path.join(folds_dir, "test1_stratified.csv"), index=False, header=True) + test2.to_csv(os.path.join(folds_dir, "test2_single_unique_allele.csv"), index=False, header=True) + print("Test dataset saved.") + else: + print("no folds created, using the existing ones") + + # load the fold csv files for specified fold + train_path = os.path.join(folds_dir, f"train_set_fold_{process_fold_n}.csv") + val_path = os.path.join(folds_dir, f"val_set_fold_{process_fold_n}.csv") + if os.path.exists(train_path) and os.path.exists(val_path): + datasets['train'] = pd.read_csv(train_path, index_col=0) + datasets['val'] = pd.read_csv(val_path, index_col=0) + else: + raise FileNotFoundError(f"Fold files not found for fold {process_fold_n}") + + ################### Pep2Vec ################### + # automatically get the number of cores + num_cores = max(os.cpu_count() - 2, 1) + # TODO only process one fold? 
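+    # (Aside on the balanced test1 split created earlier in this function, a
+    # hedged observation: because test1 is built with .reset_index(drop=True)
+    # before the mask is computed, `train_df_.index.isin(test1.index)` compares
+    # against the new 0..N-1 labels rather than the rows that were actually
+    # sampled, so sampled rows may remain in the training split.  One possible
+    # fix is to reset the index only after the mask:
+    #   test1 = train_df_.groupby('binding_label', group_keys=False).sample(
+    #       n=samples_per_label, random_state=42)
+    #   train_updated = train_df_.drop(index=test1.index)
+    #   test1 = test1.reset_index(drop=True)
+    # This keeps the held-out rows out of the training data.)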
with process_fold_n + if process_fold_n: + for split in ['train', 'val']: + csv_in = os.path.join(folds_dir, f"{split}_set_fold_{process_fold_n}.csv") + out_pq = os.path.join(output_dir, f"pep2vec_output_{split}_fold_{process_fold_n}.parquet") + if os.path.exists(csv_in): + os.system( + f"./Pep2Vec/pep2vec.bin --num_threads {num_cores} --dataset {csv_in} --output_location {out_pq} --mhctype {mhc_type}") + else: + for i in range(n_folds): + for split in ['train', 'val']: + csv_in = os.path.join(folds_dir, f"{split}_set_fold_{i}.csv") + out_pq = os.path.join(output_dir, f"pep2vec_output_{split}_fold_{i}.parquet") + if os.path.exists(csv_in): + os.system( + f"./Pep2Vec/pep2vec.bin --num_threads {num_cores} --dataset {csv_in} --output_location {out_pq} --mhctype {mhc_type}") + + # # test set + test_file = os.path.join(folds_dir, "test_original.csv") + if os.path.exists(test_file): + output_file = os.path.join(output_dir, f"pep2vec_output_test.parquet") + os.system( + f"./Pep2Vec/pep2vec.bin --num_threads {num_cores} --dataset {test_file} --output_location {output_file} --mhctype {mhc_type}") + # Process test1 and test2 datasets + for test_name in ["test1_stratified", "test2_single_unique_allele"]: + test_file = os.path.join(folds_dir, f"{test_name}.csv") + if os.path.exists(test_file): + output_file = os.path.join(output_dir, f"pep2vec_output_{test_name}.parquet") + os.system( + f"./Pep2Vec/pep2vec.bin --num_threads {num_cores} --dataset {test_file} --output_location {output_file} --mhctype {mhc_type}") + ############################################### + + return df_map + + +if __name__ == "__main__": + if len(sys.argv) > 1: + fold_n = sys.argv[1] + else: + fold_n = None + + print(f"Running for Fold {fold_n}") + + # main("Conbotnet", "mhc2", 1, 5, fold_n) + # df_map = None + # if not df_map: + # df1 = pd.read_csv(os.path.join("data", "Conbotnet", "train.csv"),) + # df2 = pd.read_csv(os.path.join("data", "Conbotnet", "test_all.csv"),) + # df_map = pd.concat([df1, df2]) + # add_binding_label_streaming( + # input_path=os.path.join("data", "Pep2Vec", "Conbotnet"), + # map_csv=df_map, + # output_dir=os.path.join("data", "Pep2Vec", "Conbotnet_new_subset"), + # is_netmhcpan=False + # ) + + # df_map = main("ConvNeXT-MHC", "mhc1", 0.01, 5, fold_n) + # df_map = None + # if not df_map: + # df1 = pd.read_csv(os.path.join("data", "ConvNeXT-MHC", "train.csv"),) + # df2 = pd.read_csv(os.path.join("data", "ConvNeXT-MHC", "test_all.csv"),) + # df_map = pd.concat([df1, df2]) + # add_binding_label_streaming( + # input_path=os.path.join("data", "Pep2Vec", "ConvNeXT-MHC"), + # map_csv=df_map, + # output_dir=os.path.join("data", "Pep2Vec", "ConvNeXT-MHC_new_subset") + # ) + # main("NetMHCpan_dataset", "mhc2", 0.01, 5) + # add_binding_label( + # df_path=os.path.join("data", "Pep2Vec", "NetMHCIIpan_dataset"), + # train_path=os.path.join("data", "NetMHCIIpan_dataset", "cleaned_data.csv"), + # output_dir=os.path.join("data", "Pep2Vec", "NetMHCIpan_dataset_new_subset") + # ) + # df_map = main("NetMHCpan_dataset", "mhc1", 1, 5, fold_n) + # add_binding_label_streaming( + # input_path=os.path.join("data", "Pep2Vec", "NetMHCpan_dataset"), + # map_csv=df_map, + # output_dir=os.path.join("data", "Pep2Vec", "NetMHCIpan_dataset_new_subset"), + # is_netmhcpan=True + # ) diff --git a/run_utils.py b/run_utils.py index a7abd1e9..57097902 100644 --- a/run_utils.py +++ b/run_utils.py @@ -15,6 +15,7 @@ import pandas as pd import subprocess import warnings + warnings.filterwarnings("ignore", category=FutureWarning) @@ -790,8 +791,28 @@ 
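+# Note on the signature change in run_and_parse_netmhcpan below (a general
+# Python point, not project-specific): a mutable default such as
+# `mhc_seq_list=[]` is evaluated once at definition time, so every call shares
+# the same list object.  Minimal illustration with a hypothetical helper:
+#
+#     def collect(item, bucket=[]):   # BUG: one list shared across calls
+#         bucket.append(item)
+#         return bucket
+#
+#     collect("a")                    # -> ['a']
+#     collect("b")                    # -> ['a', 'b'], not ['b']
+#
+# The `mhc_seq_list=None` plus `if mhc_seq_list is None:` pattern used in the
+# patched signature avoids this.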
def protein_mpnn_wrapper(output_pdbs_dict, args, max_jobs, mode='parallel'): raise ValueError("Invalid mode! Choose 'parallel' or 'single'.") -def run_and_parse_netmhcpan(peptide_fasta_file, mhc_type, output_dir, mhc_seq_list=[], mhc_allele=None, - dirty_mode=False): +def run_and_parse_netmhcpan(peptide_fasta_file, mhc_type, output_dir, mhc_seq_list=None, mhc_allele=None, + dirty_mode=False, save_csv=True): + """ + Runs the NetMHCpan tool and parses its output. + + Args: + peptide_fasta_file (str): Path to the FASTA file containing peptide sequences. + mhc_type (int): Type of MHC (1 for MHC-I, 2 for MHC-II). + output_dir (str): Directory to save the output files. + mhc_seq_list (list, optional): List of MHC sequences. Default is an empty list. + mhc_allele (str, optional): Specific MHC allele name. Default is None. + dirty_mode (bool, optional): If True, removes the raw output file after parsing. Default is False. + save_csv (bool, optional): If True, saves the parsed output as a CSV file. Default is True. + + Returns: + pd.DataFrame: DataFrame containing the parsed NetMHCpan output. + + Raises: + ValueError: If neither `mhc_seq_list` nor `mhc_allele` is provided. + """ + if mhc_seq_list is None: + mhc_seq_list = [] assert mhc_type in [1,2] if not mhc_allele and len(mhc_seq_list) == 0: raise ValueError(f'at least one of mhc_seq_list or mhc_allele should be provided') @@ -810,7 +831,7 @@ def run_and_parse_netmhcpan(peptide_fasta_file, mhc_type, output_dir, mhc_seq_li f'with the Alpha/Beta order ' f'found {len(mhc_seq_list)}: {mhc_seq_list}') else: - assert len(mhc_allele.split('/'))==2, (f'mhc_allele for mhc class 2, should contant both alpha and beta alleles seperated by "/"' + assert len(mhc_allele.split('/'))==2, (f'mhc_allele for mhc class 2, should contain both alpha and beta alleles seperated by "/"' f'\n example: DRA/DRB*01. found {mhc_allele}') mhc_seq_list = ['', ''] @@ -820,10 +841,11 @@ def run_and_parse_netmhcpan(peptide_fasta_file, mhc_type, output_dir, mhc_seq_li a = processing_functions.match_inputseq_to_netmhcpan_allele(mhc_seq_list[i], mhc_type, allele) matched_allele.append(a) if mhc_type == 1: break - print("Matched Alleles", matched_allele) + # print("Matched Alleles", matched_allele) processing_functions.run_netmhcpan(peptide_fasta_file, matched_allele, outfile, mhc_type) df = processing_functions.parse_netmhcpan_file(outfile) - df.to_csv(outfile_csv, index=False) + if save_csv: + df.to_csv(outfile_csv, index=False) if dirty_mode: os.remove(outfile) return df diff --git a/src/model3.py b/src/model3.py new file mode 100644 index 00000000..ce290d26 --- /dev/null +++ b/src/model3.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +""" +pepmhc_cross_attention.py +------------------------- + +Minimal end-to-end demo of a peptide × MHC cross-attention classifier +with explainable attention visualisation. + +Author: 2025-05-22 +""" + +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers +import matplotlib.pyplot as plt +import seaborn as sns + +# -------------------------------------------------------------------------------- +# 1. 
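+# (Before the toy-data helpers: a short, hedged usage sketch for the
+# run_and_parse_netmhcpan wrapper patched above -- the file names and allele
+# pair are made-up examples, only the parameter names come from the wrapper:
+#   from run_utils import run_and_parse_netmhcpan
+#   df = run_and_parse_netmhcpan(
+#       peptide_fasta_file="peptides.fasta",   # hypothetical FASTA of peptides
+#       mhc_type=2,                            # MHC-II needs an alpha/beta pair
+#       output_dir="netmhcpan_out",
+#       mhc_allele="DRA/DRB1*01:01",           # '/'-separated, as the assert expects
+#       save_csv=False)                        # keep only the returned DataFrame
+# )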
Synthetic toy-data helpers +# -------------------------------------------------------------------------------- +AA = "ACDEFGHIKLMNPQRSTVWY" +AA_TO_INT = {a:i for i,a in enumerate(AA)} +UNK = 20 # index for “unknown” +PAD_TOKEN = -2 # set manually to avoid confusion with UNK + +def onehot(seq: str, max_len: int) -> np.ndarray: + """Return (max_len,21) one-hot matrix.""" + mat = np.full((max_len, 21), PAD_TOKEN, dtype=np.float32) # initialize padding with -2 + for i, aa in enumerate(seq[:max_len]): + mat[i, AA_TO_INT.get(aa, UNK)] = 1.0 + return mat + +# -------------------------------------------------------------------------------- +# 2. Model building block: cross-attention with score output +# -------------------------------------------------------------------------------- +""" +model.py – light-weight peptide × MHC cross-attention classifier + plus twin model that exposes the attention matrix + for explainability. + +The only public symbol is + build_classifier(max_pep_len, max_mhc_len, …) +which returns + clf_model, att_model +""" +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers +import numpy as np + +# --------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------- +def cross_att_block(query, # (B,pep_len,D) + context, # (B,mhc_len,D) + query_mask=None, # (B,pep_len) + context_mask=None, # (B,mhc_len) + heads=8, + name="xatt"): + """ + Keras MultiHeadAttention; returns (att_out, att_scores) + Shapes: + att_out (B, pep_len, D) + att_scores (B, heads, pep_len, mhc_len) + """ + mha = layers.MultiHeadAttention(num_heads=heads, + key_dim=query.shape[-1], + name=name) + + # Create attention mask if both query_mask and context_mask are provided + if query_mask is not None and context_mask is not None: + attn_mask = query_mask[:, :, None] & context_mask[:, None, :] # (B, pep_len, mhc_len) + else: + attn_mask = None + + att_out, att_scores = mha(query=query, + value=context, + key=context, + attention_mask=attn_mask, + return_attention_scores=True) + return att_out, att_scores + + +# --------------------------------------------------------------------- +# main builder +# --------------------------------------------------------------------- +def build_classifier(max_pep_len : int, + max_mhc_len : int, + pep_emb_dim : int = 64, + mhc_emb_dim : int = 64, + heads : int = 8): + """ + Returns + clf_model – compiled model for training / inference + att_model – same weights, outputs attention scores only + """ + # Inputs ----------------------------------------------------------- + inp_pep = keras.Input(shape=(max_pep_len, 21), name="pep_onehot") + inp_mhc = keras.Input(shape=(max_mhc_len,1152), name="mhc_latent") + + # Masks ----------------------------------------------------------- + pep_mask = layers.Lambda(lambda x: tf.reduce_any(x != PAD_TOKEN, axis=-1), name="pep_mask")(inp_pep) + mhc_mask = layers.Lambda(lambda x: tf.reduce_any(x != 0, axis=-1), name="mhc_mask")(inp_mhc) + + # Linear projections ---------------------------------------------- + pep_emb = layers.Dense(pep_emb_dim, activation=None, + name="pep_proj")(inp_pep) # (B,pep_len,D) + mhc_emb = layers.Dense(mhc_emb_dim, activation=None, + name="mhc_proj")(inp_mhc) # (B,mhc_len,D) + + # Positional encoding (simple learned) ---------------------------- + pep_pos = layers.Embedding(input_dim=max_pep_len, + output_dim=pep_emb_dim, + name="pep_pos_emb")(tf.range(max_pep_len)) + mhc_pos = 
layers.Embedding(input_dim=max_mhc_len, + output_dim=mhc_emb_dim, + name="mhc_pos_emb")(tf.range(max_mhc_len)) + pep_emb = pep_emb + pep_pos + mhc_emb = mhc_emb + mhc_pos + + # Dense layer to reduce dimension of mhc_emb + mhc_emb = layers.Dense(mhc_emb_dim, activation=None, + name="mhc_dim_reduction")(mhc_emb) # (B,mhc_len,D) + + # Cross-attention -------------------------------------------------- + att_pep, att_scores = cross_att_block(pep_emb, mhc_emb, pep_mask, mhc_mask, + heads=heads, name="pep2mhc") + + # Pool & classifier ----------------------------------------------- + # Mask out padding in peptide for pooling + masked_att_pep = layers.Lambda( + lambda inputs: inputs[0] * tf.expand_dims(tf.cast(inputs[1], tf.float32), axis=-1), + name="mask_pep" + )([att_pep, pep_mask]) + sum_att_pep = layers.Lambda( + lambda x: tf.reduce_sum(x, axis=1), + name="sum_att_pep" + )(masked_att_pep) + num_non_pad = layers.Lambda( + lambda x: tf.reduce_sum(tf.cast(x, tf.float32), axis=1, keepdims=True), + name="num_non_pad" + )(pep_mask) + pooled = layers.Lambda( + lambda inputs: inputs[0] / (inputs[1] + 1e-8), + name="pooled" + )([sum_att_pep, num_non_pad]) + x = layers.Dense(32, activation="relu")(pooled) + out = layers.Dense(1, activation="sigmoid", name="prob")(x) + + clf_model = keras.Model([inp_pep, inp_mhc], out, name="PepMHC_clf") + clf_model.compile(optimizer="adam", + loss="binary_crossentropy", + metrics=["binary_accuracy","AUC"]) + + # attention twin --------------------------------------------------- + att_model = keras.Model([inp_pep, inp_mhc], att_scores, + name="PepMHC_attention") + + # cross latent space projection model ----------------------------- + cross_latent_model = keras.Model([inp_pep, inp_mhc], att_pep, + name="PepMHC_cross_latent") + + return clf_model, att_model, cross_latent_model + + +# -------------------------------------------------------------------------------- +# 4. Demo run (synthetic data, 2 epochs, heat map) +# -------------------------------------------------------------------------------- +# if __name__ == "__main__": +# tf.random.set_seed(0); np.random.seed(0) +# +# PEP_MAX = 14 +# MHC_MAX = 36 +# BATCH = 64 +# STEPS = 4 # keep tiny for a quick sanity run +# +# model, att_model, cross_latent_model = build_classifier(max_pep_len=PEP_MAX, max_mhc_len=MHC_MAX) +# +# print(model.summary(line_length=110)) +# +# # quick dummy training -------------------------------------------------- +# for step in range(STEPS): +# Xp, Xm, y = toy_batch_masked(BATCH, pep_max=PEP_MAX, mhc_max=MHC_MAX) +# model.train_on_batch([Xp,Xm], y) +# +# # one inference batch & attention --------------------------------------- +# Xp_test, Xm_test, y_test = toy_batch_masked(8, pep_max=PEP_MAX, mhc_max=MHC_MAX) +# preds = model.predict([Xp_test, Xm_test]) +# att_scores = att_model.predict([Xp_test, Xm_test]) # (B,heads,pep_len,mhc_len) +# +# print("\nPredictions:", preds.squeeze().round(3)) +# +# # -------------------------------------------------------------------------------- +# # 5. 
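+# A self-contained shape check for build_classifier (hedged sketch; note that
+# the commented demo above calls toy_batch_masked, which is defined in
+# model4_recon.py rather than in this file):
+#
+#   import numpy as np
+#   clf, att, xlat = build_classifier(max_pep_len=14, max_mhc_len=36)
+#   Xp = np.random.rand(2, 14, 21).astype("float32")
+#   Xm = np.random.rand(2, 36, 1152).astype("float32")
+#   clf.predict([Xp, Xm]).shape    # (2, 1)
+#   att.predict([Xp, Xm]).shape    # (2, 8, 14, 36)  -- heads x pep x mhc
+#   xlat.predict([Xp, Xm]).shape   # (2, 14, 64)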
visualise attention for the first sample (average over heads) +# # -------------------------------------------------------------------------------- +# sample = 0 +# A = att_scores[sample] # (heads,pep_len,mhc_len) +# A_mean = A.mean(axis=0) # (pep_len,mhc_len) +# +# plt.figure(figsize=(8,6)) +# sns.heatmap(A_mean, +# cmap="viridis", +# xticklabels=[f"M{i}" for i in range(MHC_MAX)], +# yticklabels=[f"P{j}" for j in range(PEP_MAX)], +# cbar_kws={"label":"attention"}) +# plt.xlabel("MHC position"); plt.ylabel("Peptide position") +# plt.title("Peptide→MHC attention (heads averaged) – sample 0") +# plt.tight_layout() +# plt.show() +# +# # report positions with highest influence ------------------------------ +# pep_pos = np.argmax(A_mean, axis=1) # best MHC pos for each peptide token +# mhc_pos = np.argmax(A_mean, axis=0) # most queried peptide pos per MHC token +# +# print("\nTop MHC position attended by each peptide residue:") +# for p in range(PEP_MAX): +# print(f" peptide P{p:02d} ⇢ MHC M{pep_pos[p]:02d} (score={A_mean[p,pep_pos[p]]:.3f})") +# +# print("\nPeptide position with max attention received from each MHC residue:") +# for m in range(MHC_MAX): +# print(f" MHC M{m:02d} ⇠ peptide P{mhc_pos[m]:02d} (score={A_mean[mhc_pos[m],m]:.3f})") diff --git a/src/model4_recon.py b/src/model4_recon.py new file mode 100644 index 00000000..229b4896 --- /dev/null +++ b/src/model4_recon.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python +""" +pepmhc_cross_attention.py +------------------------- + +Minimal end-to-end demo of a peptide × MHC cross-attention classifier +with explainable attention visualisation. + +Author: 2025-05-22 +""" + +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers +import matplotlib.pyplot as plt +import seaborn as sns + +# -------------------------------------------------------------------------------- +# 1. 
Synthetic toy-data helpers +# -------------------------------------------------------------------------------- +AA = "ACDEFGHIKLMNPQRSTVWY" +MASK_IDX = 20 # new index +AA_TO_INT = {a: i for i, a in enumerate(AA)} +AA_DIM = 21 # 20 AA + 1 MASK + +def onehot(seq: str, max_len: int) -> np.ndarray: + mat = np.zeros((max_len, AA_DIM), dtype=np.float32) + for i, aa in enumerate(seq[:max_len]): + mat[i, AA_TO_INT.get(aa, MASK_IDX)] = 1.0 + return mat + +def onehot_to_seq(onehot_mat: np.ndarray) -> str: + """Convert one-hot encoding back to amino acid sequence.""" + indices = np.argmax(onehot_mat, axis=1) + seq = "" + for idx in indices: + if idx == MASK_IDX: + seq += "X" # Use X to represent masked positions + else: + seq += AA[idx] + return seq + +def toy_batch_masked(batch_size=32, + pep_max=14, + mhc_max=36, + mask_rate=0.3, # 30\% of peptide tokens will be masked + mhc_dim=1152): + """Synthetic (masked peptide input, true peptide, mask weights, mhc) batch.""" + peps_in, peps_true, mask_w, mhcs = [], [], [], [] + for _ in range(batch_size): + # ----- peptide ------------------------------------------------------- + Lp = np.random.randint(8, pep_max + 1) + pep = ''.join(np.random.choice(list(AA), Lp)) + oh = onehot(pep, pep_max) # ground-truth (B,pep_max,21) + + # decide which positions to mask + mpos = np.random.rand(pep_max) < mask_rate + oh_masked = oh.copy() + oh_masked[mpos] = 0.0 + oh_masked[mpos, MASK_IDX] = 1.0 # replace by MASK token + + peps_in.append(oh_masked) # model input + peps_true.append(oh) # y_true + mask_w.append(mpos.astype(np.float32)) # sample-weight + # ----- MHC latent ---------------------------------------------------- + Lm = np.random.randint(20, mhc_max + 1) + mhc_lat = np.random.randn(Lm, mhc_dim).astype(np.float32) + pad_mhc = np.zeros((mhc_max, mhc_dim), np.float32) + pad_mhc[:Lm] = mhc_lat + mhcs.append(pad_mhc) + + return (np.stack(peps_in), + np.stack(mhcs), + np.stack(peps_true), + np.stack(mask_w)) + + +# -------------------------------------------------------------------------------- +# 2. Model building block: cross-attention with score output +# -------------------------------------------------------------------------------- +""" +model.py – light-weight peptide × MHC cross-attention classifier + plus twin model that exposes the attention matrix + for explainability. 
+ +The only public symbol is + build_classifier(max_pep_len, max_mhc_len, …) +which returns + clf_model, att_model +""" +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers +import numpy as np + +# --------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------- +def cross_att_block(query, # (B,pep_len,D) + context, # (B,mhc_len,D) + heads=8, + name="xatt"): + """ + Keras MultiHeadAttention; returns (att_out, att_scores) + Shapes: + att_out (B, pep_len, D) + att_scores (B, heads, pep_len, mhc_len) + """ + mha = layers.MultiHeadAttention(num_heads=heads, + key_dim=query.shape[-1], + name=name) + + att_out, att_scores = mha(query=query, + value=context, + key=context, + return_attention_scores=True) + return att_out, att_scores + + +# --------------------------------------------------------------------- +# main builder +# --------------------------------------------------------------------- +# ----- inputs ------------------------------------------------------------ + + +def build_reconstruction_model(max_pep_len : int, + max_mhc_len : int, + pep_emb_dim : int = 64, + mhc_emb_dim : int = 64, + mhc_latent_dim : int = 1152, + heads : int = 8): + """ + Returns + clf_model – compiled model for training / inference + att_model – same weights, outputs attention scores only + """ + # Inputs ----------------------------------------------------------- + inp_pep = keras.Input(shape=(max_pep_len, AA_DIM), name="pep_onehot") + inp_mhc = keras.Input(shape=(max_mhc_len,mhc_latent_dim), name="mhc_latent") + + # ----- linear projections & positional enc ------------------------------ + pep_emb = layers.Dense(pep_emb_dim, activation=None, name="pep_proj")(inp_pep) + mhc_emb = layers.Dense(mhc_emb_dim, activation=None, name="mhc_proj")(inp_mhc) + + pep_pos = layers.Embedding(max_pep_len, pep_emb_dim, name="pep_pos")(tf.range(max_pep_len)) + mhc_pos = layers.Embedding(max_mhc_len, mhc_emb_dim, name="mhc_pos")(tf.range(max_mhc_len)) + pep_emb = pep_emb + pep_pos + mhc_emb = mhc_emb + mhc_pos + mhc_emb = layers.Dense(mhc_emb_dim, activation=None, name="mhc_dim_reduction")(mhc_emb) + + # ----- cross-attention --------------------------------------------------- + att_pep, att_scores = cross_att_block(pep_emb, mhc_emb, heads=heads, name="pep2mhc") + + # ----- per-token classifier --------------------------------------------- + logits = layers.Dense(AA_DIM, activation=None, name="logits")(att_pep) + probs = layers.Activation("softmax", name="aa_probs")(logits) # (B,pep_len,21) + + model = keras.Model([inp_pep, inp_mhc], logits, name="PepMHC_maskedLM") + + model.compile( + optimizer="adam", + loss=keras.losses.CategoricalCrossentropy(from_logits=True), + metrics=[keras.metrics.CategoricalAccuracy(name="masked_accuracy")], + sample_weight_mode="temporal", + ) + + # optional twin that only outputs attention ------------------------------ + att_model = keras.Model([inp_pep, inp_mhc], att_scores, name="PepMHC_attention") + + return model, att_model + + +# -------------------------------------------------------------------------------- +# 4. 
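+# (Hedged note on the temporal sample weights used in compile() above: with
+# y_true and logits of shape (B, pep_len, 21) and weights w of shape
+# (B, pep_len), the effective loss is, up to Keras' exact normalisation,
+#
+#   ce   = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=logits)  # (B, pep_len)
+#   loss = tf.reduce_sum(ce * w) / tf.reduce_sum(w)
+#
+# so only the masked peptide positions (w == 1) drive the reconstruction gradient.)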
Demo run (synthetic data, 2 epochs, heat map) +# -------------------------------------------------------------------------------- +if __name__ == "__main__": + tf.random.set_seed(0); np.random.seed(0) + + PEP_MAX = 14 + MHC_MAX = 36 + BATCH = 64 + STEPS = 40 # keep tiny for a quick sanity run + + model, att_model = build_reconstruction_model(max_pep_len=PEP_MAX, max_mhc_len=MHC_MAX) + + print(model.summary(line_length=110)) + + # quick dummy training -------------------------------------------------- + for step in range(STEPS): + Xp, Xm, y_true, w = toy_batch_masked(BATCH, pep_max=PEP_MAX, mhc_max=MHC_MAX) + model.train_on_batch([Xp, Xm], y_true, sample_weight=w) + + ############## + Xp_test, Xm_test, y_test, mask_w_test = toy_batch_masked(8, pep_max=PEP_MAX, mhc_max=MHC_MAX) + preds = model.predict([Xp_test, Xm_test]) + att_scores = att_model.predict([Xp_test, Xm_test]) # (B,heads,pep_len,mhc_len) + + # Visualize peptide reconstruction examples + print("\n=== Peptide Reconstruction Examples ===") + for i in range(8): # Show first 3 examples + # Get original, masked and reconstructed peptides + original_pep = onehot_to_seq(y_test[i]) + masked_pep = onehot_to_seq(Xp_test[i]) + recon_pep = onehot_to_seq(preds[i]) + + # Trim to actual peptide length (remove trailing padding) + actual_len = len(original_pep.rstrip()) + original_pep = original_pep[:actual_len] + masked_pep = masked_pep[:actual_len] + recon_pep = recon_pep[:actual_len] + + # Highlight differences with formatting + highlighted_recon = "" + for j, (o, r) in enumerate(zip(original_pep, recon_pep)): + if masked_pep[j] == "X": + if o == r: + highlighted_recon += f"[{r}]" # Correctly reconstructed (was masked) + else: + highlighted_recon += f"({r})" # Incorrectly reconstructed (was masked) + else: + highlighted_recon += r # Position was not masked + + print(f"Example {i + 1}:") + print(f" Original: {original_pep}") + print(f" Masked input: {masked_pep}") + print(f" Reconstructed: {highlighted_recon}") + print(f" ([correct] / (incorrect) reconstruction of masked positions)\n") + + # Visualize attention for one example + plt.figure(figsize=(12, 5)) + example_idx = 0 + att_example = np.mean(att_scores[example_idx], axis=0) # Average over attention heads + + # Get non-padding length + pep_len = len(onehot_to_seq(y_test[example_idx]).rstrip()) + mhc_len = np.sum(np.any(Xm_test[example_idx] != 0, axis=1)) + + # Only display the actual peptide and MHC lengths (not padding) + att_display = att_example[:pep_len, :mhc_len] + + # Create heatmap + ax = sns.heatmap(att_display, cmap='viridis') + plt.title(f"Peptide-MHC Cross-Attention (Example {example_idx + 1})") + plt.xlabel("MHC position") + plt.ylabel("Peptide position") + + # Annotate masked positions on y-axis + masked_pep = onehot_to_seq(Xp_test[example_idx])[:pep_len] + plt.yticks(np.arange(pep_len) + 0.5, + [f"{i}:{aa}" + ("*" if aa == "X" else "") + for i, aa in enumerate(masked_pep)], + rotation=0) + + plt.tight_layout() + plt.show() + + diff --git a/src/run_SCQ_VAE.py b/src/run_SCQ_VAE.py new file mode 100644 index 00000000..6c0ae8a6 --- /dev/null +++ b/src/run_SCQ_VAE.py @@ -0,0 +1,746 @@ +#!/usr/bin/env python +""" +========================= + +End‑to‑end trainer for a **peptide×MHC SCQ-VAE clustering**. +For each fold, load cross_latents.npz containing cross-atnn and mhc_ids, and labels. +train and evaluate SCQ-VAE model on the cross-attn data. +Visualize the results with t-SNE and UMAP with mhc_ids and labels. 
+ +Author: Amirreza (memory-optimized version, 2025) +""" +import os +import time + +import numpy as np +import pandas as pd +import tensorflow as tf +from matplotlib import pyplot as plt +from tensorflow.keras import layers +from sklearn.manifold import TSNE +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +import umap +from tqdm import tqdm +from utils.model import SCQ1DAutoEncoder +from sklearn.cluster import KMeans + +# Enable eager execution explicitly to resolve graph mode issues +tf.config.run_functions_eagerly(True) +from tensorflow import keras + +# Global random state for reproducibility +random_state = 42 + +def load_cross_latents_data(file_path): + """Load cross-attention data from a .npz file.""" + data = np.load(file_path) + cross_latents = data['cross_latents'] + mhc_ids = data['mhc_ids'] + labels = data['labels'] + return cross_latents, mhc_ids, labels # (N, seq_length, embedding_dim), (N,), (N,) + +def create_dataset(cross_latents, labels=None): + """Create a TensorFlow dataset from cross-latents data.""" + if labels is None: + dataset = tf.data.Dataset.from_tensor_slices(tf.cast(cross_latents, tf.float32)) + else: + # Cast features to float32 + features = tf.cast(cross_latents, tf.float32) + labels = tf.cast(labels, tf.float32) + dataset = tf.data.Dataset.from_tensor_slices((features, labels)) + + return dataset + +def initialize_codebook_with_kmeans(X_train, num_embeddings, embedding_dim): + """Initialize codebook vectors using k-means clustering.""" + kmeans = KMeans(n_clusters=num_embeddings, random_state=random_state) + flat_data = X_train.reshape(-1, X_train.shape[-1]) + kmeans.fit(flat_data) + return kmeans.cluster_centers_.astype(np.float32) + +# Visualization functions +def plot_training_metrics(history, save_path=None): + """Plot training and validation loss and accuracy.""" + return + + +def plot_reconstructions(original, reconstructed, n_samples=5, save_path=None, visualize=True): + """Plot comparison between original and reconstructed sequences.""" + n_samples = min(n_samples, len(original)) + plt.figure(figsize=(15, 3 * n_samples)) + for i in range(n_samples): + plt.subplot(n_samples, 2, 2 * i + 1) + plt.plot(original[i]) + plt.title(f"Original Sequence {i + 1}") + plt.grid(True) + plt.subplot(n_samples, 2, 2 * i + 2) + plt.plot(reconstructed[i]) + plt.title(f"Reconstructed Sequence {i + 1}") + plt.grid(True) + plt.tight_layout() + if save_path: + plt.savefig(save_path) + if visualize: + plt.show() + else: + plt.close() + + +def plot_codebook_usage(indices, num_embeddings, save_path=None, visualize=True): + """Visualize the usage distribution of codebook vectors. + + Args: + indices: Hard indices from model output (integer indices of assigned codes) + num_embeddings: Total number of vectors in the codebook + save_path: Path to save the visualization + + Returns: + used_vectors: Number of vectors used + usage_percentage: Percentage of codebook utilized + """ + # Convert from tensor to numpy if needed + if isinstance(indices, tf.Tensor): + indices = indices.numpy() + + # Ensure indices is a NumPy array before proceeding + if not isinstance(indices, np.ndarray): + try: + # Attempt conversion if it's list-like + indices = np.array(indices) + except Exception as e: + print( + f"Error in plot_codebook_usage: Input 'indices' is not a NumPy array or convertible. Type: {type(indices)}. 
Error: {e}") + # Return default values or raise an error + return 0, 0.0 + + # Flatten indices to 1D array for counting + try: + flat_indices = indices.flatten() + except AttributeError: + print(f"Error in plot_codebook_usage: Cannot flatten 'indices'. Type: {type(indices)}") + return 0, 0.0 # Return default values + + # Count occurrences of each codebook vector + try: + unique, counts = np.unique(flat_indices, return_counts=True) + except TypeError as e: + print( + f"Error in plot_codebook_usage: Cannot compute unique values for 'flat_indices'. Type: {type(flat_indices)}. Error: {e}") + return 0, 0.0 # Return default values + + # Create full distribution including zeros for unused vectors + full_distribution = np.zeros(num_embeddings) + for idx, count in zip(unique, counts): + if 0 <= idx < num_embeddings: # Ensure index is valid + full_distribution[int(idx)] = count + + # Calculate usage statistics + used_vectors = np.sum(full_distribution > 0) + usage_percentage = (used_vectors / num_embeddings) * 100 + + # Create the plot + plt.figure(figsize=(12, 6)) + bar_positions = np.arange(num_embeddings) + bars = plt.bar(bar_positions, full_distribution) + + # Color bars by frequency + max_count = np.max(full_distribution) + if max_count > 0: + for i, bar in enumerate(bars): + intensity = full_distribution[i] / max_count + bar.set_color(plt.cm.viridis(intensity)) + + plt.xlabel('Codebook Vector Index') + plt.ylabel('Usage Count') + plt.title( + f'Codebook Vector Usage Distribution\n{used_vectors}/{num_embeddings} vectors used ({usage_percentage:.1f}%)') + plt.grid(axis='y', linestyle='--', alpha=0.7) + + # Add colorbar + sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=plt.Normalize(0, max_count)) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=plt.gca()) + cbar.set_label('Frequency') + plt.tight_layout() + + if save_path: + plt.savefig(save_path) + print(f"Codebook usage: {used_vectors}/{num_embeddings} vectors used ({usage_percentage:.1f}%)") + if visualize: + plt.show() + else: + plt.close() + return used_vectors, usage_percentage + + +def plot_soft_cluster_distribution(soft_cluster_probs, num_samples=None, save_path=None, visualize=True): + """ + Visualize the distribution of soft clustering probabilities across samples. 
+ + Args: + soft_cluster_probs: List of arrays or single array containing probability + distributions across clusters for each sample + num_samples: Number of samples to visualize (default: 10) + save_path: Path to save the visualization (default: None) + + Returns: + None + """ + if num_samples is None: + num_samples = len(soft_cluster_probs) + # Process input to get a clean 2D array (samples x clusters) + if isinstance(soft_cluster_probs, list): + if not soft_cluster_probs: + print("Warning: Empty list provided to plot_soft_cluster_distribution") + return + sample_batch = soft_cluster_probs[0] + if isinstance(sample_batch, tf.Tensor): + sample_batch = sample_batch.numpy() + soft_cluster_probs = sample_batch + elif isinstance(soft_cluster_probs, tf.Tensor): + soft_cluster_probs = soft_cluster_probs.numpy() + + # Reshape if needed + if soft_cluster_probs.ndim > 2: + print(f"Reshaping array from shape {soft_cluster_probs.shape} to 2D format") + if soft_cluster_probs.shape[1] == 1: + soft_cluster_probs = soft_cluster_probs.reshape(soft_cluster_probs.shape[0], + soft_cluster_probs.shape[2]) + else: + soft_cluster_probs = soft_cluster_probs.reshape(soft_cluster_probs.shape[0], -1) + + # Verify valid data + if soft_cluster_probs.size == 0: + print("Warning: Empty array provided to plot_soft_cluster_distribution") + return + + n_samples = min(len(soft_cluster_probs), 1000) # Limit to prevent memory issues + n_clusters = soft_cluster_probs.shape[1] + + # Create a figure with 2 subplots + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12), gridspec_kw={'height_ratios': [2, 1]}) + + # 1. Heatmap visualization (top) + display_samples = min(num_samples, n_samples) + im = ax1.imshow(soft_cluster_probs[:display_samples], aspect='auto', cmap='viridis') + ax1.set_xlabel('Cluster Index') + ax1.set_ylabel('Sample Index') + ax1.set_title(f'Soft Cluster Probability Heatmap (First {display_samples} Samples)') + plt.colorbar(im, ax=ax1, label='Probability') + + # 2. 
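+    # (Aside, a hedged extra diagnostic that is common for VQ-style codebooks:
+    # the usage perplexity,
+    #   p = counts / counts.sum()
+    #   perplexity = np.exp(-np.sum(p * np.log(p + 1e-12)))
+    # where `counts` are the per-cluster assignment counts; a value near
+    # num_embeddings means uniform usage, a value near 1 means collapse.)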
Aggregated statistics (bottom) + # Calculate statistics across all samples for each cluster + cluster_means = np.mean(soft_cluster_probs, axis=0) + cluster_max_counts = np.sum(np.argmax(soft_cluster_probs, axis=1)[:, np.newaxis] == np.arange(n_clusters), axis=0) + + # Create a twin axis for the bar plot + ax2_twin = ax2.twinx() + + # Plot mean probability for each cluster (line) + x = np.arange(n_clusters) + ax2.plot(x, cluster_means, 'r-', linewidth=2, label='Mean Probability') + ax2.set_ylabel('Mean Probability', color='r') + ax2.tick_params(axis='y', labelcolor='r') + ax2.set_ylim(0, max(cluster_means) * 1.2) + + # Plot histogram of cluster assignments (bars) + ax2_twin.bar(x, cluster_max_counts, alpha=0.3, label='Assignment Count') + ax2_twin.set_ylabel('Number of Samples\nwith Highest Probability', color='b') + ax2_twin.tick_params(axis='y', labelcolor='b') + + # Add labels and grid + ax2.set_xlabel('Cluster Index') + ax2.set_title('Cluster Usage Statistics Across All Samples') + ax2.set_xticks(np.arange(0, n_clusters, max(1, n_clusters // 20))) + ax2.grid(True, linestyle='--', alpha=0.5, axis='y') + + # Create custom legend + lines, labels = ax2.get_legend_handles_labels() + lines2, labels2 = ax2_twin.get_legend_handles_labels() + ax2.legend(lines + lines2, labels + labels2, loc='upper right') + + # Add overall statistics as text + active_clusters = np.sum(np.max(soft_cluster_probs, axis=0) > 0.01) + most_used_cluster = np.argmax(cluster_max_counts) + ax2.text(0.02, 0.95, + f"Active clusters: {active_clusters}/{n_clusters} ({active_clusters / n_clusters:.1%})\n" + f"Most used cluster: {most_used_cluster} ({cluster_max_counts[most_used_cluster]} samples)", + transform=ax2.transAxes, verticalalignment='top', + bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) + + plt.tight_layout() + + # Save if path is provided + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches='tight') + + # Show or close based on global visualize flag + if visualize: + plt.show() + else: + plt.close() + + +def plot_cluster_distribution(soft_cluster_probs, save_path=None, visualize=True): + """ + Plot distribution of samples across clusters based on one-hot encodings. 
+ + Args: + soft_cluster_probs: Soft cluster probabilities + save_path: Path to save the plot + + Returns: + used_clusters: Number of clusters used + usage_percentage: Percentage of clusters used + """ + # Convert soft cluster probabilities to one-hot encodings + one_hot_encodings = tf.one_hot(tf.argmax(soft_cluster_probs, axis=-1), depth=soft_cluster_probs.shape[-1]) + one_hot_encodings = tf.cast(one_hot_encodings, tf.float32) + + # print(f"one_hot_encodings shape: {one_hot_encodings.shape}") + # print first 5 values + # print(f"one_hot_encodings values: {one_hot_encodings[:5]}") + + # Convert one-hot to cluster indices if needed + if isinstance(one_hot_encodings, tf.Tensor): + one_hot_encodings = one_hot_encodings.numpy() + + # Handle different shapes of one_hot_encodings + if len(one_hot_encodings.shape) == 3: # (batch, seq_len, num_embeddings) + cluster_assignments = np.argmax(one_hot_encodings, axis=-1).flatten() + else: # (batch, num_embeddings) + cluster_assignments = np.argmax(one_hot_encodings, axis=-1).flatten() + + # Count occurrences of each cluster + unique_clusters, counts = np.unique(cluster_assignments, return_counts=True) + + # Create a full distribution including zeros for unused clusters + num_clusters = one_hot_encodings.shape[-1] + full_distribution = np.zeros(num_clusters) + for cluster, count in zip(unique_clusters, counts): + full_distribution[cluster] = count + + # Calculate usage statistics + used_clusters = np.sum(full_distribution > 0) + usage_percentage = (used_clusters / num_clusters) * 100 + + # Create the plot + plt.figure(figsize=(12, 6)) + bars = plt.bar(np.arange(num_clusters), full_distribution) + + # Color bars by frequency + max_count = np.max(full_distribution) + if max_count > 0: + for i, bar in enumerate(bars): + intensity = full_distribution[i] / max_count + bar.set_color(plt.cm.plasma(intensity)) + + plt.xlabel('Cluster Index') + plt.ylabel('Number of Samples') + plt.title( + f'Sample Distribution Across Clusters\n{used_clusters}/{num_clusters} clusters used ({usage_percentage:.1f}%)') + plt.grid(axis='y', linestyle='--', alpha=0.7) + + # Add a colorbar + sm = plt.cm.ScalarMappable(cmap=plt.cm.plasma, norm=plt.Normalize(0, max_count)) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=plt.gca()) + cbar.set_label('Sample Count') + + plt.tight_layout() + + if save_path: + plt.savefig(save_path) + print(f"Cluster usage: {used_clusters}/{num_clusters} clusters contain samples ({usage_percentage:.1f}%)") + if visualize: + plt.show() + else: + plt.close() + return used_clusters, usage_percentage + +def plot_tsne_umap(cross_latents, mhc_ids, labels, save_path=None, visualize=True): + """Plot t-SNE and UMAP visualizations of the cross-latents.""" + # Handle dimensions if cross_latents is 3D (reshape to 2D) + if cross_latents.ndim > 2: + cross_latents = cross_latents.reshape(cross_latents.shape[0], -1) + + # Standardize the data + scaler = StandardScaler() + cross_latents_scaled = scaler.fit_transform(cross_latents) + + # t-SNE + tsne = TSNE(n_components=2, random_state=random_state) + tsne_results = tsne.fit_transform(cross_latents_scaled) + + # UMAP + reducer = umap.UMAP(random_state=random_state) + umap_results = reducer.fit_transform(cross_latents_scaled) + + # Plotting + plt.figure(figsize=(18, 9)) + + # Determine color source + color_source = labels + color_label = 'Labels' + + if labels is None or (isinstance(labels, np.ndarray) and len(labels) == 0): + if mhc_ids is not None and len(mhc_ids) > 0: + # Convert string MHC IDs to categorical indices if 
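+            # (Aside, hedged: for long flattened latents, the t-SNE and UMAP
+            # fits above get much faster if the scaled data is first reduced, e.g.
+            #   cross_latents_scaled = PCA(n_components=50).fit_transform(cross_latents_scaled)
+            # a common preprocessing step; PCA is already imported in this module.)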
needed + if isinstance(mhc_ids[0], (str, bytes)): + unique_mhcs = np.unique(mhc_ids) + mhc_to_index = {mhc: i for i, mhc in enumerate(unique_mhcs)} + color_source = np.array([mhc_to_index[mhc] for mhc in mhc_ids]) + else: + color_source = mhc_ids + color_label = 'MHC IDs' + else: + # No coloring available + color_source = None + + # t-SNE plot + plt.subplot(1, 2, 1) + if color_source is not None: + sc = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=color_source, cmap='viridis', s=5, alpha=0.7) + plt.colorbar(sc, label=color_label) + else: + plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=5, alpha=0.7) + + plt.title('t-SNE Visualization') + plt.xlabel('t-SNE Component 1') + plt.ylabel('t-SNE Component 2') + plt.grid(True, linestyle='--', alpha=0.7) + + # UMAP plot + plt.subplot(1, 2, 2) + if color_source is not None: + sc = plt.scatter(umap_results[:, 0], umap_results[:, 1], c=color_source, cmap='viridis', s=5, alpha=0.7) + plt.colorbar(sc, label=color_label) + else: + plt.scatter(umap_results[:, 0], umap_results[:, 1], s=5, alpha=0.7) + + plt.title('UMAP Visualization') + plt.xlabel('UMAP Component 1') + plt.ylabel('UMAP Component 2') + plt.grid(True, linestyle='--', alpha=0.7) + + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=150, bbox_inches='tight') + print(f"Plots saved to {save_path}") + + if visualize: + plt.show() + else: + plt.close() + + +def plot_PCA(cross_latents, mhc_ids=None, labels=None, save_path=None, color_map_name="MHC IDs"): + """Plot PCA visualization of the cross-latents with optional label highlighting.""" + + # reduce dimensions if necessary + if cross_latents.ndim > 2: + # Flatten the cross_latents to ensure they are 2D (N, seq_length * embedding_dim) + cross_latents = cross_latents.reshape(cross_latents.shape[0], -1) + + # Standardize the data + scaler = StandardScaler() + cross_latents_scaled = scaler.fit_transform(cross_latents) + + # PCA + pca = PCA(n_components=2) + pca_results = pca.fit_transform(cross_latents_scaled) + + # Plotting + import matplotlib.pyplot as plt + import numpy as np + + plt.figure(figsize=(8, 6)) + + if mhc_ids is not None: + # Convert string MHC IDs to categorical indices + unique_mhcs = np.unique(mhc_ids) + mhc_to_index = {mhc: i for i, mhc in enumerate(unique_mhcs)} + numeric_mhc_ids = np.array([mhc_to_index[mhc] for mhc in mhc_ids]) + + # Plot using the numeric encoding + sc = plt.scatter(pca_results[:, 0], pca_results[:, 1], c=numeric_mhc_ids, cmap='viridis', s=5) + plt.colorbar(sc, label=color_map_name) + + # If not too many unique MHCs, add a legend + if len(unique_mhcs) <= 20: + from matplotlib.lines import Line2D + cmap = plt.cm.get_cmap('viridis', len(unique_mhcs)) + legend_elements = [ + Line2D([0], [0], marker='o', color='w', markerfacecolor=cmap(i), + label=str(mhc.decode('utf-8') if isinstance(mhc, bytes) else mhc), markersize=5) + for i, mhc in enumerate(unique_mhcs) + ] + plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left') + else: + # Default plotting without colors + plt.scatter(pca_results[:, 0], pca_results[:, 1], s=5) + + # Highlight positive labels (value 1) + if labels is not None: + # Flatten and convert labels if needed + if isinstance(labels, (list, np.ndarray)) and len(labels) > 0: + flat_labels = np.asarray(labels).flatten() + # Find indices of points with positive labels (value 1 or close to 1) + positive_indices = np.where(np.isclose(flat_labels, 1.0, atol=0.01))[0] + + if len(positive_indices) > 0: + # Plot circles around positive points + 
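+                # (Aside, hedged: `plt.cm.get_cmap(name, n)` used for the legend
+                # above is deprecated in newer Matplotlib releases; the registry
+                # form `matplotlib.colormaps['viridis'].resampled(n)` is the
+                # usual replacement if a deprecation warning appears.)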
plt.scatter(pca_results[positive_indices, 0], pca_results[positive_indices, 1], + s=30, facecolors='none', edgecolors='orange', linewidths=0.07, + label='Positive Labels') + + # Add a legend entry for positive labels + plt.legend(loc='upper right') + plt.title('PCA Visualization (orange circles: Positive Labels)') + else: + plt.title('PCA Visualization (No positive labels found)') + else: + plt.title('PCA Visualization') + else: + plt.title('PCA Visualization') + + if save_path: + plt.savefig(save_path, bbox_inches='tight') + print(f"PCA plot saved to {save_path}") + + plt.show() + + +def process_and_save(dataset: tf.data.Dataset, + split_name: str, + model: tf.keras.Model, + output_dir: str, + num_embeddings: int, + mhc_ids: np.ndarray = None, + labels: np.ndarray = None): + """ + Quantize `dataset` through `model`, assemble into a DataFrame, + save to parquet, and plot distributions with split-specific filenames. + """ + original = [] + quantized_latents = [] + cluster_indices_soft = [] + cluster_indices_hard = [] + labels = [] + + # Extract latent codes for every batch + for batch_X, batch_y in dataset: + Zq, out_P_proj, _, _ = model.encode_(batch_X) + quantized_latents.append(Zq.numpy()) + cluster_indices_soft.append(out_P_proj.numpy()) + cluster_indices_hard.append(tf.argmax(out_P_proj, axis=-1).numpy()) + labels.append(batch_y.numpy()) + # Store original sequences if available + original.append(batch_X.numpy()) + + # Concatenate across batches + quantized_latent = np.concatenate(quantized_latents, axis=0) + soft_probs = np.concatenate(cluster_indices_soft, axis=0) + hard_assign = np.concatenate(cluster_indices_hard, axis=0) + labels = np.concatenate(labels, axis=0) + + # Build DataFrame + records = [] + for i in range(len(quantized_latent)): + rec = {} + flat_latent = quantized_latent[i].flatten() + for j, v in enumerate(flat_latent): + rec[f'latent_{j}'] = float(v) + + flat_soft = soft_probs[i].flatten() + for j, v in enumerate(flat_soft): + rec[f'soft_cluster_{j}'] = float(v) + + rec['hard_cluster'] = int(hard_assign[i]) + # binding label if available + if i < len(labels): + lbl = labels[i] + # Handle scalar, 0-dim, or 1-dim arrays + if isinstance(lbl, (np.ndarray, list)) and np.asarray(lbl).size > 0: + rec['binding_label'] = float(np.asarray(lbl).flatten()[0]) + else: + rec['binding_label'] = float(lbl) + else: + rec['binding_label'] = np.nan + records.append(rec) + + df = pd.DataFrame(records) + # ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + + # Save parquet + parquet_path = os.path.join(output_dir, f'quantized_outputs_{split_name}.parquet') + df.to_parquet(parquet_path, index=False) + print(f"[{split_name}] saved parquet → {parquet_path}") + + # Plot distributions + # plot_cluster_distribution(soft_probs, + # save_path=os.path.join(output_dir, f'cluster_distribution_soft_{split_name}.png')) + plot_codebook_usage(hard_assign, num_embeddings, + save_path=os.path.join(output_dir, f'codebook_usage_{split_name}.png')) + plot_soft_cluster_distribution(soft_probs, 20, + save_path=os.path.join(output_dir, + f'soft_cluster_distribution_{split_name}.png')) + + # visualize PCA with label highlights + print(quantized_latent.shape) + # Pass both mhc_ids and labels to plot_PCA + plot_PCA(quantized_latent, + save_path=os.path.join(output_dir, f'quantized_latents_pca_{split_name}.png'), + mhc_ids=mhc_ids, + labels=labels,) # Add labels parameter + + + + # visualize by hard cluster assignments + plot_PCA(quantized_latent, + save_path=os.path.join(output_dir, 
f'quantized_latents_pca_hard_{split_name}.png'), + mhc_ids=hard_assign.flatten(), + labels=labels, + color_map_name="Cluster IDs") # Add labels parameter + + # visualize t-SNE and UMAP + plot_tsne_umap(quantized_latent, mhc_ids=hard_assign.flatten(), labels=labels, + save_path=os.path.join(output_dir, f'quantized_latents_tsne_umap_{split_name}.png')) + print(f"[{split_name}] processed and saved quantized outputs with {len(df)} records.") + + # visualize reconstructions + plot_reconstructions(quantized_latents[0][:5], # First 5 samples) + original[0][:5], # First 5 samples + n_samples=5, + save_path=os.path.join(output_dir, f'reconstructions_{split_name}.png'), + visualize=False) # Set visualize to False to save without showing + + +def main(): + # --- Configuration --- + num_embeddings = 16 # Number of embeddings in the codebook + embedding_dim = 4 # Dimension of each embedding vector + commitment_beta = 0.25 # Commitment loss weight + batch_size = 32 # Batch size for training + epochs = 3 # Number of training epochs + learning_rate = 1e-3 # Learning rate for the optimizer + + cross_latents_file = 'runs/run_20250626-114618/cross_latent_test1_fold_1.npz' # Path to the cross-attention data file + val_file = 'runs/run_20250626-114618/cross_latent_test2_fold_1.npz' # Path to the validation data file (if needed) + save_dir = 'runs/run_20250626-114618/scq' # Directory to save results and plots + os.makedirs(save_dir, exist_ok=True) + + # --- Load Cross-Attention Data --- + print("Loading cross-attention data...") + cross_latents, mhc_ids, labels = load_cross_latents_data(cross_latents_file) + print("Cross-latents shape:", cross_latents.shape) + # flatten the cross_latents to ensure they are 2D (N, seq_length * embedding_dim) + cross_latents = cross_latents.reshape(cross_latents.shape[0], -1) # Flatten to (N, seq_length * embedding_dim) + # cross_latents = cross_latents.mean(axis=1) # Average across the sequence length dimension + seq_length = cross_latents.shape[1] # Length of the sequences + + val_latents, val_mhc_ids, val_labels = load_cross_latents_data(val_file) + if val_latents is not None: + print("Validation cross-latents shape:", val_latents.shape) + val_data = val_latents.reshape(val_latents.shape[0], -1) + # val_data = val_latents.mean(axis=1) + else: + print("No validation data found, proceeding without validation set.") + val_data, val_mhc_ids, val_labels = None, None, None + + + # --- Create TensorFlow Dataset --- + print("Creating TensorFlow dataset...") + dataset = create_dataset(cross_latents, labels) + dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE) + # drop labels + dataset_train = create_dataset(cross_latents) + dataset_train = dataset_train.batch(batch_size).prefetch(tf.data.AUTOTUNE) # Batch and prefetch for performance + print("Dataset created with {} batches.".format(len(dataset))) + # --- Print Dataset Information --- + print("Cross-latents shape:", cross_latents.shape) + print("MHC IDs shape:", mhc_ids.shape) + print("Labels shape:", labels.shape) + print("Sequence length:", seq_length) + + # Visualize raw cross-latents data + print("Visualizing raw cross-latents data with PCA...") + plot_PCA(cross_latents, mhc_ids, save_path=os.path.join(save_dir, 'cross_latents_pca.png'), labels=labels) + print("Visualizing cross-latents data with t-SNE and UMAP...") + plot_tsne_umap(cross_latents, mhc_ids, labels, save_path=os.path.join(save_dir, 'cross_latents_tsne_umap.png')) + print("Cross-latents data visualization completed.") + + # --- Initialize Codebook --- + 
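+    # (Hedged alternative for large inputs: sklearn's MiniBatchKMeans fits the
+    # codebook in streaming batches and is usually much faster here, e.g.
+    #   from sklearn.cluster import MiniBatchKMeans
+    #   codebook_init = MiniBatchKMeans(n_clusters=num_embeddings,
+    #                                   random_state=random_state,
+    #                                   batch_size=1024).fit(
+    #       cross_latents.reshape(-1, cross_latents.shape[-1])).cluster_centers_.astype(np.float32)
+    # kept as a comment so the full k-means initialisation below stays the default.)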
print("Initializing codebook with k-means...") + codebook_init = initialize_codebook_with_kmeans(cross_latents, num_embeddings, embedding_dim) + print("Codebook initialized with shape:", codebook_init.shape) + + # --- Model Instantiation --- + print("Building the SCQ1DAutoEncoder model...") + input_shape = (seq_length,) + # Print detailed information about input dimensions + print(f"Input shape being passed to model: {input_shape}") + print(f"Expected embedding_dim: {embedding_dim}") + print(f"Actual shape of cross_latents: {cross_latents.shape}") + + # Check if the embedding_dim needs adjustment based on the actual data + if embedding_dim != cross_latents.shape[-1]: + print(f"Warning: Adjusting embedding_dim from {embedding_dim} to match data: {cross_latents.shape[-1]}") + embedding_dim = cross_latents.shape[-1] + + model = SCQ1DAutoEncoder( + input_dim=input_shape, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + commitment_beta=commitment_beta, + initial_codebook=codebook_init, + scq_params={ + 'lambda_reg': 1.0, + 'discrete_loss': False, + 'reset_dead_codes': True, + 'usage_threshold': 1e-4, + 'reset_interval': 5 + }, + cluster_lambda=1 + ) + print("Model built.") + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate)) + + # --- Training the Model --- + print("Starting training...") + print(f"Training on balanced set for {epochs} epochs...") + start_time = time.time() + history = model.fit( + dataset_train, + validation_data=val_data, + epochs=epochs, + ) + end_time = time.time() + print(f"Balanced training finished in {end_time - start_time:.2f} seconds.") + + # --- Save the Model --- + model_save_path = os.path.join(save_dir, 'scq_vae_model.h5') + print(f"Saving the model to {model_save_path}...") + model.save(model_save_path) + print("Model saved successfully.") + + # --- Visualize Results --- + print("Evaluating the model...") + process_and_save(dataset, 'train', model, save_dir, num_embeddings, mhc_ids) + if val_data is not None: + val_dataset = create_dataset(val_data, val_labels).batch(batch_size) + evaluation_results = model.evaluate(val_dataset) + print(f"Validation Loss: {evaluation_results[0]}, Validation Accuracy: {evaluation_results[1]}") + print("Visualizing results...") + process_and_save(val_dataset, 'val', model, save_dir, num_embeddings, val_mhc_ids) + else: + print("No validation data provided, skipping evaluation.") + print("Model evaluation completed.") + + # --- End of Main Function --- + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/run_model3.py b/src/run_model3.py new file mode 100644 index 00000000..35a2fcfa --- /dev/null +++ b/src/run_model3.py @@ -0,0 +1,859 @@ +#!/usr/bin/env python +""" +========================= + +MEMORY-OPTIMIZED End‑to‑end trainer for a **peptide×MHC cross‑attention classifier**. +Loads NetMHCpan‑style parquet files in true streaming fashion without loading entire datasets into memory. + +Key improvements: +1. Streaming parquet reading with configurable batch sizes +2. Lazy evaluation of dataset properties (seq length, class balance) +3. Memory-efficient TensorFlow data pipelines +4. 
Proper cleanup and memory monitoring + +Author: Amirreza (memory-optimized version, 2025) +""" +from __future__ import annotations +import os +import sys + +print(sys.executable) + +# ============================================================================= +# CRITICAL: GPU Memory Configuration - MUST BE FIRST +# ============================================================================= +import tensorflow as tf + + +def configure_gpu_memory(): + """Configure TensorFlow to use GPU memory efficiently""" + try: + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + print(f"Found {len(gpus)} GPU(s)") + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + print("✓ GPU memory growth enabled") + else: + print("No GPUs found - running on CPU") + except RuntimeError as e: + print(f"GPU configuration error: {e}") + + +# Configure GPU immediately +configure_gpu_memory() + +# --------------------------------------------------------------------- +# ► Use all logical CPU cores for TF ops that still run on CPU +# --------------------------------------------------------------------- +NUM_CPUS = os.cpu_count() or 1 +tf.config.threading.set_intra_op_parallelism_threads(NUM_CPUS) +tf.config.threading.set_inter_op_parallelism_threads(NUM_CPUS) +print(f'✓ TF intra/inter-op threads set to {NUM_CPUS}') + +# Set memory-friendly environment variables +os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async' +os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' +os.environ["PYTHONHASHSEED"] = "42" +os.environ["TF_DETERMINISTIC_OPS"] = "1" + +import math +import argparse, datetime, pathlib, json +import psutil +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from tqdm import tqdm +from model3 import build_classifier +from sklearn.metrics import ( + confusion_matrix, roc_curve, auc, precision_score, + recall_score, f1_score, accuracy_score, roc_auc_score +) +import seaborn as sns +import pyarrow.parquet as pq +import gc +import weakref +import pyarrow as pa, pyarrow.compute as pc +pa.set_cpu_count(os.cpu_count()) + + +# ============================================================================= +# Memory monitoring functions +# ============================================================================= +def monitor_memory(): + """Monitor system memory usage""" + memory = psutil.virtual_memory() + print(f"System RAM: {memory.used / 1e9:.1f}GB / {memory.total / 1e9:.1f}GB ({memory.percent:.1f}% used)") + + try: + from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo + nvmlInit() + deviceCount = nvmlDeviceGetCount() + for i in range(deviceCount): + handle = nvmlDeviceGetHandleByIndex(i) + info = nvmlDeviceGetMemoryInfo(handle) + print( + f"GPU {i}: {info.used / 1e9:.1f}GB / {info.total / 1e9:.1f}GB ({100 * info.used / info.total:.1f}% used)") + except: + print("GPU memory monitoring not available") + + +def cleanup_memory(): + """Aggressive memory cleanup""" + gc.collect() + try: + tf.keras.backend.clear_session() + except: + pass + + +# ---------------------------------------------------------------------------- +# Peptide encoding utilities +# ---------------------------------------------------------------------------- +AA = "ACDEFGHIKLMNPQRSTVWY" # 20 standard AAs, order fixed +AA_TO_IDX = {aa: i for i, aa in enumerate(AA)} +UNK_IDX = 20 # index for unknown +PAD_TOKEN = -2 # set manually to avoid confusion with UNK + + +def peptides_to_onehot(sequence: str, max_seq_len: int) -> np.ndarray: + """Convert 
peptide sequence to one-hot encoding"""
+    arr = np.full((max_seq_len, 21), PAD_TOKEN, dtype=np.float32)  # initialize padding with -2
+    for j, aa in enumerate(sequence.upper()[:max_seq_len]):
+        arr[j] = 0.0  # real positions become proper one-hot rows; padded rows keep PAD_TOKEN
+        arr[j, AA_TO_IDX.get(aa, UNK_IDX)] = 1.0
+    # count unknown residues (compare against 1.0, since the other cells hold 0.0 or PAD_TOKEN)
+    num_unks = int(np.sum(arr[:, UNK_IDX] == 1.0))
+    if num_unks > 0:
+        print(f"Warning: {num_unks} unknown amino acids in sequence '{sequence}'")
+    return arr
+
+
+def _read_embedding_file(path: str | os.PathLike) -> np.ndarray:
+    """Robust loader for latent embeddings"""
+    try:
+        arr = np.load(path)
+        if isinstance(arr, np.ndarray) and arr.dtype == np.float32:
+            return arr
+        raise ValueError
+    except ValueError:
+        obj = np.load(path, allow_pickle=True)
+        if isinstance(obj, np.ndarray) and obj.dtype == object:
+            obj = obj.item()
+        if isinstance(obj, dict) and "embedding" in obj:
+            return obj["embedding"].astype("float32")
+        raise ValueError(f"Unrecognised embedding file {path}")
+
+
+# ----------------------------------------------------------------------------
+# Streaming dataset utilities
+# ----------------------------------------------------------------------------
+class StreamingParquetReader:
+    """Memory-efficient streaming parquet reader"""
+
+    def __init__(self, parquet_path: str, batch_size: int = 1000):
+        self.parquet_path = parquet_path
+        self.batch_size = batch_size
+        self._file = None
+        self._num_rows = None
+
+    def __enter__(self):
+        self._file = pq.ParquetFile(self.parquet_path)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self._file:
+            self._file = None
+
+    @property
+    def num_rows(self):
+        """Get total number of rows without loading data"""
+        if self._num_rows is None:
+            if self._file is None:
+                with pq.ParquetFile(self.parquet_path) as f:
+                    self._num_rows = f.metadata.num_rows
+            else:
+                self._num_rows = self._file.metadata.num_rows
+        return self._num_rows
+
+    def iter_batches(self):
+        """Iterate over parquet file in batches"""
+        if self._file is None:
+            raise RuntimeError("Reader not opened. 
Use within 'with' statement.") + + for batch in self._file.iter_batches(batch_size=self.batch_size): + df = batch.to_pandas() + yield df + del df, batch # Explicit cleanup + + def sample_for_metadata(self, n_samples: int = 1000): + """Sample a small portion for metadata extraction""" + with pq.ParquetFile(self.parquet_path) as f: + # Read first batch for metadata + first_batch = next(f.iter_batches(batch_size=min(n_samples, self.num_rows))) + return first_batch.to_pandas() + + +def get_dataset_metadata(parquet_path: str): + """Extract dataset metadata without loading full dataset""" + with StreamingParquetReader(parquet_path) as reader: + sample_df = reader.sample_for_metadata(reader.num_rows) + + metadata = { + 'total_rows': reader.num_rows, + 'max_peptide_length': int(sample_df['long_mer'].str.len().max()) if 'long_mer' in sample_df.columns else 0, + 'class_distribution': sample_df[ + 'assigned_label'].value_counts().to_dict() if 'assigned_label' in sample_df.columns else {}, + } + + del sample_df + return metadata + + +def calculate_class_weights(parquet_path: str): + """Calculate class weights from a sample of the dataset""" + with StreamingParquetReader(parquet_path, batch_size=1000) as reader: + label_counts = {0: 0, 1: 0} + for batch_df in reader.iter_batches(): + batch_labels = batch_df['assigned_label'].values + unique, counts = np.unique(batch_labels, return_counts=True) + for label, count in zip(unique, counts): + if label in [0, 1]: + label_counts[int(label)] += count + del batch_df + + # Calculate balanced class weights + total = sum(label_counts.values()) + if total == 0 or label_counts[0] == 0 or label_counts[1] == 0: + return {0: 1.0, 1: 1.0} + + return { + 0: total / (2 * label_counts[0]), + 1: total / (2 * label_counts[1]) + } + + +# --------------------------------------------------------------------- +# Utility that is executed in worker processes +# (must be top-level so it can be pickled on Windows) +# --------------------------------------------------------------------- +def _row_to_tensor_pack(row_dict: dict, max_pep_seq_len: int, max_mhc_len: int): + """Convert a single row (already in plain-python dict form) into tensors.""" + # --- peptide one-hot ------------------------------------------------ + pep = row_dict["long_mer"].upper()[:max_pep_seq_len] + pep_arr = np.zeros((max_pep_seq_len, 21), dtype=np.float32) + for j, aa in enumerate(pep): + pep_arr[j, AA_TO_IDX.get(aa, UNK_IDX)] = 1.0 + + # --- load MHC embedding -------------------------------------------- + mhc = _read_embedding_file(row_dict["mhc_embedding_path"]) + # pad MHC to max_mhc_len if needed + if mhc.shape[0] < max_mhc_len: + pad_mhc = np.full((max_mhc_len, mhc.shape[1]), PAD_TOKEN, dtype=np.float32) + pad_mhc[:mhc.shape[0]] = mhc + mhc = pad_mhc + elif mhc.shape[0] > max_mhc_len: + raise ValueError( + f"MHC length {mhc.shape[0]} exceeds max_mhc_len {max_mhc_len} for row: {row_dict}") + + if mhc.shape[0] != max_mhc_len: # sanity check + raise ValueError(f"MHC length mismatch: {mhc.shape[0]} vs {max_mhc_len}") + + # --- label ---------------------------------------------------------- + label = float(row_dict["assigned_label"]) + return (pep_arr, mhc.astype("float32")), label + +from concurrent.futures import ProcessPoolExecutor +import functools, itertools + +def streaming_data_generator( + parquet_path: str, + max_pep_seq_len: int, + max_mhc_len: int, + batch_size: int = 1000): + """ + Yields *individual* samples, but converts an entire Parquet batch + on multiple CPU cores first. 
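+    Each yielded element is ((peptide_onehot, mhc_embedding), label, allele_id), matching the
+    output_signature declared in create_streaming_dataset below.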
+ """ + with StreamingParquetReader(parquet_path, batch_size) as reader, \ + ProcessPoolExecutor(max_workers=os.cpu_count()) as pool: + + # Partial function to avoid re-sending constants + worker_fn = functools.partial( + _row_to_tensor_pack, + max_pep_seq_len=max_pep_seq_len, + max_mhc_len=max_mhc_len, + ) + + for batch_df in reader.iter_batches(): + # Convert Arrow table → list[dict] once; avoids pandas overhead + dict_rows = batch_df.to_dict(orient="list") # columns -> python lists + # Re-shape to list[dict(row)] + rows_iter = ( {k: dict_rows[k][i] for k in dict_rows} # row dict + for i in range(len(batch_df)) ) + + # Parallel map; chunksize tuned for large batches + results = pool.map(worker_fn, rows_iter, chunksize=64) + + # # Stream each converted sample back to the generator consumer + # yield from results # <-- keeps memory footprint tiny + for result, sample_id in zip(results, dict_rows["allele"]): + yield result + (sample_id,) + + # explicit clean-up + del batch_df, dict_rows, rows_iter, results + + +def create_streaming_dataset(parquet_path: str, + max_pep_seq_len: int, + max_mhc_len: int, + batch_size: int = 128, + buffer_size: int = 1000): + """ + Same semantics as before, but the generator already does parallel + preprocessing. We now ask tf.data to interleave multiple generator + shards in parallel as well. + """ + output_signature = ( + ( + tf.TensorSpec(shape=(max_pep_seq_len, 21), dtype=tf.float32), + tf.TensorSpec(shape=(max_mhc_len, 1152), dtype=tf.float32), + ), + tf.TensorSpec(shape=(), dtype=tf.float32), + tf.TensorSpec(shape=(), dtype=tf.string), # MHC allele ID + ) + + # Create raw dataset with features, label, and IDs + raw_ds = tf.data.Dataset.from_generator( + lambda: streaming_data_generator( + parquet_path, + max_pep_seq_len, + max_mhc_len, + buffer_size), + output_signature=output_signature, + ) + + # Parallel interleave for speed + raw_ds = raw_ds.interleave( + lambda feats, label, mhc_id: tf.data.Dataset.from_tensors((feats, label, mhc_id)), + cycle_length=tf.data.AUTOTUNE, + num_parallel_calls=tf.data.AUTOTUNE, + deterministic=False, + ) + + # Separate dataset and IDs + ds = raw_ds.map(lambda feats, label, _: (feats, label), + num_parallel_calls=tf.data.AUTOTUNE) + ids_ds = raw_ds.map(lambda _, __, mhc_id: mhc_id, + num_parallel_calls=tf.data.AUTOTUNE) + labels_only_ds = raw_ds.map(lambda _, label, __: label, + num_parallel_calls=tf.data.AUTOTUNE) + + return ds, ids_ds, labels_only_ds + + +# ---------------------------------------------------------------------------- +# get cross latent npy +# ---------------------------------------------------------------------------- +def save_cross_latent_npy(cross_latent_model, ds, run_dir: str, name: str = "cross_latents_fold_{fold_id}", mhc_ids: np.ndarray = None, labels_only: np.ndarray = None): + cross_latents = cross_latent_model.predict(ds, verbose=0) + save_path = os.path.join(run_dir, f'{name}.npz') + # Prepare data for saving + if mhc_ids is not None or labels_only is not None: + if isinstance(mhc_ids, tf.data.Dataset): + mhc_ids = np.array(list(mhc_ids.as_numpy_iterator())) + if isinstance(labels_only, tf.data.Dataset): + labels_only = np.array(list(labels_only.as_numpy_iterator())) + if isinstance(cross_latents, tf.Tensor): + cross_latents = cross_latents.numpy() + savez_kwargs = {'cross_latents': cross_latents} + if mhc_ids is not None: + savez_kwargs['mhc_ids'] = mhc_ids + if labels_only is not None: + savez_kwargs['labels'] = labels_only + np.savez(save_path, **savez_kwargs) + else: + 
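+        # No IDs/labels were supplied, so fall back to a bare .npy holding only the latents
+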
np.save(save_path.replace('.npz', '.npy'), cross_latents) + +# ---------------------------------------------------------------------------- +# Visualization utilities (keeping the same as original) +# ---------------------------------------------------------------------------- +def plot_training_curve(history: tf.keras.callbacks.History, run_dir: str, fold_id: int = None, + model=None, val_dataset=None): + """Plot training curves and validation metrics""" + hist = history.history + plt.figure(figsize=(21, 6)) + plot_name = f"training_curve{'_fold' + str(fold_id) if fold_id is not None else ''}" + + plt.suptitle(f"Training Curves{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", + fontsize=16, fontweight='bold') + + # Plot 1: Loss curve + plt.subplot(1, 4, 1) + plt.plot(hist["loss"], label="train", linewidth=2) + plt.plot(hist["val_loss"], label="val", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.title("BCE Loss") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 2: Accuracy curve + if "binary_accuracy" in hist and "val_binary_accuracy" in hist: + plt.subplot(1, 4, 2) + plt.plot(hist["binary_accuracy"], label="train acc", linewidth=2) + plt.plot(hist["val_binary_accuracy"], label="val acc", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Accuracy") + plt.title("Binary Accuracy") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 3: AUC curve + if "AUC" in hist and "val_AUC" in hist: + plt.subplot(1, 4, 3) + plt.plot(hist["AUC"], label="train AUC", linewidth=2) + plt.plot(hist["val_AUC"], label="val AUC", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("AUC") + plt.title("AUC") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 4: Confusion matrix placeholder + plt.subplot(1, 4, 4) + if model is not None and val_dataset is not None: + # Sample a subset for confusion matrix to avoid memory issues + sample_dataset = val_dataset.take(100) # Take only 100 batches + y_pred_proba = model.predict(sample_dataset, verbose=0) + y_pred = (y_pred_proba > 0.8).astype(int) + + y_true = [] + for _, labels in sample_dataset: + y_true.extend(labels.numpy()) + y_true = np.array(y_true) + + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix (100 Batches)') + else: + plt.text(0.5, 0.5, 'Confusion Matrix N/A \n(Sample from validation)', + ha='center', va='center', transform=plt.gca().transAxes, + bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray")) + plt.axis('off') + + plt.tight_layout() + os.makedirs(run_dir, exist_ok=True) + out_png = os.path.join(run_dir, f"{plot_name}.png") + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.close() # Close to free memory + print(f"✓ Training curve saved to {out_png}") + + +def plot_test_metrics(model, test_dataset, run_dir: str, fold_id: int = None, + history=None, string: str = None): + """Plot comprehensive evaluation metrics for test dataset""" + print("Generating predictions for test metrics...") + + # Collect predictions and labels in batches to avoid memory issues + y_true_list = [] + y_pred_proba_list = [] + + for batch_x, batch_y in test_dataset: + batch_pred = model.predict(batch_x, verbose=0).flatten() + batch_y_np = batch_y.numpy().flatten() + mask = ~np.isnan(batch_y_np) + y_true_list.append(batch_y_np[mask]) + y_pred_proba_list.append(batch_pred[mask]) + + y_true = np.concatenate(y_true_list).flatten() + y_pred_proba = np.concatenate(y_pred_proba_list) + 
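+    # Binarise with the 0.8 decision threshold used throughout this script (plot_training_curve
+    # applies the same cut-off); keep it in sync with the reference line drawn in the
+    # prediction-distribution panel below.
+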
y_pred = (y_pred_proba > 0.8).astype(int) + + # Calculate ROC curve + fpr, tpr, _ = roc_curve(y_true, y_pred_proba) + roc_auc = auc(fpr, tpr) + + # Create evaluation plot + plt.figure(figsize=(15, 10)) + plt.suptitle(f"{string} Evaluation Metrics{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", + fontsize=16, fontweight='bold') + + # ROC Curve + plt.subplot(2, 2, 1) + plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})') + plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.8) + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('ROC Curve') + plt.legend(loc="lower right") + plt.grid(True, alpha=0.3) + + # Confusion Matrix + plt.subplot(2, 2, 2) + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix') + plt.xlabel('Predicted Label') + plt.ylabel('True Label') + + # Metrics bar chart + plt.subplot(2, 2, 3) + accuracy = accuracy_score(y_true, y_pred) + precision = precision_score(y_true, y_pred, zero_division=0) + recall = recall_score(y_true, y_pred, zero_division=0) + f1 = f1_score(y_true, y_pred, zero_division=0) + + metrics = {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1, 'AUC': roc_auc} + bars = plt.bar(range(len(metrics)), list(metrics.values()), + color=['skyblue', 'lightcoral', 'lightgreen', 'gold', 'orange'], + alpha=0.8, edgecolor='black', linewidth=1) + plt.xticks(range(len(metrics)), list(metrics.keys()), rotation=45, ha='right') + plt.ylim(0, 1.0) + plt.title('Evaluation Metrics') + plt.ylabel('Score') + plt.grid(True, alpha=0.3, axis='y') + + for bar, value in zip(bars, metrics.values()): + plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02, + f'{value:.3f}', ha='center', va='bottom', fontweight='bold') + + # Prediction distribution + plt.subplot(2, 2, 4) + plt.hist(y_pred_proba[y_true == 0], bins=30, alpha=0.7, label='Negative Class', + color='red', density=True) + plt.hist(y_pred_proba[y_true == 1], bins=30, alpha=0.7, label='Positive Class', + color='blue', density=True) + plt.axvline(x=0.5, color='black', linestyle='--', linewidth=2, label='Threshold') + plt.xlabel('Predicted Probability') + plt.ylabel('Density') + plt.title('Prediction Distribution') + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + os.makedirs(run_dir, exist_ok=True) + out_png = os.path.join(run_dir, f"{string}_metrics{'_fold' + str(fold_id) if fold_id is not None else ''}.png") + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.close() # Close to free memory + print(f"✓ Test metrics visualization saved to {out_png}") + + # Print summary + print("\n" + "=" * 50) + print("EVALUATION SUMMARY") + print("=" * 50) + print(f"Accuracy: {accuracy:.4f}") + print(f"Precision: {precision:.4f}") + print(f"Recall: {recall:.4f}") + print(f"F1 Score: {f1:.4f}") + print(f"ROC AUC: {roc_auc:.4f}") + print("=" * 50) + + return { + 'roc_auc': roc_auc, 'accuracy': accuracy, 'precision': precision, + 'recall': recall, 'f1': f1, 'confusion_matrix': cm.tolist() + } + +def plot_attn(att_model, val_loader, run_dir: str, fold_id: int = None): + """Generate and save attention heatmaps for 5 samples.""" + # ------------------------------------------------------------- + # ATTENTION VISUALISATION – take ONE batch from validation + # 
-------------------------------------------------------------
+    (pep_ex, mhc_ex), labels = next(iter(val_loader))  # first batch
+    att_scores = att_model.predict([pep_ex, mhc_ex], verbose=0)
+    print("attn_scores", att_scores.shape)
+    # save attention scores
+    if fold_id is None:
+        fold_id = 0
+    run_dir = os.path.join(run_dir, f"fold_{fold_id}")
+    pathlib.Path(run_dir).mkdir(parents=True, exist_ok=True)
+    print(f"Saving attention scores under {run_dir}")
+    out_attn = os.path.join(run_dir, f"attn_scores_fold{fold_id}.npy")
+    np.save(out_attn, att_scores)
+    print(f"✓ Attention scores saved to {out_attn}")
+    # -------------------------------------------------------------
+    # att_scores shape : (B, heads, pep_len, mhc_len)
+    att_mean = att_scores.mean(axis=1)  # (B,pep,mhc)
+    print("att_mean shape:", att_mean.shape)
+    # Find positive and negative samples
+    labels_np = labels.numpy().flatten()
+    pos_indices = np.where(labels_np == 1)[0]
+    neg_indices = np.where(labels_np == 0)[0]
+    # Select up to 5 samples (prioritize positive samples, then use negative if needed)
+    num_pos = min(5, len(pos_indices))
+    num_neg = min(5 - num_pos, len(neg_indices)) if num_pos < 5 else 0
+    selected_pos = pos_indices[:num_pos]
+    selected_neg = neg_indices[:num_neg] if num_neg > 0 else []
+    selected_samples = list(selected_pos) + list(selected_neg)
+    print(f"Plotting attention maps for {len(selected_pos)} positive and {len(selected_neg)} negative samples")
+    # Generate heatmaps for each selected sample
+    for i, sample_id in enumerate(selected_samples[:5]):
+        sample_type = "Positive" if sample_id in pos_indices else "Negative"
+        A = att_mean[sample_id]
+        A = A.transpose()
+        plt.figure(figsize=(8, 6))
+        ax = sns.heatmap(
+            A,
+            cmap="viridis",
+            xticklabels=[
+                # "X" labels the UNK channel (index 20), which has no entry in AA
+                (AA + "X")[pep_ex[sample_id][j].numpy().argmax()] if float(tf.reduce_sum(pep_ex[sample_id][j])) > 0 else ""
+                for j in range(A.shape[1])
+            ],
+            yticklabels=[f"M{i}" for i in range(A.shape[0])],
+            cbar_kws={"label": "attention"},
+            linewidths=0.1,  # Add lines between cells
+            linecolor='black',  # Black grid lines for better contrast
+            linestyle=':'  # Dotted grid lines
+        )
+        # Improve labels and title
+        plt.xlabel("Peptide Position (Amino Acid)")
+        plt.ylabel("MHC Position")
+        plt.title(f"Fold {fold_id} - Attention Heatmap\nSample {sample_id} ({sample_type} Example)")
+        # Add box around the entire heatmap
+        for _, spine in ax.spines.items():
+            spine.set_visible(True)
+            spine.set_linewidth(2)
+        out_png = os.path.join(run_dir, f"attention_fold{fold_id}_sample{sample_id}_{sample_type.lower()}.png")
+        plt.tight_layout()
+        plt.savefig(out_png, dpi=300, bbox_inches="tight")
+        plt.close()
+        print(f"✓ Attention heat-map {i+1}/5 saved to {out_png}")
+
+
+# ----------------------------------------------------------------------------
+# Main training function
+# ----------------------------------------------------------------------------
+def main(argv=None):
+    p = argparse.ArgumentParser()
+    p.add_argument("--dataset_path", required=True,
+                   help="Path to the dataset directory")
+    p.add_argument("--epochs", type=int, default=30)
+    p.add_argument("--batch", type=int, default=128)
+    p.add_argument("--outdir", default=None,
+                   help="Output dir (default: runs/run_YYYYmmdd-HHMMSS)")
+    p.add_argument("--buffer_size", type=int, default=1000,
+                   help="Buffer size for streaming data loading")
+    p.add_argument("--test_batches", type=int, default=3,
+                   help="Number of batches to use for test dataset evaluation")
+
+    args = p.parse_args(argv)
+
+    run_dir = args.outdir or 
f"runs/run_{datetime.datetime.now():%Y%m%d-%H%M%S}" + pathlib.Path(run_dir).mkdir(parents=True, exist_ok=True) + print(f"★ Outputs → {run_dir}\n") + + # Set seeds for reproducibility + tf.random.set_seed(42) + np.random.seed(42) + print("Setting random seeds for reproducibility...") + + print("Initial memory state:") + monitor_memory() + + # Extract metadata from datasets without loading them fully + print("Extracting dataset metadata...") + + # Get fold information + fold_dir = os.path.join(args.dataset_path, 'folds') + fold_files = sorted([f for f in os.listdir(fold_dir) if f.endswith('.parquet')]) + n_folds = len(fold_files) // 2 + + # Find maximum peptide length across all datasets + max_peptide_length = 0 + max_mhc_length = 50 + + print("Scanning datasets for maximum peptide length...") + all_parquet_files = [ + os.path.join(args.dataset_path, "test1.parquet"), + os.path.join(args.dataset_path, "test2.parquet") + ] + + # Add fold files + for i in range(1, n_folds + 1): + all_parquet_files.extend([ + os.path.join(fold_dir, f'fold_{i}_train.parquet'), + os.path.join(fold_dir, f'fold_{i}_val.parquet') + ]) + + for pq_file in all_parquet_files: + if os.path.exists(pq_file): + metadata = get_dataset_metadata(pq_file) + max_peptide_length = max(max_peptide_length, metadata['max_peptide_length']) + print( + f" {os.path.basename(pq_file)}: max_len={metadata['max_peptide_length']}, rows={metadata['total_rows']}") + + print(f"✓ Maximum peptide length across all datasets: {max_peptide_length}") + + # Create fold datasets and class weights + folds = [] + class_weights = [] + + for i in range(1, n_folds + 1): + print(f"\nProcessing fold {i}/{n_folds}") + train_path = os.path.join(fold_dir, f'fold_{i}_train.parquet') + val_path = os.path.join(fold_dir, f'fold_{i}_val.parquet') + + # Calculate class weights from training data + print(f" Calculating class weights...") + cw = calculate_class_weights(train_path) + print(f" Class weights: {cw}") + + # Create streaming datasets + train_ds, val_ids, train_labels_copy = create_streaming_dataset(train_path, max_peptide_length, max_mhc_length, + buffer_size=args.buffer_size) + train_ds = (train_ds + .shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True) + .batch(args.batch) + .take(args.test_batches) + .prefetch(tf.data.AUTOTUNE)) + + train_ids = np.asarray(val_ids) + train_labels_copy = np.asarray(train_labels_copy) + + + val_ds, val_ids, val_labels_copy = create_streaming_dataset(val_path, max_peptide_length, max_mhc_length, + buffer_size=args.buffer_size) + val_ds = (val_ds + .batch(args.batch) + .take(args.test_batches) + .prefetch(tf.data.AUTOTUNE)) + + + folds.append((train_ds, val_ds)) + class_weights.append(cw) + + val_ids = np.asarray(val_ids) + val_labels_copy = np.asarray(val_labels_copy) + + # Force cleanup + cleanup_memory() + + # Create test datasets + print("Creating test datasets...") + test1_ds, test1_ids, test1_labels_copy = create_streaming_dataset(os.path.join(args.dataset_path, "test1.parquet"), + max_peptide_length, max_mhc_length, buffer_size=args.buffer_size) + test1_ds = (test1_ds + .batch(args.batch) + .prefetch(tf.data.AUTOTUNE)) + + test1_ids = np.array(list(test1_ids.as_numpy_iterator())) + test1_labels_copy = np.array(list(test1_labels_copy.as_numpy_iterator())) + + test2_ds, test2_ids, test2_labels_copy = create_streaming_dataset(os.path.join(args.dataset_path, "test2.parquet"), + max_peptide_length, max_mhc_length, buffer_size=args.buffer_size) + test2_ds = (test2_ds + .batch(args.batch) + 
.prefetch(tf.data.AUTOTUNE)) + + test2_ids = np.array(list(test2_ids.as_numpy_iterator())) + test2_labels_copy = np.array(list(test2_labels_copy.as_numpy_iterator())) + + print(f"✓ Created {n_folds} fold datasets and 2 test datasets") + print("Memory after dataset creation:") + monitor_memory() + + # Training loop + print("\n" + "=" * 60) + print("STARTING TRAINING") + print("=" * 60) + + for fold_id, ((train_loader, val_loader), class_weight) in enumerate(zip(folds, class_weights), start=1): + print(f'\n🔥 Training fold {fold_id}/{n_folds}') + + # Clean up before each fold + cleanup_memory() + + # Build fresh model for each fold + print("Building model...") + model, attn_model, cross_latent_model = build_classifier(max_peptide_length,max_mhc_length) + model.summary() + + # Callbacks + ckpt_cb = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(run_dir, f'best_fold_{fold_id}.weights.h5'), + monitor='val_loss', save_best_only=True, mode='min', verbose=1) + early_cb = tf.keras.callbacks.EarlyStopping( + monitor='val_loss', patience=15, restore_best_weights=True, verbose=1) + + # Verify data shapes + for (x_pep, latents), labels in train_loader.take(1): + print(f"✓ Input shapes: peptide={x_pep.shape}, mhc={latents.shape}, labels={labels.shape}") + break + + print("Memory before training:") + monitor_memory() + + # Train model + print("🚀 Starting training...") + hist = model.fit( + train_loader, + validation_data=val_loader, + epochs=args.epochs, + class_weight=class_weight, + callbacks=[ckpt_cb, early_cb], + verbose=1, + ) + + print("Memory after training:") + monitor_memory() + + # Plot training curves + plot_training_curve(hist, run_dir, fold_id, model, val_loader) + plot_attn(attn_model, val_loader, run_dir, fold_id) + + # Save model and metadata + model.save_weights(os.path.join(run_dir, f'model_fold_{fold_id}.weights.h5')) + metadata = { + "fold_id": fold_id, + "epochs": args.epochs, + "batch_size": args.batch, + "max_peptide_length": max_peptide_length, + "max_mhc_length": max_mhc_length, + "class_weights": class_weight, + "run_dir": run_dir, + "mhc_class": MHC_CLASS + } + with open(os.path.join(run_dir, f'metadata_fold_{fold_id}.json'), 'w') as f: + json.dump(metadata, f, indent=4) + + # Evaluate on test sets + print(f"\n📊 Evaluating fold {fold_id} on test sets...") + + # Test1 evaluation + print("Evaluating on test1 (balanced alleles)...") + plot_test_metrics(model, test1_ds, run_dir, fold_id, string="Test1_balanced_alleles") + + # Test2 evaluation + print("Evaluating on test2 (rare alleles)...") + plot_test_metrics(model, test2_ds, run_dir, fold_id, string="Test2_rare_alleles") + + # save cross_latents for test1 and test2 + save_cross_latent_npy(cross_latent_model, test1_ds, run_dir, name=f"cross_latent_test1_fold_{fold_id}", mhc_ids=test1_ids, labels_only=test1_labels_copy) + save_cross_latent_npy(cross_latent_model, test2_ds, run_dir, name=f"cross_latent_test2_fold_{fold_id}", mhc_ids=test2_ids, labels_only=test2_labels_copy) + + print(f"✅ Fold {fold_id} completed successfully") + + # Cleanup + del model, hist + cleanup_memory() + + print("\n🎉 Training completed successfully!") + print(f"📁 All results saved to: {run_dir}") + + +if __name__ == "__main__": + BUFFER = 8192 # Reduced buffer size for memory efficiency + MHC_CLASS = 2 + dataset_path = f"../data/Custom_dataset/PMGen_sequences/mhc_{MHC_CLASS}" + main([ + "--dataset_path", dataset_path, + "--epochs", "5", + "--batch", "128", + "--buffer_size", "8192", + "--test_batches", "2000", + ]) \ No newline at end of file 
diff --git a/src/run_model4_recon.py b/src/run_model4_recon.py new file mode 100644 index 00000000..e2bba455 --- /dev/null +++ b/src/run_model4_recon.py @@ -0,0 +1,767 @@ +#!/usr/bin/env python +""" +========================= + +MEMORY-OPTIMIZED End‑to‑end trainer for a **peptide×MHC cross‑attention classifier**. +Loads NetMHCpan‑style parquet files in true streaming fashion without loading entire datasets into memory. + +Key improvements: +1. Streaming parquet reading with configurable batch sizes +2. Lazy evaluation of dataset properties (seq length, class balance) +3. Memory-efficient TensorFlow data pipelines +4. Proper cleanup and memory monitoring + +Author: Amirreza (memory-optimized version, 2025) +""" +from __future__ import annotations +import os +import sys + +print(sys.executable) + +# ============================================================================= +# CRITICAL: GPU Memory Configuration - MUST BE FIRST +# ============================================================================= +import tensorflow as tf + + +def configure_gpu_memory(): + """Configure TensorFlow to use GPU memory efficiently""" + try: + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + print(f"Found {len(gpus)} GPU(s)") + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + print("✓ GPU memory growth enabled") + else: + print("No GPUs found - running on CPU") + except RuntimeError as e: + print(f"GPU configuration error: {e}") + + +# Configure GPU immediately +configure_gpu_memory() + +# --------------------------------------------------------------------- +# ► Use all logical CPU cores for TF ops that still run on CPU +# --------------------------------------------------------------------- +NUM_CPUS = os.cpu_count() or 1 +tf.config.threading.set_intra_op_parallelism_threads(NUM_CPUS) +tf.config.threading.set_inter_op_parallelism_threads(NUM_CPUS) +print(f'✓ TF intra/inter-op threads set to {NUM_CPUS}') + +# Set memory-friendly environment variables +os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async' +os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' +os.environ["PYTHONHASHSEED"] = "42" +os.environ["TF_DETERMINISTIC_OPS"] = "1" + +import math +import argparse, datetime, pathlib, json +import psutil +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from tqdm import tqdm +from model4_recon import build_reconstruction_model +from sklearn.metrics import ( + confusion_matrix, roc_curve, auc, precision_score, + recall_score, f1_score, accuracy_score, roc_auc_score +) +import seaborn as sns +import pyarrow.parquet as pq +import gc +import weakref +import pyarrow as pa, pyarrow.compute as pc +pa.set_cpu_count(os.cpu_count()) + + +# ============================================================================= +# Memory monitoring functions +# ============================================================================= +def monitor_memory(): + """Monitor system memory usage""" + memory = psutil.virtual_memory() + print(f"System RAM: {memory.used / 1e9:.1f}GB / {memory.total / 1e9:.1f}GB ({memory.percent:.1f}% used)") + + try: + from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo + nvmlInit() + deviceCount = nvmlDeviceGetCount() + for i in range(deviceCount): + handle = nvmlDeviceGetHandleByIndex(i) + info = nvmlDeviceGetMemoryInfo(handle) + print( + f"GPU {i}: {info.used / 1e9:.1f}GB / {info.total / 1e9:.1f}GB ({100 * info.used / info.total:.1f}% used)") + except: + print("GPU memory 
monitoring not available") + + +def cleanup_memory(): + """Aggressive memory cleanup""" + gc.collect() + try: + tf.keras.backend.clear_session() + except: + pass + + +# ---------------------------------------------------------------------------- +# Peptide encoding utilities +# ---------------------------------------------------------------------------- +# add a MASK channel +AA = "ACDEFGHIKLMNPQRSTVWY" +MASK_IDX = 20 +AA_TO_IDX = {aa: i for i, aa in enumerate(AA)} +AA_DIM = 21 # 20 aa + MASK token + +def onehot(seq: str, max_len: int) -> np.ndarray: + """Return one-hot with an explicit MASK channel (all-zeros for padding).""" + arr = np.zeros((max_len, AA_DIM), dtype=np.float32) + for i, aa in enumerate(seq[:max_len].upper()): + arr[i, AA_TO_IDX.get(aa, MASK_IDX)] = 1.0 + return arr + +def mask_random_positions(oh: np.ndarray, mask_rate: float = .3): + """ + Replace ~mask_rate positions by MASK token. + Returns (masked_input, y_true, sample_weight) + """ + pep_len = oh.shape[0] + mpos = np.random.rand(pep_len) < mask_rate + y_true = oh.copy() + sample_weight = mpos.astype(np.float32) # 1 on masked positions + + masked = oh.copy() + masked[mpos] = 0.0 + masked[mpos, MASK_IDX] = 1.0 + return masked, y_true, sample_weight + + +def _read_embedding_file(path: str | os.PathLike) -> np.ndarray: + """Robust loader for latent embeddings""" + try: + arr = np.load(path) + if isinstance(arr, np.ndarray) and arr.dtype == np.float32: + return arr + raise ValueError + except ValueError: + obj = np.load(path, allow_pickle=True) + if isinstance(obj, np.ndarray) and obj.dtype == object: + obj = obj.item() + if isinstance(obj, dict) and "embedding" in obj: + return obj["embedding"].astype("float32") + raise ValueError(f"Unrecognised embedding file {path}") + + +# ---------------------------------------------------------------------------- +# Streaming dataset utilities +# ---------------------------------------------------------------------------- +class StreamingParquetReader: + """Memory-efficient streaming parquet reader""" + + def __init__(self, parquet_path: str, batch_size: int = 1000): + self.parquet_path = parquet_path + self.batch_size = batch_size + self._file = None + self._num_rows = None + + def __enter__(self): + self._file = pq.ParquetFile(self.parquet_path) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._file: + self._file = None + + @property + def num_rows(self): + """Get total number of rows without loading data""" + if self._num_rows is None: + if self._file is None: + with pq.ParquetFile(self.parquet_path) as f: + self._num_rows = f.metadata.num_rows + else: + self._num_rows = self._file.metadata.num_rows + return self._num_rows + + def iter_batches(self): + """Iterate over parquet file in batches""" + if self._file is None: + raise RuntimeError("Reader not opened. 
Use within 'with' statement.") + + for batch in self._file.iter_batches(batch_size=self.batch_size): + df = batch.to_pandas() + yield df + del df, batch # Explicit cleanup + + def sample_for_metadata(self, n_samples: int = 1000): + """Sample a small portion for metadata extraction""" + with pq.ParquetFile(self.parquet_path) as f: + # Read first batch for metadata + first_batch = next(f.iter_batches(batch_size=min(n_samples, self.num_rows))) + return first_batch.to_pandas() + + +def get_dataset_metadata(parquet_path: str): + """Extract dataset metadata without loading full dataset""" + with StreamingParquetReader(parquet_path) as reader: + sample_df = reader.sample_for_metadata(reader.num_rows) + + metadata = { + 'total_rows': reader.num_rows, + 'max_peptide_length': int(sample_df['long_mer'].str.len().max()) if 'long_mer' in sample_df.columns else 0, + 'class_distribution': sample_df[ + 'assigned_label'].value_counts().to_dict() if 'assigned_label' in sample_df.columns else {}, + } + + del sample_df + return metadata + + +def calculate_class_weights(parquet_path: str): + """Calculate class weights from a sample of the dataset""" + with StreamingParquetReader(parquet_path, batch_size=1000) as reader: + label_counts = {0: 0, 1: 0} + for batch_df in reader.iter_batches(): + batch_labels = batch_df['assigned_label'].values + unique, counts = np.unique(batch_labels, return_counts=True) + for label, count in zip(unique, counts): + if label in [0, 1]: + label_counts[int(label)] += count + del batch_df + + # Calculate balanced class weights + total = sum(label_counts.values()) + if total == 0 or label_counts[0] == 0 or label_counts[1] == 0: + return {0: 1.0, 1: 1.0} + + return { + 0: total / (2 * label_counts[0]), + 1: total / (2 * label_counts[1]) + } + + +# --------------------------------------------------------------------- +# Utility that is executed in worker processes +# (must be top-level so it can be pickled on Windows) +# --------------------------------------------------------------------- +def _row_to_tensor_pack(row_dict: dict, + max_pep_seq_len: int, + max_mhc_len: int, + mask_rate: float = 0.3): + # peptide ----------------------------------------------------------- + pep_oh = onehot(row_dict["long_mer"], max_pep_seq_len) + pep_masked, y_true, sw = mask_random_positions(pep_oh, mask_rate) + + # MHC latent -------------------------------------------------------- + mhc = _read_embedding_file(row_dict["mhc_embedding_path"]).astype("float32") + if mhc.shape[0] != max_mhc_len: + raise ValueError("MHC length mismatch") + + return (pep_masked, mhc), y_true, sw # <-- three-tuple + +from concurrent.futures import ProcessPoolExecutor +import functools, itertools + +def streaming_data_generator( + parquet_path: str, + max_pep_seq_len: int, + max_mhc_len: int, + batch_size: int = 1000): + """ + Yields *individual* samples, but converts an entire Parquet batch + on multiple CPU cores first. 
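+    Each yielded element is ((masked_peptide_onehot, mhc_embedding), true_peptide_onehot,
+    sample_weight), where sample_weight is 1.0 on masked positions and 0.0 elsewhere.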
+ """ + with StreamingParquetReader(parquet_path, batch_size) as reader, \ + ProcessPoolExecutor(max_workers=os.cpu_count()) as pool: + + # Partial function to avoid re-sending constants + worker_fn = functools.partial( + _row_to_tensor_pack, + max_pep_seq_len=max_pep_seq_len, + max_mhc_len=max_mhc_len, + ) + + for batch_df in reader.iter_batches(): + # Convert Arrow table → list[dict] once; avoids pandas overhead + dict_rows = batch_df.to_dict(orient="list") # columns -> python lists + # Re-shape to list[dict(row)] + rows_iter = ( {k: dict_rows[k][i] for k in dict_rows} # row dict + for i in range(len(batch_df)) ) + + # Parallel map; chunksize tuned for large batches + results = pool.map(worker_fn, rows_iter, chunksize=64) + + # Stream each converted sample back to the generator consumer + yield from results # <-- keeps memory footprint tiny + + # explicit clean-up + del batch_df, dict_rows, rows_iter, results + + +def create_streaming_dataset(parquet_path: str, + max_pep_seq_len: int, + max_mhc_len: int, + batch_size: int = 128, + buffer_size: int = 1000): + """ + Same semantics as before, but the generator already does parallel + preprocessing. We now ask tf.data to interleave multiple generator + shards in parallel as well. + """ + output_signature = ( + (tf.TensorSpec((max_pep_seq_len, AA_DIM), tf.float32), + tf.TensorSpec((max_mhc_len, 1152), tf.float32)), + tf.TensorSpec((max_pep_seq_len, AA_DIM), tf.float32), # y_true + tf.TensorSpec((max_pep_seq_len,), tf.float32), # sample_weight + ) + + ds = tf.data.Dataset.from_generator( + lambda: streaming_data_generator( + parquet_path, + max_pep_seq_len, + max_mhc_len, + buffer_size), + output_signature=output_signature, + ) + + # ► Parallel interleave gives another speed-up if the Parquet file has + # many row-groups – adjust cycle_length as needed. 
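+    # Note: from_generator itself yields sequentially; the interleave below only re-wraps each
+    # element so tf.data can overlap the lightweight per-sample handling, while the heavy per-row
+    # preprocessing is already parallelised by the ProcessPoolExecutor inside the generator.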
+ ds = ds.interleave( + lambda features, y_true, sw: tf.data.Dataset.from_tensors((features, y_true, sw)), + cycle_length=tf.data.AUTOTUNE, + num_parallel_calls=tf.data.AUTOTUNE, + deterministic=False, + ) + + return ds + + +# ---------------------------------------------------------------------------- +# Visualization utilities (keeping the same as original) +# ---------------------------------------------------------------------------- +def plot_training_curve(history: tf.keras.callbacks.History, run_dir: str, fold_id: int = None, + model=None, val_dataset=None): + """Plot training curves and validation metrics""" + hist = history.history + plt.figure(figsize=(21, 6)) + plot_name = f"training_curve{'_fold' + str(fold_id) if fold_id is not None else ''}" + + plt.suptitle(f"Training Curves{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", + fontsize=16, fontweight='bold') + + # Plot 1: Loss curve + plt.subplot(1, 4, 1) + plt.plot(hist["loss"], label="train", linewidth=2) + plt.plot(hist["val_loss"], label="val", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.title("BCE Loss") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 2: Accuracy curve + if "binary_accuracy" in hist and "val_binary_accuracy" in hist: + plt.subplot(1, 4, 2) + plt.plot(hist["binary_accuracy"], label="train acc", linewidth=2) + plt.plot(hist["val_binary_accuracy"], label="val acc", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Accuracy") + plt.title("Binary Accuracy") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 3: AUC curve + if "AUC" in hist and "val_AUC" in hist: + plt.subplot(1, 4, 3) + plt.plot(hist["AUC"], label="train AUC", linewidth=2) + plt.plot(hist["val_AUC"], label="val AUC", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("AUC") + plt.title("AUC") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 4: Confusion matrix placeholder + plt.subplot(1, 4, 4) + if model is not None and val_dataset is not None: + # Sample a subset for confusion matrix to avoid memory issues + sample_dataset = val_dataset.take(100) # Take only 100 batches + y_pred_proba = model.predict(sample_dataset, verbose=0) + y_pred = (y_pred_proba > 0.5).astype(int) + + y_true = [] + for _, labels, _ in sample_dataset: + y_true.extend(labels.numpy()) + y_true = np.array(y_true) + + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix (100 Batches)') + else: + plt.text(0.5, 0.5, 'Confusion Matrix N/A \n(Sample from validation)', + ha='center', va='center', transform=plt.gca().transAxes, + bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray")) + plt.axis('off') + + plt.tight_layout() + os.makedirs(run_dir, exist_ok=True) + out_png = os.path.join(run_dir, f"{plot_name}.png") + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.close() # Close to free memory + print(f"✓ Training curve saved to {out_png}") + + +def plot_test_metrics(model, test_dataset, run_dir: str, fold_id: int = None, + history=None, string: str = None): + """Plot comprehensive evaluation metrics for test dataset""" + print("Generating predictions for test metrics...") + + # Collect predictions and labels in batches to avoid memory issues + y_true_list = [] + y_pred_proba_list = [] + + for batch_x, batch_y in test_dataset: + batch_pred = model.predict(batch_x, verbose=0) + y_true_list.append(batch_y.numpy()) + y_pred_proba_list.append(batch_pred.flatten()) + + y_true = 
np.concatenate(y_true_list).flatten() + y_pred_proba = np.concatenate(y_pred_proba_list) + y_pred = (y_pred_proba > 0.5).astype(int) + + # Calculate ROC curve + fpr, tpr, _ = roc_curve(y_true, y_pred_proba) + roc_auc = auc(fpr, tpr) + + # Create evaluation plot + plt.figure(figsize=(15, 10)) + plt.suptitle(f"{string} Evaluation Metrics{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", + fontsize=16, fontweight='bold') + + # ROC Curve + plt.subplot(2, 2, 1) + plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})') + plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.8) + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('ROC Curve') + plt.legend(loc="lower right") + plt.grid(True, alpha=0.3) + + # Confusion Matrix + plt.subplot(2, 2, 2) + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix') + plt.xlabel('Predicted Label') + plt.ylabel('True Label') + + # Metrics bar chart + plt.subplot(2, 2, 3) + accuracy = accuracy_score(y_true, y_pred) + precision = precision_score(y_true, y_pred, zero_division=0) + recall = recall_score(y_true, y_pred, zero_division=0) + f1 = f1_score(y_true, y_pred, zero_division=0) + + metrics = {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1, 'AUC': roc_auc} + bars = plt.bar(range(len(metrics)), list(metrics.values()), + color=['skyblue', 'lightcoral', 'lightgreen', 'gold', 'orange'], + alpha=0.8, edgecolor='black', linewidth=1) + plt.xticks(range(len(metrics)), list(metrics.keys()), rotation=45, ha='right') + plt.ylim(0, 1.0) + plt.title('Evaluation Metrics') + plt.ylabel('Score') + plt.grid(True, alpha=0.3, axis='y') + + for bar, value in zip(bars, metrics.values()): + plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02, + f'{value:.3f}', ha='center', va='bottom', fontweight='bold') + + # Prediction distribution + plt.subplot(2, 2, 4) + plt.hist(y_pred_proba[y_true == 0], bins=30, alpha=0.7, label='Negative Class', + color='red', density=True) + plt.hist(y_pred_proba[y_true == 1], bins=30, alpha=0.7, label='Positive Class', + color='blue', density=True) + plt.axvline(x=0.5, color='black', linestyle='--', linewidth=2, label='Threshold') + plt.xlabel('Predicted Probability') + plt.ylabel('Density') + plt.title('Prediction Distribution') + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + os.makedirs(run_dir, exist_ok=True) + out_png = os.path.join(run_dir, f"{string}_metrics{'_fold' + str(fold_id) if fold_id is not None else ''}.png") + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.close() # Close to free memory + print(f"✓ Test metrics visualization saved to {out_png}") + + # Print summary + print("\n" + "=" * 50) + print("EVALUATION SUMMARY") + print("=" * 50) + print(f"Accuracy: {accuracy:.4f}") + print(f"Precision: {precision:.4f}") + print(f"Recall: {recall:.4f}") + print(f"F1 Score: {f1:.4f}") + print(f"ROC AUC: {roc_auc:.4f}") + print("=" * 50) + + return { + 'roc_auc': roc_auc, 'accuracy': accuracy, 'precision': precision, + 'recall': recall, 'f1': f1, 'confusion_matrix': cm.tolist() + } + +def plot_attn(att_model, val_loader, run_dir: str, fold_id: int = None): + # ------------------------------------------------------------- + # ATTENTION VISUALISATION – take ONE batch from validation + # 
------------------------------------------------------------- + (pep_ex, mhc_ex), _ = next(iter(val_loader)) # first batch + att_scores = att_model.predict([pep_ex, mhc_ex], verbose=0) + print("attn_scores", att_scores.shape) + print("attn_scores first 5 samples:", att_scores[:5]) + # save attention scores + if fold_id is None: + fold_id = 0 + run_dir = os.path.join(run_dir, f"fold_{fold_id}") + pathlib.Path(run_dir).mkdir(parents=True, exist_ok=True) + print(f"✓ Attention scores saved to {run_dir}") + out_attn = os.path.join(run_dir, f"attn_scores_fold{fold_id}.npy") + np.save(out_attn, att_scores) + print(f"✓ Attention scores saved to {out_attn}") + # ------------------------------------------------------------- + # att_scores shape : (B, heads, pep_len, mhc_len) + att_mean = att_scores.mean(axis=1) # (B,pep,mhc) + print("att_mean shape:", att_mean.shape) + sample_id = 0 + A = att_mean[sample_id] + A = A.transpose() + plt.figure(figsize=(8, 6)) + sns.heatmap(A, + cmap = "viridis", + xticklabels = [f"P{j}" for j in range(A.shape[1])], + yticklabels = [f"M{i}" for i in range(A.shape[0])], + cbar_kws = {"label": "attention"}) + plt.title(f"Fold {fold_id} – attention sample {sample_id}") + out_png = os.path.join(run_dir,f"attention_fold{fold_id}_sample{sample_id}.png") + plt.savefig(out_png, dpi=300, bbox_inches="tight") + plt.close() + print(f"✓ Attention heat-map saved to {out_png}") + + +# ---------------------------------------------------------------------------- +# Main training function +# ---------------------------------------------------------------------------- +def main(argv=None): + p = argparse.ArgumentParser() + p.add_argument("--dataset_path", required=True, + help="Path to the dataset directory") + p.add_argument("--epochs", type=int, default=30) + p.add_argument("--batch", type=int, default=128) + p.add_argument("--outdir", default=None, + help="Output dir (default: runs/run_YYYYmmdd-HHMMSS)") + p.add_argument("--buffer_size", type=int, default=1000, + help="Buffer size for streaming data loading") + p.add_argument("--test_batches", type=int, default=3, + help="Number of batches to use for test dataset evaluation") + + args = p.parse_args(argv) + + run_dir = args.outdir or f"runs/run_{datetime.datetime.now():%Y%m%d-%H%M%S}" + pathlib.Path(run_dir).mkdir(parents=True, exist_ok=True) + print(f"★ Outputs → {run_dir}\n") + + # Set seeds for reproducibility + tf.random.set_seed(42) + np.random.seed(42) + print("Setting random seeds for reproducibility...") + + print("Initial memory state:") + monitor_memory() + + # Extract metadata from datasets without loading them fully + print("Extracting dataset metadata...") + + # Get fold information + fold_dir = os.path.join(args.dataset_path, 'folds') + fold_files = sorted([f for f in os.listdir(fold_dir) if f.endswith('.parquet')]) + n_folds = len(fold_files) // 2 + + # Find maximum peptide length across all datasets + max_peptide_length = 0 + max_mhc_length = 36 # Fixed for now + + print("Scanning datasets for maximum peptide length...") + all_parquet_files = [ + os.path.join(args.dataset_path, "test1.parquet"), + os.path.join(args.dataset_path, "test2.parquet") + ] + + # Add fold files + for i in range(1, n_folds + 1): + all_parquet_files.extend([ + os.path.join(fold_dir, f'fold_{i}_train.parquet'), + os.path.join(fold_dir, f'fold_{i}_val.parquet') + ]) + + for pq_file in all_parquet_files: + if os.path.exists(pq_file): + metadata = get_dataset_metadata(pq_file) + max_peptide_length = max(max_peptide_length, 
metadata['max_peptide_length']) + print( + f" {os.path.basename(pq_file)}: max_len={metadata['max_peptide_length']}, rows={metadata['total_rows']}") + + print(f"✓ Maximum peptide length across all datasets: {max_peptide_length}") + + # Create fold datasets and class weights + folds = [] + class_weights = [] + + for i in range(1, n_folds + 1): + print(f"\nProcessing fold {i}/{n_folds}") + train_path = os.path.join(fold_dir, f'fold_{i}_train.parquet') + val_path = os.path.join(fold_dir, f'fold_{i}_val.parquet') + + # Calculate class weights from training data + print(f" Calculating class weights...") + cw = calculate_class_weights(train_path) + print(f" Class weights: {cw}") + + # Create streaming datasets + train_ds = (create_streaming_dataset(train_path, max_peptide_length, max_mhc_length, + buffer_size=args.buffer_size) + .shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True) + .batch(args.batch) + .take(args.test_batches) + .prefetch(tf.data.AUTOTUNE)) + + val_ds = (create_streaming_dataset(val_path, max_peptide_length, max_mhc_length, + buffer_size=args.buffer_size) + .batch(args.batch) + .take(args.test_batches) + .prefetch(tf.data.AUTOTUNE)) + + folds.append((train_ds, val_ds)) + class_weights.append(cw) + + # Force cleanup + cleanup_memory() + + # Create test datasets + print("Creating test datasets...") + test1_ds = (create_streaming_dataset(os.path.join(args.dataset_path, "test1.parquet"), + max_peptide_length, max_mhc_length, buffer_size=args.buffer_size) + .batch(args.batch) + .prefetch(tf.data.AUTOTUNE)) + + test2_ds = (create_streaming_dataset(os.path.join(args.dataset_path, "test2.parquet"), + max_peptide_length, max_mhc_length, buffer_size=args.buffer_size) + .batch(args.batch) + .prefetch(tf.data.AUTOTUNE)) + + print(f"✓ Created {n_folds} fold datasets and 2 test datasets") + print("Memory after dataset creation:") + monitor_memory() + + # Training loop + print("\n" + "=" * 60) + print("STARTING TRAINING") + print("=" * 60) + + for fold_id, ((train_loader, val_loader), class_weight) in enumerate(zip(folds, class_weights), start=1): + print(f'\n🔥 Training fold {fold_id}/{n_folds}') + + # Clean up before each fold + cleanup_memory() + + # Build fresh model for each fold + print("Building model...") + model, attn_model = build_reconstruction_model(max_peptide_length, max_mhc_length) + model.summary() + + # Callbacks + ckpt_cb = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(run_dir, f'best_fold_{fold_id}.weights.h5'), + monitor='val_loss', save_best_only=True, mode='min', verbose=1) + early_cb = tf.keras.callbacks.EarlyStopping( + monitor='val_loss', patience=15, restore_best_weights=True, verbose=1) + + # Verify data shapes + for (x_pep, latents), labels, sw in train_loader.take(1): + print(f"✓ Input shapes: peptide={x_pep.shape}, mhc={latents.shape}, labels={labels.shape}") + break + + print("Memory before training:") + monitor_memory() + + # Train model + print("🚀 Starting training...") + hist = model.fit(train_loader, + validation_data=val_loader, + epochs=args.epochs, + verbose=1) + + print("Memory after training:") + monitor_memory() + + # Plot training curves + #plot_training_curve(hist, run_dir, fold_id, model, val_loader) + #plot_attn(attn_model, val_loader, run_dir, fold_id) + + # Save model and metadata + model.save_weights(os.path.join(run_dir, f'model_fold_{fold_id}.weights.h5')) + metadata = { + "fold_id": fold_id, + "epochs": args.epochs, + "batch_size": args.batch, + "max_peptide_length": max_peptide_length, + "max_mhc_length": 
max_mhc_length, + "class_weights": class_weight, + "run_dir": run_dir, + "mhc_class": MHC_CLASS + } + with open(os.path.join(run_dir, f'metadata_fold_{fold_id}.json'), 'w') as f: + json.dump(metadata, f, indent=4) + + # Evaluate on test sets + print(f"\n📊 Evaluating fold {fold_id} on test sets...") + + # Test1 evaluation + print("Evaluating on test1 (balanced alleles)...") + #plot_test_metrics(model, test1_ds, run_dir, fold_id, string="Test1_balanced_alleles") + + # Test2 evaluation + print("Evaluating on test2 (rare alleles)...") + #plot_test_metrics(model, test2_ds, run_dir, fold_id, string="Test2_rare_alleles") + + print(f"✅ Fold {fold_id} completed successfully") + + # Cleanup + del model, hist + cleanup_memory() + + print("\n🎉 Training completed successfully!") + print(f"📁 All results saved to: {run_dir}") + + +if __name__ == "__main__": + BUFFER = 8192 # Reduced buffer size for memory efficiency + MHC_CLASS = 2 + dataset_path = f"../data/Custom_dataset/NetMHCpan_dataset/mhc_{MHC_CLASS}" + main([ + "--dataset_path", dataset_path, + "--epochs", "10", + "--batch", "32", + "--buffer_size", "8192", + "--test_batches", "500", + ]) \ No newline at end of file diff --git a/src/run_pMHC_DL_ESM.py b/src/run_pMHC_DL_ESM.py new file mode 100644 index 00000000..8866e631 --- /dev/null +++ b/src/run_pMHC_DL_ESM.py @@ -0,0 +1,752 @@ +#!/usr/bin/env python +""" +========================= + +End‑to‑end trainer for a **peptide×MHC cross‑attention classifier**. +It loads a NetMHCpan‑style parquet that contains + + long_mer, assigned_label, allele, MHC_class, + mhc_embedding **OR** mhc_embedding_path + +columns. Each row supplies + +* a peptide sequence (long_mer) +* a pre‑computed MHC pseudo‑sequence embedding (36, 1152) +* a binary label (assigned_label) + +The script + +1.Derives the longest peptide length → SEQ_LEN. +2.Converts every peptide into a 21‑dim one‑hot tensor (SEQ_LEN, 21). +3.Feeds the pair + + (one_hot_peptide, mhc_latent) → classifier → P(binding) + +4.Trains with binary‑cross‑entropy and saves the best weights & metadata. 
+ +Author : Amirreza (updated for cross‑attention, 2025‑05‑22) +""" +from __future__ import annotations +import os, sys, argparse, datetime, pathlib, json +print(sys.executable) + +import numpy as np +import pandas as pd +import tensorflow as tf +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split +from tqdm import tqdm + +from utils.model import build_classifier +from sklearn.metrics import ( + confusion_matrix, roc_curve, auc, precision_score, + recall_score, f1_score, accuracy_score, roc_auc_score +) +import seaborn as sns + + +# --------------------------------------------------------------------------- +# Utility: peptide → one‑hot (seq_len, 21) +# --------------------------------------------------------------------------- +AA = "ACDEFGHIKLMNPQRSTVWY" # 20 standard AAs, order fixed +AA_TO_IDX = {aa: i for i, aa in enumerate(AA)} +UNK_IDX = 20 # index for unknown / padding + + +def peptides_to_onehot(seqs: list[str], seq_len: int) -> np.ndarray: + """Vectorise *seqs* into (N, seq_len, 21) one‑hot with padding.""" + N = len(seqs) + arr = np.zeros((N, seq_len, 21), dtype=np.float32) + for i, s in enumerate(seqs): + if not s: + raise "[ERR] empty peptide sequence" + for j, aa in enumerate(s.upper()[:seq_len]): + arr[i, j, AA_TO_IDX.get(aa, UNK_IDX)] = 1.0 + # remaining positions stay zero → acts as padded UNK (index 20) + return arr + + +# --------------------------------------------------------------------------- +# Robust loader for latent embeddings (same as before) +# --------------------------------------------------------------------------- + +def _read_embedding_file(path: str | os.PathLike) -> np.ndarray: + # Try fast numeric path first + try: + arr = np.load(path) + if isinstance(arr, np.ndarray) and arr.dtype == np.float32: + return arr + raise ValueError + except ValueError: + obj = np.load(path, allow_pickle=True) + if isinstance(obj, np.ndarray) and obj.dtype == object: + obj = obj.item() + if isinstance(obj, dict) and "embedding" in obj: + return obj["embedding"].astype("float32") + raise ValueError(f"Unrecognised embedding file {path}") + + +# --------------------------------------------------------------------------- +# Dataset loader – returns (peptide_onehot, latent), labels +# --------------------------------------------------------------------------- + +def load_dataset(parquet_path: str): + print(f"→ Reading {parquet_path}") + df = pd.read_parquet(parquet_path) + + # 1) Peptide one‑hot ---------------------------------------------------- + if "long_mer" not in df.columns: + raise ValueError("Expected a 'long_mer' column with peptide sequences") + pep_seq_len = int(df["long_mer"].str.len().max()) + print(f" longest peptide = {pep_seq_len} residues") + pep_onehot = peptides_to_onehot(df["long_mer"].tolist(), pep_seq_len) + print(" peptide one-hot shape:", pep_onehot.shape) + + # 2) Latent embeddings -------------------------------------------------- + print(f" loading MHC embeddings") + if "mhc_embedding" in df.columns: + latents = np.stack(df["mhc_embedding"].values).astype("float32") + elif "mhc_embedding_path" in df.columns: + latents = np.stack([_read_embedding_file(p) for p in df["mhc_embedding_path"]]) + else: + raise ValueError("Need 'mhc_embedding' or 'mhc_embedding_path' column") + + print(" latent shape:", latents.shape) + if latents.shape[1:] != (36, 1152): + raise ValueError(f"Unexpected latent shape {latents.shape[1:]}, expected (36,1152)") + + # 3) Labels ------------------------------------------------------------- + 
labels = df["assigned_label"].astype("float32").values[:, None] # (N,1) + + return pep_onehot, latents, labels, pep_seq_len + + +def load_dataset_in_batches(parquet: str, target_seq_len: int, batch_size: int = 100, subset: float = None): + print(f"→ reading {parquet} in batches of {batch_size}") + if subset is not None: + print(f" using subset fraction: {subset}") + df = pd.read_parquet(parquet).sample(frac=subset, random_state=42) + else: + df = pd.read_parquet(parquet) + if "long_mer" not in df.columns: + raise ValueError("Expected a 'long_mer' column with peptide sequences") + # REMOVE or IGNORE: pep_seq_len = int(df["long_mer"].str.len().max()) # This was the problematic line for one-hot encoding + total = len(df) + for start in range(0, total, batch_size): + end = min(start + batch_size, total) + # Use target_seq_len for one-hot encoding + pep_onehot = peptides_to_onehot(df["long_mer"].values[start:end].tolist(), target_seq_len) + if "mhc_embedding" in df.columns: + latents = np.stack(df["mhc_embedding"].values[start:end]).astype("float32") + elif "mhc_embedding_path" in df.columns: + latents = np.stack([_read_embedding_file(p) for p in df["mhc_embedding_path"].values[start:end]]) + else: + raise ValueError("Need a 'mhc_embedding' or 'mhc_embedding_path' column") + + if latents.shape[1:] != (36, 1152): + raise ValueError(f"Unexpected latent shape {latents.shape[1:]}, expected (36,1152)") + + labels = df["assigned_label"].values[start:end].astype("int32") + labels = labels[:, None] + + yield pep_onehot, latents, labels # Removed the local pep_seq_len from yield + + + +# --------------------------------------------------------------------------- +# Visualisation utility +# --------------------------------------------------------------------------- +def plot_training_curve(history: tf.keras.callbacks.History, run_dir: str, fold_id: int = None, + model=None, val_dataset=None): + """ + Plot training curves and validation metrics from a Keras history object. + + Args: + history: Keras history object containing training metrics. + run_dir: Directory to save the plot. + fold_id: Optional fold identifier for naming the output file. + model: Optional model to compute confusion matrix and ROC curve. + val_dataset: Optional validation dataset to generate predictions. 
+ + Returns: None + """ + hist = history.history + plt.figure(figsize=(21, 6)) + plot_name = f"training_curve{'_fold' + str(fold_id) if fold_id is not None else ''}" + + # set plot name above all plots + plt.suptitle(f"Training Curves{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", fontsize=16, fontweight='bold') + + # Plot 1: Loss curve + plt.subplot(1, 4, 1) + plt.plot(hist["loss"], label="train", linewidth=2) + plt.plot(hist["val_loss"], label="val", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.title("BCE Loss") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 2: Accuracy curve + if "binary_accuracy" in hist and "val_binary_accuracy" in hist: + plt.subplot(1, 4, 2) + plt.plot(hist["binary_accuracy"], label="train acc", linewidth=2) + plt.plot(hist["val_binary_accuracy"], label="val acc", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Accuracy") + plt.title("Binary Accuracy") + plt.legend() + plt.grid(True, alpha=0.3) + elif "accuracy" in hist and "val_accuracy" in hist: + plt.subplot(1, 4, 2) + plt.plot(hist["accuracy"], label="train acc", linewidth=2) + plt.plot(hist["val_accuracy"], label="val acc", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Accuracy") + plt.title("Accuracy") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 3: AUC curve + if "auc" in hist and "val_auc" in hist: + plt.subplot(1, 4, 3) + plt.plot(hist["auc"], label="train AUC", linewidth=2) + plt.plot(hist["val_auc"], label="val AUC", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("AUC") + plt.title("AUC") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 4: Confusion matrix (if model and validation dataset provided) + if model is not None and val_dataset is not None: + plt.subplot(1, 4, 4) + + # Get predictions + y_pred_proba = model.predict(val_dataset) + y_pred = (y_pred_proba > 0.5).astype(int) + + # Extract true labels + y_true = [] + for _, labels in val_dataset: + y_true.extend(labels.numpy()) + y_true = np.array(y_true) + + # print ranges of y_true and y_pred + print(f"y_true range: {y_true.min()} to {y_true.max()}") + + print(f"y_pred range: {y_pred.min()} to {y_pred.max()}") + + # Create confusion matrix + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix') + plt.xlabel('Predicted') + plt.ylabel('Actual') + else: + # If no model/dataset provided, show a placeholder or additional metric + plt.subplot(1, 4, 4) + plt.text(0.5, 0.5, 'Confusion Matrix\n(Requires model + val_dataset)', + ha='center', va='center', transform=plt.gca().transAxes, + bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray")) + plt.axis('off') + + # Save main plot + plt.tight_layout() + out_png = os.path.join(run_dir, f"{plot_name}.png") + + # Create directory if it doesn't exist + os.makedirs(run_dir, exist_ok=True) + + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.show() + print(f"✓ Training curve saved to {out_png}") + + +def plot_test_metrics(model, test_dataset, run_dir: str, fold_id: int = None, history=None, string: str =None): + """ + Plot comprehensive evaluation metrics for a test dataset using a trained model. + + Args: + model: Trained Keras model. + test_dataset: TensorFlow dataset for testing. + run_dir: Directory to save the plot. + fold_id: Optional fold identifier for naming the output file. + history: Optional training history object to display loss/accuracy curves. 
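The confusion-matrix panel above obtains its ground truth by re-iterating the dataset after `model.predict`. A one-pass variant is sketched below; the name `collect_predictions` is hypothetical and this is a defensive alternative rather than the script's method (the validation loaders here are built with `shuffle=False`, so alignment is already safe), but iterating once keeps labels and probabilities paired even if a pipeline reshuffles between iterations.

```python
import numpy as np

def collect_predictions(model, dataset):
    """Collect probabilities and labels in one pass over a ((peptide, mhc), label) dataset."""
    y_true, y_prob = [], []
    for (x_pep, x_mhc), labels in dataset:
        y_prob.append(model((x_pep, x_mhc), training=False).numpy())
        y_true.append(labels.numpy())
    return np.concatenate(y_true).ravel(), np.concatenate(y_prob).ravel()
```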
+ + Returns: Dictionary containing evaluation metrics + """ + # Collect predictions + print("Generating predictions...") + y_pred_proba = model.predict(test_dataset) + y_pred = (y_pred_proba > 0.5).astype(int).flatten() + + # Extract true labels + y_true = [] + for _, labels in test_dataset: + y_true.extend(labels.numpy()) + y_true = np.array(y_true).flatten() + + # Calculate ROC curve + fpr, tpr, _ = roc_curve(y_true, y_pred_proba.flatten()) + roc_auc = auc(fpr, tpr) + + # Create a multi-panel figure + plt.figure(figsize=(15, 10)) + plt.suptitle(f"{string} Evaluation Metrics{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", + fontsize=16, fontweight='bold') + + # Panel 1: ROC Curve + plt.subplot(2, 2, 1) + plt.plot(fpr, tpr, color='darkorange', lw=2, + label=f'ROC curve (AUC = {roc_auc:.4f})') + plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.8) + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('Receiver Operating Characteristic (ROC) Curve') + plt.legend(loc="lower right") + plt.grid(True, alpha=0.3) + + # Panel 2: Confusion Matrix + plt.subplot(2, 2, 2) + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix') + plt.xlabel('Predicted Label') + plt.ylabel('True Label') + + # Panel 3: Loss curve (if history provided) + if history is not None: + plt.subplot(2, 2, 3) + hist = history.history + plt.plot(hist.get("loss", []), label="train", linewidth=2) + plt.plot(hist.get("val_loss", []), label="val", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.title("BCE Loss") + plt.legend() + plt.grid(True, alpha=0.3) + + # Panel 4: Accuracy curve (if available in history) + plt.subplot(2, 2, 4) + if "binary_accuracy" in hist and "val_binary_accuracy" in hist: + plt.plot(hist["binary_accuracy"], label="train acc", linewidth=2) + plt.plot(hist["val_binary_accuracy"], label="val acc", linewidth=2) + elif "accuracy" in hist and "val_accuracy" in hist: + plt.plot(hist["accuracy"], label="train acc", linewidth=2) + plt.plot(hist["val_accuracy"], label="val acc", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Accuracy") + plt.title("Accuracy") + plt.legend() + plt.grid(True, alpha=0.3) + else: + # Panel 3: Test metrics when history not available + plt.subplot(2, 2, 3) + + # Calculate metrics + accuracy = accuracy_score(y_true, y_pred) + precision = precision_score(y_true, y_pred, zero_division=0) + recall = recall_score(y_true, y_pred, zero_division=0) + f1 = f1_score(y_true, y_pred, zero_division=0) + + metrics = { + 'Accuracy': accuracy, + 'Precision': precision, + 'Recall': recall, + 'F1 Score': f1, + 'AUC': roc_auc + } + + # Create a bar chart for metrics + bars = plt.bar(range(len(metrics)), list(metrics.values()), + color=['skyblue', 'lightcoral', 'lightgreen', 'gold', 'orange'], + alpha=0.8, edgecolor='black', linewidth=1) + plt.xticks(range(len(metrics)), list(metrics.keys()), rotation=45, ha='right') + plt.ylim(0, 1.0) + plt.title('Test Set Evaluation Metrics') + plt.ylabel('Score') + plt.grid(True, alpha=0.3, axis='y') + + # Add values on top of bars + for i, (bar, value) in enumerate(zip(bars, metrics.values())): + plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02, + f'{value:.4f}', ha='center', va='bottom', fontweight='bold') + + # Panel 4: Prediction distribution + plt.subplot(2, 2, 4) + 
plt.hist(y_pred_proba[y_true == 0], bins=30, alpha=0.7, + label='Negative Class', color='red', density=True) + plt.hist(y_pred_proba[y_true == 1], bins=30, alpha=0.7, + label='Positive Class', color='blue', density=True) + plt.axvline(x=0.5, color='black', linestyle='--', linewidth=2, + label='Decision Threshold') + plt.xlabel('Predicted Probability') + plt.ylabel('Density') + plt.title('Prediction Probability Distribution') + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + + # Create directory if it doesn't exist + os.makedirs(run_dir, exist_ok=True) + + out_png = os.path.join(run_dir, f"{string}_metrics{'_fold' + str(fold_id) if fold_id is not None else ''}.png") + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.show() + print(f"✓ Test metrics visualization saved to {out_png}") + + # Calculate and return evaluation metrics + metrics_dict = { + 'roc_auc': roc_auc, + 'accuracy': accuracy_score(y_true, y_pred), + 'precision': precision_score(y_true, y_pred, zero_division=0), + 'recall': recall_score(y_true, y_pred, zero_division=0), + 'f1': f1_score(y_true, y_pred, zero_division=0), + 'confusion_matrix': cm.tolist(), + 'fpr': fpr.tolist(), + 'tpr': tpr.tolist() + } + + # Print summary + print("\n" + "=" * 50) + print("TEST SET EVALUATION SUMMARY") + print("=" * 50) + print(f"Accuracy: {metrics_dict['accuracy']:.4f}") + print(f"Precision: {metrics_dict['precision']:.4f}") + print(f"Recall: {metrics_dict['recall']:.4f}") + print(f"F1 Score: {metrics_dict['f1']:.4f}") + print(f"ROC AUC: {metrics_dict['roc_auc']:.4f}") + print("=" * 50) + + return metrics_dict + +# --------------------------------------------------------------------------- +# TF‑data helper +# --------------------------------------------------------------------------- + +def make_tf_dataset(source, longest_peptide_seq_length, batch: int = 128, shuffle: bool = True, subset: float = None): + if isinstance(source, (str, os.PathLike)): + def gen(): + # Pass longest_peptide_seq_length as target_seq_len + # The generator now yields 3 items + for peps, lats, labs in load_dataset_in_batches(str(source), + target_seq_len=longest_peptide_seq_length, + batch_size=batch, + subset=subset): + yield (peps, lats), labs + output_signature = ( + (tf.TensorSpec(shape=(None, longest_peptide_seq_length, 21), dtype=tf.float32), + tf.TensorSpec(shape=(None, 36, 1152), dtype=tf.float32)), + tf.TensorSpec(shape=(None, 1), dtype=tf.float32) + ) + ds = tf.data.Dataset.from_generator(gen, output_signature=output_signature) + if shuffle: + # Consider increasing buffer_size for better shuffling if dataset is large + # e.g., buffer_size=num_batches_in_epoch or a sufficiently large number + ds = ds.shuffle(buffer_size=max(10, len(pd.read_parquet(str(source))) // batch // 2 ), seed=42) + return ds.prefetch(tf.data.AUTOTUNE) + else: # This part for in-memory dataframes seems correct as it uses longest_peptide_seq_length + peps = peptides_to_onehot(source["long_mer"].tolist(), longest_peptide_seq_length) + labels = source["assigned_label"].values.astype("float32")[:, None] + + if "mhc_embedding" in source.columns: + lats = np.stack(source["mhc_embedding"].values).astype("float32") + elif "mhc_embedding_path" in source.columns: + lats = np.stack([_read_embedding_file(p) for p in source["mhc_embedding_path"]]).astype("float32") + else: + raise ValueError("Need 'mhc_embedding' or 'mhc_embedding_path' column") + + ds = tf.data.Dataset.from_tensor_slices(((peps, lats), labels)) + if shuffle: + ds = ds.shuffle(buffer_size=len(labels), 
seed=42) + return ds.batch(batch).prefetch(tf.data.AUTOTUNE) + + # peps = source["long_mer"] + # labels = source["assigned_label"] + # lats = np.stack([np.load(p) for p in source["mhc_embedding_path"]]).astype("float32") + # ds = tf.data.Dataset.from_tensor_slices(((peps, lats), labels)) + # if shuffle: + # ds = ds.shuffle(buffer_size=len(labels), seed=42) + # return ds.batch(batch).prefetch(tf.data.AUTOTUNE) + + +# def _collect_batches(loader, max_batch_size: int = 3_000): +# ''' +# Concatenate a lazy *loader* (yielding pep, latent, label, seq_len) +# into single large NumPy arrays **with just one final copy**. +# +# On each iteration we append references to the individual batch arrays; +# only at the very end do we *concatenate* along axis0. This strategy is +# RAM‑friendly because the intermediate lists only hold *views* of the +# already‑allocated batch memory while the batches are alive, and we avoid +# the O(N²) cost of repeated `np.concatenate` calls inside the loop. +# +# If the total number of samples reaches *max_batch_size*, stop loading more. +# ''' +# pep_chunks, lat_chunks, lab_chunks = [], [], [] +# seq_len = None +# total = 0 +# +# for peps, lats, labels, seq_len in loader: +# n = len(peps) +# if total + n > max_batch_size: +# # Only take up to the limit +# take = max_batch_size - total +# if take <= 0: +# break +# pep_chunks.append(peps[:take]) +# lat_chunks.append(lats[:take]) +# labels = labels[:take] +# labels = labels.reshape(-1, 1) if labels.ndim == 1 else labels +# lab_chunks.append(labels) +# total += take +# break +# pep_chunks.append(peps) +# lat_chunks.append(lats) +# labels = labels.reshape(-1, 1) if labels.ndim == 1 else labels +# lab_chunks.append(labels) +# total += n +# if total >= max_batch_size: +# break +# +# # single allocation per tensor ↓ +# peps = np.concatenate(pep_chunks, axis=0, dtype=np.float32) +# lats = np.concatenate(lat_chunks, axis=0, dtype=np.float32) +# labels = np.concatenate(lab_chunks, axis=0, dtype=np.float32) +# +# # explicitly free the now‑unneeded small arrays +# del pep_chunks, lat_chunks, lab_chunks +# return peps, lats, labels, seq_len + +# convenience wrappers ------------------------------------------------------- + +# def collect_parquet(parquet_path: str, batch_size: int = 30_000): +# '''Load **all** samples in *parquet_path* using streaming batches.''' +# loader = load_dataset_in_batches(parquet_path, batch_size=batch_size) +# return _collect_batches(loader) +# +# +# def collect_fold(fold_path: str, batch_size: int = 3_000): +# return collect_parquet(fold_path, batch_size=batch_size) + +# --------------------------------------------------------------------------- +# Training script +# --------------------------------------------------------------------------- + +def main(argv=None): + p = argparse.ArgumentParser() + p.add_argument("--parquet", required=False, + help="Input parquet with peptide, latent, label") + p.add_argument("--dataset_path", default=None, + help="Path to a custom dataset (not parquet)") + p.add_argument("--epochs", type=int, default=30) + p.add_argument("--batch", type=int, default=128) + p.add_argument("--outdir", default=None, + help="Output dir (default: runs/run_YYYYmmdd-HHMMSS)") + p.add_argument("--val_fraction", type=float, default=0.2, + help="Fraction for train/val split") + args = p.parse_args(argv) + + run_dir = args.outdir or f"runs/run_{datetime.datetime.now():%Y%m%d-%H%M%S}" + pathlib.Path(run_dir).mkdir(parents=True, exist_ok=True) + print(f"★ Outputs → {run_dir}\n") + + # 
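The argument parser above accepts either a single `--parquet` file or a `--dataset_path` with pre-built folds. A hypothetical programmatic invocation of the single-parquet branch is shown below; the parquet path matches the one left commented out in the `__main__` block at the bottom of this file, and the epoch/batch values are examples.

```python
# Hypothetical call for the --parquet branch; equivalent to running the script
# from the command line with the same flags.
main([
    "--parquet", "../data/Custom_dataset/NetMHCpan_dataset/mhc_2/mhc2_with_esm_embeddings.parquet",
    "--epochs", "30",
    "--batch", "128",
    "--val_fraction", "0.2",
])
```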
----------------------- Load & split ---------------------------------- + if args.parquet: + # peps, lats, labels, seq_len = load_dataset(args.parquet) + peps, lats, labels, longest_peptide_seq_length = load_dataset(args.parquet) + + print(f"✓ loaded {len(peps):,} samples" + f" ({peps.shape[1]} residues, {lats.shape[1]} latent features)") + + X_train_p, X_val_p, X_train_l, X_val_l, y_train, y_val = train_test_split( + peps, lats, labels, + test_size=args.val_fraction, + random_state=42, + stratify=labels) + + train_loader = make_tf_dataset((X_train_p, X_train_l, y_train), longest_peptide_seq_length=longest_peptide_seq_length, batch=args.batch, shuffle=True) + val_loader = make_tf_dataset((X_val_p, X_val_l, y_val), longest_peptide_seq_length=longest_peptide_seq_length ,batch=args.batch, shuffle=False) + + elif args.dataset_path: + # Load test sets + test1 = pd.read_parquet(args.dataset_path + "/test1.parquet") + test2 = pd.read_parquet(args.dataset_path + "/test2.parquet") + fold_files = sorted([f for f in os.listdir(os.path.join(args.dataset_path, 'folds')) if f.endswith('.parquet')]) + n_folds = len(fold_files) // 2 + + # Find the longest peptide sequence across all datasets + longest_peptide_seq_length = 0 + + # Check test datasets + if "long_mer" in test1.columns: + longest_peptide_seq_length = max(longest_peptide_seq_length, int(test1["long_mer"].str.len().max())) + if "long_mer" in test2.columns: + longest_peptide_seq_length = max(longest_peptide_seq_length, int(test2["long_mer"].str.len().max())) + + # Check all fold files + for i in range(1, n_folds + 1): + train_path = os.path.join(args.dataset_path, f'folds/fold_{i}_train.parquet') + val_path = os.path.join(args.dataset_path, f'folds/fold_{i}_val.parquet') + + train_df = pd.read_parquet(train_path) + val_df = pd.read_parquet(val_path) + + if "long_mer" in train_df.columns: + longest_peptide_seq_length = max(longest_peptide_seq_length, int(train_df["long_mer"].str.len().max())) + if "long_mer" in val_df.columns: + longest_peptide_seq_length = max(longest_peptide_seq_length, int(val_df["long_mer"].str.len().max())) + + print(f"✓ Longest peptide sequence length across all datasets: {longest_peptide_seq_length}") + + # Create fold datasets with consistent sequence length + folds = [] + for i in range(1, n_folds + 1): + train_path = os.path.join(args.dataset_path, f'folds/fold_{i}_train.parquet') + val_path = os.path.join(args.dataset_path, f'folds/fold_{i}_val.parquet') + + train_loader = make_tf_dataset(train_path, longest_peptide_seq_length=longest_peptide_seq_length, batch=args.batch, shuffle=True) + val_loader = make_tf_dataset(val_path, longest_peptide_seq_length=longest_peptide_seq_length, batch=args.batch, shuffle=False) + + folds.append((train_loader, val_loader)) + + # Create test loaders with the same sequence length + test1_loader = make_tf_dataset(test1, longest_peptide_seq_length=longest_peptide_seq_length, batch=args.batch, shuffle=False) + test2_loader = make_tf_dataset(test2, longest_peptide_seq_length=longest_peptide_seq_length, batch=args.batch, shuffle=False) + + print(f"✓ loaded {len(test1):,} test1 samples, " + f"{len(test2):,} test2 samples, " + f"{len(folds)} folds") + else: + raise ValueError("Need either --parquet or --dataset_path argument") + + # ----------------------- Model ----------------------------------------- + # clear GPU memory + tf.keras.backend.clear_session() + # set random seeds for reproducibility + os.environ["PYTHONHASHSEED"] = "42" + os.environ["TF_DETERMINISTIC_OPS"] = "1" + 
tf.random.set_seed(42) + np.random.seed(42) + + # # Verify and explicitly pad/trim to match training seq_len + # if test1_seq_len != seq_len or test2_seq_len != seq_len: + # print(f"Adjusting test datasets to match training seq_len: {seq_len}") + # X_test1_p = X_test1_p[:, :seq_len, :] + # X_test2_p = X_test2_p[:, :seq_len, :] + # + # # Pad or trim test peptides to match seq_len + # X_test1_p = X_test1_p[:, :seq_len, :] + # X_test2_p = X_test2_p[:, :seq_len, :] + # test1_ds = make_tf_dataset(X_test1_p, X_test1_l, y_test1, batch=args.batch, shuffle=False) + # test2_ds = make_tf_dataset(X_test2_p, X_test2_l, y_test2, batch=args.batch, shuffle=False) + + # print(f'✓ loaded {len(test1):,} test1 samples, ' + # f'{len(test2):,} test2 samples, ' + # f'{len(folds)} folds') + # + # print("★ Done.") + # # TODO think about ensembling folds later + # + # else: + # raise ValueError("Need either --parquet or --dataset_path argument") + + # ------------------------- TRAIN -------------------------------------- + ckpt_cb = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(run_dir, 'best_weights.h5'), + monitor='val_loss', save_best_only=True, mode='min') + early_cb = tf.keras.callbacks.EarlyStopping( + monitor='val_loss', patience=10, restore_best_weights=True) + + if args.parquet: + tf.keras.backend.clear_session() + os.environ['PYTHONHASHSEED'] = '42' + os.environ['TF_DETERMINISTIC_OPS'] = '1' + model = build_classifier(longest_peptide_seq_length) + model.summary() + + + history = model.fit(train_loader, + validation_data=val_loader, + epochs=args.epochs, + callbacks=[ckpt_cb, early_cb]) + + # plot + plot_training_curve(history, run_dir, fold_id=None, model=model, val_dataset=val_loader) + + # save model and metadata + model.save(os.path.join(run_dir, 'model.h5')) + metadata = { + "epochs": args.epochs, + "batch_size": args.batch, + "longest_peptide_seq_length": longest_peptide_seq_length, + "run_dir": run_dir + } + with open(os.path.join(run_dir, 'metadata.json'), 'w') as f: + json.dump(metadata, f, indent=4) + + + elif args.dataset_path: + for fold_id, (train_loader, val_loader) in enumerate(folds, start=1): + print(f'Training on fold {fold_id}/{len(folds)}') + tf.keras.backend.clear_session() + tf.random.set_seed(42) + np.random.seed(42) + print("########################### seq length: ", longest_peptide_seq_length) + model = build_classifier(longest_peptide_seq_length) + model.summary() + + # print one sample + # print("Sample input shape:", next(iter(train_loader))[0][0].shape) + # print("Sample latent shape:", next(iter(train_loader))[0][1].shape) + # print("Sample label shape:", next(iter(train_loader))[1].shape) + + history = model.fit(train_loader, + validation_data=val_loader, + epochs=args.epochs, + callbacks=[ckpt_cb, early_cb]) + + # plot + plot_training_curve(history, run_dir, fold_id, model, val_loader) + + # save model and metadata + model.save(os.path.join(run_dir, f'model_fold_{fold_id}.h5')) + metadata = { + "fold_id": fold_id, + "epochs": args.epochs, + "batch_size": args.batch, + "seq_len": longest_peptide_seq_length, + "run_dir": run_dir + } + with open(os.path.join(run_dir, f'metadata_fold_{fold_id}.json'), 'w') as f: + json.dump(metadata, f, indent=4) + print(f"✓ Fold {fold_id} model saved to {run_dir}") + + # Evaluate on test sets + print("Evaluating on test1 set...") + test1_results = model.evaluate(test1_loader, verbose=1) + print(f"Test1 results: {test1_results}") + print("Evaluating on test2 set...") + test2_results = model.evaluate(test2_loader, verbose=1) + 
print(f"Test2 results: {test2_results}") + + # Plot ROC curve for test1 + plot_test_metrics(model, test1_loader, run_dir, fold_id, string="Test1 - balanced alleles") + # Plot ROC curve for test2 + plot_test_metrics(model, test2_loader, run_dir, fold_id, string="Test2 - rare alleles") + + +if __name__ == "__main__": + main([ + # "--parquet", "../data/Custom_dataset/NetMHCpan_dataset/mhc_2/mhc2_with_esm_embeddings.parquet", + "--dataset_path", "../data/Custom_dataset/NetMHCpan_dataset/mhc_2", + "--epochs", "3", "--batch", "32" + ]) \ No newline at end of file diff --git a/user_setting.py b/user_setting.py index 15574e55..49828ac1 100644 --- a/user_setting.py +++ b/user_setting.py @@ -1,7 +1,7 @@ ##### PLEASE UPDATE ##### #Absolute path to NetMHCIPan executable file e.g. 'home/user/netMHCpan-4.1/netMHCpan' -netmhcipan_path = '/home/amir/amir/ParseFold/PMGen/netMHCIpan-4.1/netMHCpan' -netmhciipan_path = '/home/amir/amir/ParseFold/PMGen/netMHCIIpan-4.3/netMHCIIpan' +netmhcipan_path = '/home/amirreza/Desktop/NetMHCPan/netMHCpan-4.1/netMHCpan' +netmhciipan_path = '/home/amirreza/Desktop/NetMHCPan/netMHCIIpan-4.3/netMHCIIpan' ##### Do not Change ####### import os diff --git a/utils/create_ESM_dataset.py b/utils/create_ESM_dataset.py new file mode 100644 index 00000000..28b35898 --- /dev/null +++ b/utils/create_ESM_dataset.py @@ -0,0 +1,505 @@ +#!/usr/bin/env python +""" +Create a parquet dataset that combines metadata with ESM-3 embeddings +for each MHC allele. + +Usage: python create_dataset.py +""" + +import os +import pathlib +import re + +import numpy as np +import pandas as pd +from tqdm import tqdm # progress bars +from typing import Dict +from processing_functions import create_progressive_k_fold_cross_validation, create_k_fold_leave_one_out_stratified_cv, normalize_netmhcpan_allele_to_pmgen + +# --------------------------------------------------------------------- +# 1. CONFIGURATION – adjust if your paths change +# --------------------------------------------------------------------- +dataset_name = "NetMHCpan_dataset" # "PMGen_sequences" # "NetMHCpan_dataset" +mhc_class = 1 +CSV_PATH = pathlib.Path(f"../data/NetMHCpan_dataset/combined_data_{mhc_class}.csv") # Training dataset +NPZ_PATH = pathlib.Path( + f"../data/ESM/esmc_600m/{dataset_name}/mhc{mhc_class}_encodings.npz" +) +OUT_PARQUET = pathlib.Path( + f"../data/Custom_dataset/{dataset_name}/mhc_{mhc_class}/mhc{mhc_class}_with_esm_embeddings.parquet" +) # Training dataset with ESM embeddings + +BENCHMARKS_PATH = pathlib.Path( + f"../data/Custom_dataset/benchmarks/mhc_{mhc_class}" +) + +# If you want to *save* each array as its own .npy rather than store the +# full tensor inside parquet, set this to a directory; otherwise None. +EMB_OUT_DIR = pathlib.Path( + f"../data/Custom_dataset/{dataset_name}/mhc_{mhc_class}/mhc{mhc_class}_encodings" +) # or `None` to embed directly + +AUGMENTATION = "down_sampling" # "GNUSS", None +K = 10 # Number of folds for cross-validation +# --------------------------------------------------------------------- + + +def load_mhc_embeddings(npz_path: pathlib.Path) -> Dict[str, np.ndarray]: + """ + Returns {allele_name: 187×1152 tensor}. 
+ """ + emb_dict = {} + with np.load(npz_path) as npz_file: + for k in npz_file.files: + emb = npz_file[k] + # if emb.shape != (36, 1152): + # print(f"[WARN] Skipping {k}: shape {emb.shape} != (36,1152)") + # continue + emb_dict[k] = emb + print(f"Loaded embeddings for {len(emb_dict):,} alleles.") + return emb_dict + + +def normalise_netmhcpan_allele(a: str) -> str: + """ + Make the allele string format identical to the keys inside the NPZ. + Tweak this if your formats differ! + """ + # remove * and : from the allele name + # and convert to upper case + # e.g. "HLA-A*01:01" -> "HLA-A0101" + # e.g. "HLA-A*01:01:01" -> "HLA-A010101" + # remove spaces + # e.g. "HLA-A 01:01" -> "HLA-A0101" + # TODO fix eg. HLA-DRA/HLA-DRB1_0101 > DRB10101 eg. HLA-DRA/HLA-DRB1_0401 > DRB10401 + # Format heterodimer allele strings, e.g. "HLA-DRA/HLA-DRB1_0101" -> "DRB10101" + a = a.strip() + if "HLA-DRA/" in a: + # For heterodimers, take the second chain and format as e.g. DRB10101 + _, second = a.split("/", 1) + a = second.replace("HLA-", "") + + elif "mice-" in a: + _, second = a.split("/", 1) + a = second.replace("mice-", "") + else: + a = a.replace("/HLA", "") + # remove "*" and spaces + a = a.replace("*", "").replace(" ", "") + + return a + + +# def normalise_allele2(a: str) -> tuple[str, str]: +# """ +# Make the allele string format identical to the keys inside the NPZ. +# Tweak this if your formats differ! +# """ +# # remove * and : from the allele name +# # and convert to upper case +# # e.g. "HLA-A*01:01" -> "HLA-A0101" +# # e.g. "HLA-A*01:01:01" -> "HLA-A010101" +# # remove spaces +# # e.g. "HLA-A 01:01" -> "HLA-A0101" +# a = a.strip() +# if "HLA-DRA/" in a: +# # eg. "HLA-DRA/HLA-DRB1_0401" -> a: "HLA-DRA1_0401" b: "HLA-DRB1_0401" +# # For heterodimers, take the second chain and format as e.g. DRB10101 +# _, second = a.split("/", 1) +# b = second.replace("HLA-", "") +# a = b.replace("HLA-DRB1_", "HLA-DRA1_") +# return a, b +# +# # TODO fix later +# # elif "mice-" in a: +# # _, second = a.split("/", 1) +# # a = second.replace("mice-", "") +# # return a, a +# +# else: +# a = a.replace("/HLA", "") +# # remove "*" and spaces +# a = a.replace("*", "").replace(" ", "") +# return a, None + + +def normalise_allele_NetMHCPan_toPMGen(a: str) -> tuple[str, str]: + a = a.strip() + # remove : * _ from the allele name + # e.g. "HLA-A*01:01" -> "HLA-A0101" + + # + if ":" in a: + # remove ":" from the allele name + # e.g. "HLA-A*01:01" -> "HLA-A0101" + a = a.replace(":", "") + if "HLA-DRA/" in a: + _, second = a.split("/", 1) + b = second.replace("HLA-", "").replace("*", "").replace(" ", "") + a1 = b.replace("HLA-DRB1_", "HLA-DRA1_") + return a1, b + if "_" in a: + name, digits = a.split("_", 1) + name = name.upper() + if not name.startswith("HLA-"): + name = f"HLA-{name}" + return f"{name}{digits}", None + if "BoLa-" in a: + # eg. "BoLa-1*04:01" -> "BOLA-10401" + prefix, rest = a.split("-", 1) + prefix = prefix.replace("BoLa", "BOLA").upper() + digits = rest.replace("*", "").replace(":", "") + return f"{prefix}{digits}", None + if "*" in a: + prefix, rest = a.split("*", 1) + prefix = prefix.upper() + digits = rest.replace(":", "") + return f"{prefix}{digits}", None + + return a, None + + + +def attach_embeddings( + df: pd.DataFrame, + emb_dict: Dict[str, np.ndarray], + out_dir: pathlib.Path = None, +) -> pd.DataFrame: + """ + Add either: + • a column "mhc_embedding" holding the full tensor (object dtype) + • or a column "mhc_embedding_path" pointing to a .npy on disk. 
+ + Rows with no matching embedding are dropped (could also be imputed). + """ + # Pre-compute paths if we are writing .npy files + if out_dir is not None: + out_dir.mkdir(parents=True, exist_ok=True) + + paths, embeds = [], [] + # Extract unique alleles + unique_alleles = df["allele"].unique() + print(f"Processing {len(unique_alleles)} unique alleles...") + allele_to_key_map = {} + allele_to_emb_map = {} + + # Process each unique allele once + for allele in tqdm(unique_alleles, desc="Creating allele mapping"): + key = normalise_netmhcpan_allele(allele) + emb = emb_dict.get(key) + + if emb is None: + if mhc_class == 2: + key1, key2 = normalize_netmhcpan_allele_to_pmgen(allele) + print(key1, key2) + emb1 = emb_dict.get(key1) + emb2 = emb_dict.get(key2) + if emb1 is not None and emb2 is not None: + try: + emb = np.concatenate([emb1, emb2], axis=0) + except ValueError as e: + print(f"Error concatenating embeddings for {allele}: {e}") + emb = None + elif emb1 is not None: + emb = emb1 + elif emb2 is not None: + emb = emb2 + else: + print(f"No embedding found for {allele} (keys tried: {key1}, {key2})") + emb = None + # # Try to find the first chain + # prefix_matches1 = [k for k in emb_dict if k.startswith(key1)] if key1 else [] + # if key2 is not None: + # # Try to find the second chain + # prefix_matches2 = [k for k in emb_dict if k.startswith(key2)] + # if prefix_matches1 and prefix_matches2: + # key = max(prefix_matches1, key=len) + # print(f"{allele} -> {key}") + # emb = emb_dict.get(key) + # if emb is None: + # key = max(prefix_matches2, key=len) + # print(f"{allele} -> {key}") + # emb = emb_dict.get(key) + # else: + # key = max(prefix_matches1, key=len) if prefix_matches1 else None + # print(f"{allele} -> {key}") + # emb = emb_dict.get(key) + else: + key1, _ = normalize_netmhcpan_allele_to_pmgen(allele) + print(key1) + # # find exact prefix matches for longer names + # prefix_matches = [k for k in emb_dict if k.startswith(key1)] + # if prefix_matches: + # key = max(prefix_matches, key=len) + print(f"{allele} -> {key1}") + emb = emb_dict.get(key1) + + allele_to_key_map[allele] = key + allele_to_emb_map[allele] = emb + + # Now use the maps for all rows + paths, embeds = [], [] + for allele in tqdm(df["allele"], desc="Attaching embeddings"): + key = allele_to_key_map.get(allele) + emb = allele_to_emb_map.get(allele) + + if emb is None: + paths.append(None) + embeds.append(None) + continue + + if out_dir is None: + embeds.append(emb) + paths.append(None) + else: + fname = key.replace("*", "").replace(" ", "").replace("/HLA", "") + ".npy" + fpath = out_dir / fname + if not fpath.exists(): # avoid double-saving + np.save(fpath, emb) + paths.append(str(fpath)) + embeds.append(None) + + if out_dir is None: + df = df.assign(mhc_embedding=embeds) + else: + df = df.assign(mhc_embedding_path=paths) + + # Drop rows where we could not find an embedding + n_before = len(df) + # print unique alleles of missing embeddings + if "mhc_embedding" in df.columns: + missing_alleles = df.loc[df["mhc_embedding"].isna(), "allele"].unique() + else: + missing_alleles = df.loc[df["mhc_embedding_path"].isna(), "allele"].unique() + if len(missing_alleles) > 0: + print(f"Missing embeddings for {len(missing_alleles):,} alleles:") + print(", ".join(sorted(missing_alleles))) + else: + print("✓ All embeddings found.") + df = df.dropna( + subset=["mhc_embedding"] if out_dir is None else ["mhc_embedding_path"] + ) + print(f"Dropped {n_before - len(df):,} rows with no MHC-{mhc_class} embedding.") + return df + + +def 
create_test_set(df: pd.DataFrame, bench_df= pd.DataFrame, samples_per_label: int=10000) -> dict[str, pd.DataFrame]: + """ + Create two test sets: + - test1: equally sample samples_per_label from each assigned_label + - test2: select all entries of the allele with the lowest count in train + """ + datasets = {} + train_df_ = df.copy() + + # test1: sample equally from each label + test1 = ( + train_df_ + .groupby('assigned_label', group_keys=False) + .sample(n=samples_per_label, random_state=42) + .reset_index(drop=True) + ) + train_mask = ~train_df_.index.isin(test1.index) + train_updated = train_df_.loc[train_mask].reset_index(drop=True) + + # remove alleles that are in the benchmark dataset from the train set + if not bench_df.empty: + # ensure the allele column is of type string + bench_df['allele'] = bench_df['allele'].astype(str) + # filter out alleles that are in the benchmark dataset + train_updated = train_updated[ + ~train_updated['allele'].isin(bench_df['allele']) + ].reset_index(drop=True) + + # drop the rows with nan assigned_label, allele and long_mer from the benchmark dataset + bench_df = bench_df.dropna(subset=['assigned_label', 'allele', 'long_mer']) + # ensure the assigned_label is of type int + bench_df['assigned_label'] = bench_df['assigned_label'].astype(int) + # ensure the long_mer is of type string + bench_df['long_mer'] = bench_df['long_mer'].astype(str) + # ensure the allele is of type string + bench_df['allele'] = bench_df['allele'].astype(str) + # ensure the mhc_class is of type int + bench_df['mhc_class'] = bench_df['mhc_class'].astype(int) + datasets["benchmark_dataset"] = bench_df + + + # test2: remove lowest-frequency allele from train to form test2 + allele_counts = train_updated['allele'].value_counts() + n_test2_samples = 0 + i = 1 + while n_test2_samples < 1000: + i += 1 + lowest_alleles = allele_counts.nsmallest(n=i).index.tolist() + test2 = train_updated[train_updated['allele'].isin(lowest_alleles)].copy() + n_test2_samples = test2.shape[0] + + train_updated = ( + train_updated + [train_updated['allele'].isin(lowest_alleles) == False] + .reset_index(drop=True) + ) + + datasets['train'] = train_updated + datasets['test1'] = test1 + datasets['test2'] = test2 + + return datasets + + +def main() -> None: + print("→ Loading cleaned NetMHCpan CSV") + df = pd.read_csv( + CSV_PATH, + usecols=["long_mer", "assigned_label", "allele", "mhc_class"], + ) + + # filter out + df = df[df["mhc_class"] == mhc_class] + + print(len(df), "rows in the dataset after filtering for MHC class", mhc_class) + + # drop mhc_class column + df = df.drop(columns=["mhc_class"]) + + print(f"→ Loading MHC class {mhc_class} embeddings") + emb_dict = load_mhc_embeddings(NPZ_PATH) + + # save keys to a text file + emb_keys_path = NPZ_PATH.parent / f"mhc{mhc_class}_emb_keys.txt" + with open(emb_keys_path, "w") as f: + for key in emb_dict.keys(): + f.write(f"{key}\n") + + print("→ Merging") + df = attach_embeddings(df, emb_dict, EMB_OUT_DIR) + + print(len(df), "rows in the dataset after filtering for Mer") + + # print("→ Dropping duplicates and nones") + df = df.drop_duplicates(subset=["long_mer", "allele"]) + df = df.dropna(subset=["long_mer"]) + + print(len(df), "rows in the dataset after dropping duplicates and Nones") + + # move the label column to the end + label_col = "assigned_label" + cols = [col for col in df.columns if col != label_col] + [label_col] + df = df[cols] + + print(len(df), "rows in the dataset after dropping duplicates and Nones") + + print(f"→ Writing parquet to 
{OUT_PARQUET}") + df.to_parquet(OUT_PARQUET, engine="pyarrow", index=False, compression="zstd") + + print(f"→ Dataset shape: {df.shape}") + + + # load benchmark datasets + # for folder in path and for file in folder read them, then save them in (OUT_PARQUET.parent / "benchmarks").mkdir(parents=True, exist_ok=True) + # with benchmark_{file_name}.parquet + + # load and save benchmark datasets + benchmarks_dir = OUT_PARQUET.parent / "benchmarks" + benchmarks_dir.mkdir(parents=True, exist_ok=True) + + for folder in BENCHMARKS_PATH.iterdir(): + if not folder.is_dir(): + continue + for csv_file in folder.glob("*.csv"): + print(f"Loading benchmark dataset {csv_file.name}") + tmp = pd.read_csv(csv_file, usecols=["long_mer", "assigned_label", "allele", "mhc_class"]) + tmp["assigned_label"] = tmp["assigned_label"].astype(int) + tmp["allele"] = tmp["allele"].astype(str) + tmp["long_mer"] = tmp["long_mer"].astype(str) + tmp = attach_embeddings(tmp, emb_dict, EMB_OUT_DIR) + out_path = benchmarks_dir / f"benchmark_{csv_file.stem}.parquet" + tmp.to_parquet(out_path, index=False, engine="pyarrow", compression="zstd") + print(f"Saved benchmark to {out_path}") + + bench_df1 = pd.read_csv( + "../data/Custom_dataset/benchmark_Conbot.csv") + # rename columns to match the main dataset + bench_df1 = bench_df1.rename(columns={ + "binding_label": "assigned_label", + }) + # ensure the assigned_label is of type int + bench_df1['assigned_label'] = bench_df1['assigned_label'].astype(int) + # ensure the allele is of type string + bench_df1['allele'] = bench_df1['allele'].astype(str) + # convert II to 2 # and I to 1 + bench_df1['mhc_class'] = bench_df1['mhc_class'].replace({"II": 2, "I": 1}) + + print(bench_df1.columns) + # TODO process later + bench_df2 = pd.read_csv( + "../data/Custom_dataset/benchmark_ConvNeXT.csv", + usecols=["allele"], + + ) + + # combine benchmark datasets + bench_df = pd.concat([bench_df1, bench_df2], ignore_index=True) + + print("→ Create and save test sets") + datasets = create_test_set(df, bench_df) + for name, subset in datasets.items(): + print(f"{name.capitalize()} set shape: {subset.shape}") + if name != "train": + subset.to_parquet(OUT_PARQUET.parent / f"{name}.parquet", index=False, engine="pyarrow", compression="zstd") + + ### + # TODO remove later + print("→ Loading existing dataset from parquet") + datasets = {} + datasets['train'] = pd.read_parquet(OUT_PARQUET, engine="pyarrow") + ### + + # Drop NaNs before converting to int + print("→ number of rows in train set before normalization:", len(datasets['train'])) + datasets['train'] = datasets['train'].dropna(subset=['assigned_label', 'allele']) + print("→ number of rows in train set after normalization:", len(datasets['train'])) + datasets['train']['assigned_label'] = datasets['train']['assigned_label'].astype(int) + + print("→ Creating cross-validation folds") + folds = create_k_fold_leave_one_out_stratified_cv( + datasets['train'], + target_col="assigned_label", + k=K, + id_col="allele", + augmentation=AUGMENTATION, + ) + + print("→ Saving folds to CSV") + # Ensure the output directory exists /folds + (OUT_PARQUET.parent / "folds").mkdir(parents=True, exist_ok=True) + held_out_ids_path = OUT_PARQUET.parent / "folds" / "held_out_ids.txt" + if held_out_ids_path.exists(): + held_out_ids_path.unlink() + + for fold_id, (train_df, val_df, validation_ids) in enumerate(folds, start=1): + train_path = OUT_PARQUET.parent / f"folds/fold_{fold_id}_train.parquet" + val_path = OUT_PARQUET.parent / f"folds/fold_{fold_id}_val.parquet" + 
train_df.to_parquet(train_path, index=False, engine="pyarrow", compression="zstd") + val_df.to_parquet(val_path, index=False, engine="pyarrow", compression="zstd") + print(f"Saved fold {fold_id} train to {train_path}") + print(f"Saved fold {fold_id} val to {val_path}") + if isinstance(validation_ids, list): + ids_str = ", ".join(map(str, validation_ids)) + else: + ids_str = str(validation_ids) + with open(held_out_ids_path, "a") as f: + f.write(f"Fold {fold_id}: {ids_str}\n") + + + # save only positive labels to a separate file + pos_labels = datasets['train'][datasets['train']['assigned_label'] == 1] + pos_labels_path = OUT_PARQUET.parent / "folds" / "positive_labels.parquet" + pos_labels.to_parquet(pos_labels_path, index=False, engine="pyarrow", compression="zstd") + print(f"Saved positive labels to {pos_labels_path}") + + print("✓ Done") + + +if __name__ == "__main__": + main() diff --git a/utils/model.py b/utils/model.py index 3a2d45e7..1014dc4b 100644 --- a/utils/model.py +++ b/utils/model.py @@ -1,39 +1,46 @@ +# -*- coding: utf-8 -*- import tensorflow as tf -import keras -from keras import layers +from tensorflow import keras +from tensorflow.keras import layers -# ------------------------------ -# SCQ Layer as defined before -# ------------------------------ -class SCQ(layers.Layer): - def __init__(self, num_embed=64, dim_embed=32, dim_input=128, lambda_reg=0.1, proj_iter=10, - descrete_loss=False, beta_loss=0.25, **kwargs): + +class SCQ_layer(layers.Layer): + def __init__(self, num_embed, dim_embed, lambda_reg=1.0, proj_iter=10, + discrete_loss=False, beta_loss=0.25, reset_dead_codes=False, + usage_threshold=1e-3, reset_interval=5, **kwargs): """ Soft Convex Quantization (SCQ) layer. Args: num_embed: Number of codebook vectors. dim_embed: Embedding dimension (for both encoder and codebook). - dim_input: Dimension of input features (projected to dim_embed). lambda_reg: Regularization parameter in the SCQ objective. proj_iter: Number of iterations for the iterative simplex projection. - descrete_loss: if True, commitment and codebook loss are calculated similar - to VQ-VAE loss, with stop-gradient operation. If False, only quantization error - is calculated as loss |Z_q - Z_e|**2 - beta_loss: A float to be used in VQ loss. Only used if descrete_loss==True. - multiplied to (1-b)*commitment_loss + b*codebook_loss + discrete_loss: If True, commitment and codebook loss are calculated similar + to VQ-VAE loss, with stop-gradient operation. If False, only the quantization error + is calculated as the loss |Z_q - Z_e|**2. + beta_loss: A float used in VQ loss. Only applicable if discrete_loss is True. + Multiplied to (1-beta_loss)*commitment_loss + beta_loss*codebook_loss. + reset_dead_codes: Whether to reset unused codebook vectors periodically. + usage_threshold: Minimum usage threshold for codebook vectors to avoid being reset. + reset_interval: Number of calls before checking and resetting dead codebooks. """ super().__init__(**kwargs) self.num_embed = num_embed self.dim_embed = dim_embed - self.dim_input = dim_input self.lambda_reg = lambda_reg self.proj_iter = proj_iter - self.descrete_loss = descrete_loss + self.discrete_loss = discrete_loss self.beta_loss = beta_loss self.epsilon = 1e-5 + # Codebook reset parameters + self.call_count = 0 + self.reset_dead_codes = reset_dead_codes + self.usage_threshold = usage_threshold + self.reset_interval = reset_interval + def build(self, input_shape): # Learnable scale and bias for layer normalization. 
self.gamma = self.add_weight( @@ -54,7 +61,7 @@ def build(self, input_shape): name='scq_embedding') # Projection weights to map input features into the embedding space. self.d_w = self.add_weight( - shape=(self.dim_input, self.dim_embed), + shape=(self.dim_embed, self.dim_embed), initializer='random_normal', trainable=True, name='scq_w') @@ -63,14 +70,27 @@ def build(self, input_shape): initializer='zeros', trainable=True, name='scq_b') + # Usage tracking for codebook vectors + self.code_usage = self.add_weight( + shape=(self.num_embed,), + initializer='zeros', + trainable=False, + name='code_usage', + dtype=tf.float32 + ) def call(self, inputs): """ - Forward pass for SCQ. + Forward pass of the SCQ layer. Args: - inputs: Tensor of shape (B, N, dim_input) + inputs: Input tensor of shape (B, N, dim_embed), where B is the batch size, + N is the number of input features, and dim_embed is the embedding dimension. Returns: - Quantized output: Tensor of shape (B, N, dim_embed) + Zq: Quantized output tensor of shape (B, N, dim_embed). + out_P_proj: + loss: The total loss calculated during the forward pass. + perplexity: The perplexity of the quantized output. + """ # 1. Project inputs to the embedding space and apply layer normalization. x = tf.matmul(inputs, self.d_w) + self.d_b # (B, N, dim_embed) @@ -79,6 +99,9 @@ def call(self, inputs): input_shape = tf.shape(x) # (B, N, dim_embed) flat_inputs = tf.reshape(x, [-1, self.dim_embed]) # (B*N, dim_embed) + # add a small epsilon to avoid numerical issues # TODO added + # flat_inputs = flat_inputs + tf.random.normal(tf.shape(flat_inputs), stddev=0.001) + # 2. Compute hard VQ assignments as initialization. flat_detached = tf.stop_gradient(flat_inputs) x_norm_sq = tf.reduce_sum(flat_detached ** 2, axis=1, keepdims=True) # (B*N, 1) @@ -90,6 +113,18 @@ def call(self, inputs): P_tilde = tf.one_hot(assign_indices, depth=self.num_embed, dtype=tf.float32) # (B*N, num_embed) P_tilde = tf.transpose(P_tilde) # (num_embed, B*N) + # Track codebook usage # TODO added + if self.reset_dead_codes: + # Update usage counts (moving average) + batch_usage = tf.reduce_mean(tf.transpose(P_tilde), axis=0) # (num_embed,) + decay = 0.99 # Exponential moving average decay factor + self.code_usage.assign(decay * self.code_usage + (1 - decay) * batch_usage) + + # Periodically check for dead codes and reset them + self.call_count += 1 + if self.call_count % self.reset_interval == 0: + self._reset_dead_codes(flat_inputs) + # 3. Solve the SCQ optimization via a linear system. C = self.embed_w # (dim_embed, num_embed) Z = tf.transpose(flat_inputs) # (dim_embed, B*N) @@ -102,28 +137,97 @@ def call(self, inputs): # 4. Project each column of P_sol onto the probability simplex. 
P_proj = self.project_columns_to_simplex(P_sol) # (num_embed, B*N) - # P_proj = tf.nn.softmax(P_sol, axis=0) out_P_proj = tf.transpose(P_proj) # (B*N, num_embed) + # save in a file for debugging + # print out for debugging + # tf.print("P_proj:", P_proj) + # tf.print("P_proj min:", tf.reduce_min(P_proj)) + # tf.print("P_proj max:", tf.reduce_max(P_proj)) + # tf.print("P_proj shape:", tf.shape(P_proj)) + # tf.print("out_P_proj1:", out_P_proj) + # tf.print("out_P_proj1 min:", tf.reduce_min(out_P_proj)) + # tf.print("out_P_proj1 max:", tf.reduce_max(out_P_proj)) + # tf.print("out_P_proj1 shape:", tf.shape(out_P_proj)) + out_P_proj = tf.reshape(out_P_proj, (input_shape[0], input_shape[1], -1)) # (B,N,num_embed) - # out_P_proj = tf.reduce_mean(out_P_proj, axis=-1, keepdims=True) #(B,N,1) + # tf.print("out_P_proj2:", out_P_proj) + # tf.print("out_P_proj2 min:", tf.reduce_min(out_P_proj)) + # tf.print("out_P_proj2 max:", tf.reduce_max(out_P_proj)) + # tf.print("out_P_proj2 shape:", tf.shape(out_P_proj)) + + # compute perplexity + perplexity = self.compute_perplexity(out_P_proj) # (B,N,num_embed) + # tf.print("perplexity out_proj2:", perplexity) + + # average over the last dimension + # out_P_proj = tf.reduce_mean(out_P_proj, axis=-1, keepdims=True) # (B,N,1) + # tf.print("out_P_proj3:", out_P_proj) + # tf.print("out_P_proj3 min:", tf.reduce_min(out_P_proj)) + # tf.print("out_P_proj3 max:", tf.reduce_max(out_P_proj)) + # tf.print("out_P_proj3 shape:", tf.shape(out_P_proj)) + + # # check if the shape of out_P_proj is (B,N,1) + # if out_P_proj.shape[-1] != 1: + # raise ValueError(f"Expected out_P_proj shape to be (B,N,1), but got {out_P_proj.shape}") # 5. Reconstruct quantized output: Z_q = C * P_proj. Zq_flat = tf.matmul(C, P_proj) # (dim_embed, B*N) Zq = tf.transpose(Zq_flat) # (B*N, dim_embed) - # 6. calculate quantization loss, combined commitment and codebook losses - if not self.descrete_loss: - loss = tf.reduce_mean((Zq - flat_inputs) ** 2) # (B*N, dim_embed) + # 6. Calculate quantization loss and add regularization penalties + # Calculate common regularization terms first + p_mean = tf.reduce_mean(out_P_proj, axis=[0, 1]) # Average usage across batch (num_embed,) + # Ensure probabilities sum to 1 for stable entropy calculation + p_mean_normalized = p_mean / (tf.reduce_sum(p_mean) + self.epsilon) + # Maximize entropy: encourages uniform cluster usage. Negative sign makes it a penalty. + # Adding epsilon inside log avoids log(0). 
+ entropy_reg = -tf.reduce_sum(p_mean_normalized * tf.math.log(p_mean_normalized + self.epsilon)) + + # Calculate cosine similarities between codebook vectors + norm_codebook = tf.nn.l2_normalize(self.embed_w, axis=0) # (dim_embed, num_embed) + similarities = tf.matmul(tf.transpose(norm_codebook), norm_codebook) # (num_embed, num_embed) + # Create a mask to ignore self-similarities (diagonal) + mask = tf.ones_like(similarities) - tf.eye(self.num_embed) + masked_similarities = similarities * mask + # Penalize positive cosine similarities between different codebook vectors + # Encourages codebook vectors to be dissimilar (orthogonal ideally) + diversity_loss = tf.reduce_mean(tf.nn.relu(masked_similarities)) + + # Define weights for regularization terms (these can be tuned) + alpha_entropy = 0.5 # Weight for entropy regularization penalty + alpha_diversity = 0.5 # Weight for diversity penalty + + # Calculate the primary loss based on discrete_loss flag + if not self.discrete_loss: + # Original SCQ loss (quantization error) + primary_loss = tf.reduce_mean((Zq - flat_inputs) ** 2) else: + # VQ-VAE style loss commitment_loss = tf.reduce_mean((tf.stop_gradient(Zq) - flat_inputs) ** 2) - codebook_loss = tf.reduce_mean((Zq - flat_detached) ** 2) - loss = ( + codebook_loss = tf.reduce_mean((Zq - tf.stop_gradient(flat_inputs)) ** 2) # Corrected VQ codebook loss + primary_loss = ( (tf.cast(1 - self.beta_loss, tf.float32) * commitment_loss) + - (tf.cast(self.beta, tf.float32) * codebook_loss) + (tf.cast(self.beta_loss, tf.float32) * codebook_loss) ) + # Combine primary loss with regularization penalties + loss = primary_loss + alpha_entropy * entropy_reg + alpha_diversity * diversity_loss + Zq = tf.reshape(Zq, input_shape) # (B, N, dim_embed) - return Zq, out_P_proj, loss # (B,N,embed_dim), #(B,N,num_embed), (1,) + self.add_loss(loss) # Register the loss with the layer + + # Get hard indices and create one-hot representation # TODO added + # hard_indices = tf.argmax(out_P_proj, axis=-1) # (B,N) + # cast to int32 for one-hot encoding + # out_P_proj_cast = tf.cast(out_P_proj*10, tf.int32) # (B,N,num_embed) + # one_hot_indices = tf.one_hot(hard_indices, depth=self.num_embed) # (B,N,num_embed) + + # 7. Calculate perplexity (similar to VQ layer) + # flat_P_proj = tf.reshape(out_P_proj, [-1, self.num_embed]) # (B*N, num_embed) + # avg_probs = tf.reduce_mean(flat_P_proj, axis=0) # (num_embed,) + # perplexity = tf.exp(-tf.reduce_sum(avg_probs * tf.math.log(avg_probs + 1e-10))) + + return Zq, out_P_proj, loss, perplexity def project_columns_to_simplex(self, P_sol): """Projects columns of a matrix to the probability simplex.""" @@ -153,6 +257,14 @@ def project_columns_to_simplex(self, P_sol): P_projected = tf.nn.relu(P_t - theta[:, tf.newaxis]) return tf.transpose(P_projected) + def compute_perplexity(slef, out_P_proj): + p_j = tf.reduce_mean(out_P_proj, axis=[0, 1]) # (B, N, K) -> (K,) + p_j = tf.clip_by_value(p_j, 1e-10, 1 - (1e-9)) + p_j = p_j / tf.reduce_sum(p_j) # Normalize to ensure sum to 1 + entropy = -tf.reduce_sum(p_j * tf.math.log(p_j) / tf.math.log(2.0)) # Entropy: -sum(p_j * log2(p_j)) + perplexity = tf.pow(2.0, entropy) # Perplexity: 2^entropy + return perplexity + def layernorm(self, inputs): """ Custom layer normalization over the last dimension. 
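The `compute_perplexity` helper above reduces the average code-usage distribution to a base-2 entropy and exponentiates it. A small standalone numeric check, mirroring that definition with 64 codes as an example size, makes the two limiting cases concrete: uniform usage gives a perplexity equal to the codebook size, while a collapsed codebook drives it towards 1.

```python
# Perplexity of the average code-usage distribution: K when all K codes are
# used uniformly, approaching 1 when a single code absorbs all the mass.
import numpy as np

def perplexity(p):
    p = np.clip(p, 1e-10, 1.0)
    p = p / p.sum()
    entropy = -(p * np.log2(p)).sum()
    return 2.0 ** entropy

print(perplexity(np.full(64, 1 / 64)))                    # 64.0 -> healthy codebook usage
print(perplexity(np.array([0.999] + [0.001 / 63] * 63)))  # ~1.01 -> codebook collapse
```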
@@ -162,318 +274,2028 @@ def layernorm(self, inputs): normed = (inputs - mean) / tf.sqrt(variance + self.epsilon) return self.gamma * normed + self.beta + def _reset_dead_codes(self, encoder_outputs): + dead_codes = tf.where(self.code_usage < self.usage_threshold) + num_dead = tf.shape(dead_codes)[0] + if num_dead > 0: + tf.print(f"Resetting {num_dead} dead codebook vectors") + most_used_idx = tf.argsort(self.code_usage, direction='DESCENDING')[:tf.maximum(3, num_dead)] + most_used = tf.gather(tf.transpose(self.embed_w), most_used_idx) + batch_size = tf.shape(encoder_outputs)[0] + for i in range(num_dead): + dead_idx = tf.cast(dead_codes[i][0], tf.int32) + if i % 2 == 0 and batch_size > 0: + random_idx = tf.random.uniform(shape=[], minval=0, maxval=batch_size, dtype=tf.int32) + new_vector = encoder_outputs[random_idx] + else: + source_idx = i % tf.shape(most_used)[0] + source_vector = most_used[source_idx] + noise = tf.random.normal(shape=tf.shape(source_vector), stddev=0.5) + new_vector = source_vector + noise + new_vector = new_vector / (tf.norm(new_vector) + 1e-8) * tf.norm(source_vector) + self.embed_w[:, dead_idx].assign(tf.reshape(new_vector, [self.dim_embed])) + self.code_usage[dead_idx].assign(0.2) -class Encoder(layers.Layer): - def __init__(self, dim_input, dim_embed, heads=4, names='encoder', **kwargs): + +class SCQ1DAutoEncoder(keras.Model): + def __init__(self, input_dim, num_embeddings, embedding_dim, commitment_beta, + scq_params, initial_codebook=None, cluster_lambda=1.0): super().__init__() - self.dim_input = dim_input - self.dim_embed = dim_embed - self.heads = heads - self.names = names - self.epsilon = 1e-5 - self.scale = 1 / tf.sqrt(tf.cast(dim_embed, dtype=tf.float32)) + self.input_dim = input_dim # e.g., (1024,) + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.commitment_beta = commitment_beta - def build(self, input_shape): - # Learnable scale and bias for layer normalization. - self.gamma = self.add_weight( - shape=(self.dim_embed,), - initializer="ones", - trainable=True, - name=f"ln_gamma_{self.names}") - self.beta = self.add_weight( - shape=(self.dim_embed,), - initializer="zeros", - trainable=True, - name=f"ln_beta_{self.names}") - # Projection weights to map input features into the embedding space. 
- self.d_w = self.add_weight( - shape=(self.dim_input, self.dim_embed), - initializer='random_normal', - trainable=True, - name=f'd_w_{self.names}') - self.d_b = self.add_weight( - shape=(self.dim_embed,), - initializer='zeros', - trainable=True, - name=f'd_b_{self.names}') - # gating - self.d_g = self.add_weight( - shape=(self.heads, 1, self.dim_embed, self.dim_embed), - initializer='uniform', - trainable=True, - name=f'g_w_{self.names}') - # attention - self.q = self.add_weight( - shape=(self.heads, 1, self.dim_embed, self.dim_embed), - initializer='random_normal', - trainable=True, - name=f'query_{self.names}') - self.k = self.add_weight( - shape=(self.heads, 1, self.dim_embed, self.dim_embed), - initializer='random_normal', - trainable=True, - name=f'key_{self.names}') - self.v = self.add_weight( - shape=(self.heads, 1, self.dim_embed, self.dim_embed), - initializer='random_normal', - trainable=True, - name=f'value_{self.names}') - # final projection - self.d_o = self.add_weight(shape=(int(self.heads * self.dim_embed), self.dim_embed), - initializer='random_normal', - trainable=True, name=f'dout_{self.names}') + # weight on the new cluster‐consistency loss + self.cluster_lambda = cluster_lambda - def call(self, inputs): - x = tf.matmul(inputs, self.d_w) + self.d_b # (B,N,I)->(B,N,E) - x = self.layernorm(x) - # attention - query = tf.matmul(x, self.q) # (B,N,E)@(H,1,E,E)->(H,B,N,E) - key = tf.matmul(x, self.k) - value = tf.matmul(x, self.v) - gate = tf.nn.sigmoid(tf.matmul(x, self.d_g)) + # Encoder: Dense layers to compress 1024 features to embedding_dim + self.encoder = keras.Sequential([ + layers.Dense(512, activation='relu', input_shape=self.input_dim), + layers.Dense(256, activation='relu'), + layers.Dense(128, activation='relu'), + layers.Dense(self.embedding_dim, activation='linear') + ]) - qk = tf.einsum('hbni,hbki->hbnk', query, key) # (H,B,N,N) - qk *= self.scale - att = tf.nn.softmax(qk) - out = tf.matmul(att, value) # (H,B,N,N)@(H,B,N,E)->(H,B,N,E) - out = tf.multiply(out, gate) # gate + # SCQ Layer for quantization + self.scq_layer = SCQ_layer(num_embed=self.num_embeddings, dim_embed=self.embedding_dim, + beta_loss=self.commitment_beta, **scq_params) - out = tf.transpose(out, perm=[1, 2, 3, 0]) - b, h, n, o = tf.shape(out)[0], tf.shape(out)[-1], tf.shape(out)[1], tf.shape(out)[2] - out = tf.reshape(out, [b, n, h * o]) # (B, N, H * E) - out = tf.matmul(out, self.d_o) # (B,N,H*E)@(H*E,E)->(B,N,E) + # Initialize codebook if provided + if initial_codebook is not None: + # Ensure the shape matches + expected_shape = (self.embedding_dim, self.num_embeddings) + if initial_codebook.shape == expected_shape: + self.scq_layer.embed_w.assign(initial_codebook) + else: + tf.print( + f"Warning: Initial codebook shape {initial_codebook.shape} doesn't match expected shape {expected_shape}. Using default initialization.") - return out + # Decoder: Dense layers to reconstruct from embedding_dim to 1024 features + self.decoder = keras.Sequential([ + layers.Dense(128, activation='relu', input_shape=(self.embedding_dim,)), + layers.Dense(256, activation='relu'), + layers.Dense(512, activation='relu'), + layers.Dense(self.input_dim[0], activation='linear') # output shape (B, 1024) + ]) - def layernorm(self, inputs): - """ - Custom layer normalization over the last dimension. 
- """ - mean = tf.reduce_mean(inputs, axis=-1, keepdims=True) - variance = tf.math.reduce_variance(inputs, axis=-1, keepdims=True) - normed = (inputs - mean) / tf.sqrt(variance + self.epsilon) - return self.gamma * normed + self.beta + # Loss trackers + self.total_loss_tracker = keras.metrics.Mean(name="total_loss") + self.recon_loss_tracker = keras.metrics.Mean(name="recon_loss") + self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss") + self.perplexity_tracker = keras.metrics.Mean(name="perplexity") + + @property + def metrics(self): + return [ + self.total_loss_tracker, + self.recon_loss_tracker, + self.vq_loss_tracker, + self.perplexity_tracker, + ] + + + def call(self, inputs, training=False): + if isinstance(inputs, tuple): + x, labels = inputs + else: + x, labels = inputs, None + # Encoder: Compress input to embedding space + x = self.encoder(x) # (batch_size, embedding_dim) + x = tf.expand_dims(x, axis=1) # (batch_size, 1, embedding_dim) + # SCQ: Quantize the embedding + Zq, out_P_proj, vq_loss, perplexity = self.scq_layer(x) + # Decoder: Reconstruct from quantized embedding + y = tf.squeeze(Zq, axis=1) # (batch_size, embedding_dim) + output = self.decoder(y) # (batch_size, 1024) + return output, Zq, out_P_proj, vq_loss, perplexity + + # # Method to get just the latent sequence and one-hot encodings for inference + def encode_(self, inputs): + if isinstance(inputs, tuple): + x, labels = inputs + else: + x, labels = inputs, None + # Encoder: Compress input to embedding space + x = self.encoder(x) # (batch_size, embedding_dim) + x = tf.expand_dims(x, axis=1) # (batch_size, 1, embedding_dim) + # SCQ: Quantize the embedding + Zq, out_P_proj, vq_loss, perplexity = self.scq_layer(x) + return Zq, out_P_proj, vq_loss, perplexity + + def train_step(self, data): + # unpack features and (optional) labels + if isinstance(data, tuple): + x, labels = data + else: + x, labels = data, None + + with tf.GradientTape() as tape: + reconstruction, Zq, _, vq_loss, perplexity = self(x, training=True) + + # Reconstruction loss + recon_loss = tf.reduce_mean(tf.math.squared_difference(x, reconstruction)) + + # Feature matching loss + with tape.stop_recording(): + _, mid_features, _, _, _ = self(reconstruction, training=False) + orig_mid_features = Zq + feature_loss = tf.reduce_mean(tf.math.squared_difference(orig_mid_features, mid_features)) + + # Total loss + total_loss = recon_loss + vq_loss + 0.1 * feature_loss + + # Usage penalty + usage_penalty = tf.cond( + perplexity < self.num_embeddings * 0.5, + lambda: 0.5 * (self.num_embeddings * 0.5 - perplexity), + lambda: 0.0 + ) + total_loss += usage_penalty + + grads = tape.gradient(total_loss, self.trainable_variables) + grads, _ = tf.clip_by_global_norm(grads, 5.0) + self.optimizer.apply_gradients(zip(grads, self.trainable_variables)) + + # Update metrics + self.total_loss_tracker.update_state(total_loss) + self.recon_loss_tracker.update_state(recon_loss) + self.vq_loss_tracker.update_state(vq_loss) + self.perplexity_tracker.update_state(perplexity) + + # Log perplexity every 100 steps + step = self.optimizer.iterations + tf.cond( + tf.equal(tf.math.floormod(step, 100), 0), + lambda: tf.print("Step:", step, "Perplexity:", perplexity, "Target:", self.num_embeddings * 0.5), + lambda: tf.no_op() + ) + return {m.name: m.result() for m in self.metrics} -class Decoder(layers.Layer): - def __init__(self, dim_input, dim_embed, dim_output, heads=4, names='decoder', **kwargs): + def test_step(self, data): + if isinstance(data, tuple): + x = data[0] + y = data[1] 
if len(data) > 1 else data[0] + else: + x = data + + reconstruction, quantized, _, vq_loss, perplexity = self(x, training=False) + recon_loss = tf.reduce_mean(tf.math.squared_difference(x, reconstruction)) + total_loss = recon_loss + vq_loss + self.commitment_beta * tf.reduce_mean( + tf.math.squared_difference(x, reconstruction)) + + self.total_loss_tracker.update_state(total_loss) + self.recon_loss_tracker.update_state(recon_loss) + self.vq_loss_tracker.update_state(vq_loss) + self.perplexity_tracker.update_state(perplexity) + + return {m.name: m.result() for m in self.metrics} + + +########### Mixture of Experts (MoE) model ########### +class SparseDispatcher: + """Helper for dispatching inputs to experts and combining expert outputs.""" + def __init__(self, num_experts, gates): + # gates: [batch, num_experts] float tensor + self.num_experts = num_experts + self.gates = gates # shape [B, E] + # Find nonzero gate entries + indices = tf.where(gates > 0) + # Sort by expert_idx then batch_idx for consistency + sort_order = tf.argsort(indices[:,1] * tf.cast(tf.shape(gates)[0], indices.dtype) + indices[:,0]) + sorted_indices = tf.gather(indices, sort_order) + self.batch_index = sorted_indices[:,0] + self.expert_index = sorted_indices[:,1] + # Count number of samples per expert + self.part_sizes = tf.reduce_sum(tf.cast(gates > 0, tf.int32), axis=0) + # Extract the nonzero gate values + self.nonzero_gates = tf.gather_nd(gates, sorted_indices) + + def dispatch(self, inputs): + """inputs: [batch, ...] -> Returns list of [num_samples_i, ...]""" + inputs_expanded = tf.gather(inputs, self.batch_index) + parts = tf.split(inputs_expanded, self.part_sizes, axis=0) + return parts + + def combine(self, expert_outputs, multiply_by_gates=True): + """expert_outputs: list of [num_samples_i, output_dim] -> Returns [batch, output_dim]""" + stitched = tf.concat(expert_outputs, axis=0) + if multiply_by_gates: + stitched = stitched * tf.expand_dims(self.nonzero_gates, axis=1) + batch_size = tf.shape(self.gates)[0] + combined = tf.math.unsorted_segment_sum(stitched, self.batch_index, batch_size) + return combined + + +class Expert(layers.Layer): + """A binary prediction expert with added complexity using GLU.""" + + def __init__(self, input_dim, hidden_dim, output_dim=1): super().__init__() - self.dim_input = dim_input - self.dim_embed = dim_embed - self.dim_output = dim_output - self.heads = heads - self.names = names - self.epsilon = 1e-5 - self.scale = 1 / tf.sqrt(tf.cast(dim_embed, dtype=tf.float32)) + self.fc1 = layers.Dense(hidden_dim, activation='relu', input_shape=(input_dim,)) + self.fc2 = layers.Dense(output_dim) - def build(self, input_shape): - self.w_in = self.add_weight(shape=(self.dim_input, self.dim_embed), trainable=True, initializer='random_normal', - name=f'w_in_{self.names}') - self.b_in = self.add_weight(shape=(self.dim_embed,), trainable=True, initializer='zeros', - name=f'b_in_{self.names}') - # Learnable scale and bias for layer normalization. 
- self.gamma = self.add_weight(shape=(self.dim_embed,), initializer="ones", trainable=True, - name=f"ln_gamma_{self.names}") - self.beta = self.add_weight(shape=(self.dim_embed,), initializer="zeros", trainable=True, - name=f"ln_beta_{self.names}") - # attetnion - self.q = self.add_weight(shape=(self.heads, 1, self.dim_embed, self.dim_embed), trainable=True, - initializer='random_normal', name=f'query_{self.names}') - self.k = self.add_weight(shape=(self.heads, 1, self.dim_embed, self.dim_embed), trainable=True, - initializer='random_normal', name=f'key_{self.names}') - self.v = self.add_weight(shape=(self.heads, 1, self.dim_embed, self.dim_embed), trainable=True, - initializer='random_normal', name=f'value_{self.names}') - # project out - self.w_out = self.add_weight(shape=(int(self.heads * self.dim_embed), self.dim_output), trainable=True, - initializer='random_normal', name=f'w_out_{self.names}') - self.b_out = self.add_weight(shape=(self.dim_output,), trainable=True, initializer='zeros', - name=f'b_out_{self.names}') + def call(self, x): + x = self.fc1(x) + x = self.fc2(x) + return tf.nn.sigmoid(x) - def call(self, inputs): - Zq = inputs # (B,N,embed_dim), #(B,N,1), (1,) - x = tf.matmul(Zq, self.w_in) + self.b_in - x = self.layernorm(x) # (B,N,dim_embed) - # attention - query = tf.matmul(x, self.q) # (B,N,dim_emned)@(H,1,dim_embed,dim_embed)->(H,B,N,dim_embed) - key = tf.matmul(x, self.k) - value = tf.matmul(x, self.v) - qk = tf.einsum('hbni,hbji->hbnj', query, key) # (H,B,N,N) - qk *= self.scale - att = tf.nn.softmax(qk) - out = tf.matmul(att, value) # (H,B,N,N)@(H,B,N,dim_embed) - # out projection - out = tf.transpose(out, perm=[1, 2, 3, 0]) - b, h, n, o = tf.shape(out)[0], tf.shape(out)[-1], tf.shape(out)[1], tf.shape(out)[2] - out = tf.reshape(out, [b, n, h * o]) # (B, N, H * E) - out = tf.matmul(out, self.w_out) + self.b_out # (B,N,H*E)@(H*E,O)->(B,N,O) - out = tf.nn.softmax(out) - return out +# class Expert(layers.Layer): +# """Enhanced binary prediction expert with dropout and more layers.""" +# def __init__(self, input_dim, hidden_dim, output_dim=1, dropout_rate=0.2): +# super().__init__() +# self.fc1 = layers.Dense(hidden_dim, activation='relu', input_shape=(input_dim,)) +# self.dropout1 = layers.Dropout(dropout_rate) +# self.fc2 = layers.Dense(hidden_dim // 2, activation='relu') +# self.dropout2 = layers.Dropout(dropout_rate) +# self.fc3 = layers.Dense(output_dim) +# +# def call(self, x, training=False): +# x = self.fc1(x) +# x = self.dropout1(x, training=training) +# x = self.fc2(x) +# x = self.dropout2(x, training=training) +# x = self.fc3(x) +# return tf.nn.sigmoid(x) - def layernorm(self, inputs): - """ - Custom layer normalization over the last dimension. 
- """ - mean = tf.reduce_mean(inputs, axis=-1, keepdims=True) - variance = tf.math.reduce_variance(inputs, axis=-1, keepdims=True) - normed = (inputs - mean) / tf.sqrt(variance + self.epsilon) - return self.gamma * normed + self.beta +class MixtureOfExperts(layers.Layer): + """Mixture of Experts layer with optional learned gating network.""" -class SCQ_model(tf.keras.models.Model): - def __init__(self, general_embed_dim=128, codebook_dim=16, codebook_num=64, - descrete_loss=False, heads=4, names='SCQ_model', - weight_recon=1, weight_vq=1, **kwargs): + def __init__(self, input_dim, hidden_dim, num_experts, use_provided_gates=True, + gating_hidden_dim=64, top_k=None): super().__init__() - self.general_embed_dim = general_embed_dim - self.codebook_dim = codebook_dim - self.codebook_num = codebook_num - self.descrete_loss = descrete_loss - self.heads = heads - self.names = names - self.weight_recon = tf.cast(weight_recon, tf.float32) - self.weight_vq = tf.cast(weight_vq, tf.float32) - # define model - self.encoder = Encoder(dim_input=21, dim_embed=self.general_embed_dim, heads=self.heads) + self.num_experts = num_experts + self.use_provided_gates = use_provided_gates + self.top_k = top_k # If set, only route to top_k experts + self.experts = [Expert(input_dim, hidden_dim) for _ in range(num_experts)] - self.scq = SCQ(dim_input=self.general_embed_dim, dim_embed=self.codebook_dim, - num_embed=self.codebook_num, descrete_loss=self.descrete_loss) + # Add a learned gating network if we might use it + if not use_provided_gates: + self.gating_network = tf.keras.Sequential([ + layers.Dense(gating_hidden_dim, activation='relu', input_shape=(input_dim,)), + layers.Dense(num_experts, activation='softmax') + ]) - self.decoder = Decoder(dim_input=self.codebook_dim, dim_embed=self.general_embed_dim, - dim_output=21, heads=self.heads) + def call(self, inputs, training=False): + if isinstance(inputs, tuple) and len(inputs) == 2: + x, gates = inputs + else: + x = inputs + gates = None - # Loss trackers - self.total_loss_tracker = tf.keras.metrics.Mean(name='total_loss') - self.recon_loss_tracker = tf.keras.metrics.Mean(name='recon_loss') - self.vq_loss_tracker = tf.keras.metrics.Mean(name='vq_loss') - self.perplexity_tracker = tf.keras.metrics.Mean(name='perplexity') + # Determine the gates to use + if gates is None or not self.use_provided_gates: + if hasattr(self, 'gating_network'): + # Use learned gating if available + gates = self.gating_network(x) + else: + # Fall back to uniform distribution + batch_size = tf.shape(x)[0] + gates = tf.ones([batch_size, self.num_experts]) / self.num_experts - @property - def metrics(self): - return [ - self.total_loss_tracker, - self.recon_loss_tracker, - self.vq_loss_tracker, - self.perplexity_tracker + # Apply top-k gating if configured + if self.top_k is not None and self.top_k < self.num_experts: + # Get values and indices of top-k gate values + _, top_k_indices = tf.math.top_k(gates, k=self.top_k) + # Create a mask for the top_k gates (1 for top-k, 0 for others) + mask = tf.reduce_sum( + tf.one_hot(top_k_indices, depth=self.num_experts), + axis=1 + ) + # Zero out non-top-k gates and renormalize + gates = gates * mask + gates = gates / (tf.reduce_sum(gates, axis=-1, keepdims=True) + 1e-10) + + # Create dispatcher with these gates + dispatcher = SparseDispatcher(self.num_experts, gates) + + # Dispatch inputs to experts + expert_inputs = dispatcher.dispatch(x) + + expert_outputs = [ + expert(inp) + for expert, inp in zip(self.experts, expert_inputs) ] - def 
train_step(self, x): - with tf.GradientTape() as tape: - encoded = self.encoder(x) - Zq, out_P_proj, vq_loss = self.scq(encoded) # (B,N,E),(B,N,K),(1,) - decoded = self.decoder(Zq) - - y = tf.clip_by_value(decoded, 1e-10, 1 - (1e-10)) - recon_loss = -tf.reduce_sum(x * tf.math.log(y), axis=-1) - recon_loss = tf.reduce_mean(recon_loss) - - final_loss = self.weight_recon * recon_loss + self.weight_vq * vq_loss - - vars = self.encoder.trainable_weights + self.scq.trainable_weights + self.decoder.trainable_weights - grads = tape.gradient(final_loss, vars) - self.optimizer.apply_gradients(zip(grads, vars)) - self.perplexity = self.compute_perplexity(out_P_proj) - # Loss Tracking - self.total_loss_tracker.update_state(final_loss) - self.recon_loss_tracker.update_state(self.weight_recon * recon_loss) - self.vq_loss_tracker.update_state(self.weight_vq * vq_loss) - self.perplexity_tracker.update_state(self.perplexity) + # Combine expert outputs + combined_outputs = dispatcher.combine(expert_outputs) return { - 'loss': tf.convert_to_tensor(self.total_loss_tracker.result()), - 'recon': tf.convert_to_tensor(self.recon_loss_tracker.result()), - 'vq': tf.convert_to_tensor(self.vq_loss_tracker.result()), - 'perplexity': tf.convert_to_tensor(self.perplexity_tracker.result()) + 'prediction': combined_outputs, + 'gates': gates, + 'expert_outputs': expert_outputs if training else None } +class MoEModel(tf.keras.Model): + """Model wrapping the MoE layer with optional learned gating.""" + def __init__(self, input_dim, hidden_dim, num_experts, use_provided_gates=True, + gating_hidden_dim=64, top_k=None): + super().__init__() + self.use_provided_gates = use_provided_gates + self.moe_layer = MixtureOfExperts( + input_dim, + hidden_dim, + num_experts, + use_provided_gates=use_provided_gates, + gating_hidden_dim=gating_hidden_dim, + top_k=top_k + ) + def call(self, inputs, training=False): - encoded = self.encoder(inputs) - Zq, out_P_proj, vq_loss = self.scq(encoded) # Ignore the loss output - decoded = self.decoder(Zq) - return decoded, Zq, out_P_proj # Return final output + # Handle both cases: with and without provided gates + moe_outputs = self.moe_layer(inputs, training=training) + return moe_outputs if not training else moe_outputs['prediction'] - def compute_perplexity(slef, out_P_proj): - p_j = tf.reduce_mean(out_P_proj, axis=[0, 1]) # (B, N, K) -> (K,) - p_j = tf.clip_by_value(p_j, 1e-10, 1 - (1e-9)) - p_j = p_j / tf.reduce_sum(p_j) # Normalize to ensure sum to 1 - entropy = -tf.reduce_sum(p_j * tf.math.log(p_j) / tf.math.log(2.0)) # Entropy: -sum(p_j * log2(p_j)) - perplexity = tf.pow(2.0, entropy) # Perplexity: 2^entropy - return perplexity + def train_step(self, data): + if isinstance(data, tuple): + x, y = data + else: + raise ValueError("Training data must include both inputs and labels") + + with tf.GradientTape() as tape: + # Forward pass - handle both input types + if isinstance(x, tuple) and len(x) == 2: + # Input includes features and gates + inputs, gates = x + outputs = self.moe_layer((inputs, gates), training=True) + else: + # Input is just features, gates will be computed by gating network + outputs = self.moe_layer(x, training=True) + + predictions = outputs['prediction'] + + # Compute loss + loss = self.compiled_loss(y, predictions, regularization_losses=self.losses) + + # Add expert load balancing loss + gates = outputs['gates'] + expert_usage = tf.reduce_mean(gates, axis=0) + load_balancing_loss = tf.reduce_sum(expert_usage * tf.math.log(expert_usage + 1e-10)) * 0.01 + total_loss = loss + 
load_balancing_loss + + # Compute gradients and update weights + gradients = tape.gradient(total_loss, self.trainable_variables) + self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + + # Update metrics + self.compiled_metrics.update_state(y, predictions) + + # Return metrics + results = {m.name: m.result() for m in self.metrics} + results.update({'loss': loss, 'load_balancing_loss': load_balancing_loss}) + return results + + def test_step(self, data): + if isinstance(data, tuple): + x, y = data + else: + raise ValueError("Test data must include both inputs and labels") + + # Forward pass - handle both input types + if isinstance(x, tuple) and len(x) == 2: + # Input includes features and gates + inputs, gates = x + outputs = self.moe_layer((inputs, gates), training=False) + else: + # Input is just features, gates will be computed by gating network + outputs = self.moe_layer(x, training=False) + + predictions = outputs['prediction'] + + # Compute loss + loss = self.compiled_loss(y, predictions, regularization_losses=self.losses) + + # Update metrics + self.compiled_metrics.update_state(y, predictions) + + # Return metrics + results = {m.name: m.result() for m in self.metrics} + results.update({'loss': loss}) + return results + + def predict(self, inputs): + if isinstance(inputs, tuple) and len(inputs) == 2: + # Input includes features and gates + inputs, gates = inputs + outputs = self.moe_layer((inputs, gates), training=False) + else: + # Input is just features, gates will be computed by gating network + outputs = self.moe_layer(inputs, training=False) + + predictions = outputs['prediction'] + return predictions + + +def visualize_dataset_analysis(features, labels, cluster_probs, method='pca', raw_dot_plot=False, feature_indices=None): + """ + Visualize the relationship between features, labels, and cluster probabilities. 
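+
+    Example (illustrative sketch): a minimal call with synthetic tensors; the
+    sample count, feature dimension, and cluster count below are arbitrary.
+        feats = tf.random.normal((200, 64))
+        labels = tf.random.uniform((200, 1), maxval=2, dtype=tf.int32)
+        probs = tf.nn.softmax(tf.random.normal((200, 8)), axis=-1)
+        visualize_dataset_analysis(feats, labels, probs, method='pca')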
+ + Args: + features: TensorFlow tensor containing feature vectors + labels: TensorFlow tensor containing labels + cluster_probs: TensorFlow tensor containing cluster probabilities + method: Dimensionality reduction method ('umap', 'tsne', or 'pca') + raw_dot_plot: If True, plot raw feature values directly + feature_indices: Indices of features to plot (tuple of 2 indices) + """ + import matplotlib.pyplot as plt + import numpy as np + + # Convert tensors to numpy arrays + features_np = features.numpy() + labels_np = labels.numpy().flatten() + cluster_probs_np = cluster_probs.numpy() + + # Compute dominant cluster for each sample + dominant_clusters = np.argmax(cluster_probs_np, axis=1) + + if raw_dot_plot: + # Plot raw feature values directly without dimensionality reduction + feature_indices = feature_indices or (0, 1) # Default to first two features + + # Create figure with two subplots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) + + # Scatter by true label + scatter1 = ax1.scatter( + features_np[:, feature_indices[0]], + features_np[:, feature_indices[1]], + c=labels_np, + cmap='viridis', + s=10, + alpha=0.6 + ) + ax1.set_title(f'Raw Features: Colored by Label') + ax1.set_xlabel(f'Feature {feature_indices[0]}') + ax1.set_ylabel(f'Feature {feature_indices[1]}') + cbar1 = plt.colorbar(scatter1, ax=ax1) + cbar1.set_label('Label') + + # Scatter by dominant cluster + n_clusters = cluster_probs_np.shape[1] + cmap_clusters = plt.cm.get_cmap('tab20b', n_clusters) if n_clusters <= 20 else plt.cm.get_cmap('gist_ncar', n_clusters) + + scatter2 = ax2.scatter( + features_np[:, feature_indices[0]], + features_np[:, feature_indices[1]], + c=dominant_clusters, + cmap=cmap_clusters, + s=10, + alpha=0.6 + ) + ax2.set_title(f'Raw Features: Colored by Cluster') + ax2.set_xlabel(f'Feature {feature_indices[0]}') + ax2.set_ylabel(f'Feature {feature_indices[1]}') + cbar2 = plt.colorbar(scatter2, ax=ax2, ticks=range(n_clusters)) + cbar2.set_label('Cluster ID') + + plt.tight_layout() + plt.savefig(f'raw_feature_analysis.png', dpi=300, bbox_inches='tight') + plt.show() + return + + # Original dimensionality reduction visualization + try: + if method == 'umap': + import umap + reducer = umap.UMAP(n_neighbors=30, min_dist=0.1, n_components=2, random_state=42) + features_2d = reducer.fit_transform(features_np) + method_name = "UMAP" + elif method == 'tsne': + from sklearn.manifold import TSNE + reducer = TSNE(n_components=2, perplexity=30, n_iter=1000, random_state=42) + features_2d = reducer.fit_transform(features_np) + method_name = "t-SNE" + else: # default to PCA + from sklearn.decomposition import PCA + reducer = PCA(n_components=2) + features_2d = reducer.fit_transform(features_np) + method_name = "PCA" + except ImportError: + # Fall back to PCA if the requested method is not available + from sklearn.decomposition import PCA + reducer = PCA(n_components=2) + features_2d = reducer.fit_transform(features_np) + method_name = "PCA" + print(f"Warning: {method} not available, falling back to PCA") + + # Prepare subplots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) + + # Scatter by true label + scatter1 = ax1.scatter( + features_2d[:, 0], + features_2d[:, 1], + c=labels_np, + cmap='viridis', + s=10, + alpha=0.6 + ) + ax1.set_title(f'Samples by Label ({method_name})') + ax1.set_xlabel(f'{method_name} Component 1') + ax1.set_ylabel(f'{method_name} Component 2') + cbar1 = plt.colorbar(scatter1, ax=ax1) + cbar1.set_label('Label') + + # Scatter by dominant cluster with improved discrete colormap + 
n_clusters = cluster_probs_np.shape[1] + # Use a better colormap for distinguishing clusters + cmap_clusters = plt.cm.get_cmap('tab20b', n_clusters) if n_clusters <= 20 else plt.cm.get_cmap('gist_ncar', n_clusters) + + # Add confidence information - marker size based on max probability + max_probs = np.max(cluster_probs_np, axis=1) + marker_sizes = 10 + 40 * max_probs # Scale confidence to marker size + + scatter2 = ax2.scatter( + features_2d[:, 0], + features_2d[:, 1], + c=dominant_clusters, + cmap=cmap_clusters, + s=marker_sizes, + alpha=0.6, + edgecolors='none' + ) + ax2.set_title(f'Samples by Dominant Cluster ({method_name})') + ax2.set_xlabel(f'{method_name} Component 1') + ax2.set_ylabel(f'{method_name} Component 2') + cbar2 = plt.colorbar(scatter2, ax=ax2, ticks=range(n_clusters)) + cbar2.set_label('Cluster ID') + + plt.tight_layout() + plt.savefig(f'cluster_analysis_{method}.png', dpi=300, bbox_inches='tight') + plt.show() + + # Plot overall label and cluster probability distributions + fig, ax = plt.subplots(1, 2, figsize=(12, 4)) + + # Label distribution + unique, counts = np.unique(labels_np, return_counts=True) + ax[0].bar(unique.astype(int), counts, color='skyblue') + ax[0].set_xlabel('Label') + ax[0].set_ylabel('Count') + ax[0].set_title('Label Distribution') + + # Cluster probability distribution + cluster_prob_sums = np.sum(cluster_probs_np, axis=0) + ax[1].bar(np.arange(len(cluster_prob_sums)), cluster_prob_sums, color='salmon') + ax[1].set_xlabel('Cluster') + ax[1].set_ylabel('Sum of Probabilities') + ax[1].set_title('Cluster Probability Distribution') + + plt.tight_layout() + plt.savefig('dataset_distributions.png', dpi=300, bbox_inches='tight') + plt.show() + + # TODO show the whole dataset as a dot plot + # Plot the dataset as a dot plot of the first two features + plt.figure(figsize=(10, 8)) + plt.scatter(features_np[:, 0], features_np[:, 1], c=labels_np, cmap='coolwarm', + s=10, alpha=0.7, edgecolors='none') + plt.colorbar(label='Label') + plt.title('Dot Plot of First Two Features') + plt.xlabel('Feature 0') + plt.ylabel('Feature 1') + plt.grid(linestyle='--', alpha=0.6) + plt.savefig('feature_dot_plot.png', dpi=300, bbox_inches='tight') + plt.show() + + +# # Example usage +# if __name__ == "__main__": +# # Generate dummy dataset with clustered data and labels for training +# num_train_samples = 8000 +# num_test_samples = 2000 +# feature_dim = 64 +# num_clusters = 32 +# +# # Function to generate dataset with specified parameters +# def generate_dataset(num_samples, feature_dim, num_clusters, epsilon=0.1): +# # Compute samples per cluster +# base_count = num_samples // num_clusters +# counts = [base_count + (1 if i < num_samples % num_clusters else 0) for i in range(num_clusters)] +# +# features_list = [] +# labels_list = [] +# cluster_probs_list = [] +# distributions = ['normal', 'uniform', 'gamma', 'poisson'] +# +# for c in range(num_clusters): +# cluster_count = counts[c] +# n0 = cluster_count // 2 +# n1 = cluster_count - n0 +# dist = distributions[(2 * c) % len(distributions)] +# print(f"Cluster {c}: {dist} distribution, {n0} samples 0, {n1} samples 1") +# +# if dist == 'normal': +# features_0 = tf.random.normal([n0, feature_dim], mean=c, stddev=1.0) +# elif dist == 'uniform': +# features_0 = tf.random.uniform([n0, feature_dim], minval=0, maxval=1) +# elif dist == 'gamma': +# features_0 = tf.random.gamma([n0, feature_dim], alpha=2.0, beta=1.0) +# elif dist == 'poisson': +# features_0 = tf.cast(tf.random.poisson([n0, feature_dim], lam=3), tf.float32) +# else: +# 
features_0 = tf.random.normal([n0, feature_dim], mean=c, stddev=1.0) +# +# if dist == 'normal': +# features_1 = tf.random.normal([n1, feature_dim], mean=c+0.5, stddev=1.5) +# elif dist == 'uniform': +# features_1 = tf.random.uniform([n1, feature_dim], minval=1, maxval=2) +# elif dist == 'gamma': +# features_1 = tf.random.gamma([n1, feature_dim], alpha=5.0, beta=2.0) +# elif dist == 'poisson': +# features_1 = tf.cast(tf.random.poisson([n1, feature_dim], lam=6), tf.float32) +# else: +# features_1 = tf.random.normal([n1, feature_dim], mean=c+0.5, stddev=1.5) +# +# features_i = tf.concat([features_0, features_1], axis=0) +# labels_i = tf.concat([tf.zeros([n0, 1], tf.int32), tf.ones([n1, 1], tf.int32)], axis=0) +# features_list.append(features_i) +# labels_list.append(labels_i) +# +# # Generate random cluster probabilities per sample +# cluster_indices = tf.fill([cluster_count], c) +# lam_value = tf.maximum(tf.cast(c, tf.float32) + 1.0, 1.0) +# noise = tf.cast(tf.random.poisson([cluster_count, num_clusters], lam=lam_value), tf.float32) +# noise = noise + epsilon +# probs = noise / (tf.reduce_sum(noise, axis=1, keepdims=True)) +# alpha = tf.random.uniform([cluster_count, 1], minval=0.5, maxval=0.8) +# probs = (1 - alpha) * probs + alpha * tf.one_hot(cluster_indices, num_clusters) +# probs = probs / tf.reduce_sum(probs, axis=1, keepdims=True) +# cluster_probs_list.append(probs) +# +# features = tf.concat(features_list, axis=0) +# labels = tf.concat(labels_list, axis=0) +# cluster_probs = tf.concat(cluster_probs_list, axis=0) +# +# # Shuffle dataset +# indices = tf.random.shuffle(tf.range(tf.shape(features)[0])) +# features = tf.gather(features, indices) +# labels = tf.gather(labels, indices) +# cluster_probs = tf.gather(cluster_probs, indices) +# +# return features, labels, cluster_probs +# +# # Generate training dataset +# print("\nGenerating training dataset...") +# train_features, train_labels, train_cluster_probs = generate_dataset( +# num_train_samples, feature_dim, num_clusters) +# +# # Generate separate test dataset +# print("\nGenerating test dataset...") +# test_features, test_labels, test_cluster_probs = generate_dataset( +# num_test_samples, feature_dim, num_clusters, epsilon=1) +# +# print(f"\nTraining labels min: {tf.reduce_min(train_labels)}, max: {tf.reduce_max(train_labels)}") +# print(f"Training labels head: {train_labels[:5]}") +# print(f"Test labels min: {tf.reduce_min(test_labels)}, max: {tf.reduce_max(test_labels)}") +# print(f"Test labels head: {test_labels[:5]}") +# +# # Visualize training dataset +# print("\nVisualizing training dataset...") +# visualize_dataset_analysis(train_features, train_labels, train_cluster_probs, +# raw_dot_plot=False, method='pca', feature_indices=(0, 1)) +# +# # Create training dataset +# train_dataset = tf.data.Dataset.from_tensor_slices( +# ((train_features, train_cluster_probs), train_labels) +# ).shuffle(1000).batch(64) +# +# # Create test dataset (no need to shuffle extensively) +# test_dataset = tf.data.Dataset.from_tensor_slices( +# ((test_features, test_cluster_probs), test_labels) +# ).batch(64) +# +# # Print info about test dataset +# for features, labels in test_dataset.take(1): +# print(f"\nTest features shape: {features[0].shape}, Test labels shape: {labels.shape}") +# print(f"Test features head: {features[0][:3]}") +# print(f"Test labels head: {labels[:3]}") +# +# # Create and compile model +# model = MoEModel(feature_dim, hidden_dim=8, num_experts=num_clusters, use_provided_gates=True) +# model.compile( +# 
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), +# loss=tf.keras.losses.BinaryCrossentropy(), +# metrics=['accuracy'], +# ) +# +# # Train model +# print("\nTraining model...") +# model.fit(train_dataset, epochs=20) +# +# # Evaluate on independent test set +# print("\nEvaluating on independent test set...") +# eval_results = model.evaluate(test_dataset) +# print(f"Evaluation results: {eval_results}") +# +# # Predict and analyze +# # test_batch = next(iter(test_dataset.take(1))) +# # predictions = model(test_batch[0], training=False) +# # print(f"Sample predictions shape: {predictions['prediction'].shape}") +# # print(f"Gate activations: {tf.reduce_mean(predictions['gates'], axis=0)}") +# +# # Test on a single sample +# for (feat, cluster_prob), label in test_dataset.unbatch().take(1): +# feat = tf.expand_dims(feat, 0) +# cluster_prob = tf.expand_dims(cluster_prob, 0) +# output = model((feat, cluster_prob), training=False) +# print("Single sample prediction:", output['prediction'].numpy()[0], "True label:", label.numpy()) +# print("Gate activations:", tf.reduce_mean(output['gates'], axis=0).numpy()) +# +# # Optional: Visualize test dataset +# print("\nVisualizing test dataset...") +# visualize_dataset_analysis(test_features, test_labels, test_cluster_probs, +# raw_dot_plot=False, method='pca', feature_indices=(0, 1)) + +class BinaryMLP(tf.keras.Model): + def __init__(self, input_dim=1024): + super(BinaryMLP, self).__init__() + # First hidden layer + self.dense1 = tf.keras.layers.Dense( + units=512, activation='relu', name='dense_1' + ) + self.dropout1 = tf.keras.layers.Dropout( + rate=0.5, name='dropout_1' + ) + # Second hidden layer + self.dense2 = tf.keras.layers.Dense( + units=256, activation='relu', name='dense_2' + ) + self.dropout2 = tf.keras.layers.Dropout( + rate=0.5, name='dropout_2' + ) + # Output layer for binary classification + self.output_layer = tf.keras.layers.Dense( + units=1, activation='sigmoid', name='output_layer' + ) + + def call(self, inputs, training=False): + # If inputs come in as (features, labels), unpack and ignore labels + if isinstance(inputs, (tuple, list)): + x, _ = inputs + else: + x = inputs + + x = self.dense1(x) + x = self.dropout1(x, training=training) + x = self.dense2(x) + x = self.dropout2(x, training=training) + return self.output_layer(x) + +# # Example usage +# if __name__ == "__main__": +# # Generate dummy dataset +# num_samples = 1000 +# input_dim = 1024 +# features = tf.random.normal((num_samples, input_dim)) +# labels = tf.random.uniform((num_samples,), minval=0, maxval=2, dtype=tf.int32) +# +# # Create and compile model +# model = BinaryMLP(input_dim=input_dim) +# model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) +# +# # Train model +# model.fit(features, labels, epochs=5, batch_size=32) +# # Evaluate model +# loss, accuracy = model.evaluate(features, labels) +# print(f"Loss: {loss}, Accuracy: {accuracy}") +# # Predict +# predictions = model.predict(features) +# print(f"Predictions shape: {predictions.shape}") +# print(f"Predictions head: {predictions[:5]}") + + +class TransformerBlock(layers.Layer): + """Single Transformer encoder block.""" + + def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs): + super().__init__(**kwargs) + self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim) + self.dropout1 = layers.Dropout(dropout_rate) + self.norm1 = layers.LayerNormalization(epsilon=1e-6) + + self.ffn = tf.keras.Sequential([ + layers.Dense(ff_dim, 
activation='gelu'), + layers.Dropout(dropout_rate), + layers.Dense(embed_dim), + layers.Dropout(dropout_rate), + ]) + self.norm2 = layers.LayerNormalization(epsilon=1e-6) + + def call(self, inputs, training=False): + # Self-attention block + attn_output = self.attn(inputs, inputs) + attn_output = self.dropout1(attn_output, training=training) + out1 = self.norm1(inputs + attn_output) + + # Feed-forward block + ffn_output = self.ffn(out1, training=training) + return self.norm2(out1 + ffn_output) -'''import tensorflow as tf +import tensorflow as tf +from tensorflow.keras import layers, Model + +class TransformerBlock(layers.Layer): + """Single Transformer encoder block.""" + def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs): + super().__init__(**kwargs) + self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim) + self.dropout1 = layers.Dropout(dropout_rate) + self.norm1 = layers.LayerNormalization(epsilon=1e-6) + + self.ffn = tf.keras.Sequential([ + layers.Dense(ff_dim, activation='gelu'), + layers.Dropout(dropout_rate), + layers.Dense(embed_dim), + ]) + self.dropout2 = layers.Dropout(dropout_rate) + self.norm2 = layers.LayerNormalization(epsilon=1e-6) + + def call(self, x, training=False): + # Self-attention + residual + norm + attn_output = self.attn(x, x, training=training) + attn_output = self.dropout1(attn_output, training=training) + x = self.norm1(x + attn_output) + + # Feed-forward + residual + norm + ffn_output = self.ffn(x, training=training) + ffn_output = self.dropout2(ffn_output, training=training) + return self.norm2(x + ffn_output) + + +class TabularTransformer(Model): + """ + Transformer-based classifier for tabular data. + Applies a downsampling pool to reduce sequence length and avoid OOM. 
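+
+    Shape sketch with the default arguments (illustrative only):
+        (batch, 1024) -> expand_dims -> (batch, 1024, 1)
+        -> feature_embedding -> (batch, 1024, 64) -> MaxPool1D(4) -> (batch, 256, 64)
+        -> 4 x TransformerBlock -> GlobalAveragePooling1D -> (batch, 64) -> Dense(1, sigmoid)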
+ """ + def __init__( + self, + input_dim=1024, + embed_dim=64, + num_heads=8, + ff_dim=256, + num_layers=4, + dropout_rate=0.1, + pool_size=4, + **kwargs + ): + super().__init__(**kwargs) + self.pool_size = pool_size + # Project each scalar feature to an embedding + self.feature_embedding = layers.Dense(embed_dim, name='feature_embedding') + # Downsample sequence length: 1024 -> 1024/pool_size + self.pool = layers.MaxPool1D(pool_size=pool_size, name='sequence_pool') + reduced_seq_len = input_dim // pool_size + + # Positional embeddings for reduced tokens + self.pos_embedding = self.add_weight( + name='pos_embedding', + shape=(1, reduced_seq_len, embed_dim), + initializer='random_normal' + ) + # Transformer encoder stack + self.transformer_blocks = [ + TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate) + for _ in range(num_layers) + ] + # Pool & classification head + self.global_pool = layers.GlobalAveragePooling1D(name='global_avg_pool') + self.dropout = layers.Dropout(dropout_rate, name='dropout_final') + self.classifier = layers.Dense(1, activation='sigmoid', name='output') + + def call(self, inputs, training=False): + # Unpack if inputs come as (features, labels) + if isinstance(inputs, (tuple, list)): + x, _ = inputs + else: + x = inputs + + # shape -> (batch, input_dim, 1) + x = tf.expand_dims(x, axis=-1) + # Embed features -> (batch, input_dim, embed_dim) + x = self.feature_embedding(x) + # Downsample tokens -> (batch, reduced_seq_len, embed_dim) + x = self.pool(x) + # Add positional embeddings + x = x + self.pos_embedding + + # Transformer encoder stack + for block in self.transformer_blocks: + x = block(x, training=training) + + # Pool over tokens -> (batch, embed_dim) + x = self.global_pool(x) + x = self.dropout(x, training=training) + return self.classifier(x) + +# # Example usage +# if __name__ == "__main__": +# num_samples = 1000 +# input_dim = 1024 +# features = tf.random.normal((num_samples, input_dim)) +# labels = tf.random.uniform((num_samples,), maxval=2, dtype=tf.int32) +# +# model = TabularTransformer(input_dim=input_dim, pool_size=4) +# model.build(input_shape=(None, input_dim)) +# model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) +# model.summary() +# +# model.fit(features, labels, epochs=5, batch_size=32) +# loss, accuracy = model.evaluate(features, labels) +# print(f"Loss: {loss:.4f}, Accuracy: {accuracy:.4f}") +# # Predict +# predictions = model.predict(features) +# print(f"Predictions shape: {predictions.shape}") +# print(f"Predictions head: {predictions[:5]}") + + +class EmbeddingCNN(Model): + """ + 1D-CNN + MLP head for binary classification on fixed-length embeddings. 
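+
+    Shape sketch for input_dim=1024 (illustrative only):
+        (batch, 1024) -> Reshape -> (batch, 1024, 1)
+        -> Conv1D(64, 5, same) + MaxPool1D(2) -> (batch, 512, 64)
+        -> Conv1D(128, 5, same) + MaxPool1D(2) -> (batch, 256, 128)
+        -> Flatten -> Dense(64) -> Dropout -> Dense(1, sigmoid)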
+ """ + + def __init__(self, input_dim, dropout_rate=0.5): + super().__init__(name='Embedding_CNN') + self.input_dim = input_dim + + # Reshape flat embedding to sequence (length=input_dim, channels=1) + self.reshape_layer = layers.Reshape((input_dim, 1), name='reshape') + + # Convolutional blocks + self.conv1 = layers.Conv1D(64, 5, padding='same', activation='relu', name='conv1') + self.pool1 = layers.MaxPool1D(2, name='pool1') + + self.conv2 = layers.Conv1D(128, 5, padding='same', activation='relu', name='conv2') + self.pool2 = layers.MaxPool1D(2, name='pool2') + + # Flatten and MLP head + self.flatten = layers.Flatten(name='flatten') + self.dense = layers.Dense(64, activation='relu', name='dense') + self.dropout = layers.Dropout(dropout_rate, name='dropout') + + # Final binary output + self.output_layer = layers.Dense(1, activation='sigmoid', name='output') + + def call(self, inputs, training=False): + # Unpack if inputs come as (features, labels) + if isinstance(inputs, (tuple, list)): + x, _ = inputs + else: + x = inputs + + # Now safe to reshape just the feature tensor + x = self.reshape_layer(x) + x = self.conv1(x) + x = self.pool1(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.flatten(x) + x = self.dense(x) + x = self.dropout(x, training=training) + return self.output_layer(x) + +# # # Example usage +# if __name__ == "__main__": +# num_samples = 1000 +# input_dim = 1024 +# features = tf.random.normal((num_samples, input_dim)) +# labels = tf.random.uniform((num_samples,), maxval=2, dtype=tf.int32) +# model = EmbeddingCNN(input_dim=input_dim, dropout_rate=0.5) +# model.build(input_shape=(None, input_dim)) +# model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) +# model.summary() +# model.fit(features, labels, epochs=5, batch_size=32) +# loss, accuracy = model.evaluate(features, labels) +# print(f"Loss: {loss:.4f}, Accuracy: {accuracy:.4f}") +# # Predict +# predictions = model.predict(features) +# print(f"Predictions shape: {predictions.shape}") +# print(f"Predictions head: {predictions[:5]}") + + +# mixture_of_experts.py +""" +Mixture‑of‑Experts implementation supporting **soft** and **hard** clustering. + +* **Soft clustering** (probabilistic gates) is used **during inference** to fuse the + parameters of all experts into a *virtual* expert. +* **Hard clustering** (one‑hot gates) is used **during training**; each sample + activates exactly one expert so gradients flow only through the selected expert. + +Key ideas +--------- +1. **SparseDispatcher** routes inputs to experts and fuses outputs. It now accepts + either soft or hard gates transparently. +2. **Expert** is a light two‑layer perceptron ending in a sigmoid. +3. **MixtureOfExperts** + * consumes `inputs` **and** a *soft* clustering vector (`gates_soft`). + * converts to hard gates (`gates_hard`) with `tf.one_hot(tf.argmax(...))` when + `hard_gating=True`. + * during inference (`hard_gating=False`) the *soft* gates are used to create a + **parameter‑blended virtual expert**: every weight matrix and bias vector is + a convex combination of the corresponding parameters of the individual + experts. +4. **MoEModel** overrides `train_step` / `test_step` so that + * training → `hard_gating=True` + * inference → `hard_gating=False` + +The code is ready to run in a standard TensorFlow‑2 environment. 
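+
+Usage sketch (illustrative; the dataset layout mirrors the commented example at
+the bottom of this module, and the hyper-parameters are arbitrary):
+
+    model = EnhancedMoEModel(input_dim=64, hidden_dim=16, num_experts=8,
+                             use_hard_clustering=True)
+    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
+    # each dataset element is ((features, soft_cluster_probs), label)
+    model.fit(train_dataset, epochs=5)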
+""" + + +class Expert(layers.Layer): + """A binary prediction expert with added complexity.""" + + def __init__(self, input_dim, hidden_dim, output_dim=1, dropout_rate=0.2): + super().__init__() + self.fc1 = layers.Dense(hidden_dim, activation='relu', input_shape=(input_dim,)) + self.dropout1 = layers.Dropout(dropout_rate) + self.fc2 = layers.Dense(hidden_dim // 2, activation='relu') + self.dropout2 = layers.Dropout(dropout_rate) + self.fc3 = layers.Dense(output_dim) + + def call(self, x, training=False): + x = self.fc1(x) + x = self.dropout1(x, training=training) + x = self.fc2(x) + x = self.dropout2(x, training=training) + x = self.fc3(x) + return tf.nn.sigmoid(x) + + +class EnhancedMixtureOfExperts(layers.Layer): + """ + Enhanced Mixture of Experts layer that uses cluster assignments. + + This implementation eliminates the need for a SparseDispatcher by: + - During training: Using hard clustering to train specific experts + - During inference: Using soft clustering to mix the experts' weights + """ + + def __init__(self, input_dim, hidden_dim, num_experts, output_dim=1, + use_hard_clustering=True, dropout_rate=0.2): + super().__init__() + self.num_experts = num_experts + self.use_hard_clustering = use_hard_clustering + self.output_dim = output_dim + + # Create n experts + self.experts = [ + Expert(input_dim, hidden_dim, output_dim, dropout_rate) + for _ in range(num_experts) + ] + + def convert_to_hard_clustering(self, soft_clusters): + """Convert soft clustering values to hard clustering (one-hot encoding)""" + # Get the index of the maximum value for each sample + hard_indices = tf.argmax(soft_clusters, axis=1) + # Convert to one-hot encoding + return tf.one_hot(hard_indices, depth=self.num_experts) + + def call(self, inputs, training=False): + # Unpack inputs + if isinstance(inputs, tuple) and len(inputs) == 2: + x, soft_cluster_probs = inputs + else: + raise ValueError("Inputs must include both features and clustering values") + + batch_size = tf.shape(x)[0] + + # Convert to hard clustering during training if requested + if training and self.use_hard_clustering: + clustering = self.convert_to_hard_clustering(soft_cluster_probs) + else: + clustering = soft_cluster_probs + + # Initialize output tensor + combined_output = tf.zeros([batch_size, self.output_dim]) + + # Process each expert + for i, expert in enumerate(self.experts): + # Get the weight for this expert for each sample in the batch + expert_weights = clustering[:, i:i + 1] # Shape: [batch_size, 1] + + # Only compute outputs for samples with non-zero weights + # to save computation during training with hard clustering + if training and self.use_hard_clustering: + # Find samples assigned to this expert + assigned_indices = tf.where(expert_weights[:, 0] > 0)[:, 0] + + if tf.size(assigned_indices) > 0: + # Get assigned samples + assigned_x = tf.gather(x, assigned_indices) + + # Get expert output for assigned samples + expert_output = expert(assigned_x, training=training) + + # Use scatter_nd to place results back into full batch tensor + indices = tf.expand_dims(assigned_indices, axis=1) + updates = expert_output + combined_output += tf.scatter_nd(indices, updates, [batch_size, self.output_dim]) + else: + # During inference or when using soft clustering: + # Compute expert output for all samples + expert_output = expert(x, training=training) + + # Weight the output by the clustering values + weighted_output = expert_output * expert_weights + + # Add to combined output + combined_output += weighted_output + + return 
combined_output + + +class EnhancedMoEModel(tf.keras.Model): + """Complete model wrapping the Enhanced MoE layer""" + + def __init__(self, input_dim, hidden_dim, num_experts, output_dim=1, + use_hard_clustering=True, dropout_rate=0.2): + super().__init__() + self.use_hard_clustering = use_hard_clustering + self.moe_layer = EnhancedMixtureOfExperts( + input_dim, + hidden_dim, + num_experts, + output_dim=output_dim, + use_hard_clustering=use_hard_clustering, + dropout_rate=dropout_rate + ) + + def call(self, inputs, training=False): + return self.moe_layer(inputs, training=training) + + def train_step(self, data): + if isinstance(data, tuple) and len(data) == 2: + # Unpack the data + x, y = data + + # Ensure x contains both inputs and clustering values + if not (isinstance(x, tuple) and len(x) == 2): + raise ValueError("Input must be a tuple of (features, clustering)") + else: + raise ValueError("Training data must include both inputs and labels") + + # Unpack inputs + inputs, soft_cluster_vector = x + + with tf.GradientTape() as tape: + # Forward pass + predictions = self(x, training=True) + + # Compute loss + loss = self.compiled_loss(y, predictions, regularization_losses=self.losses) + + # Add expert load balancing loss if using soft clustering + if not self.use_hard_clustering: + expert_usage = tf.reduce_mean(soft_cluster_vector, axis=0) + load_balancing_loss = tf.reduce_sum(expert_usage * tf.math.log(expert_usage + 1e-10)) * 0.01 + loss += load_balancing_loss + + # Compute gradients and update weights + gradients = tape.gradient(loss, self.trainable_variables) + self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + + # Update metrics + self.compiled_metrics.update_state(y, predictions) + + # Return metrics + results = {m.name: m.result() for m in self.metrics} + results.update({'loss': loss}) + return results + + def test_step(self, data): + if isinstance(data, tuple) and len(data) == 2: + # Unpack the data + x, y = data + + # Ensure x contains both inputs and clustering values + if not (isinstance(x, tuple) and len(x) == 2): + raise ValueError("Input must be a tuple of (features, clustering)") + else: + raise ValueError("Test data must include both inputs and labels") + + # Forward pass (always use soft clustering for inference) + original_hard_clustering = self.moe_layer.use_hard_clustering + self.moe_layer.use_hard_clustering = False + predictions = self(x, training=False) + self.moe_layer.use_hard_clustering = original_hard_clustering + + # Compute loss + loss = self.compiled_loss(y, predictions, regularization_losses=self.losses) + + # Update metrics + self.compiled_metrics.update_state(y, predictions) + + # Return metrics + results = {m.name: m.result() for m in self.metrics} + results.update({'loss': loss}) + return results + + +# if __name__ == "__main__": +# # Generate dummy dataset with clustered data and labels for training +# num_train_samples = 8000 +# num_test_samples = 2000 +# feature_dim = 64 +# num_clusters = 32 +# +# # Function to generate dataset with specified parameters +# def generate_dataset(num_samples, feature_dim, num_clusters, epsilon=0.1): +# # Compute samples per cluster +# base_count = num_samples // num_clusters +# counts = [base_count + (1 if i < num_samples % num_clusters else 0) for i in range(num_clusters)] +# +# features_list = [] +# labels_list = [] +# cluster_probs_list = [] +# distributions = ['normal', 'uniform', 'gamma', 'poisson'] +# +# for c in range(num_clusters): +# cluster_count = counts[c] +# n0 = cluster_count // 2 +# 
n1 = cluster_count - n0 +# dist = distributions[(2 * c) % len(distributions)] +# print(f"Cluster {c}: {dist} distribution, {n0} samples 0, {n1} samples 1") +# +# if dist == 'normal': +# features_0 = tf.random.normal([n0, feature_dim], mean=c, stddev=1.0) +# elif dist == 'uniform': +# features_0 = tf.random.uniform([n0, feature_dim], minval=0, maxval=1) +# elif dist == 'gamma': +# features_0 = tf.random.gamma([n0, feature_dim], alpha=2.0, beta=1.0) +# elif dist == 'poisson': +# features_0 = tf.cast(tf.random.poisson([n0, feature_dim], lam=3), tf.float32) +# else: +# features_0 = tf.random.normal([n0, feature_dim], mean=c, stddev=1.0) +# +# if dist == 'normal': +# features_1 = tf.random.normal([n1, feature_dim], mean=c+0.5, stddev=1.5) +# elif dist == 'uniform': +# features_1 = tf.random.uniform([n1, feature_dim], minval=1, maxval=2) +# elif dist == 'gamma': +# features_1 = tf.random.gamma([n1, feature_dim], alpha=5.0, beta=2.0) +# elif dist == 'poisson': +# features_1 = tf.cast(tf.random.poisson([n1, feature_dim], lam=6), tf.float32) +# else: +# features_1 = tf.random.normal([n1, feature_dim], mean=c+0.5, stddev=1.5) +# +# features_i = tf.concat([features_0, features_1], axis=0) +# labels_i = tf.concat([tf.zeros([n0, 1], tf.int32), tf.ones([n1, 1], tf.int32)], axis=0) +# features_list.append(features_i) +# labels_list.append(labels_i) +# +# # Generate random cluster probabilities per sample +# cluster_indices = tf.fill([cluster_count], c) +# lam_value = tf.maximum(tf.cast(c, tf.float32) + 1.0, 1.0) +# noise = tf.cast(tf.random.poisson([cluster_count, num_clusters], lam=lam_value), tf.float32) +# noise = noise + epsilon +# probs = noise / (tf.reduce_sum(noise, axis=1, keepdims=True)) +# alpha = tf.random.uniform([cluster_count, 1], minval=0.5, maxval=0.8) +# probs = (1 - alpha) * probs + alpha * tf.one_hot(cluster_indices, num_clusters) +# probs = probs / tf.reduce_sum(probs, axis=1, keepdims=True) +# cluster_probs_list.append(probs) +# +# features = tf.concat(features_list, axis=0) +# labels = tf.concat(labels_list, axis=0) +# cluster_probs = tf.concat(cluster_probs_list, axis=0) +# +# # Shuffle dataset +# indices = tf.random.shuffle(tf.range(tf.shape(features)[0])) +# features = tf.gather(features, indices) +# labels = tf.gather(labels, indices) +# cluster_probs = tf.gather(cluster_probs, indices) +# +# return features, labels, cluster_probs +# +# # Generate training dataset +# print("\nGenerating training dataset...") +# train_features, train_labels, train_cluster_probs = generate_dataset( +# num_train_samples, feature_dim, num_clusters) +# +# # Generate separate test dataset +# print("\nGenerating test dataset...") +# test_features, test_labels, test_cluster_probs = generate_dataset( +# num_test_samples, feature_dim, num_clusters, epsilon=1) +# +# print(f"\nTraining labels min: {tf.reduce_min(train_labels)}, max: {tf.reduce_max(train_labels)}") +# print(f"Training labels head: {train_labels[:5]}") +# print(f"Test labels min: {tf.reduce_min(test_labels)}, max: {tf.reduce_max(test_labels)}") +# print(f"Test labels head: {test_labels[:5]}") +# +# # Visualize training dataset +# print("\nVisualizing training dataset...") +# visualize_dataset_analysis(train_features, train_labels, train_cluster_probs, +# raw_dot_plot=False, method='pca', feature_indices=(0, 1)) +# +# # Create training dataset +# train_dataset = tf.data.Dataset.from_tensor_slices( +# ((train_features, train_cluster_probs), train_labels) +# ).shuffle(1000).batch(64) +# +# # Create test dataset (no need to shuffle extensively) 
+# test_dataset = tf.data.Dataset.from_tensor_slices( +# ((test_features, test_cluster_probs), test_labels) +# ).batch(64) +# +# # Print info about test dataset +# for features, labels in test_dataset.take(1): +# print(f"\nTest features shape: {features[0].shape}, Test labels shape: {labels.shape}") +# print(f"Test features head: {features[0][:3]}") +# print(f"Test labels head: {labels[:3]}") +# +# # Create and compile model +# model = EnhancedMoEModel(feature_dim, hidden_dim=8, num_experts=num_clusters, use_hard_clustering=True) +# model.compile( +# optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), +# loss=tf.keras.losses.BinaryCrossentropy(), +# metrics=['accuracy'], +# ) +# +# # Train model +# print("\nTraining model...") +# model.fit(train_dataset, epochs=20) +# +# # Evaluate on independent test set +# print("\nEvaluating on independent test set...") +# eval_results = model.evaluate(test_dataset) +# print(f"Evaluation results: {eval_results}") +# +# # Predict and analyze +# # test_batch = next(iter(test_dataset.take(1))) +# # predictions = model(test_batch[0], training=False) +# # print(f"Sample predictions shape: {predictions['prediction'].shape}") +# # print(f"Gate activations: {tf.reduce_mean(predictions['gates'], axis=0)}") +# +# # Test on a single sample +# for (feat, cluster_prob), label in test_dataset.unbatch().take(1): +# feat = tf.expand_dims(feat, 0) +# cluster_prob = tf.expand_dims(cluster_prob, 0) +# output = model((feat, cluster_prob), training=False) +# print("Single sample prediction:", output.numpy()[0], "True label:", label.numpy()) + + +# ------------------------------------------------------------------------------- # +# manual implementations +# ----------------------------------------------------------------------------- # +class AttentionLayer(keras.layers.Layer): + """ + Custom multi-head attention layer supporting self- and cross-attention. + + Args: + input_dim (int): Input feature dimension. + output_dim (int): Output feature dimension per head. + type (str): 'self' or 'cross'. + heads (int): Number of attention heads. + resnet (bool): Whether to use residual connection. + return_att_weights (bool): Whether to return attention weights. + name (str): Name for weight scopes. + epsilon (float): Epsilon for layer normalization. + gate (bool): Whether to use gating mechanism. 
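+
+    Example (sketch; tensor sizes are illustrative):
+        attn = AttentionLayer(input_dim=64, output_dim=64, type='self', heads=4)
+        x = tf.random.normal((2, 10, 64))   # (B, N, D)
+        y = attn(x)                         # (B, N, 64); all-zero (padded) tokens are masked out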
+ """ + + def __init__(self, input_dim, output_dim, type, heads=4, + resnet=True, return_att_weights=False, name='attention', + epsilon=1e-6, gate=True): + super().__init__(name=name) + assert isinstance(input_dim, int) and isinstance(output_dim, int) + assert type in ['self', 'cross'] + if resnet: + assert input_dim == output_dim + + self.input_dim = input_dim + self.output_dim = output_dim + self.type = type + self.heads = heads + self.resnet = resnet + self.return_att_weights = return_att_weights + self.epsilon = epsilon + self.gate = gate + + self.q = self.add_weight(shape=(heads, input_dim, output_dim), + initializer='random_normal', trainable=True, name=f'q_{name}') + self.k = self.add_weight(shape=(heads, input_dim, output_dim), + initializer='random_normal', trainable=True, name=f'k_{name}') + self.v = self.add_weight(shape=(heads, input_dim, output_dim), + initializer='random_normal', trainable=True, name=f'v_{name}') + if gate: + self.g = self.add_weight(shape=(heads, input_dim, output_dim), + initializer='random_uniform', trainable=True, name=f'gate_{name}') + self.norm = layers.LayerNormalization(epsilon=epsilon, name=f'ln_{name}') + self.norm_out = layers.LayerNormalization(epsilon=epsilon, name=f'ln_out_{name}') + if resnet: + self.norm_resnet = layers.LayerNormalization(epsilon=epsilon, name=f'ln_resnet_{name}') + self.out_w = self.add_weight(shape=(output_dim * heads, output_dim), + initializer='random_normal', trainable=True, name=f'outw_{name}') + self.out_b = self.add_weight(shape=(output_dim,), initializer='zeros', + trainable=True, name=f'outb_{name}') + self.scale = 1.0 / tf.math.sqrt(tf.cast(output_dim, tf.float32)) + + def call(self, x, context=None, mask=None): + """ + Args: + x: Tensor of shape (B, N, D) + context: Optional tensor (B, M, D) for cross-attention + mask: Optional boolean mask of shape (B, N) or (B, N, 1) + """ + # Auto-generate padding mask if not provided (based on all-zero tokens) + if mask is None: + mask = tf.reduce_sum(tf.abs(x), axis=-1) > 0 # shape: (B, N) + mask = tf.cast(mask, tf.float32) # shape: (B, N) + + x_norm = self.norm(x) + if self.type == 'self': + q_input = k_input = v_input = x_norm + mask_k = mask_q = mask + else: + assert context is not None, "context is required for cross-attention" + context_norm = self.norm(context) + q_input = x_norm + k_input = v_input = context_norm + mask_q = tf.cast(tf.reduce_sum(tf.abs(x), axis=-1) > 0, tf.float32) + mask_k = tf.cast(tf.reduce_sum(tf.abs(context), axis=-1) > 0, tf.float32) + + q = tf.einsum('bnd,hde->hbne', q_input, self.q) + k = tf.einsum('bmd,hde->hbme', k_input, self.k) + v = tf.einsum('bmd,hde->hbme', v_input, self.v) + + att = tf.einsum('hbne,hbme->hbnm', q, k) * self.scale + + # Add large negative mask to padded keys + mask_k = tf.expand_dims(mask_k, 1) # (B, 1, M) + mask_q = tf.expand_dims(mask_q, 1) # (B, 1, N) + attention_mask = tf.einsum('bqn,bkm->bnm', mask_q, mask_k) # (B, N, M) + attention_mask = tf.expand_dims(attention_mask, 0) # (1, B, N, M) + att += (1.0 - attention_mask) * -1e9 + + att = tf.nn.softmax(att, axis=-1) * attention_mask + + out = tf.einsum('hbnm,hbme->hbne', att, v) + + if self.gate: + g = tf.einsum('bnd,hde->hbne', x_norm, self.g) + g = tf.nn.sigmoid(g) + out *= g + + if self.resnet: + out += tf.expand_dims(x, axis=0) + out = self.norm_resnet(out) + + out = tf.transpose(out, [1, 2, 3, 0]) # (B, N, E, H) + out = tf.reshape(out, [tf.shape(x)[0], tf.shape(x)[1], self.output_dim * self.heads]) + out = tf.matmul(out, self.out_w) + self.out_b + + if self.resnet: 
+ out += x + out = self.norm_out(out) + # Zero out padded tokens after bias addition + mask_exp = tf.expand_dims(mask, axis=-1) # (B, N, 1) + out *= mask_exp + return (out, att) if self.return_att_weights else out + + +class PositionalEncoding(keras.layers.Layer): + """ + Sinusoidal Positional Encoding layer that applies encodings + only to non-masked tokens. + + Args: + embed_dim (int): Dimension of embeddings (must match input last dim). + max_len (int): Maximum sequence length expected (used to precompute encodings). + """ + + def __init__(self, embed_dim, max_len=100): + super().__init__() + self.embed_dim = embed_dim + self.max_len = max_len + + # Create (1, max_len, embed_dim) encoding matrix + pos = tf.range(max_len, dtype=tf.float32)[:, tf.newaxis] # (max_len, 1) + i = tf.range(embed_dim, dtype=tf.float32)[tf.newaxis, :] # (1, embed_dim) + angle_rates = 1 / tf.pow(1000.0, (2 * (i // 2)) / tf.cast(embed_dim, tf.float32)) + angle_rads = pos * angle_rates # (max_len, embed_dim) + + # Apply sin to even indices, cos to odd indices + sines = tf.sin(angle_rads[:, 0::2]) + cosines = tf.cos(angle_rads[:, 1::2]) + + pos_encoding = tf.concat([sines, cosines], axis=-1) # (max_len, embed_dim) + pos_encoding = pos_encoding[tf.newaxis, ...] # (1, max_len, embed_dim) + self.pos_encoding = tf.cast(pos_encoding, dtype=tf.float32) + + def call(self, x, mask=None): + """ + Args: + x: Input tensor of shape (B, N, D) + mask: Optional boolean mask of shape (B, N). True = valid, False = padding + Returns: + Tensor with positional encodings added where mask is True. + """ + seq_len = tf.shape(x)[1] + pe = self.pos_encoding[:, :seq_len, :] # (1, N, D) + + if mask is not None: + mask = tf.cast(mask[:, :, tf.newaxis], tf.float32) # (B, N, 1) + pe = pe * mask # zero out positions where mask is 0 (# TODO: check if this is correct) + + return x + pe + +# --------------------------------------------------------------------------- # + import tensorflow as tf +from tensorflow.keras import layers, Model, Sequential + + +# --------------------------------------------------------------------------- # +# 1.1 Positional + projection layer for the peptide (21-dim per residue) # +# --------------------------------------------------------------------------- # +class PeptideProj(layers.Layer): + """ + Projects peptide vectors (one-hot or 21-dim physicochemical) to embed_dim + and adds a learned positional embedding. 
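+
+    Example (illustrative shapes only; a sketch, not taken from the training code):
+        proj = PeptideProj(max_seq_len=50, embed_dim=256)
+        out = proj(tf.zeros((2, 12, 21)))   # -> (2, 12, 256)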
+ """ + + def __init__(self, max_seq_len, embed_dim, **kwargs): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.max_seq_len = max_seq_len + self.proj = layers.Dense(embed_dim, use_bias=False, name="peptide_proj") + self.pos_emb = layers.Embedding( + input_dim=max_seq_len, + output_dim=embed_dim, + name="peptide_pos" + ) + + def call(self, x): + # x: (batch, S, 21) + batch_size = tf.shape(x)[0] + seq_len = tf.shape(x)[1] + + # Project input features + h = self.proj(x) # (batch, S, embed_dim) + + # Create position indices + pos_indices = tf.range(seq_len) # (S,) + pos_embeddings = self.pos_emb(pos_indices) # (S, embed_dim) + + # Add positional embeddings (broadcasting) + pos_embeddings = tf.expand_dims(pos_embeddings, 0) # (1, S, embed_dim) + return h + pos_embeddings # (batch, S, embed_dim) + + def get_config(self): + config = super().get_config() + config.update({ + "max_seq_len": self.max_seq_len, + "embed_dim": self.embed_dim + }) + return config + + +# --------------------------------------------------------------------------- # +# 1.2 Positional + projection layer for the latent (1152-dim per residue) # +# --------------------------------------------------------------------------- # +class LatentProj(layers.Layer): + """ + Projects latent vectors (1152-dim) to embed_dim and adds a learned positional + embedding. + """ + + def __init__(self, max_n_residues, embed_dim, **kwargs): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.max_n_residues = max_n_residues + self.proj = layers.Dense(embed_dim, use_bias=False, name="latent_proj") + self.pos_emb = layers.Embedding( + input_dim=max_n_residues, + output_dim=embed_dim, + name="latent_pos" + ) + def call(self, x): + # x: (batch, R, 1152) + batch_size = tf.shape(x)[0] + n_residues = tf.shape(x)[1] -def random_one_hot_sequence(batch, seq_len, num_classes=21, class_probs=None): + # Project input features + h = self.proj(x) # (batch, R, embed_dim) + + # Create position indices + pos_indices = tf.range(n_residues) # (R,) + pos_embeddings = self.pos_emb(pos_indices) # (R, embed_dim) + + # Add positional embeddings (broadcasting) + pos_embeddings = tf.expand_dims(pos_embeddings, 0) # (1, R, embed_dim) + return h + pos_embeddings # (batch, R, embed_dim) + + def get_config(self): + config = super().get_config() + config.update({ + "max_n_residues": self.max_n_residues, + "embed_dim": self.embed_dim + }) + return config + + +# --------------------------------------------------------------------------- # +# 2.1 Self-attention transformer block # +# --------------------------------------------------------------------------- # +class SelfAttentionBlock(layers.Layer): + """ + Self-attention block for sequences, followed by FFN + residuals. 
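+
+    Example (illustrative; the 36-token length is an assumption for the sketch):
+        block = SelfAttentionBlock(embed_dim=256, num_heads=8, ff_dim=512)
+        out = block(tf.random.normal((2, 36, 256)))   # -> (2, 36, 256)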
+ """ + + def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.num_heads = num_heads + self.ff_dim = ff_dim + self.dropout_rate = dropout_rate + + self.attn = layers.MultiHeadAttention( + num_heads=num_heads, + key_dim=embed_dim // num_heads, + name="self_attn" + ) + self.ffn = Sequential([ + layers.Dense(ff_dim, activation="relu"), + layers.Dense(embed_dim) + ], name="ffn") + self.norm1 = layers.LayerNormalization(epsilon=1e-6) + self.norm2 = layers.LayerNormalization(epsilon=1e-6) + self.drop1 = layers.Dropout(dropout_rate) + self.drop2 = layers.Dropout(dropout_rate) + + def call(self, x, mask=None, training=False): + # x: (B, L, D) + # Convert mask to proper format for attention if provided + attention_mask = None + if mask is not None: + # mask: (B, L) -> need (B, 1, 1, L) for self-attention + attention_mask = mask[:, tf.newaxis, tf.newaxis, :] # (B, 1, 1, L) + + # Self-attention + attn_out = self.attn( + query=x, key=x, value=x, + attention_mask=attention_mask, + training=training + ) + attn_out = self.drop1(attn_out, training=training) + x = self.norm1(x + attn_out) # residual + norm + + # Feed-forward + ff_out = self.ffn(x) + ff_out = self.drop2(ff_out, training=training) + return self.norm2(x + ff_out) # residual + norm + + def get_config(self): + config = super().get_config() + config.update({ + "embed_dim": self.embed_dim, + "num_heads": self.num_heads, + "ff_dim": self.ff_dim, + "dropout_rate": self.dropout_rate + }) + return config + + +# --------------------------------------------------------------------------- # +# 2.2 Cross-attention transformer block # +# --------------------------------------------------------------------------- # +class CrossAttentionBlock(layers.Layer): + """ + Cross-attention block with self-attention on queries first, then cross-attention. + First applies self-attention to queries, then cross-attention with keys/values. 
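+
+    Example (illustrative; latent queries attend to peptide keys/values, shapes assumed):
+        block = CrossAttentionBlock(embed_dim=256, num_heads=8, ff_dim=512)
+        out = block(queries=tf.random.normal((2, 36, 256)),
+                    keys_values=tf.random.normal((2, 12, 256)))   # -> (2, 36, 256)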
+ """ + + def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, **kwargs): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.num_heads = num_heads + self.ff_dim = ff_dim + self.dropout_rate = dropout_rate + + # Self-attention for queries + self.self_attn = layers.MultiHeadAttention( + num_heads=num_heads, + key_dim=embed_dim // num_heads, + name="self_attn" + ) + + # Cross-attention between queries and keys/values + self.cross_attn = layers.MultiHeadAttention( + num_heads=num_heads, + key_dim=embed_dim // num_heads, + name="cross_attn" + ) + + self.ffn = Sequential([ + layers.Dense(ff_dim, activation="relu"), + layers.Dense(embed_dim) + ], name="ffn") + + # Normalization layers + self.norm1 = layers.LayerNormalization(epsilon=1e-6, name="norm1") + self.norm2 = layers.LayerNormalization(epsilon=1e-6, name="norm2") + self.norm3 = layers.LayerNormalization(epsilon=1e-6, name="norm3") + + # Dropout layers + self.drop1 = layers.Dropout(dropout_rate) + self.drop2 = layers.Dropout(dropout_rate) + self.drop3 = layers.Dropout(dropout_rate) + + def call(self, queries, keys_values, query_mask=None, key_mask=None, training=False): + """ + Args: + queries: (B, L_q, D) - query sequences + keys_values: (B, L_kv, D) - key/value sequences + query_mask: (B, L_q) - mask for queries + key_mask: (B, L_kv) - mask for keys/values + training: bool + """ + # Convert masks to proper format for attention if provided + query_attention_mask = None + if query_mask is not None: + # query_mask: (B, L_q) -> need (B, 1, 1, L_q) for self-attention + query_attention_mask = query_mask[:, tf.newaxis, tf.newaxis, :] # (B, 1, 1, L_q) + + key_attention_mask = None + if key_mask is not None: + # key_mask: (B, L_kv) -> need (B, 1, 1, L_kv) for cross-attention + key_attention_mask = key_mask[:, tf.newaxis, tf.newaxis, :] # (B, 1, 1, L_kv) + + # Step 1: Self-attention on queries + self_attn_out = self.self_attn( + query=queries, + key=queries, + value=queries, + attention_mask=query_attention_mask, + training=training + ) + self_attn_out = self.drop1(self_attn_out, training=training) + queries_refined = self.norm1(queries + self_attn_out) # residual + norm + + # Step 2: Cross-attention between refined queries and keys/values + cross_attn_out = self.cross_attn( + query=queries_refined, + key=keys_values, + value=keys_values, + attention_mask=key_attention_mask, + training=training + ) + cross_attn_out = self.drop2(cross_attn_out, training=training) + x = self.norm2(queries_refined + cross_attn_out) # residual connection + + # Step 3: Feed-forward network + ff_out = self.ffn(x) + ff_out = self.drop3(ff_out, training=training) + return self.norm3(x + ff_out) # residual + norm + + def get_config(self): + config = super().get_config() + config.update({ + "embed_dim": self.embed_dim, + "num_heads": self.num_heads, + "ff_dim": self.ff_dim, + "dropout_rate": self.dropout_rate + }) + return config + + +# --------------------------------------------------------------------------- # +# 3. Build the complete classifier # +# --------------------------------------------------------------------------- # +def build_classifier(max_seq_len=50, + max_n_residues=500, + n_blocks=4, + embed_dim=256, + num_heads=8, + ff_dim=512, + dropout_rate=0.01): """ - Generates a random one-hot encoded tensor of shape (batch, seq_len, num_classes) - using a multinomial distribution for more varied sampling. + Build peptide-latent interaction classifier. Args: - batch: Number of sequences (batch size). - seq_len: Length of each sequence. 
- num_classes: Number of classes for one-hot encoding (default is 21). - class_probs: Optional tensor of shape (num_classes,) defining the probability - distribution over classes. If None, assumes uniform distribution. + max_seq_len: Maximum peptide sequence length + max_n_residues: Maximum number of protein residues + n_blocks: Number of transformer blocks + embed_dim: Embedding dimension + num_heads: Number of attention heads + ff_dim: Feed-forward dimension + dropout_rate: Dropout rate Returns: - A tensor of shape (batch, seq_len, num_classes) with one-hot encoded values. + Compiled Keras model """ - # Define class probabilities (uniform if not provided) - if class_probs is None: - class_probs = tf.ones([num_classes], dtype=tf.float32) / num_classes - else: - class_probs = tf.convert_to_tensor([class_probs], dtype=tf.float32) - class_probs = class_probs / tf.reduce_sum(class_probs) # Normalize to sum to 1 - # Expand class probabilities to match the number of samples - logits = tf.math.log(class_probs)[tf.newaxis, :] # Shape: (1, num_classes) - logits = tf.tile(logits, [batch * seq_len, 1]) # Shape: (batch * seq_len, num_classes) + # --- Inputs ------------------------------------------------------------- + peptide_in = layers.Input(shape=(None, 21), name="peptide") # (B, S, 21) + latent_in = layers.Input(shape=(None, 1152), name="latent_raw") # (B, R, 1152) - # Sample indices from the multinomial distribution - random_indices = tf.random.categorical(logits=logits, num_samples=1) # Shape: (batch * seq_len, 1) - random_indices = tf.reshape(random_indices, [batch, seq_len]) # Reshape to (batch, seq_len) + # --- Create attention masks --------------------------------------------- + # Peptide mask: True where peptide has content (non-zero vectors) + pep_mask = layers.Lambda(lambda x: tf.reduce_any(tf.abs(x) > 1e-6, axis=-1))(peptide_in) # (B, S) - # Apply one-hot encoding - one_hot_encoded = tf.one_hot(random_indices, depth=num_classes) + # Latent mask: True where latent has content (non-zero vectors) + latent_mask = layers.Lambda(lambda x: tf.reduce_any(tf.abs(x) > 1e-6, axis=-1))(latent_in) # (B, R) - return one_hot_encoded + # --- Projections -------------------------------------------------------- + pep_proj = PeptideProj(max_seq_len, embed_dim, name="peptide_projection")(peptide_in) + latent_proj = LatentProj(max_n_residues, embed_dim, name="latent_projection")(latent_in) + # --- Self-attention blocks for latent representation ------------------- + latent_embed = latent_proj + for i in range(n_blocks): + latent_embed = SelfAttentionBlock( + embed_dim=embed_dim, + num_heads=num_heads, + ff_dim=ff_dim, + dropout_rate=dropout_rate, + name=f"latent_self_attn_block_{i + 1}" + )(latent_embed, mask=latent_mask) -import tensorflow as tf -import numpy as np -import pandas as pd -import os - -# Set parameters -os.makedirs('test_tmp', exist_ok=True) -batch_size = 600 -seq_length = 20 # Example sequence length -feature_dim = 21 # Must match encoder input dim -# Generate random input data (batch_size, seq_length, feature_dim) -x_train = random_one_hot_sequence(batch_size, seq_length, feature_dim) -# Initialize SCQ model -model = SCQ_model(general_embed_dim=128, codebook_dim=32, codebook_num=5, descrete_loss=False, heads=8) -# Compile model -model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001)) -# Train model and capture history -history = model.fit(x_train, epochs=1000, batch_size=batch_size) -# Convert history to DataFrame and save as CSV -history_df = pd.DataFrame(history.history) 
-history_df.to_csv("test_tmp/model_history.csv", index=False) -# Test model on new data -x_test = random_one_hot_sequence(batch_size, seq_length, feature_dim) -output = model(x_test) -# Save input and output arrays as .npy files - -os.makedirs -np.save("test_tmp/input_data.npy", x_test) -np.savez("test_tmp/output_data.npz", decoded=output[0].numpy(), zq=output[1].numpy(), pj=output[2].numpy()) -print("Training complete. Model history saved as 'model_history.csv'.") -print("Input and output arrays saved as 'input_data.npy' and 'output_data.npy'.") -print((output[2])) -print(x_train[0])''' + # --- Cross-attention fusion --------------------------------------------- + # Latent queries attend to peptide keys/values + fused = latent_embed + for i in range(n_blocks): + fused = CrossAttentionBlock( + embed_dim=embed_dim, + num_heads=num_heads, + ff_dim=ff_dim, + dropout_rate=dropout_rate, + name=f"cross_attn_block_{i + 1}" + )(queries=fused, keys_values=pep_proj, query_mask=latent_mask, key_mask=pep_mask) + + + # --- Aggregation and prediction head ----------------------------------- + # Global average pooling with masking + latent_mask_expanded = layers.Lambda(lambda x: tf.expand_dims(tf.cast(x, tf.float32), -1))(latent_mask) # (B, R, 1) + masked_fused = fused * latent_mask_expanded # Zero out padded positions + + # Compute mean only over valid positions + pooled = layers.Lambda(lambda x: tf.reduce_sum(x, axis=1))(masked_fused) # (B, D) + valid_lengths = layers.Lambda(lambda x: tf.reduce_sum(x, axis=1))(latent_mask_expanded) # (B, 1) + pooled = pooled / (valid_lengths + 1e-8) # Average over valid positions + + # Simpler classification head + output = layers.Dense(1, activation="sigmoid", name="output")(pooled) + # # Final prediction layers + # x = layers.Dense(embed_dim, activation="relu", name="pred_hidden")(pooled) + # x = layers.Dropout(dropout_rate)(x) + # x = layers.Dense(embed_dim // 2, activation="relu", name="pred_hidden2")(x) + # x = layers.Dropout(dropout_rate)(x) + # output = layers.Dense(1, activation="softmax", name="output")(x) + + # --- Build and compile model -------------------------------------------- + model = Model( + inputs=[peptide_in, latent_in], + outputs=output, + name="peptide_latent_classifier" + ) + + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), + loss="binary_crossentropy", + metrics=["binary_accuracy", "AUC"] + ) + + return model + + +# ----------------------------------------------------------------------------- # +# # 4. 
Utility function to test the model                                           #
+# # --------------------------------------------------------------------------- #
+# def test_model():
+#     """Test the model with dummy data to ensure it works."""
+#
+#     # Create model
+#     model = build_classifier(
+#         max_seq_len=20,
+#         max_n_residues=100,
+#         n_blocks=2,
+#         embed_dim=128,
+#         num_heads=4,
+#         ff_dim=256,
+#         dropout_rate=0.1
+#     )
+#
+#     # Print model summary
+#     print("Model Summary:")
+#     model.summary()
+#
+#     # Create dummy data
+#     batch_size = 4
+#     seq_len = 15
+#     n_residues = 80
+#
+#     # Dummy peptide data (one-hot encoded)
+#     peptide_data = tf.random.uniform((batch_size, seq_len, 21), maxval=2, dtype=tf.int32)
+#     peptide_data = tf.cast(peptide_data, tf.float32)
+#
+#     # Dummy latent data
+#     latent_data = tf.random.normal((batch_size, n_residues, 1152))
+#
+#     # Test forward pass
+#     print("\nTesting forward pass...")
+#     predictions = model([peptide_data, latent_data])
+#     print(f"Output shape: {predictions.shape}")
+#     print(f"Sample predictions: {predictions[:3].numpy().flatten()}")
+#
+#     # Test with different sequence lengths
+#     print("\nTesting with variable sequence lengths...")
+#     peptide_data2 = tf.random.uniform((batch_size, 10, 21), maxval=2, dtype=tf.int32)
+#     peptide_data2 = tf.cast(peptide_data2, tf.float32)
+#     latent_data2 = tf.random.normal((batch_size, 60, 1152))
+#
+#     predictions2 = model([peptide_data2, latent_data2])
+#     print(f"Output shape with different lengths: {predictions2.shape}")
+#
+#     print("\nModel test completed successfully!")
+#
+#     return model
+#
+#
+# if __name__ == "__main__":
+#     # Test the model
+#     test_model()
+
+
+# --------------------------------------------------------------------------- #
+# 4. 
Demo: instantiate and inspect # +# --------------------------------------------------------------------------- # +# if __name__ == "__main__": +# SEQ_LEN = 15 # typical peptide length (adjust to your data) +# +# model = build_classifier(max_seq_len=SEQ_LEN) +# model.summary() +# +# # Training example (dummy): +# peptide_batch = tf.random.uniform((320, SEQ_LEN, 21)) +# latent_batch = tf.random.uniform((320, 36, 1152)) +# labels = tf.random.uniform((320, 1), maxval=2, dtype=tf.int32) +# model.fit([peptide_batch, latent_batch], labels, epochs=100) +## Barcode peptides ## +# a model that creates a barcode for peptides by taking 9mer windows and returning a 1D vector that represents \ No newline at end of file diff --git a/utils/model_archive.py b/utils/model_archive.py new file mode 100644 index 00000000..c76ecea5 --- /dev/null +++ b/utils/model_archive.py @@ -0,0 +1,838 @@ +#!/usr/bin/env python +""" +========================= + +End‑to‑end trainer for a **peptide×MHC cross‑attention classifier**. +It loads a NetMHCpan‑style parquet that contains + + long_mer, assigned_label, allele, MHC_class, + mhc_embedding **OR** mhc_embedding_path + +columns. Each row supplies + +* a peptide sequence (long_mer) +* a pre‑computed MHC pseudo‑sequence embedding (36, 1152) +* a binary label (assigned_label) + +The script + +1.Derives the longest peptide length → SEQ_LEN. +2.Converts every peptide into a 21‑dim one‑hot tensor (SEQ_LEN, 21). +3.Feeds the pair + + (one_hot_peptide, mhc_latent) → classifier → P(binding) + +4.Trains with binary‑cross‑entropy and saves the best weights & metadata. + +Author : Amirreza (updated for cross‑attention, 2025‑05‑22) +""" +from __future__ import annotations +import os, sys, argparse, datetime, pathlib, json +from random import random + +print(sys.executable) + +import numpy as np +from tensorflow import keras +from tensorflow.keras import layers + +# ~8×-40× faster on CPU, zero-copy on GPU +import tensorflow as tf + +AA = tf.constant(list("ACDEFGHIKLMNPQRSTVWY")) +TABLE = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer(AA, + tf.range(20, dtype=tf.int32)), + default_value=20) # UNK_IDX + +def peptides_to_onehot_tf(seqs, seq_len, k=9): + # seqs : tf.Tensor([b'PEPTIDE', ...]) shape=(N,) + # output: (N, RF, k, 21) where RF = seq_len - k + 1 + tokens = tf.strings.bytes_split(seqs) # ragged ‹N, L› + idx = TABLE.lookup(tokens.flat_values) + ragged = tf.RaggedTensor.from_row_lengths(idx, tokens.row_lengths()) + idx_pad = ragged.to_tensor(default_value=20, shape=(None, seq_len)) + onehot = tf.one_hot(idx_pad, 21, dtype=tf.float32) # (N, seq_len, 21) + + RF = seq_len - k + 1 + ta = tf.TensorArray(dtype=tf.float32, size=RF) + def body(i, ta): + slice_k = onehot[:, i:i+k, :] # (N, k, 21) + return i+1, ta.write(i, slice_k) + _, ta = tf.while_loop(lambda i, _: i < RF, body, [0, ta]) + patches = ta.stack() # (RF, N, k, 21) + return tf.transpose(patches, [1, 0, 2, 3]) # (N, RF, k, 21) + + +class AttentionLayer(keras.layers.Layer): + """ + Custom multi-head attention layer supporting self- and cross-attention. + + Args: + input_dim (int): Input feature dimension. + output_dim (int): Output feature dimension per head. + type (str): 'self' or 'cross'. + heads (int): Number of attention heads. + resnet (bool): Whether to use residual connection. + return_att_weights (bool): Whether to return attention weights. + name (str): Name for weight scopes. + epsilon (float): Epsilon for layer normalization. + gate (bool): Whether to use gating mechanism. 
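+        mask_token (float): Sentinel value in `mask` for masked tokens (mask learning);
+            these positions are still attended. Default -1.
+        pad_token (float): Sentinel value in `mask` for padded tokens; these positions
+            are excluded from attention and zeroed in the output. Default -2.
+
+    Example (illustrative sketch; an all-ones mask means no padding):
+        att = AttentionLayer(input_dim=64, output_dim=64, type='self', heads=4)
+        out = att(tf.random.normal((2, 9, 64)), mask=tf.ones((2, 9)))   # -> (2, 9, 64)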
+ """ + + def __init__(self, input_dim, output_dim, type, heads=4, + resnet=True, return_att_weights=False, name='attention', + epsilon=1e-6, gate=True, mask_token=-1., pad_token=-2.): + super().__init__(name=name) + assert isinstance(input_dim, int) and isinstance(output_dim, int) + assert type in ['self', 'cross'] + if resnet: + assert input_dim == output_dim + + self.input_dim = input_dim + self.output_dim = output_dim + self.type = type + self.heads = heads + self.resnet = resnet + self.return_att_weights = return_att_weights + self.epsilon = epsilon + self.gate = gate + self.mask_token = mask_token + self.pad_token = pad_token + + def build(self, x): + self.q = self.add_weight(shape=(self.heads, self.input_dim, self.output_dim), + initializer='random_normal', trainable=True, name=f'q_{self.name}') + self.k = self.add_weight(shape=(self.heads, self.input_dim, self.output_dim), + initializer='random_normal', trainable=True, name=f'k_{self.name}') + self.v = self.add_weight(shape=(self.heads, self.input_dim, self.output_dim), + initializer='random_normal', trainable=True, name=f'v_{self.name}') + if self.gate: + self.g = self.add_weight(shape=(self.heads, self.input_dim, self.output_dim), + initializer='random_uniform', trainable=True, name=f'gate_{self.name}') + self.norm = layers.LayerNormalization(epsilon=self.epsilon, name=f'ln_{self.name}') + if self.type == 'cross': + self.norm_context = layers.LayerNormalization(epsilon=self.epsilon, name=f'ln_context_{self.name}') + self.norm_out = layers.LayerNormalization(epsilon=self.epsilon, name=f'ln_out_{self.name}') + if self.resnet: + self.norm_resnet = layers.LayerNormalization(epsilon=self.epsilon, name=f'ln_resnet_{self.name}') + self.out_w = self.add_weight(shape=(self.output_dim * self.heads, self.output_dim), + initializer='random_normal', trainable=True, name=f'outw_{self.name}') + self.out_b = self.add_weight(shape=(self.output_dim,), initializer='zeros', + trainable=True, name=f'outb_{self.name}') + self.scale = 1.0 / tf.math.sqrt(tf.cast(self.output_dim, tf.float32)) + + def call(self, x, mask, context=None, context_mask=None): + """ + Args: + x: Tensor of shape (B, N, D) + mask: Tensor of shape (B,N) + context: Optional tensor (B, M, D) for cross-attention + """ + # Auto-generate padding mask if not provided (based on all-zero tokens) + mask = tf.cast(mask, tf.float32) # shape: (B, N) + + x_norm = self.norm(x) + if self.type == 'self': + q_input = k_input = v_input = x_norm + mask = tf.where(mask == self.pad_token, 0., + 1.) # all padded ones are 0 and masked ones (for mask learning) and normal ones are 1 + mask_k = mask_q = mask + else: + assert context is not None, "context is required for cross-attention" + assert context_mask is not None, "context_mask is required for cross-attention" + context_norm = self.norm_context(context) + q_input = x_norm + k_input = v_input = context_norm + mask_q = tf.where(mask == self.pad_token, 0., 1.) + mask_k = tf.where(context_mask == self.pad_token, 0., 1.) 
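+
+        # Per-head projections: each weight tensor is (heads, input_dim, output_dim),
+        # so the einsums below keep an explicit head axis, e.g. q has shape (H, B, N, E).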
+ + q = tf.einsum('bnd,hde->hbne', q_input, self.q) + k = tf.einsum('bmd,hde->hbme', k_input, self.k) + v = tf.einsum('bmd,hde->hbme', v_input, self.v) + + att = tf.einsum('hbne,hbme->hbnm', q, k) * self.scale + + # Add large negative mask to padded keys + mask_k = tf.expand_dims(mask_k, 1) # (B, 1, M) + mask_q = tf.expand_dims(mask_q, 1) # (B, 1, N) + attention_mask = tf.einsum('bqn,bkm->bnm', mask_q, mask_k) # (B, N, M) + attention_mask = tf.expand_dims(attention_mask, 0) # (1, B, N, M) + att += (1.0 - attention_mask) * -1e9 + + att = tf.nn.softmax(att, axis=-1) * attention_mask + + out = tf.einsum('hbnm,hbme->hbne', att, v) + + if self.gate: + g = tf.einsum('bnd,hde->hbne', x_norm, self.g) + g = tf.nn.sigmoid(g) + out *= g + + if self.resnet: + out += tf.expand_dims(x, axis=0) + out = self.norm_resnet(out) + + out = tf.transpose(out, [1, 2, 3, 0]) # (B, N, E, H) + out = tf.reshape(out, [tf.shape(x)[0], tf.shape(x)[1], self.output_dim * self.heads]) + out = tf.matmul(out, self.out_w) + self.out_b + + if self.resnet: + out += x + out = self.norm_out(out) + # Zero out padded tokens after bias addition + mask_exp = tf.expand_dims(mask, axis=-1) # (B, N, 1) + out *= mask_exp + return (out, att) if self.return_att_weights else out + + def get_config(self): + config = super().get_config() + config.update({ + 'input_dim': self.input_dim, + 'output_dim': self.output_dim, + 'type': self.type, + 'heads': self.heads, + 'resnet': self.resnet, + 'return_att_weights': self.return_att_weights, + 'epsilon': self.epsilon, + 'gate': self.gate, + 'mask_token': self.mask_token, + 'pad_token': self.pad_token, + }) + return config + + +class PositionalEncoding(keras.layers.Layer): + """ + Sinusoidal Positional Encoding layer that applies encodings + only to non-masked tokens. + + Args: + embed_dim (int): Dimension of embeddings (must match input last dim). + max_len (int): Maximum sequence length expected (used to precompute encodings). + """ + + def __init__(self, embed_dim, max_len: int =100, mask_token: float =-1., pad_token: float =-2., name: str ='positional_encoding'): + super().__init__(name=name) + self.embed_dim = embed_dim + self.max_len = max_len + self.mask_token = mask_token + self.pad_token = pad_token + + def build(self, x): + # Create (1, max_len, embed_dim) encoding matrix + pos = tf.range(self.max_len, dtype=tf.float32)[:, tf.newaxis] # (max_len, 1) + i = tf.range(self.embed_dim, dtype=tf.float32)[tf.newaxis, :] # (1, embed_dim) + angle_rates = 1 / tf.pow(300.0, (2 * (i // 2)) / tf.cast(self.embed_dim, tf.float32)) + angle_rads = pos * angle_rates # (max_len, embed_dim) + + # Apply sin to even indices, cos to odd indices + sines = tf.sin(angle_rads[:, 0::2]) + cosines = tf.cos(angle_rads[:, 1::2]) + + pos_encoding = tf.concat([sines, cosines], axis=-1) # (max_len, embed_dim) + pos_encoding = pos_encoding[tf.newaxis, ...] # (1, max_len, embed_dim) + self.pos_encoding = tf.cast(pos_encoding, dtype=tf.float32) + + def call(self, x, mask): + """ + Args: + x: Input tensor of shape (B, N, D) + mask: Tensor of shape (B,N) + Returns: + Tensor with positional encodings added for masked and non padded tokens. + """ + seq_len = tf.shape(x)[1] + pe = self.pos_encoding[:, :seq_len, :] # (1, N, D) + mask = tf.cast(mask[:, :, tf.newaxis], tf.float32) # (B, N, 1) + mask = tf.where(mask == self.pad_token, 0., 1.) 
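+        # Only pad_token positions lose the positional signal; masked (mask_token)
+        # and ordinary tokens still receive it.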
+ pe = pe * mask # zero out positions where mask is 0 + + return x + pe + + def get_config(self): + config = super().get_config() + config.update({ + 'embed_dim': self.embed_dim, + 'max_len': self.max_len, + 'mask_token': self.mask_token, + 'pad_token': self.pad_token, + }) + return config + +# class RotaryPositionalEncoding(keras.layers.Layer): +# """ +# Rotary Positional Encoding layer for transformer models. +# Applies rotary embeddings to the last two dimensions of the input. +# Args: +# embed_dim (int): Embedding dimension (must be even). +# max_len (int): Maximum sequence length. +# """ +# +# def __init__(self, embed_dim, max_len: int = 100, mask_token: float = -1., pad_token: float = -2., name: str = 'rotary_positional_encoding'): +# super().__init__(name=name) +# assert embed_dim % 2 == 0, "embed_dim must be even for rotary encoding" +# self.embed_dim = embed_dim +# self.max_len = max_len +# self.mask_token = mask_token +# self.pad_token = pad_token +# +# def build(self, x): +# # Precompute rotary frequencies +# pos = tf.range(self.max_len, dtype=tf.float32)[:, tf.newaxis] # (max_len, 1) +# dim = tf.range(self.embed_dim // 2, dtype=tf.float32)[tf.newaxis, :] # (1, embed_dim//2) +# inv_freq = 1.0 / (10000 ** (dim / (self.embed_dim // 2))) +# freqs = pos * inv_freq # (max_len, embed_dim//2) +# self.cos_cached = tf.cast(tf.cos(freqs), tf.float32) # (max_len, embed_dim//2) +# self.sin_cached = tf.cast(tf.sin(freqs), tf.float32) # (max_len, embed_dim//2) +# +# def call(self, x, mask): +# """ +# Args: +# x: Input tensor of shape (B, N, D) +# mask: Tensor of shape (B, N) +# Returns: +# Tensor with rotary positional encoding applied. +# """ +# seq_len = tf.shape(x)[1] +# cos = self.cos_cached[:seq_len, :] # (N, D//2) +# sin = self.sin_cached[:seq_len, :] # (N, D//2) +# cos = tf.expand_dims(cos, 0) # (1, N, D//2) +# sin = tf.expand_dims(sin, 0) # (1, N, D//2) +# +# x1, x2 = tf.split(x, 2, axis=-1) # (B, N, D//2), (B, N, D//2) +# x_rot = tf.concat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1) # (B, N, D) +# +# mask = tf.cast(mask[:, :, tf.newaxis], tf.float32) # (B, N, 1) +# mask = tf.where(mask == self.pad_token, 0., 1.) +# x_rot = x_rot * mask # zero out positions where mask is 0 +# +# return x_rot +# + # def get_config(self): + # config = super().get_config() + # config.update({ + # 'embed_dim': self.embed_dim, + # 'max_len': self.max_len, + # 'mask_token': self.mask_token, + # 'pad_token': self.pad_token, + # }) + # return config + + +@tf.function +def select_indices(ind, n, m_range): + """ + Select top-n indices from `ind` (descending sorted) such that: + - First index is always selected. + - Each subsequent index has a distance from all previously selected + indices between m_range[0] and m_range[1], inclusive. + Args: + ind: Tensor of shape (B, N) with descending sorted indices. + n: Number of indices to select. + m_range: List or tuple [min_distance, max_distance] + Returns: + Tensor of shape (B, n) with selected indices per batch. 
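+
+    Example (worked by hand for a single batch row):
+        ind = [[9, 7, 5, 2, 0]], n = 2, m_range = [2, 4]
+        -> 9 is taken first; 7 lies |9 - 7| = 2 positions away, which is within
+           [2, 4], so the selection is [[9, 7]].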
+ """ + m_min = tf.constant(m_range[0], dtype=tf.int32) + m_max = tf.constant(m_range[1], dtype=tf.int32) + + def per_batch_select(indices): + top = indices[0] + selected = tf.TensorArray(dtype=tf.int32, size=n) + selected = selected.write(0, top) + count = tf.constant(1) + i = tf.constant(1) + + def cond(i, count, selected): + return tf.logical_and(i < tf.shape(indices)[0], count < n) + + def body(i, count, selected): + candidate = indices[i] + selected_vals = selected.stack()[:count] + distances = tf.abs(selected_vals - candidate) + if_valid = tf.reduce_all( + tf.logical_and(distances >= m_min, distances <= m_max) + ) + selected = tf.cond(if_valid, + lambda: selected.write(count, candidate), + lambda: selected) + count = tf.cond(if_valid, lambda: count + 1, lambda: count) + return i + 1, count, selected + + _, _, selected = tf.while_loop( + cond, body, [i, count, selected], + shape_invariants=[i.get_shape(), count.get_shape(), tf.TensorShape(None)] + ) + return selected.stack() + + return tf.map_fn(per_batch_select, ind, dtype=tf.int32) + + +class AnchorPositionExtractor(keras.layers.Layer): + def __init__(self, num_anchors, dist_thr, name='anchor_extractor', project=True, + mask_token=-1., pad_token=-2., return_att_weights=False): + super().__init__(name=name) + assert isinstance(dist_thr, list) and len(dist_thr) == 2 + assert num_anchors > 0 + self.num_anchors = num_anchors + self.dist_thr = dist_thr + self.project = project + self.mask_token = mask_token + self.pad_token = pad_token + self.return_att_weights = return_att_weights + + def build(self, input_shape): # att_out (B,N,E) + b, n, e = input_shape[0], input_shape[1], input_shape[2] + self.barcode = tf.random.uniform(shape=(1, 1, e)) # add as a token to input + self.q = self.add_weight(shape=(e, e), + initializer='random_normal', + trainable=True, name=f'query_{self.name}') + self.k = self.add_weight(shape=(e, e), + initializer='random_normal', + trainable=True, name=f'key_{self.name}') + self.v = self.add_weight(shape=(e, e), + initializer='random_normal', + trainable=True, name=f'value_{self.name}') + self.ln = layers.LayerNormalization(name=f'ln_{self.name}') + if self.project: + self.g = self.add_weight(shape=(self.num_anchors, e, e), + initializer='random_uniform', + trainable=True, name=f'gate_{self.name}') + self.w = self.add_weight(shape=(1, self.num_anchors, e, e), + initializer='random_normal', + trainable=True, name=f'w_{self.name}') + + def call(self, input, mask): # (B,N,E) this is peptide embedding and (B,N) for mask + + mask = tf.cast(mask, tf.float32) # (B, N) + mask = tf.where(mask == self.pad_token, 0., 1.) + + barcode = self.barcode + barcode = tf.broadcast_to(barcode, (tf.shape(input)[0], 1, tf.shape(input)[-1])) # (B,N,E) + q = tf.matmul(barcode, self.q) # (B,1,E)*(E,E)->(B,1,E) + k = tf.matmul(input, self.k) # (B,N,E)*(E,E)->(B,N,E) + v = tf.matmul(input, self.v) # (B,N,E)*(E,E)->(B,N,E) + scale = 1 / tf.math.sqrt(tf.cast(tf.shape(input)[-1], tf.float32)) + barcode_att = tf.matmul(q, k, transpose_b=True) * scale # (B,1,E)*(B,E,N)->(B,1,N) + # mask: (B,N) => (B,1,N) + mask_exp = tf.expand_dims(mask, axis=1) + additive_mask = (1.0 - mask_exp) * -1e9 + barcode_att += additive_mask + barcode_att = tf.nn.softmax(barcode_att) + barcode_att *= mask_exp # to remove the impact of row wise attention of padded tokens. 
since all are 1e-9 + barcode_out = tf.matmul(barcode_att, v) # (B,1,N)*(B,N,E)->(B,1,E) + # barcode_out represents a vector for all information from peptide + # barcode_att represents the anchor positions which are the tokens with highest weights + inds, weights, outs = self.find_anchor(input, + barcode_att) # (B,num_anchors) (B,num_anchors) (B, num_anchors, E) + if self.project: + pos_encoding = tf.broadcast_to( + tf.expand_dims(inds, axis=-1), + (tf.shape(outs)[0], tf.shape(outs)[1], tf.shape(outs)[2]) + ) + pos_encoding = tf.cast(pos_encoding, tf.float32) + dim = tf.cast(tf.shape(outs)[-1], tf.float32) + ra = tf.range(dim, dtype=tf.float32) / dim + pos_encoding = tf.sin(pos_encoding / tf.pow(40., ra)) + outs += pos_encoding + + weights_bc = tf.expand_dims(weights, axis=-1) + weights_bc = tf.broadcast_to(weights_bc, (tf.shape(weights_bc)[0], + tf.shape(weights_bc)[1], + tf.shape(outs)[-1] + )) # (B,num_anchors, E) + outs = tf.expand_dims(outs, axis=-2) # (B, num_anchors, 1, E) + outs_w = tf.matmul(outs, self.w) # (B,num_anchors,1,E)*(1,num_anchors,E,E)->(B,num_anchors,1,E) + outs_g = tf.nn.sigmoid(tf.matmul(outs, self.g)) + outs_w = tf.squeeze(outs_w, axis=-2) # (B,num_anchors,E) + outs_g = tf.squeeze(outs_g, axis=-2) + # multiply by attention weights from barcode_att to choose best anchors and additional feature gating + outs = outs_w * outs_g * weights_bc # (B, num_anchors, E) + outs = self.ln(outs) + # outs -> anchor info, inds -> anchor indeces, weights -> anchor att weights, barcode_out -> whole peptide features + # (B,num_anchors,E), (B,num_anchors), (B,num_anchors), (B,E) + if self.return_att_weights: + return outs, inds, weights, tf.squeeze(barcode_out, axis=1), barcode_att + else: + return outs, inds, weights, tf.squeeze(barcode_out, axis=1) + + def find_anchor(self, input, barcode_att): # (B,N,E), (B,1,N) + inds = tf.argsort(barcode_att, axis=-1, direction='DESCENDING', stable=False) # (B,1,N) + inds = tf.squeeze(inds, axis=1) # (B,N) + selected_inds = select_indices(inds, n=self.num_anchors, m_range=self.dist_thr) # (B,num_anchors) + sorted_selected_inds = tf.sort(selected_inds) + sorted_selected_weights = tf.gather(tf.squeeze(barcode_att, axis=1), + sorted_selected_inds, + axis=1, + batch_dims=1) # (B,num_anchors) + sorted_selected_output = tf.gather(input, sorted_selected_inds, axis=1, batch_dims=1) # (B,num_anchors,E) + return sorted_selected_inds, sorted_selected_weights, sorted_selected_output + + def get_config(self): + config = super().get_config() + config.update({ + 'num_anchors': self.num_anchors, + 'dist_thr': self.dist_thr, + 'project': self.project, + 'mask_token': self.mask_token, + 'pad_token': self.pad_token, + 'return_att_weights': self.return_att_weights, + }) + return config + + +def generate_mhc(samples=1024, min_len=5, max_len=15, dim=16): + X_list = [] + mask_list = [] + + for _ in range(samples): + l = np.random.randint(min_len, max_len + 1) + x = np.random.rand(l, dim).astype(np.float32) + X_list.append(x) + mask = np.ones((l,), dtype=bool) + mask_list.append(mask) + + # Pad sequences + X_padded = tf.keras.preprocessing.sequence.pad_sequences( + X_list, maxlen=max_len, dtype='float32', padding='post' + ) # shape: (samples, max_len, dim) + + mask_padded = tf.keras.preprocessing.sequence.pad_sequences( + mask_list, maxlen=max_len, dtype=bool, padding='post' + ) # shape: (samples, max_len) + + # Compute masked mean (only over valid tokens) + mask_exp = np.expand_dims(mask_padded, axis=-1).astype(np.float32) # (samples, max_len, 1) + y = np.sum(X_padded * 
mask_exp, axis=1) / np.maximum(np.sum(mask_exp, axis=1), 1e-8) # (samples, dim) + mean_of_all = tf.reduce_mean(y) + y = tf.reduce_mean(y, axis=-1, keepdims=True) + y = np.where(y > mean_of_all, 1., 0.) + mask_padded = np.where(mask_padded == False, -2., 0.) + + return X_padded, y, mask_padded + +def generate_peptide(samples=1024, min_len=5, max_len=15, k=9): + """ + Generate random peptide one-hot tensors of shape (N, RF, k, 21) + where RF = max_len - k + 1. + """ + peptides = [] + for _ in range(samples): + l = np.random.randint(min_len, max_len + 1) + seq = np.random.choice(list("ACDEFGHIKLMNPQRSTVWY"), l) + seq_str = ''.join(seq) + peptides.append(seq_str) + # Convert to tf.Tensor + seqs = tf.constant(peptides) + RF = max_len - k + 1 + onehot = peptides_to_onehot_tf(seqs, max_len, k) # (N, RF, k, 21) + return onehot.numpy() + + + +# --------------------------------------------------------------------------- +# Constants & helpers +# --------------------------------------------------------------------------- +MASK_TOKEN = -1 +PAD_TOKEN = -2. + +@tf.keras.utils.register_keras_serializable() +def make_rf_mask(pep_batch: tf.Tensor, MASK_TOKEN, PAD_TOKEN) -> tf.Tensor: + """Return (B, RF) float mask for a peptide batch of shape *(B, RF,K,21)*. + + A slot is **1** when any (K,21) element is non‑zero, else **PAD_TOKEN**. + """ + # True where *any* channel is non‑zero → valid RF window + non_zero = tf.math.reduce_any(tf.not_equal(pep_batch, 0.), axis=[2, 3]) # (B, RF) + return tf.where(non_zero, 1., PAD_TOKEN) # float32 + + +#### define model +# input layers +def build_custom_classifier(max_len_peptide: int = 50, + max_len_mhc: int = 200, + k: int = 9, + embed_dim_pep: int = 64, + embed_dim_mhc: int = 128, + mask_token: float = MASK_TOKEN, + pad_token: float = PAD_TOKEN): + """Return a compiled Keras model that consumes + + pep_input : (RF_max, k, 21) + mhc_input : (max_len_mhc, 1152) + + RF_max is ``max_len_peptide − k + 1``. 
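+
+    Example (illustrative sketch with dummy tensors; the defaults give RF_max = 42):
+        model = build_custom_classifier()
+        pep = tf.random.uniform((4, 42, 9, 21))
+        mhc = tf.random.uniform((4, 200, 1152))
+        prob = model.predict([pep, mhc])   # -> (4, 1) binding probabilities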
+ """ + RF_max = max_len_peptide - k + 1 + + # ----- inputs ---------------------------------------------------------- + pep_input = keras.Input(shape=(RF_max, k, 21), name="pep_input") + mhc_input = keras.Input(shape=(max_len_mhc, 1152), name="mhc_input") + + # ----- peptide branch -------------------------------------------------- + pep_mask = layers.Lambda(lambda x: make_rf_mask(x, MASK_TOKEN, PAD_TOKEN), + output_shape=lambda input_shape: (input_shape[0], input_shape[1]), + name="pep_mask")(pep_input) # (B, RF) + + pep_flat = layers.Reshape((RF_max, k * 21), name="pep_flat")(pep_input) # (B, RF, k·21) + + # TODO, check if it make sence to do this before or after positional encoding + pep_proj = layers.Dense(embed_dim_pep, activation="relu", + name="pep_proj1")(pep_flat) # (B, RF, 64) + + # Use Lambda to dynamically determine max_len from input shape + pep_pe = PositionalEncoding(embed_dim_pep, max_len=RF_max, + mask_token=mask_token, pad_token=pad_token, + name="pep_pos_enc")(pep_proj, pep_mask) + + pep_att1 = AttentionLayer(input_dim=embed_dim_pep, output_dim=embed_dim_pep, + type="self", heads=4, name="pep_self_att1", + mask_token=mask_token, pad_token=pad_token)( + pep_pe, mask=pep_mask) + + # ----- MHC branch ------------------------------------------------------ + mhc_mask = layers.Lambda( + lambda x: tf.where(tf.math.reduce_any(tf.not_equal(x, 0.), axis=-1), 1., pad_token), + output_shape=lambda input_shape: (input_shape[0], input_shape[1]), + name="mhc_mask" + )(mhc_input) + + # TODO check if it is correct + mhc_proj1 = layers.Dense(embed_dim_mhc, activation="relu", + name="mhc_proj1")(mhc_input) # (B, L_mhc, D_mhc) + + mhc_pe = PositionalEncoding(embed_dim=embed_dim_mhc,max_len=max_len_mhc, + mask_token=mask_token,pad_token=pad_token, + name="mhc_pos_enc")(mhc_proj1, mhc_mask) + + mhc_att1 = AttentionLayer(input_dim=embed_dim_mhc, output_dim=embed_dim_mhc, + type="self", heads=4, name="mhc_self_att1", + mask_token=mask_token, pad_token=pad_token)( + mhc_pe, mask=mhc_mask) + mhc_att2 = AttentionLayer(input_dim=embed_dim_mhc, output_dim=embed_dim_mhc, + type="self", heads=4, name="mhc_self_att2", + resnet=False, mask_token=mask_token, pad_token=pad_token)( + mhc_att1, mask=mhc_mask) # (B, L_mhc, D_mhc), att_weights + + + # NEW: project MHC tokens → pep embed‑dim (64) so Q/K/V dims align + mhc_to_D_pep = layers.Dense(embed_dim_pep, activation="relu", + name="mhc_to_D_pep")(mhc_att2) # (B, L_mhc, D_pep) + + # ----- cross‑attention ------------------------------------------------- + cross_att = AttentionLayer(input_dim=embed_dim_pep, output_dim=embed_dim_pep, + type="cross", heads=8, name="cross_att", + mask_token=mask_token, pad_token=pad_token)( + pep_att1, mask=pep_mask, + context=mhc_to_D_pep, context_mask=mhc_mask) + + cross_proj = layers.Dense(64, activation="relu", name="cross_proj")(cross_att) + + final_att = AttentionLayer(input_dim=64, output_dim=64, type="self", + heads=2, name="final_pep_self_att", + return_att_weights=True, + mask_token=mask_token, pad_token=pad_token)( + cross_proj, mask=pep_mask) + + final_features = AnchorPositionExtractor(num_anchors=2, dist_thr=[8, 15], # outs, inds, weights, barcode_out, barcode_att + name="anchor_extractor", # (B,num_anchors,E), (B,num_anchors), (B,num_anchors), (B,E), (B,N,N) + project=True, + return_att_weights=True, + mask_token=mask_token, pad_token=pad_token)( + final_att[0], mask=pep_mask) + + # ---------------------------------------------------------------------- + # Three heads (barcode, anchors, pooled) — unchanged + # 
---------------------------------------------------------------------- + # barcode_vec = final_features[-2] + # x_bc = layers.Dense(32, activation="relu", name="barcodout1_dense")(barcode_vec) + # x_bc = layers.Dropout(0.3, name="barcodout1_dropout")(x_bc) + # out_bc = layers.Dense(1, activation="sigmoid", name="barcode_cls")(x_bc) + + anchor_feat = final_features[0] # (B, num_anchors, E) + x_an1 = layers.Flatten(name="anchorout_flatten")(anchor_feat) + x_an2 = layers.Dense(64, activation="relu", name="anchorout1_dense")(x_an1) + x_an3 = layers.Dropout(0.3, name="anchorout1_dropout")(x_an2) + x_an4 = layers.Dense(16, activation="relu", name="anchorout2_dense")(x_an3) + x_an5 = layers.Dropout(0.3, name="anchorout2_dropout")(x_an4) + out_an = layers.Dense(1, activation="sigmoid", name="anchor_cls")(x_an5) + + # pooled = layers.GlobalAveragePooling1D(name="attout_gap")(final_att[0]) + # x_po = layers.Dense(32, activation="relu", name="attout1_dense")(pooled) + # x_po = layers.Dropout(0.3, name="attout1_dropout")(x_po) + # out_po = layers.Dense(1, activation="sigmoid", name="attn_cls")(x_po) + + model = keras.Model(inputs=[pep_input, mhc_input], + outputs=out_an, + name="PeptideMHC_CrossAtt") + + model.compile( + loss="binary_crossentropy", + optimizer="adam", + metrics=["binary_accuracy", "AUC"], + ) + return model + +# def main(): +# max_len_peptide = 14 +# k = 9 +# max_len_mhc = 36 +# RF_max = max_len_peptide - k + 1 +# +# model = build_custom_classifier(max_len_peptide, max_len_mhc, k=k) +# model.summary(line_length=110) +# +# batch = 16 +# # pep_dummy = np.zeros((batch, RF_max, k, 21), dtype=np.float32) +# # pep_dummy[:, :3] = np.random.rand(batch, 3, k, 21) +# # +# # mhc_dummy = np.zeros((batch, max_len_mhc, 1152), dtype=np.float32) +# # mhc_dummy[:, :25] = np.random.rand(batch, 25, 1152) +# +# pep_syn = generate_peptide(samples=10, min_len=9, max_len=max_len_peptide, k=k) +# mhc_syn, _, mhc_mask = generate_mhc(samples=10, min_len=25, max_len=max_len_mhc, dim=1152) +# +# y = np.random.randint(0, 2, size=(10, 1)).astype(np.float32) +# +# history = model.fit(x=[pep_syn, mhc_syn], y=y, epochs=2, batch_size=batch) +# +# # save model +# model_dir = pathlib.Path("model_output") +# model_dir.mkdir(parents=True, exist_ok=True) +# model_path = model_dir / "peptide_mhc_cross_attention_model.h5" +# model.save(model_path) +# print(f"Model saved to {model_path}") +# +# print("Sanity‑check complete — no dimension errors.") +# +# # PREDICT +# preds = model.predict([pep_syn[:batch], mhc_syn[:batch]]) +# print("Predictions for first batch:") +# for i, pred in enumerate(preds): +# print(f"Sample {i + 1}: Anchor: {pred[0]:.4f}") +# # Save model metadata +# metadata = { +# "max_len_peptide": max_len_peptide, +# "k": k, +# "max_len_mhc": max_len_mhc, +# "RF_max": RF_max, +# "embed_dim_pep": 64, +# "embed_dim_mhc": 128, +# "mask_token": MASK_TOKEN, +# "pad_token": PAD_TOKEN +# } +# metadata_path = model_dir / "model_metadata.json" +# with open(metadata_path, 'w') as f: +# json.dump(metadata, f, indent=4) +# +# print(f"Model metadata saved to {metadata_path}") +# +# # plot metrics and confusion +# import matplotlib.pyplot as plt +# import seaborn as sns +# import pandas as pd +# history_df = pd.DataFrame(history.history) +# print(f"Keys in history: {list(history_df.columns)}") +# +# ## Plot metrics with error handling and saving to disk +# metrics_dir = model_dir / "metrics" +# metrics_dir.mkdir(parents=True, exist_ok=True) +# +# # Plot binary accuracy +# plt.figure(figsize=(10, 5)) +# 
sns.lineplot(data=history_df, x=history_df.index, y='binary_accuracy', label='Binary Accuracy') +# plt.title('Binary Accuracy Over Epochs') +# plt.xlabel('Epochs') +# plt.ylabel('Binary Accuracy') +# plt.legend() +# # plt.show() +# plt.savefig("model_output/binary_accuracy_plot.png") +# +# # plot AUC +# plt.figure(figsize=(10, 5)) +# sns.lineplot(data=history_df, x=history_df.index, y='AUC', label='AUC') +# plt.title('AUC Over Epochs') +# plt.xlabel('Epochs') +# plt.ylabel('AUC') +# plt.legend() +# # plt.show() +# plt.savefig("model_output/auc_plot.png") +# + +# def test1(): +# import matplotlib.pyplot as plt +# import seaborn as sns +# import numpy as np +# +# max_len_peptide = 50 # Maximum length of peptide sequences +# max_len_mhc = 200 # Maximum length of MHC sequences +# k = 9 # Length of peptide k-mers +# batch = 5 # Number of samples to generate +# RF_max = max_len_peptide - k + 1 # Number of RF tokens +# +# # Build and summarize model +# model = build_custom_classifier(max_len_peptide, max_len_mhc, k=k, +# mask_token=MASK_TOKEN, pad_token=PAD_TOKEN) +# model.summary(line_length=110) +# +# # Generate synthetic peptide input and MHC input +# pep_syn = generate_peptide(samples=batch, min_len=9, max_len=max_len_peptide, k=k) +# mhc_syn, _, _ = generate_mhc(samples=batch, min_len=25, max_len=max_len_mhc, dim=1152) +# +# # For demonstration, we perform one prediction without training. +# # In practice, you would train the model and then run prediction. +# preds = model.predict([pep_syn, mhc_syn]) +# print("Overall Predictions (binding probability) per sample:") +# for i, pred in enumerate(preds): +# print(f"Sample {i + 1}: Probability = {pred[0]:.4f}") +# +# # ----- Extract attention weights from the final self-attention layer ----- +# # We know that final self-attention is applied with return_att_weights=True +# # and its name is 'final_pep_self_att' +# # Its output is a tuple: (final_features, att_weights) +# # We create a new model with the same inputs but outputting the attention tuple. +# final_att_layer = model.get_layer("final_pep_self_att") +# att_model = keras.Model(inputs=model.inputs, outputs=final_att_layer.output) +# +# # Run the synthetic data through the attention model +# # Note: Depending on your model, the layer returns a tuple. +# att_outputs = att_model.predict([pep_syn, mhc_syn]) +# # att_outputs is a tuple: (final_features, att_weights) +# # final_features: shape (B, RF, output_dim) +# # att_weights: shape (heads, B, RF, RF) [RF tokens attend to RF tokens] +# final_features, att_weights = att_outputs +# att_weights = np.array(att_weights) # ensure numpy array +# +# # Average attention weights over heads +# # New shape: (B, RF, RF) +# att_avg = np.mean(att_weights, axis=0) +# +# # For each sample and each row in the RF dimension, determine the index that got the highest attention. +# print("\nAttention analysis per sample and per RF row:") +# for sample in range(batch): +# print(f"\nSample {sample + 1}:") +# for i in range(RF_max): +# att_row = att_avg[sample, i, :] +# # To ignore padded positions (if any), you might add a threshold. Here we just take argmax. +# max_ind = np.argmax(att_row) +# max_val = att_row[max_ind] +# print(f" RF row {i}: highest attention on row {max_ind} (value = {max_val:.4f})") +# +# # ----- Visualization ----- +# # For one sample (say sample 0), plot a heatmap of the averaged attention matrix. 
+# sample_to_plot = 0 +# plt.figure(figsize=(6, 5)) +# sns.heatmap(att_avg[sample_to_plot], annot=True, cmap="viridis", +# xticklabels=[f"r{j}" for j in range(RF_max)], +# yticklabels=[f"r{i}" for i in range(RF_max)]) +# plt.title("Averaged Attention Weights (over heads) for Sample 1") +# plt.xlabel("Attended RF row") +# plt.ylabel("Query RF row") +# plt.tight_layout() +# plt.savefig("model_output/attention_heatmap_sample1.png") +# plt.show() +# print("Attention heatmap saved to 'model_output/attention_heatmap_sample1.png'") + +# if __name__ == "__main__": +# main() diff --git a/utils/processing_functions.py b/utils/processing_functions.py index 6e471ab7..9ec3d379 100644 --- a/utils/processing_functions.py +++ b/utils/processing_functions.py @@ -16,6 +16,9 @@ warnings.filterwarnings("ignore") sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from user_setting import netmhcipan_path, netmhciipan_path, pmgen_abs_dir +from sklearn.model_selection import train_test_split +from sklearn.model_selection import GroupShuffleSplit, StratifiedShuffleSplit +from sklearn.utils import resample def process_dataframe(df): # Step 1: Sort IDs with 2 or 5 parts @@ -1027,7 +1030,7 @@ def run_netmhcpan(peptide_fasta, allele_list, output, mhc_type, final_allele += allele if 'DQB' in allele or 'DPB' in allele: final_allele += f'-{allele.replace("HLA-", "")}' - print(allele) + # print(allele) cmd = [str(netmhciipan_path), '-f', str(peptide_fasta), '-BA', '-u', '-s', '-length', '9,10,11,12,13,14,15,16,17,18', '-inptype', '0', '-a', str(final_allele)] @@ -1057,4 +1060,342 @@ def fetch_polypeptide_sequences(pdb_path): sequences[chain_id] = sequence return sequences +# leave out k ids, then create the first fold (train and validation) then add one ID and create fold 2 from fold 1 and added data and so on. +# the added data is sampled from the rest equal to total / k +def create_progressive_k_fold_cross_validation(df, k=5, target_col="Label", id_col="id", + random_state=42, save_to_disk=False, + output_dir=None): + """ + Creates k folds for cross-validation where each fold progressively adds IDs to training. + + Strategy: + 1. First fold leaves out k IDs for validation + 2. Each subsequent fold adds one ID to training and samples a new validation set + 3. Each fold maintains approximately equal size (total/k samples) + + Args: + df (pd.DataFrame): The input dataframe. + k (int): The number of folds. + target_col (str): The name of the column containing target labels. + id_col (str): The name of the column containing the IDs. + random_state (int): Random state for reproducibility. + save_to_disk (bool): Whether to save the folds to disk. + output_dir (str): Directory to save fold files if save_to_disk is True. + + Returns: + list: A list of tuples, where each tuple contains (train_df, val_df) for a fold. 
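+              Each tuple also carries the fold's remaining validation-ID array as a
+              third element.
+
+    Example (illustrative; assumes the default `id` and `Label` columns exist in `df`):
+        folds = create_progressive_k_fold_cross_validation(df, k=5)
+        train_df, val_df, val_ids = folds[0]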
+ """ + # Get unique IDs + unique_ids = df[id_col].unique() + + if len(unique_ids) < k: + raise ValueError(f"Not enough unique IDs ({len(unique_ids)}) for {k} folds") + + # Shuffle IDs + np.random.seed(random_state) + np.random.shuffle(unique_ids) + + # Calculate target size for each fold + target_fold_size = len(df) // k + + # Initialize list to store folds + folds = [] + + # Start with k IDs left out + validation_ids = unique_ids[:k] + training_ids = np.array([]) + + for fold_idx in range(k): + # For each fold, add one ID to training and get a new validation ID + if fold_idx > 0: + # Add one ID from validation to training + training_ids = np.append(training_ids, validation_ids[0]) + # Remove the first ID from validation + validation_ids = validation_ids[1:] + + # Get all data for current training and validation IDs + train_mask = df[id_col].isin(training_ids) + val_mask = df[id_col].isin([validation_ids[0]]) # Use only the first validation ID + + train_df = df[train_mask].copy() + val_df = df[val_mask].copy() + + # Add samples from remaining IDs to reach target fold size + remaining_ids = [id for id in unique_ids if id not in training_ids and id not in [validation_ids[0]]] + + if len(train_df) < target_fold_size and remaining_ids: + samples_needed = target_fold_size - len(train_df) + additional_mask = df[id_col].isin(remaining_ids) + additional_df = df[additional_mask].sample( + n=min(samples_needed, sum(additional_mask)), + random_state=random_state+fold_idx + ) + train_df = pd.concat([train_df, additional_df]) + + # Reset indices + train_df = train_df.reset_index(drop=True) + val_df = val_df.reset_index(drop=True) + + folds.append((train_df, val_df, validation_ids)) + + # Save folds to disk if requested + if save_to_disk and output_dir: + os.makedirs(output_dir, exist_ok=True) + train_df.to_parquet(f"{output_dir}/train_fold_{fold_idx}.parquet") + val_df.to_parquet(f"{output_dir}/val_fold_{fold_idx}.parquet") + + return folds + + +# write a function that produces 5 fold cross validation from the df, each fold must have one allele to be left out completely and produce a stratified train and validation based on the label. +# take into account the subset proportion. +# def create_k_fold_leave_one_out_stratified_cross_validation(df, k=5, target_col="label", id_col="allele", +# subset_prop=1.0, train_size=0.8, random_state=42): +# """ +# Creates k folds for cross-validation where each fold: +# 1. Leaves one unique ID out completely +# 2. Has a validation set with at least one unique allele not present in the training set +# 3. Splits the remaining data into stratified train/validation sets +# 4. Balances both train and validation sets based on target labels +# +# Args: +# df (pd.DataFrame): The input dataframe. +# k (int): The number of folds. +# target_col (str): The name of the column containing the target labels for stratification. +# id_col (str): The name of the column containing the IDs (alleles) to leave out. +# subset_prop (float): The proportion of the remaining data to use. +# train_size (float): The proportion of the non-held-out data to use for training. +# random_state (int): Random state for reproducibility. +# +# Returns: +# list: A list of tuples, where each tuple contains (train_df, val_df, left_out_id) for a fold. 
+# """ +# # Take a subset of the dataframe +# df = df.sample(frac=subset_prop, random_state=random_state) +# +# # Get unique IDs (alleles) +# unique_ids = df[id_col].unique() +# +# # Check if k is larger than the number of unique IDs +# if k > len(unique_ids): +# raise ValueError(f"k must be less than or equal to the number of unique IDs ({len(unique_ids)}).") +# +# # Set random seed for reproducibility +# np.random.seed(random_state) +# +# # Randomly select k IDs to leave out (one per fold) +# selected_ids = np.random.choice(unique_ids, k, replace=False) +# +# folds = [] +# for leave_out_id in selected_ids: +# print("leave out ID:", leave_out_id) +# # Create a mask for the ID to leave out +# leave_out_mask = df[id_col] == leave_out_id +# +# # Get remaining data (exclude the left out ID) +# remaining_df = df[~leave_out_mask].copy() +# +# # Get remaining unique IDs +# remaining_unique_ids = remaining_df[id_col].unique() +# +# # Select a validation-only ID (will only appear in validation set) +# val_only_id = np.random.choice(remaining_unique_ids, 1)[0] +# +# # Split data for unique validation ID and the rest +# val_only_mask = remaining_df[id_col] == val_only_id +# val_only_df = remaining_df[val_only_mask].copy() +# train_eligible_df = remaining_df[~val_only_mask].copy() +# +# # Create stratified split on the training-eligible data +# train_df, extra_val_df = train_test_split( +# train_eligible_df, +# train_size=train_size, +# stratify=train_eligible_df[target_col], +# random_state=random_state +# ) +# +# # Combine validation-only data with extra validation data +# val_df = pd.concat([val_only_df, extra_val_df], ignore_index=True) +# +# # Balance training set +# train_class_counts = train_df[target_col].value_counts() +# min_train_count = train_class_counts.min() +# balanced_train_dfs = [] +# +# for label in train_class_counts.index: +# label_df = train_df[train_df[target_col] == label] +# # Downsample to the minority class count +# balanced_label_df = label_df.sample(min_train_count, random_state=random_state) +# balanced_train_dfs.append(balanced_label_df) +# +# balanced_train_df = pd.concat(balanced_train_dfs, ignore_index=True) +# +# # Balance validation set +# val_class_counts = val_df[target_col].value_counts() +# min_val_count = val_class_counts.min() +# balanced_val_dfs = [] +# +# for label in val_class_counts.index: +# label_df = val_df[val_df[target_col] == label] +# # Downsample to the minority class count +# balanced_label_df = label_df.sample(min_val_count, random_state=random_state) +# balanced_val_dfs.append(balanced_label_df) +# +# balanced_val_df = pd.concat(balanced_val_dfs, ignore_index=True) +# +# # Include the left_out_id in the tuple +# folds.append((balanced_train_df, balanced_val_df, leave_out_id)) +# +# return folds + + +def create_k_fold_leave_one_out_stratified_cv( + df: pd.DataFrame, + k: int = 5, + target_col: str = "label", + id_col: str = "allele", + subset_prop: float = 1.0, + train_size: float = 0.8, + random_state: int = 42, + augmentation: str = None # "down_sampling" or "GNUSS" +): + """ + Build *k* folds such that + + 1. **One whole ID (group) is left out of both train & val** (`left_out_id`). + 2. **Validation contains exactly one additional ID** (`val_only_id`) + that never appears in train. + 3. Remaining rows are split *stratified* on `target_col` + (`train_size` fraction for training). + 4. Train & val are **down-sampled** to perfectly balanced label counts. 
+ + Returns + ------- + list[tuple[pd.DataFrame, pd.DataFrame, Hashable]] + Each tuple = (train_df, val_df, left_out_id). + """ + rng = np.random.RandomState(random_state) + if subset_prop < 1.0: + if subset_prop <= 0.0 or subset_prop > 1.0: + raise ValueError(f"subset_prop must be in (0, 1], got {subset_prop}") + # Take a random subset of the DataFrame + print(f"Taking {subset_prop * 100:.2f}% of the data for k-fold CV") + df = df.sample(frac=subset_prop, random_state=random_state).reset_index(drop=True) + + # --- pick the k IDs that will be held out completely ------------------- + unique_ids = df[id_col].unique() + if k > len(unique_ids): + raise ValueError(f"k={k} > unique {id_col} count ({len(unique_ids)})") + left_out_ids = rng.choice(unique_ids, size=k, replace=False) + + folds = [] + for fold_idx, left_out_id in enumerate(left_out_ids, 1): + fold_seed = random_state + fold_idx + mask_left_out = df[id_col] == left_out_id + working_df = df.loc[~mask_left_out].copy() + + # --------------------------------------------------------------- + # 1) choose ONE id that will appear *only* in validation + # (GroupShuffleSplit with test_size=1 group) + # --------------------------------------------------------------- + gss = GroupShuffleSplit( + n_splits=1, test_size=1, random_state=fold_seed + ) + (train_groups_idx, val_only_groups_idx), = gss.split( + X=np.zeros(len(working_df)), y=None, groups=working_df[id_col] + ) + val_only_group_id = working_df.iloc[val_only_groups_idx][id_col].unique()[0] + + mask_val_only = working_df[id_col] == val_only_group_id + df_val_only = working_df[mask_val_only] + df_eligible = working_df[~mask_val_only] + + # --------------------------------------------------------------- + # 2) stratified split of *eligible* rows + # --------------------------------------------------------------- + sss = StratifiedShuffleSplit( + n_splits=1, train_size=train_size, random_state=fold_seed + ) + train_idx, extra_val_idx = next( + sss.split(df_eligible, df_eligible[target_col]) + ) + df_train = df_eligible.iloc[train_idx] + df_val = pd.concat( + [df_val_only, df_eligible.iloc[extra_val_idx]], ignore_index=True + ) + + print(f"Fold size: train={len(df_train)}, val={len(df_val)} | ") + + # --------------------------------------------------------------- + # 3) balance train and val via down-sampling + # --------------------------------------------------------------- + def _balance_down_sampling(frame: pd.DataFrame) -> pd.DataFrame: + min_count = frame[target_col].value_counts().min() + print(f"Balancing {len(frame)} rows to {min_count} per class") + balanced_parts = [ + resample( + frame[frame[target_col] == lbl], + replace=False, + n_samples=min_count, + random_state=fold_seed, + ) + for lbl in frame[target_col].unique() + ] + return pd.concat(balanced_parts, ignore_index=True) + + def _balance_GNUSS(frame: pd.DataFrame) -> pd.DataFrame: + """ + Balance the DataFrame by upsampling the minority class with Gaussian noise. 
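+            Gaussian noise (sigma = 1e-6) is added to every numeric column of the
+            duplicated rows; note that this includes `target_col` when the label
+            itself is stored as a numeric dtype.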
+ """ + # Determine label counts and the maximum class size + counts = frame[target_col].value_counts() + max_count = counts.max() + + # Identify numeric columns for noise injection + numeric_cols = frame.select_dtypes(include="number").columns + + balanced_parts = [] + for label, count in counts.items(): + df_label = frame[frame[target_col] == label] + balanced_parts.append(df_label) + if count < max_count: + # Upsample with replacement + n_needed = max_count - count + sampled = df_label.sample(n=n_needed, replace=True, random_state=fold_seed) + # Add Gaussian noise to numeric features + noise = pd.DataFrame( + rng.normal(loc=0, scale=1e-6, size=(n_needed, len(numeric_cols))), + columns=numeric_cols, + index=sampled.index + ) + sampled[numeric_cols] = sampled[numeric_cols] + noise + balanced_parts.append(sampled) + + # Combine and return + return pd.concat(balanced_parts).reset_index(drop=True) + + if augmentation == "GNUSS": + df_train_bal = _balance_GNUSS(df_train) + df_val_bal = _balance_GNUSS(df_val) + elif augmentation == "down_sampling": # default to down-sampling + df_train_bal = _balance_down_sampling(df_train) + df_val_bal = _balance_down_sampling(df_val) + elif not augmentation: + df_train_bal = df_train.copy() + df_val_bal = df_val.copy() + print("No augmentation applied, using original train and val sets.") + else: + raise ValueError(f"Unknown augmentation method: {augmentation}") + + # Shuffle both datasets to avoid any ordering bias + df_train_bal = df_train_bal.sample(frac=1.0, random_state=fold_seed).reset_index(drop=True) + df_val_bal = df_val_bal.sample(frac=1.0, random_state=fold_seed).reset_index(drop=True) + folds.append((df_train_bal, df_val_bal, left_out_id)) + + print( + f"[fold {fold_idx}/{k}] left-out={left_out_id} | " + f"val-only={val_only_group_id} | " + f"train={len(df_train_bal)}, val={len(df_val_bal)}" + ) + return folds diff --git a/utils/run_pMHC_DL_ESM2.py b/utils/run_pMHC_DL_ESM2.py new file mode 100644 index 00000000..c75353ca --- /dev/null +++ b/utils/run_pMHC_DL_ESM2.py @@ -0,0 +1,737 @@ +#!/usr/bin/env python +""" +========================= + +MEMORY-OPTIMIZED End‑to‑end trainer for a **peptide×MHC cross‑attention classifier**. +Loads NetMHCpan‑style parquet files in true streaming fashion without loading entire datasets into memory. + +Key improvements: +1. Streaming parquet reading with configurable batch sizes +2. Lazy evaluation of dataset properties (seq length, class balance) +3. Memory-efficient TensorFlow data pipelines +4. 
Proper cleanup and memory monitoring + +Author: Amirreza (memory-optimized version, 2025) +""" +from __future__ import annotations +import os +import sys + +print(sys.executable) +# ============================================================================= +# CRITICAL: GPU Memory Configuration - MUST BE FIRST +# ============================================================================= +import tensorflow as tf + + +def configure_gpu_memory(): + """Configure TensorFlow to use GPU memory efficiently""" + try: + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + print(f"Found {len(gpus)} GPU(s)") + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + print("✓ GPU memory growth enabled") + else: + print("No GPUs found - running on CPU") + except RuntimeError as e: + print(f"GPU configuration error: {e}") + + +# Configure GPU immediately +configure_gpu_memory() + +# --------------------------------------------------------------------- +# ► Use all logical CPU cores for TF ops that still run on CPU +# --------------------------------------------------------------------- +NUM_CPUS = os.cpu_count() or 1 +tf.config.threading.set_intra_op_parallelism_threads(NUM_CPUS) +tf.config.threading.set_inter_op_parallelism_threads(NUM_CPUS) +print(f'✓ TF intra/inter-op threads set to {NUM_CPUS}') + +# Set memory-friendly environment variables +os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async' +os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' +os.environ["PYTHONHASHSEED"] = "42" +os.environ["TF_DETERMINISTIC_OPS"] = "1" + +import math +import argparse, datetime, pathlib, json +import psutil +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from tqdm import tqdm +from model_archive import build_custom_classifier +from sklearn.metrics import ( + confusion_matrix, roc_curve, auc, precision_score, + recall_score, f1_score, accuracy_score, roc_auc_score +) +import seaborn as sns +import pyarrow.parquet as pq +import gc +import weakref +import pyarrow as pa, pyarrow.compute as pc +pa.set_cpu_count(os.cpu_count()) + + +# ============================================================================= +# Memory monitoring functions +# ============================================================================= +def monitor_memory(): + """Monitor system memory usage""" + memory = psutil.virtual_memory() + print(f"System RAM: {memory.used / 1e9:.1f}GB / {memory.total / 1e9:.1f}GB ({memory.percent:.1f}% used)") + + try: + from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo + nvmlInit() + deviceCount = nvmlDeviceGetCount() + for i in range(deviceCount): + handle = nvmlDeviceGetHandleByIndex(i) + info = nvmlDeviceGetMemoryInfo(handle) + print( + f"GPU {i}: {info.used / 1e9:.1f}GB / {info.total / 1e9:.1f}GB ({100 * info.used / info.total:.1f}% used)") + except: + print("GPU memory monitoring not available") + + +def cleanup_memory(): + """Aggressive memory cleanup""" + gc.collect() + try: + tf.keras.backend.clear_session() + except: + pass + + + +# ---------------------------------------------------------------------------- +# 1) Peptide → k‑mer one‑hot *inside* TF graph (GPU/TPU friendly) +# ---------------------------------------------------------------------------- +pad_token = 20 + +def peptides_to_onehot_kmer_windows(seq, max_seq_len, k=9): + """ + Converts a peptide sequence into a sliding window of k-mers, one-hot encoded. 
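+    Positions that fall outside the peptide are one-hot encoded at the padding
+    index (20); e.g. peptides_to_onehot_kmer_windows("SIINFEKL", max_seq_len=15)
+    returns an array of shape (7, 9, 21).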
+ Output shape: (RF, k, 21), where RF = max_seq_len - k + 1 + """ + AA = "ACDEFGHIKLMNPQRSTVWY" + aa_to_idx = {aa: i for i, aa in enumerate(AA)} + RF = max_seq_len - k + 1 + RFs = np.zeros((RF, k, 21), dtype=np.float32) + for window in range(RF): + if window + k <= len(seq): + kmer = seq[window:window + k] + for i, aa in enumerate(kmer): + idx = aa_to_idx.get(aa, pad_token) + RFs[window, i, idx] = 1.0 + # Pad remaining positions in k-mer if sequence is too short + for i in range(len(kmer), k): + RFs[window, i, pad_token] = 1.0 + else: + # Entire k-mer is padding if out of sequence + RFs[window, :, pad_token] = 1.0 + return np.array(RFs) + + +def _parquet_rowcount(parquet_path: str | os.PathLike) -> int: + return pq.ParquetFile(parquet_path).metadata.num_rows + + +def _read_embedding_file(path: str | os.PathLike) -> np.ndarray: + # Try fast numeric path first + try: + arr = np.load(path) + if isinstance(arr, np.ndarray) and arr.dtype == np.float32: + return arr + raise ValueError + except ValueError: + obj = np.load(path, allow_pickle=True) + if isinstance(obj, np.ndarray) and obj.dtype == object: + obj = obj.item() + if isinstance(obj, dict) and "embedding" in obj: + return obj["embedding"].astype("float32") + raise ValueError(f"Unrecognised embedding file {path}") + +# ---------------------------------------------------------------------------- +# Streaming dataset utilities +# ---------------------------------------------------------------------------- +class StreamingParquetReader: + """Memory-efficient streaming parquet reader""" + + def __init__(self, parquet_path: str, batch_size: int = 1000): + self.parquet_path = parquet_path + self.batch_size = batch_size + self._file = None + self._num_rows = None + + def __enter__(self): + self._file = pq.ParquetFile(self.parquet_path) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._file: + self._file = None + + @property + def num_rows(self): + """Get total number of rows without loading data""" + if self._num_rows is None: + if self._file is None: + with pq.ParquetFile(self.parquet_path) as f: + self._num_rows = f.metadata.num_rows + else: + self._num_rows = self._file.metadata.num_rows + return self._num_rows + + def iter_batches(self): + """Iterate over parquet file in batches""" + if self._file is None: + raise RuntimeError("Reader not opened. 
Use within 'with' statement.") + + for batch in self._file.iter_batches(batch_size=self.batch_size): + df = batch.to_pandas() + yield df + del df, batch # Explicit cleanup + + def sample_for_metadata(self, n_samples: int = 1000): + """Sample a small portion for metadata extraction""" + with pq.ParquetFile(self.parquet_path) as f: + # Read first batch for metadata + first_batch = next(f.iter_batches(batch_size=min(n_samples, self.num_rows))) + return first_batch.to_pandas() + + +def get_dataset_metadata(parquet_path: str): + """Extract dataset metadata without loading full dataset""" + with StreamingParquetReader(parquet_path) as reader: + sample_df = reader.sample_for_metadata(reader.num_rows) + + metadata = { + 'total_rows': reader.num_rows, + 'max_peptide_length': int(sample_df['long_mer'].str.len().max()) if 'long_mer' in sample_df.columns else 0, + 'class_distribution': sample_df[ + 'assigned_label'].value_counts().to_dict() if 'assigned_label' in sample_df.columns else {}, + } + + del sample_df + return metadata + + +def calculate_class_weights(parquet_path: str): + """Calculate class weights from a sample of the dataset""" + with StreamingParquetReader(parquet_path, batch_size=1000) as reader: + label_counts = {0: 0, 1: 0} + for batch_df in reader.iter_batches(): + batch_labels = batch_df['assigned_label'].values + unique, counts = np.unique(batch_labels, return_counts=True) + for label, count in zip(unique, counts): + if label in [0, 1]: + label_counts[int(label)] += count + del batch_df + + # Calculate balanced class weights + total = sum(label_counts.values()) + if total == 0 or label_counts[0] == 0 or label_counts[1] == 0: + return {0: 1.0, 1: 1.0} + + return { + 0: total / (2 * label_counts[0]), + 1: total / (2 * label_counts[1]) + } + +# --------------------------------------------------------------------- +# Utility that is executed in worker processes +# (must be top-level so it can be pickled on Windows) +# --------------------------------------------------------------------- +def _row_to_tensor_pack(row_dict: dict, max_pep_seq_len: int, max_mhc_len: int): + """Convert a single row (already in plain-python dict form) into tensors.""" + # --- peptide one-hot ------------------------------------------------ + pep = row_dict["long_mer"].upper()[:max_pep_seq_len] + pep_arr = peptides_to_onehot_kmer_windows( + pep, max_seq_len=max_pep_seq_len, k=9 + ) # shape: (RF, k, 21) + + # --- load MHC embedding -------------------------------------------- + mhc = _read_embedding_file(row_dict["mhc_embedding_path"]) + if mhc.shape[0] != max_mhc_len: # sanity check + raise ValueError(f"MHC length mismatch: {mhc.shape[0]} vs {max_mhc_len}") + + # --- label ---------------------------------------------------------- + label = float(row_dict["assigned_label"]) + return (pep_arr, mhc.astype("float32")), label + +from concurrent.futures import ProcessPoolExecutor +import functools, itertools + +def streaming_data_generator( + parquet_path: str, + max_pep_seq_len: int, + max_mhc_len: int, + batch_size: int = 1000): + """ + Yields *individual* samples, but converts an entire Parquet batch + on multiple CPU cores first. 
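+    Each Parquet batch (`batch_size` rows) is decoded, converted row by row with
+    _row_to_tensor_pack in a ProcessPoolExecutor, and then yielded sample by
+    sample, so at most one batch is held in memory at a time.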
+ """ + with StreamingParquetReader(parquet_path, batch_size) as reader, \ + ProcessPoolExecutor(max_workers=os.cpu_count()) as pool: + + # Partial function to avoid re-sending constants + worker_fn = functools.partial( + _row_to_tensor_pack, + max_pep_seq_len=max_pep_seq_len, + max_mhc_len=max_mhc_len, + ) + + for batch_df in reader.iter_batches(): + # Convert Arrow table → list[dict] once; avoids pandas overhead + dict_rows = batch_df.to_dict(orient="list") # columns -> python lists + # Re-shape to list[dict(row)] + rows_iter = ( {k: dict_rows[k][i] for k in dict_rows} # row dict + for i in range(len(batch_df)) ) + + # Parallel map; chunksize tuned for large batches + results = pool.map(worker_fn, rows_iter, chunksize=64) + + # Stream each converted sample back to the generator consumer + yield from results # <-- keeps memory footprint tiny + + # explicit clean-up + del batch_df, dict_rows, rows_iter, results + + +def create_streaming_dataset(parquet_path: str, + max_pep_seq_len: int, + max_mhc_len: int, + batch_size: int = 128, + buffer_size: int = 1000, + k: int = 9): + """ + Same semantics as before, but the generator already does parallel + preprocessing. We now ask tf.data to interleave multiple generator + shards in parallel as well. + """ + RF = max_pep_seq_len - k + 1 + output_signature = ( + ( + tf.TensorSpec(shape=(RF, k, 21), dtype=tf.float32), + tf.TensorSpec(shape=(max_mhc_len, 1152), dtype=tf.float32), + ), + tf.TensorSpec(shape=(), dtype=tf.float32), + ) + + ds = tf.data.Dataset.from_generator( + lambda: streaming_data_generator( + parquet_path, + max_pep_seq_len, + max_mhc_len, + buffer_size), + output_signature=output_signature, + ) + + # ► Parallel interleave gives another speed-up if the Parquet file has + # many row-groups – adjust cycle_length as needed. 
+ ds = ds.interleave( + lambda x, y: tf.data.Dataset.from_tensors((x, y)), + cycle_length=tf.data.AUTOTUNE, + num_parallel_calls=tf.data.AUTOTUNE, + deterministic=False, + ) + + return ds + +# --------------------------------------------------------------------------- +# Visualisation utility +# --------------------------------------------------------------------------- + +def plot_training_curve(history: tf.keras.callbacks.History, run_dir: str, fold_id: int = None, + model=None, val_dataset=None): + """Plot training curves and validation metrics""" + hist = history.history + plt.figure(figsize=(21, 6)) + plot_name = f"training_curve{'_fold' + str(fold_id) if fold_id is not None else ''}" + + plt.suptitle(f"Training Curves{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", + fontsize=16, fontweight='bold') + + # Plot 1: Loss curve + plt.subplot(1, 4, 1) + plt.plot(hist["loss"], label="train", linewidth=2) + plt.plot(hist["val_loss"], label="val", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.title("BCE Loss") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 2: Accuracy curve + if "binary_accuracy" in hist and "val_binary_accuracy" in hist: + plt.subplot(1, 4, 2) + plt.plot(hist["binary_accuracy"], label="train acc", linewidth=2) + plt.plot(hist["val_binary_accuracy"], label="val acc", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Accuracy") + plt.title("Binary Accuracy") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 3: AUC curve + if "AUC" in hist and "val_AUC" in hist: + plt.subplot(1, 4, 3) + plt.plot(hist["AUC"], label="train AUC", linewidth=2) + plt.plot(hist["val_AUC"], label="val AUC", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("AUC") + plt.title("AUC") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 4: Confusion matrix placeholder + plt.subplot(1, 4, 4) + if model is not None and val_dataset is not None: + # Sample a subset for confusion matrix to avoid memory issues + sample_dataset = val_dataset.take(100) # Take only 100 batches + y_pred_proba = model.predict(sample_dataset, verbose=0) + y_pred = (y_pred_proba > 0.5).astype(int) + + y_true = [] + for _, labels in sample_dataset: + y_true.extend(labels.numpy()) + y_true = np.array(y_true) + + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix (Sample)') + else: + plt.text(0.5, 0.5, 'Confusion Matrix N/A \n(Sample from validation)', + ha='center', va='center', transform=plt.gca().transAxes, + bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray")) + plt.axis('off') + + plt.tight_layout() + os.makedirs(run_dir, exist_ok=True) + out_png = os.path.join(run_dir, f"{plot_name}.png") + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.close() # Close to free memory + print(f"✓ Training curve saved to {out_png}") + + +def plot_test_metrics(model, test_dataset, run_dir: str, fold_id: int = None, + history=None, string: str = None): + """Plot comprehensive evaluation metrics for test dataset""" + print("Generating predictions for test metrics...") + + # Collect predictions and labels in batches to avoid memory issues + y_true_list = [] + y_pred_proba_list = [] + + for batch_x, batch_y in test_dataset: + batch_pred = model.predict(batch_x, verbose=0) + y_true_list.append(batch_y.numpy()) + y_pred_proba_list.append(batch_pred.flatten()) + + y_true = np.concatenate(y_true_list).flatten() + y_pred_proba = 
np.concatenate(y_pred_proba_list) + y_pred = (y_pred_proba > 0.5).astype(int) + + # Calculate ROC curve + fpr, tpr, _ = roc_curve(y_true, y_pred_proba) + roc_auc = auc(fpr, tpr) + + # Create evaluation plot + plt.figure(figsize=(15, 10)) + plt.suptitle(f"{string} Evaluation Metrics{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", + fontsize=16, fontweight='bold') + + # ROC Curve + plt.subplot(2, 2, 1) + plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})') + plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.8) + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('ROC Curve') + plt.legend(loc="lower right") + plt.grid(True, alpha=0.3) + + # Confusion Matrix + plt.subplot(2, 2, 2) + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix') + plt.xlabel('Predicted Label') + plt.ylabel('True Label') + + # Metrics bar chart + plt.subplot(2, 2, 3) + accuracy = accuracy_score(y_true, y_pred) + precision = precision_score(y_true, y_pred, zero_division=0) + recall = recall_score(y_true, y_pred, zero_division=0) + f1 = f1_score(y_true, y_pred, zero_division=0) + + metrics = {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1, 'AUC': roc_auc} + bars = plt.bar(range(len(metrics)), list(metrics.values()), + color=['skyblue', 'lightcoral', 'lightgreen', 'gold', 'orange'], + alpha=0.8, edgecolor='black', linewidth=1) + plt.xticks(range(len(metrics)), list(metrics.keys()), rotation=45, ha='right') + plt.ylim(0, 1.0) + plt.title('Evaluation Metrics') + plt.ylabel('Score') + plt.grid(True, alpha=0.3, axis='y') + + for bar, value in zip(bars, metrics.values()): + plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02, + f'{value:.3f}', ha='center', va='bottom', fontweight='bold') + + # Prediction distribution + plt.subplot(2, 2, 4) + plt.hist(y_pred_proba[y_true == 0], bins=30, alpha=0.7, label='Negative Class', + color='red', density=True) + plt.hist(y_pred_proba[y_true == 1], bins=30, alpha=0.7, label='Positive Class', + color='blue', density=True) + plt.axvline(x=0.5, color='black', linestyle='--', linewidth=2, label='Threshold') + plt.xlabel('Predicted Probability') + plt.ylabel('Density') + plt.title('Prediction Distribution') + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + os.makedirs(run_dir, exist_ok=True) + out_png = os.path.join(run_dir, f"{string}_metrics{'_fold' + str(fold_id) if fold_id is not None else ''}.png") + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.close() # Close to free memory + print(f"✓ Test metrics visualization saved to {out_png}") + + # Print summary + print("\n" + "=" * 50) + print("EVALUATION SUMMARY") + print("=" * 50) + print(f"Accuracy: {accuracy:.4f}") + print(f"Precision: {precision:.4f}") + print(f"Recall: {recall:.4f}") + print(f"F1 Score: {f1:.4f}") + print(f"ROC AUC: {roc_auc:.4f}") + print("=" * 50) + + return { + 'roc_auc': roc_auc, 'accuracy': accuracy, 'precision': precision, + 'recall': recall, 'f1': f1, 'confusion_matrix': cm.tolist() + } + +# --------------------------------------------------------------------------- +# Training script +# --------------------------------------------------------------------------- + +def main(argv=None): + p = argparse.ArgumentParser() + p.add_argument("--dataset_path", 
required=True, + help="Path to the dataset directory") + p.add_argument("--epochs", type=int, default=30) + p.add_argument("--batch", type=int, default=128) + p.add_argument("--outdir", default=None, + help="Output dir (default: runs/run_YYYYmmdd-HHMMSS)") + p.add_argument("--buffer_size", type=int, default=1000, + help="Buffer size for streaming data loading") + p.add_argument("--test_batches", type=int, default=None, + help="Number of batches to use for test datasets (default: 30)") + + args = p.parse_args(argv) + + run_dir = args.outdir or f"runs/run_{datetime.datetime.now():%Y%m%d-%H%M%S}" + pathlib.Path(run_dir).mkdir(parents=True, exist_ok=True) + print(f"★ Outputs → {run_dir}\n") + + # Set seeds for reproducibility + tf.random.set_seed(42) + np.random.seed(42) + print("Setting random seeds for reproducibility...") + + print("Initial memory state:") + monitor_memory() + + # Extract metadata from datasets without loading them fully + print("Extracting dataset metadata...") + + # Get fold information + fold_dir = os.path.join(args.dataset_path, 'folds') + fold_files = sorted([f for f in os.listdir(fold_dir) if f.endswith('.parquet')]) + n_folds = len(fold_files) // 2 + + # Find maximum peptide length across all datasets + max_peptide_length = 0 + max_mhc_length = 36 # Fixed for now + + print("Scanning datasets for maximum peptide length...") + all_parquet_files = [ + os.path.join(args.dataset_path, "test1.parquet"), + os.path.join(args.dataset_path, "test2.parquet") + ] + + # Add fold files + for i in range(1, n_folds + 1): + all_parquet_files.extend([ + os.path.join(fold_dir, f'fold_{i}_train.parquet'), + os.path.join(fold_dir, f'fold_{i}_val.parquet') + ]) + + for pq_file in all_parquet_files: + if os.path.exists(pq_file): + metadata = get_dataset_metadata(pq_file) + max_peptide_length = max(max_peptide_length, metadata['max_peptide_length']) + print( + f" {os.path.basename(pq_file)}: max_len={metadata['max_peptide_length']}, rows={metadata['total_rows']}") + + print(f"✓ Maximum peptide length across all datasets: {max_peptide_length}") + + # Create fold datasets and class weights + folds = [] + class_weights = [] + + for i in range(1, n_folds + 1): + print(f"\nProcessing fold {i}/{n_folds}") + train_path = os.path.join(fold_dir, f'fold_{i}_train.parquet') + val_path = os.path.join(fold_dir, f'fold_{i}_val.parquet') + + # Calculate class weights from training data + print(f" Calculating class weights...") + cw = calculate_class_weights(train_path) + print(f" Class weights: {cw}") + + # Create streaming datasets + train_ds = (create_streaming_dataset(train_path, max_peptide_length, max_mhc_length, + buffer_size=args.buffer_size) + .shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True) + .batch(args.batch) + # .take(args.test_batches) # Limit to buffer size for memory efficiency + .prefetch(tf.data.AUTOTUNE)) + + val_ds = (create_streaming_dataset(val_path, max_peptide_length, max_mhc_length, + buffer_size=args.buffer_size) + .batch(args.batch) + # .take(args.test_batches) # Limit to buffer size for memory efficiency + .prefetch(tf.data.AUTOTUNE)) + + folds.append((train_ds, val_ds)) + class_weights.append(cw) + + # Force cleanup + cleanup_memory() + + # Create test datasets + print("Creating test datasets...") + test1_ds = (create_streaming_dataset(os.path.join(args.dataset_path, "test1.parquet"), + max_peptide_length, max_mhc_length, buffer_size=args.buffer_size) + .batch(args.batch) + .prefetch(tf.data.AUTOTUNE)) + + test2_ds = 
(create_streaming_dataset(os.path.join(args.dataset_path, "test2.parquet"), + max_peptide_length, max_mhc_length, buffer_size=args.buffer_size) + .batch(args.batch) + .prefetch(tf.data.AUTOTUNE)) + + print(f"✓ Created {n_folds} fold datasets and 2 test datasets") + print("Memory after dataset creation:") + monitor_memory() + + # Training loop + print("\n" + "=" * 60) + print("STARTING TRAINING") + print("=" * 60) + + for fold_id, ((train_loader, val_loader), class_weight) in enumerate(zip(folds, class_weights), start=1): + print(f'\n🔥 Training fold {fold_id}/{n_folds}') + + # Clean up before each fold + cleanup_memory() + + # Build fresh model for each fold + print("Building model...") + model = build_custom_classifier(max_peptide_length, max_mhc_length) + model.summary() + + # Callbacks + ckpt_cb = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(run_dir, f'best_fold_{fold_id}.weights.h5'), + monitor='val_loss', save_best_only=True, mode='min', verbose=1) + early_cb = tf.keras.callbacks.EarlyStopping( + monitor='val_loss', patience=15, restore_best_weights=True, verbose=1) + + # Verify data shapes + for (x_pep, latents), labels in train_loader.take(1): + print(f"✓ Input shapes: peptide={x_pep.shape}, mhc={latents.shape}, labels={labels.shape}") + break + + print("Memory before training:") + monitor_memory() + + # Train model + print("🚀 Starting training...") + hist = model.fit( + train_loader, + validation_data=val_loader, + epochs=args.epochs, + class_weight=class_weight, + callbacks=[ckpt_cb, early_cb], + verbose=1, + ) + + print("Memory after training:") + monitor_memory() + + # Plot training curves + plot_training_curve(hist, run_dir, fold_id, model, val_loader) + + # Save model and metadata + model.save_weights(os.path.join(run_dir, f'model_fold_{fold_id}.weights.h5')) + metadata = { + "fold_id": fold_id, + "epochs": args.epochs, + "batch_size": args.batch, + "max_peptide_length": max_peptide_length, + "max_mhc_length": max_mhc_length, + "class_weights": class_weight, + "run_dir": run_dir, + "mhc_class": MHC_CLASS, + } + with open(os.path.join(run_dir, f'metadata_fold_{fold_id}.json'), 'w') as f: + json.dump(metadata, f, indent=4) + + # Evaluate on test sets + print(f"\n📊 Evaluating fold {fold_id} on test sets...") + + # Test1 evaluation + print("Evaluating on test1 (balanced alleles)...") + plot_test_metrics(model, test1_ds, run_dir, fold_id, string="Test1_balanced_alleles") + + # Test2 evaluation + print("Evaluating on test2 (rare alleles)...") + plot_test_metrics(model, test2_ds, run_dir, fold_id, string="Test2_rare_alleles") + + print(f"✅ Fold {fold_id} completed successfully") + + # Cleanup + del model, hist + cleanup_memory() + + print("\n🎉 Training completed successfully!") + print(f"📁 All results saved to: {run_dir}") + + +if __name__ == "__main__": + BUFFER = 8192 # Reduced buffer size for memory efficiency + MHC_CLASS = 1 + dataset_path = f"../data/Custom_dataset/NetMHCpan_dataset/mhc_{MHC_CLASS}" + main([ + "--dataset_path", dataset_path, + "--epochs", "10", + "--batch", "8192", + "--buffer_size", "8192", + ]) \ No newline at end of file diff --git a/utils/run_pMHC_DL_ESM3.py b/utils/run_pMHC_DL_ESM3.py new file mode 100644 index 00000000..9abd7ced --- /dev/null +++ b/utils/run_pMHC_DL_ESM3.py @@ -0,0 +1,716 @@ +#!/usr/bin/env python +""" +========================= + +MEMORY-OPTIMIZED End‑to‑end trainer for a **peptide×MHC cross‑attention classifier**. +Loads NetMHCpan‑style parquet files in true streaming fashion without loading entire datasets into memory. 
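+Unlike run_pMHC_DL_ESM2.py, peptides are encoded here as flat (max_len, 21)
+one-hot arrays rather than sliding k-mer windows, and the classifier is built
+with model.build_classifier instead of model_archive.build_custom_classifier.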
+ +Key improvements: +1. Streaming parquet reading with configurable batch sizes +2. Lazy evaluation of dataset properties (seq length, class balance) +3. Memory-efficient TensorFlow data pipelines +4. Proper cleanup and memory monitoring + +Author: Amirreza (memory-optimized version, 2025) +""" +from __future__ import annotations +import os +import sys + +print(sys.executable) + +# ============================================================================= +# CRITICAL: GPU Memory Configuration - MUST BE FIRST +# ============================================================================= +import tensorflow as tf + + +def configure_gpu_memory(): + """Configure TensorFlow to use GPU memory efficiently""" + try: + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + print(f"Found {len(gpus)} GPU(s)") + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + print("✓ GPU memory growth enabled") + else: + print("No GPUs found - running on CPU") + except RuntimeError as e: + print(f"GPU configuration error: {e}") + + +# Configure GPU immediately +configure_gpu_memory() + +# --------------------------------------------------------------------- +# ► Use all logical CPU cores for TF ops that still run on CPU +# --------------------------------------------------------------------- +NUM_CPUS = os.cpu_count() or 1 +tf.config.threading.set_intra_op_parallelism_threads(NUM_CPUS) +tf.config.threading.set_inter_op_parallelism_threads(NUM_CPUS) +print(f'✓ TF intra/inter-op threads set to {NUM_CPUS}') + +# Set memory-friendly environment variables +os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async' +os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' +os.environ["PYTHONHASHSEED"] = "42" +os.environ["TF_DETERMINISTIC_OPS"] = "1" + +import math +import argparse, datetime, pathlib, json +import psutil +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from tqdm import tqdm +from model import build_classifier +from sklearn.metrics import ( + confusion_matrix, roc_curve, auc, precision_score, + recall_score, f1_score, accuracy_score, roc_auc_score +) +import seaborn as sns +import pyarrow.parquet as pq +import gc +import weakref +import pyarrow as pa, pyarrow.compute as pc +pa.set_cpu_count(os.cpu_count()) + + +# ============================================================================= +# Memory monitoring functions +# ============================================================================= +def monitor_memory(): + """Monitor system memory usage""" + memory = psutil.virtual_memory() + print(f"System RAM: {memory.used / 1e9:.1f}GB / {memory.total / 1e9:.1f}GB ({memory.percent:.1f}% used)") + + try: + from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo + nvmlInit() + deviceCount = nvmlDeviceGetCount() + for i in range(deviceCount): + handle = nvmlDeviceGetHandleByIndex(i) + info = nvmlDeviceGetMemoryInfo(handle) + print( + f"GPU {i}: {info.used / 1e9:.1f}GB / {info.total / 1e9:.1f}GB ({100 * info.used / info.total:.1f}% used)") + except: + print("GPU memory monitoring not available") + + +def cleanup_memory(): + """Aggressive memory cleanup""" + gc.collect() + try: + tf.keras.backend.clear_session() + except: + pass + + +# ---------------------------------------------------------------------------- +# Peptide encoding utilities +# ---------------------------------------------------------------------------- +AA = "ACDEFGHIKLMNPQRSTVWY" # 20 standard AAs, order fixed +AA_TO_IDX = {aa: i for i, aa 
in enumerate(AA)} +UNK_IDX = 20 # index for unknown / padding + + +def peptides_to_onehot(sequence: str, max_seq_len: int) -> np.ndarray: + """Convert peptide sequence to one-hot encoding""" + arr = np.zeros((max_seq_len, 21), dtype=np.float32) + for j, aa in enumerate(sequence.upper()[:max_seq_len]): + arr[j, AA_TO_IDX.get(aa, UNK_IDX)] = 1.0 + return arr + + +def _read_embedding_file(path: str | os.PathLike) -> np.ndarray: + """Robust loader for latent embeddings""" + try: + arr = np.load(path) + if isinstance(arr, np.ndarray) and arr.dtype == np.float32: + return arr + raise ValueError + except ValueError: + obj = np.load(path, allow_pickle=True) + if isinstance(obj, np.ndarray) and obj.dtype == object: + obj = obj.item() + if isinstance(obj, dict) and "embedding" in obj: + return obj["embedding"].astype("float32") + raise ValueError(f"Unrecognised embedding file {path}") + + +# ---------------------------------------------------------------------------- +# Streaming dataset utilities +# ---------------------------------------------------------------------------- +class StreamingParquetReader: + """Memory-efficient streaming parquet reader""" + + def __init__(self, parquet_path: str, batch_size: int = 1000): + self.parquet_path = parquet_path + self.batch_size = batch_size + self._file = None + self._num_rows = None + + def __enter__(self): + self._file = pq.ParquetFile(self.parquet_path) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._file: + self._file = None + + @property + def num_rows(self): + """Get total number of rows without loading data""" + if self._num_rows is None: + if self._file is None: + with pq.ParquetFile(self.parquet_path) as f: + self._num_rows = f.metadata.num_rows + else: + self._num_rows = self._file.metadata.num_rows + return self._num_rows + + def iter_batches(self): + """Iterate over parquet file in batches""" + if self._file is None: + raise RuntimeError("Reader not opened. 
Use within 'with' statement.") + + for batch in self._file.iter_batches(batch_size=self.batch_size): + df = batch.to_pandas() + yield df + del df, batch # Explicit cleanup + + def sample_for_metadata(self, n_samples: int = 1000): + """Sample a small portion for metadata extraction""" + with pq.ParquetFile(self.parquet_path) as f: + # Read first batch for metadata + first_batch = next(f.iter_batches(batch_size=min(n_samples, self.num_rows))) + return first_batch.to_pandas() + + +def get_dataset_metadata(parquet_path: str): + """Extract dataset metadata without loading full dataset""" + with StreamingParquetReader(parquet_path) as reader: + sample_df = reader.sample_for_metadata(reader.num_rows) + + metadata = { + 'total_rows': reader.num_rows, + 'max_peptide_length': int(sample_df['long_mer'].str.len().max()) if 'long_mer' in sample_df.columns else 0, + 'class_distribution': sample_df[ + 'assigned_label'].value_counts().to_dict() if 'assigned_label' in sample_df.columns else {}, + } + + del sample_df + return metadata + + +def calculate_class_weights(parquet_path: str): + """Calculate class weights from a sample of the dataset""" + with StreamingParquetReader(parquet_path, batch_size=1000) as reader: + label_counts = {0: 0, 1: 0} + for batch_df in reader.iter_batches(): + batch_labels = batch_df['assigned_label'].values + unique, counts = np.unique(batch_labels, return_counts=True) + for label, count in zip(unique, counts): + if label in [0, 1]: + label_counts[int(label)] += count + del batch_df + + # Calculate balanced class weights + total = sum(label_counts.values()) + if total == 0 or label_counts[0] == 0 or label_counts[1] == 0: + return {0: 1.0, 1: 1.0} + + return { + 0: total / (2 * label_counts[0]), + 1: total / (2 * label_counts[1]) + } + + +# --------------------------------------------------------------------- +# Utility that is executed in worker processes +# (must be top-level so it can be pickled on Windows) +# --------------------------------------------------------------------- +def _row_to_tensor_pack(row_dict: dict, max_pep_seq_len: int, max_mhc_len: int): + """Convert a single row (already in plain-python dict form) into tensors.""" + # --- peptide one-hot ------------------------------------------------ + pep = row_dict["long_mer"].upper()[:max_pep_seq_len] + pep_arr = np.zeros((max_pep_seq_len, 21), dtype=np.float32) + for j, aa in enumerate(pep): + pep_arr[j, AA_TO_IDX.get(aa, UNK_IDX)] = 1.0 + + # --- load MHC embedding -------------------------------------------- + mhc = _read_embedding_file(row_dict["mhc_embedding_path"]) + if mhc.shape[0] != max_mhc_len: # sanity check + raise ValueError(f"MHC length mismatch: {mhc.shape[0]} vs {max_mhc_len}") + + # --- label ---------------------------------------------------------- + label = float(row_dict["assigned_label"]) + return (pep_arr, mhc.astype("float32")), label + +from concurrent.futures import ProcessPoolExecutor +import functools, itertools + +def streaming_data_generator( + parquet_path: str, + max_pep_seq_len: int, + max_mhc_len: int, + batch_size: int = 1000): + """ + Yields *individual* samples, but converts an entire Parquet batch + on multiple CPU cores first. 
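+    Identical to the ESM2 pipeline except that rows are encoded as flat
+    (max_pep_seq_len, 21) one-hot peptides.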
+ """ + with StreamingParquetReader(parquet_path, batch_size) as reader, \ + ProcessPoolExecutor(max_workers=os.cpu_count()) as pool: + + # Partial function to avoid re-sending constants + worker_fn = functools.partial( + _row_to_tensor_pack, + max_pep_seq_len=max_pep_seq_len, + max_mhc_len=max_mhc_len, + ) + + for batch_df in reader.iter_batches(): + # Convert Arrow table → list[dict] once; avoids pandas overhead + dict_rows = batch_df.to_dict(orient="list") # columns -> python lists + # Re-shape to list[dict(row)] + rows_iter = ( {k: dict_rows[k][i] for k in dict_rows} # row dict + for i in range(len(batch_df)) ) + + # Parallel map; chunksize tuned for large batches + results = pool.map(worker_fn, rows_iter, chunksize=64) + + # Stream each converted sample back to the generator consumer + yield from results # <-- keeps memory footprint tiny + + # explicit clean-up + del batch_df, dict_rows, rows_iter, results + + +def create_streaming_dataset(parquet_path: str, + max_pep_seq_len: int, + max_mhc_len: int, + batch_size: int = 128, + buffer_size: int = 1000): + """ + Same semantics as before, but the generator already does parallel + preprocessing. We now ask tf.data to interleave multiple generator + shards in parallel as well. + """ + output_signature = ( + ( + tf.TensorSpec(shape=(max_pep_seq_len, 21), dtype=tf.float32), + tf.TensorSpec(shape=(max_mhc_len, 1152), dtype=tf.float32), + ), + tf.TensorSpec(shape=(), dtype=tf.float32), + ) + + ds = tf.data.Dataset.from_generator( + lambda: streaming_data_generator( + parquet_path, + max_pep_seq_len, + max_mhc_len, + buffer_size), + output_signature=output_signature, + ) + + # ► Parallel interleave gives another speed-up if the Parquet file has + # many row-groups – adjust cycle_length as needed. + ds = ds.interleave( + lambda x, y: tf.data.Dataset.from_tensors((x, y)), + cycle_length=tf.data.AUTOTUNE, + num_parallel_calls=tf.data.AUTOTUNE, + deterministic=False, + ) + + return ds + + +# ---------------------------------------------------------------------------- +# Visualization utilities (keeping the same as original) +# ---------------------------------------------------------------------------- +def plot_training_curve(history: tf.keras.callbacks.History, run_dir: str, fold_id: int = None, + model=None, val_dataset=None): + """Plot training curves and validation metrics""" + hist = history.history + plt.figure(figsize=(21, 6)) + plot_name = f"training_curve{'_fold' + str(fold_id) if fold_id is not None else ''}" + + plt.suptitle(f"Training Curves{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", + fontsize=16, fontweight='bold') + + # Plot 1: Loss curve + plt.subplot(1, 4, 1) + plt.plot(hist["loss"], label="train", linewidth=2) + plt.plot(hist["val_loss"], label="val", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.title("BCE Loss") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 2: Accuracy curve + if "binary_accuracy" in hist and "val_binary_accuracy" in hist: + plt.subplot(1, 4, 2) + plt.plot(hist["binary_accuracy"], label="train acc", linewidth=2) + plt.plot(hist["val_binary_accuracy"], label="val acc", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("Accuracy") + plt.title("Binary Accuracy") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 3: AUC curve + if "AUC" in hist and "val_AUC" in hist: + plt.subplot(1, 4, 3) + plt.plot(hist["AUC"], label="train AUC", linewidth=2) + plt.plot(hist["val_AUC"], label="val AUC", linewidth=2) + plt.xlabel("Epoch") + plt.ylabel("AUC") + plt.title("AUC") + 
plt.legend() + plt.grid(True, alpha=0.3) + + # Plot 4: Confusion matrix placeholder + plt.subplot(1, 4, 4) + if model is not None and val_dataset is not None: + # Sample a subset for confusion matrix to avoid memory issues + sample_dataset = val_dataset.take(100) # Take only 100 batches + y_pred_proba = model.predict(sample_dataset, verbose=0) + y_pred = (y_pred_proba > 0.5).astype(int) + + y_true = [] + for _, labels in sample_dataset: + y_true.extend(labels.numpy()) + y_true = np.array(y_true) + + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix (100 Batches)') + else: + plt.text(0.5, 0.5, 'Confusion Matrix N/A \n(Sample from validation)', + ha='center', va='center', transform=plt.gca().transAxes, + bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray")) + plt.axis('off') + + plt.tight_layout() + os.makedirs(run_dir, exist_ok=True) + out_png = os.path.join(run_dir, f"{plot_name}.png") + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.close() # Close to free memory + print(f"✓ Training curve saved to {out_png}") + + +def plot_test_metrics(model, test_dataset, run_dir: str, fold_id: int = None, + history=None, string: str = None): + """Plot comprehensive evaluation metrics for test dataset""" + print("Generating predictions for test metrics...") + + # Collect predictions and labels in batches to avoid memory issues + y_true_list = [] + y_pred_proba_list = [] + + for batch_x, batch_y in test_dataset: + batch_pred = model.predict(batch_x, verbose=0) + y_true_list.append(batch_y.numpy()) + y_pred_proba_list.append(batch_pred.flatten()) + + y_true = np.concatenate(y_true_list).flatten() + y_pred_proba = np.concatenate(y_pred_proba_list) + y_pred = (y_pred_proba > 0.5).astype(int) + + # Calculate ROC curve + fpr, tpr, _ = roc_curve(y_true, y_pred_proba) + roc_auc = auc(fpr, tpr) + + # Create evaluation plot + plt.figure(figsize=(15, 10)) + plt.suptitle(f"{string} Evaluation Metrics{' (Fold ' + str(fold_id) + ')' if fold_id is not None else ''}", + fontsize=16, fontweight='bold') + + # ROC Curve + plt.subplot(2, 2, 1) + plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})') + plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.8) + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('ROC Curve') + plt.legend(loc="lower right") + plt.grid(True, alpha=0.3) + + # Confusion Matrix + plt.subplot(2, 2, 2) + cm = confusion_matrix(y_true, y_pred) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['Negative', 'Positive'], + yticklabels=['Negative', 'Positive']) + plt.title('Confusion Matrix') + plt.xlabel('Predicted Label') + plt.ylabel('True Label') + + # Metrics bar chart + plt.subplot(2, 2, 3) + accuracy = accuracy_score(y_true, y_pred) + precision = precision_score(y_true, y_pred, zero_division=0) + recall = recall_score(y_true, y_pred, zero_division=0) + f1 = f1_score(y_true, y_pred, zero_division=0) + + metrics = {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1, 'AUC': roc_auc} + bars = plt.bar(range(len(metrics)), list(metrics.values()), + color=['skyblue', 'lightcoral', 'lightgreen', 'gold', 'orange'], + alpha=0.8, edgecolor='black', linewidth=1) + plt.xticks(range(len(metrics)), list(metrics.keys()), rotation=45, ha='right') + plt.ylim(0, 1.0) + plt.title('Evaluation 
Metrics') + plt.ylabel('Score') + plt.grid(True, alpha=0.3, axis='y') + + for bar, value in zip(bars, metrics.values()): + plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02, + f'{value:.3f}', ha='center', va='bottom', fontweight='bold') + + # Prediction distribution + plt.subplot(2, 2, 4) + plt.hist(y_pred_proba[y_true == 0], bins=30, alpha=0.7, label='Negative Class', + color='red', density=True) + plt.hist(y_pred_proba[y_true == 1], bins=30, alpha=0.7, label='Positive Class', + color='blue', density=True) + plt.axvline(x=0.5, color='black', linestyle='--', linewidth=2, label='Threshold') + plt.xlabel('Predicted Probability') + plt.ylabel('Density') + plt.title('Prediction Distribution') + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + os.makedirs(run_dir, exist_ok=True) + out_png = os.path.join(run_dir, f"{string}_metrics{'_fold' + str(fold_id) if fold_id is not None else ''}.png") + plt.savefig(out_png, dpi=300, bbox_inches='tight') + plt.close() # Close to free memory + print(f"✓ Test metrics visualization saved to {out_png}") + + # Print summary + print("\n" + "=" * 50) + print("EVALUATION SUMMARY") + print("=" * 50) + print(f"Accuracy: {accuracy:.4f}") + print(f"Precision: {precision:.4f}") + print(f"Recall: {recall:.4f}") + print(f"F1 Score: {f1:.4f}") + print(f"ROC AUC: {roc_auc:.4f}") + print("=" * 50) + + return { + 'roc_auc': roc_auc, 'accuracy': accuracy, 'precision': precision, + 'recall': recall, 'f1': f1, 'confusion_matrix': cm.tolist() + } + + +# ---------------------------------------------------------------------------- +# Main training function +# ---------------------------------------------------------------------------- +def main(argv=None): + p = argparse.ArgumentParser() + p.add_argument("--dataset_path", required=True, + help="Path to the dataset directory") + p.add_argument("--epochs", type=int, default=30) + p.add_argument("--batch", type=int, default=128) + p.add_argument("--outdir", default=None, + help="Output dir (default: runs/run_YYYYmmdd-HHMMSS)") + p.add_argument("--buffer_size", type=int, default=1000, + help="Buffer size for streaming data loading") + + args = p.parse_args(argv) + + run_dir = args.outdir or f"runs/run_{datetime.datetime.now():%Y%m%d-%H%M%S}" + pathlib.Path(run_dir).mkdir(parents=True, exist_ok=True) + print(f"★ Outputs → {run_dir}\n") + + # Set seeds for reproducibility + tf.random.set_seed(42) + np.random.seed(42) + print("Setting random seeds for reproducibility...") + + print("Initial memory state:") + monitor_memory() + + # Extract metadata from datasets without loading them fully + print("Extracting dataset metadata...") + + # Get fold information + fold_dir = os.path.join(args.dataset_path, 'folds') + fold_files = sorted([f for f in os.listdir(fold_dir) if f.endswith('.parquet')]) + n_folds = len(fold_files) // 2 + + # Find maximum peptide length across all datasets + max_peptide_length = 0 + max_mhc_length = 36 # Fixed for now + + print("Scanning datasets for maximum peptide length...") + all_parquet_files = [ + os.path.join(args.dataset_path, "test1.parquet"), + os.path.join(args.dataset_path, "test2.parquet") + ] + + # Add fold files + for i in range(1, n_folds + 1): + all_parquet_files.extend([ + os.path.join(fold_dir, f'fold_{i}_train.parquet'), + os.path.join(fold_dir, f'fold_{i}_val.parquet') + ]) + + for pq_file in all_parquet_files: + if os.path.exists(pq_file): + metadata = get_dataset_metadata(pq_file) + max_peptide_length = max(max_peptide_length, metadata['max_peptide_length']) + 
print( + f" {os.path.basename(pq_file)}: max_len={metadata['max_peptide_length']}, rows={metadata['total_rows']}") + + print(f"✓ Maximum peptide length across all datasets: {max_peptide_length}") + + # Create fold datasets and class weights + folds = [] + class_weights = [] + + for i in range(1, n_folds + 1): + print(f"\nProcessing fold {i}/{n_folds}") + train_path = os.path.join(fold_dir, f'fold_{i}_train.parquet') + val_path = os.path.join(fold_dir, f'fold_{i}_val.parquet') + + # Calculate class weights from training data + print(f" Calculating class weights...") + cw = calculate_class_weights(train_path) + print(f" Class weights: {cw}") + + # Create streaming datasets + train_ds = (create_streaming_dataset(train_path, max_peptide_length, max_mhc_length, + buffer_size=args.buffer_size) + .shuffle(buffer_size=args.buffer_size, reshuffle_each_iteration=True) + .batch(args.batch) + .prefetch(tf.data.AUTOTUNE)) + + val_ds = (create_streaming_dataset(val_path, max_peptide_length, max_mhc_length, + buffer_size=args.buffer_size) + .batch(args.batch) + .prefetch(tf.data.AUTOTUNE)) + + folds.append((train_ds, val_ds)) + class_weights.append(cw) + + # Force cleanup + cleanup_memory() + + # Create test datasets + print("Creating test datasets...") + test1_ds = (create_streaming_dataset(os.path.join(args.dataset_path, "test1.parquet"), + max_peptide_length, max_mhc_length, buffer_size=args.buffer_size) + .batch(args.batch) + .prefetch(tf.data.AUTOTUNE)) + + test2_ds = (create_streaming_dataset(os.path.join(args.dataset_path, "test2.parquet"), + max_peptide_length, max_mhc_length, buffer_size=args.buffer_size) + .batch(args.batch) + .prefetch(tf.data.AUTOTUNE)) + + print(f"✓ Created {n_folds} fold datasets and 2 test datasets") + print("Memory after dataset creation:") + monitor_memory() + + # Training loop + print("\n" + "=" * 60) + print("STARTING TRAINING") + print("=" * 60) + + for fold_id, ((train_loader, val_loader), class_weight) in enumerate(zip(folds, class_weights), start=1): + print(f'\n🔥 Training fold {fold_id}/{n_folds}') + + # Clean up before each fold + cleanup_memory() + + # Build fresh model for each fold + print("Building model...") + model = build_classifier(max_peptide_length, max_mhc_length) + model.summary() + + # Callbacks + ckpt_cb = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(run_dir, f'best_fold_{fold_id}.weights.h5'), + monitor='val_loss', save_best_only=True, mode='min', verbose=1) + early_cb = tf.keras.callbacks.EarlyStopping( + monitor='val_loss', patience=15, restore_best_weights=True, verbose=1) + + # Verify data shapes + for (x_pep, latents), labels in train_loader.take(1): + print(f"✓ Input shapes: peptide={x_pep.shape}, mhc={latents.shape}, labels={labels.shape}") + break + + print("Memory before training:") + monitor_memory() + + # Train model + print("🚀 Starting training...") + hist = model.fit( + train_loader, + validation_data=val_loader, + epochs=args.epochs, + class_weight=class_weight, + callbacks=[ckpt_cb, early_cb], + verbose=1, + ) + + print("Memory after training:") + monitor_memory() + + # Plot training curves + plot_training_curve(hist, run_dir, fold_id, model, val_loader) + + # Save model and metadata + model.save_weights(os.path.join(run_dir, f'model_fold_{fold_id}.weights.h5')) + metadata = { + "fold_id": fold_id, + "epochs": args.epochs, + "batch_size": args.batch, + "max_peptide_length": max_peptide_length, + "max_mhc_length": max_mhc_length, + "class_weights": class_weight, + "run_dir": run_dir, + "mhc_class": MHC_CLASS + } + with 
open(os.path.join(run_dir, f'metadata_fold_{fold_id}.json'), 'w') as f: + json.dump(metadata, f, indent=4) + + # Evaluate on test sets + print(f"\n📊 Evaluating fold {fold_id} on test sets...") + + # Test1 evaluation + print("Evaluating on test1 (balanced alleles)...") + plot_test_metrics(model, test1_ds, run_dir, fold_id, string="Test1_balanced_alleles") + + # Test2 evaluation + print("Evaluating on test2 (rare alleles)...") + plot_test_metrics(model, test2_ds, run_dir, fold_id, string="Test2_rare_alleles") + + print(f"✅ Fold {fold_id} completed successfully") + + # Cleanup + del model, hist + cleanup_memory() + + print("\n🎉 Training completed successfully!") + print(f"📁 All results saved to: {run_dir}") + + +if __name__ == "__main__": + BUFFER = 8192 # Reduced buffer size for memory efficiency + MHC_CLASS = 1 + dataset_path = f"../data/Custom_dataset/NetMHCpan_dataset/mhc_{MHC_CLASS}" + main([ + "--dataset_path", dataset_path, + "--epochs", "10", + "--batch", "8192", + "--buffer_size", "8192", + ]) \ No newline at end of file diff --git a/utils/visualize_tensorflow_model.py b/utils/visualize_tensorflow_model.py new file mode 100644 index 00000000..20514733 --- /dev/null +++ b/utils/visualize_tensorflow_model.py @@ -0,0 +1,52 @@ +import tensorflow as tf +from tensorflow.keras.models import load_model +import os + +# Load the model +# Import the custom layer - adjust the import path as needed +# from utils.model import LatentProj, SelfAttentionBlock, CrossAttentionBlock, PeptideProj # Update this import based on where your LatentProj is defined +# with tf.keras.utils.custom_object_scope({'LatentProj': LatentProj,'PeptideProj': PeptideProj, 'SelfAttentionBlock': SelfAttentionBlock, 'CrossAttentionBlock': CrossAttentionBlock}): +# model = load_model('runs/run_20250603-111633/best_weights.h5') + +from utils.model_archive import AttentionLayer, PositionalEncoding, AnchorPositionExtractor +# Import any additional classes/functions used by Lambda layers +from utils.model_archive import * # Import all potential dependencies + + +import uuid + +def wrap_layer(layer_class): + def fn(**config): + config.pop('trainable', None) + config.pop('dtype', None) + # assign a unique name to avoid duplicates + config['name'] = f"{layer_class.__name__.lower()}_{uuid.uuid4().hex[:8]}" + return layer_class.from_config(config) + return fn + + +# Load the model with wrapped custom objects +model = load_model( + 'model_output/peptide_mhc_cross_attention_model.h5', + custom_objects={ + 'AttentionLayer': wrap_layer(AttentionLayer), + 'PositionalEncoding': wrap_layer(PositionalEncoding), + 'AnchorPositionExtractor': wrap_layer(AnchorPositionExtractor), + } +) +# Display and save the model's architecture + +# Display model summary +model.summary() + +# Create better visualization as SVG with cleaner layout +tf.keras.utils.plot_model( + model, + to_file='model_output/model_architecture.png', + show_shapes=True, + show_layer_names=True, + rankdir='TB', # Top to bottom layout + dpi=200, # Higher resolution + expand_nested=True, # Expand nested models to show all layers + show_layer_activations=True # Show activation functions +) \ No newline at end of file
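+
+# Optionally persist the textual summary next to the diagram (a minimal sketch;
+# the output path below is illustrative, not part of the original script).
+with open('model_output/model_summary.txt', 'w') as fh:
+    model.summary(print_fn=lambda line: fh.write(line + '\n'))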