From a17b954a08dfe0866083663bfe87ae25c80ac82a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 12:01:28 +0200 Subject: [PATCH 01/12] Add makefile --- Makefile | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..9f3705f --- /dev/null +++ b/Makefile @@ -0,0 +1,56 @@ +CEBRA_LENS_VERSION := 0.1.0.dev0 + +dist: + python3 -m pip install virtualenv + python3 -m pip install --upgrade build twine + python3 -m build --wheel --sdist + +build: dist + +test: + python -m pytest --ff tests + +interrogate: + interrogate \ + --ignore-property-decorators \ + --ignore-init-method \ + --verbose \ + --ignore-semiprivate \ + --ignore-private \ + --ignore-magic \ + --omit-covered-files \ + -f 80 \ + cebra_lens + +docs: + jupyter-book build docs + +docs-touch: + find docs/docs -iname '*.md' -exec touch {} \; + jupyter-book build docs/docs + +docs-strict: + jupyter-book build docs --keep-going --strict + +# Serve the docs +serve_docs: + python -m http.server 8080 --bind 127.0.0.1 -d docs/_build/html + +# Serve the entire page +serve_page: + python -m http.server 8080 --bind 127.0.0.1 -d docs/_build/html + +# Format code in the main package and docs +format: + yapf -i -p -r cebra_lens + yapf -i -p -r tests + yapf -i -p -r docs/docs/examples + yapf -i -p -r docs/docs/conf.py + isort cebra_lens/ + isort tests/ + +codespell: + codespell cebra_lens/ tests/ docs/docs/*.md -L "nce, nd" + + +.PHONY: docs docs-touch docs-strict serve_docs serve_page \ No newline at end of file From 19f1dbfc8020041483417e25e903a91b8f64e407 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 12:01:41 +0200 Subject: [PATCH 02/12] Update requirements.txt --- requirements.txt | 186 +++++++++-------------------------------------- 1 file changed, 33 insertions(+), 153 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4a514c6..fa8f2b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,154 +1,34 @@ -anaconda-anon-usage -anyio -archspec -argon2-cffi -argon2-cffi-bindings -arrow -asttokens -async-lru -attrs -Babel -beautifulsoup4 -bleach -boltons -Brotli -cebra -certifi -cffi -charset-normalizer -colorama -comm -conda -conda-content-trust -conda-libmamba-solver -conda-package-handling -conda_package_streaming -contourpy -cryptography -cycler -debugpy -decorator -defusedxml -distro -et_xmlfile -executing -fastjsonschema -filelock -fonttools -fqdn -frozendict -fsspec -h11 -h5py -hdf5storage -idna -ipykernel -ipython -ipywidgets -isoduration -jedi -Jinja2 -joblib -json5 -jsonpatch -jsonpointer -jsonschema -jsonschema-specifications -jupyter -jupyter_client -jupyter-console -jupyter_core -jupyter-events -jupyter-lsp -jupyter_server -jupyter_server_terminals -jupyterlab -jupyterlab-pygments -jupyterlab_server -jupyterlab-widgets -kiwisolver -libmambapy -literate-dataclasses -MarkupSafe -matplotlib +cebra +joblib +# Platform-specific numpy constraints +numpy<2.0; platform_system=="Windows" +numpy<2.0; platform_system!="Windows" and python_version<"3.10" +numpy; platform_system!="Windows" and python_version>="3.10" +literate-dataclasses +scikit-learn +scipy +torch>=2.4.0 +tqdm +matplotlib<3.11 matplotlib-inline -menuinst -mistune -mpmath -nbclient -nbconvert -nbformat -nest-asyncio -networkx -notebook -notebook_shim -numpy -openpyxl -overrides -packaging -pandas -pandocfilters -parso -pillow -pip -platformdirs -pluggy -ply -prometheus-client -prompt-toolkit -psutil -pure-eval -pycosat -pycparser -Pygments -pyparsing -PyQt5 -PyQt5-sip -PySocks -python-dateutil -python-json-logger -pytz -pywin32 -pywinpty -PyYAML -pyzmq -qtconsole -QtPy -referencing -requests -rfc3339-validator -rfc3986-validator -rpds-py -ruamel.yaml -scikit-learn -scipy -seaborn -Send2Trash -setuptools -sip -six -sniffio -soupsieve -stack-data -sympy -terminado -threadpoolctl -tinycss2 -torch -tornado -tqdm -traitlets -truststore -types-python-dateutil -typing_extensions -tzdata -uri-template -urllib3 -wcwidth -webcolors -webencodings -websocket-client -wheel -widgetsnbextension -win-inet-pton -zstandard +requests +pandas +plotly +seaborn +jupyter-book +ghp-import +ipykernel +jupyter +nbconvert +nbformat +pylint +toml +yapf +black +isort +coverage +pytest +licenseheaders +interrogate +codespell +cffconvert \ No newline at end of file From 0834687115c72495114f6bc0cfaa2fb29c4638e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 12:02:15 +0200 Subject: [PATCH 03/12] Fix tests, run formater and codespell checks --- cebra_lens/activations.py | 46 ++-- cebra_lens/matplotlib.py | 265 +++++++++++++----------- cebra_lens/quantification/base.py | 8 +- cebra_lens/quantification/cka_metric.py | 56 ++--- cebra_lens/quantification/decoding.py | 107 +++++----- cebra_lens/quantification/distance.py | 117 ++++++----- cebra_lens/quantification/misc.py | 49 ++--- cebra_lens/quantification/rdm_metric.py | 28 ++- cebra_lens/quantification/tsne.py | 6 +- cebra_lens/utils.py | 10 +- cebra_lens/utils_allen.py | 82 +++----- cebra_lens/utils_hpc.py | 6 +- docs/docs/installation.md | 2 +- tests/test_activations.py | 4 +- tests/test_decoding.py | 6 +- tests/test_misc.py | 9 +- tests/test_rdm.py | 38 ++-- tests/test_utils.py | 9 +- 18 files changed, 425 insertions(+), 423 deletions(-) diff --git a/cebra_lens/activations.py b/cebra_lens/activations.py index c9beb9d..69b38a0 100644 --- a/cebra_lens/activations.py +++ b/cebra_lens/activations.py @@ -10,9 +10,8 @@ import matplotlib.pyplot as plt -def _cut_array( - array: npt.NDArray, cut_indices: Tuple[np.int64, np.int64] -) -> npt.NDArray: +def _cut_array(array: npt.NDArray, + cut_indices: Tuple[np.int64, np.int64]) -> npt.NDArray: """ Slices the input array based on the provided cut indices. This is used to remove the padding from activations in `get_activations_model`. @@ -36,7 +35,7 @@ def _cut_array( sliced_array = array else: # Otherwise, slice the array - sliced_array = array[:, start : end if end != 0 else start :] + sliced_array = array[:, start:end if end != 0 else start:] return sliced_array @@ -80,10 +79,12 @@ def get_cut_indices( # add for output layer cut_indices.append((0, 0)) elif layer_type == None: - raise NotImplementedError("Padding handling not implemented for 'all'.") + raise NotImplementedError( + "Padding handling not implemented for 'all'.") else: # need to analyze the padding from the last output of Conv1 and apply the same cut - raise NotImplementedError(f"Padding handling not implemented for {layer_type}.") + raise NotImplementedError( + f"Padding handling not implemented for {layer_type}.") return cut_indices @@ -125,26 +126,25 @@ def get_activations_model( activations = {} transform_kwargs = {} if model.solver_name_ in [ - "multi-session", - "multi-session-aux", - "multiobjective-solver", + "multi-session", + "multi-session-aux", + "multiobjective-solver", ]: model_ = model.model_[session_id] transform_kwargs.update({"session_id": session_id}) elif model.solver_name_ in [ - "single-session", - "single-session-aux", - "single-session-hybrid", - "single-session-full", + "single-session", + "single-session-aux", + "single-session-hybrid", + "single-session-full", ]: model_ = model.model_ else: raise NotImplementedError( - f"Solver {model.solver_name_} is not yet implemented." - ) + f"Solver {model.solver_name_} is not yet implemented.") activations, handles, conv_layer_info = _attach_hooks( activations=activations, @@ -209,14 +209,14 @@ def process_activations( name=model_name, instance=i, layer_type=layer_type, - ) - ) + )) return activations # Function to create a hook that stores the activations in the dictionary def _get_activation(name: str, activations: Dict): + def hook(model, input, output): activations[name] = output.detach().squeeze().numpy() @@ -262,8 +262,7 @@ def _attach_hooks( # attach hook to the layer_type and to the output layer if isinstance(model.net[i], layer_type) or i == len(model.net) - 1: hook, activations = _get_activation( - f"{name}_{instance}_layer_{num_layer}", activations - ) + f"{name}_{instance}_layer_{num_layer}", activations) if isinstance(model.net[i], layer_type): conv_layer_info.append(model.net[i].kernel_size[0]) handle = model.net[i].register_forward_hook(hook) @@ -298,8 +297,7 @@ def _attach_hooks( else: hook, activations = _get_activation( - f"{name}_{instance}_layer_{num_layer}", activations - ) + f"{name}_{instance}_layer_{num_layer}", activations) handle = model.net[i].register_forward_hook(hook) handles.append(handle) @@ -309,8 +307,7 @@ def _attach_hooks( def aggregate_activations( - activations: Dict[str, npt.NDArray], -) -> Dict[str, npt.NDArray]: + activations: Dict[str, npt.NDArray], ) -> Dict[str, npt.NDArray]: """ Aggregates activations by model identifier aka. instance. This function takes a dictionary of activations where the keys are strings containing model identifiers and layer information, @@ -387,8 +384,7 @@ def get_activations( activations = activations or {} aggregated_activations = aggregate_activations( - process_activations(models, data, session_id, activations, layer_type) - ) + process_activations(models, data, session_id, activations, layer_type)) activations_dict = {} for key, value in aggregated_activations.items(): diff --git a/cebra_lens/matplotlib.py b/cebra_lens/matplotlib.py index 1675062..482db85 100644 --- a/cebra_lens/matplotlib.py +++ b/cebra_lens/matplotlib.py @@ -52,9 +52,8 @@ class _GenericPlot(_BasePlot): """ - def __init__( - self, axis: Optional[matplotlib.axes.Axes], figsize: Tuple, title: str - ): + def __init__(self, axis: Optional[matplotlib.axes.Axes], figsize: Tuple, + title: str): super().__init__(axis, figsize) self.title = title self.unique_keys = [] @@ -78,7 +77,8 @@ def plot(self, plot_data: Dict[str, npt.NDArray], y_axis: str) -> None: for values in data_list: layer_values.append(values) sns.lineplot( - x=np.arange(1, len(values) + 1), + x=np.arange(1, + len(values) + 1), y=values, linestyle="-", marker="D", @@ -89,14 +89,12 @@ def plot(self, plot_data: Dict[str, npt.NDArray], y_axis: str) -> None: layer_values = np.array(layer_values) - mean_values = ( - layer_values - if layer_values.ndim == 1 - else np.mean(layer_values, axis=0) - ) + mean_values = (layer_values if layer_values.ndim == 1 else np.mean( + layer_values, axis=0)) sns.lineplot( - x=np.arange(1, len(mean_values) + 1), + x=np.arange(1, + len(mean_values) + 1), y=mean_values, linestyle="-", marker="D", @@ -138,7 +136,8 @@ def __init__( self.results_dict = results_dict self.plot_data = self._transform() - self.unique_keys = list(self.results_dict.keys()) # Define unique keys here + self.unique_keys = list( + self.results_dict.keys()) # Define unique keys here self.colors = sns.color_palette("husl", len(self.unique_keys)) def _transform(self): @@ -190,7 +189,8 @@ def __init__( self.results_dict = results_dict self.plot_data = self._transform() - self.unique_keys = list(self.results_dict.keys()) # Define unique keys here + self.unique_keys = list( + self.results_dict.keys()) # Define unique keys here self.colors = sns.color_palette("husl", len(self.unique_keys)) def _transform(self) -> Dict[str, List[List[np.float64]]]: @@ -269,7 +269,8 @@ def __init__( self.dataset_label = dataset_label self.results_dict = results_dict self.plot_data = self._transform() - self.unique_keys = list(self.results_dict.keys()) # Define unique keys here + self.unique_keys = list( + self.results_dict.keys()) # Define unique keys here self.colors = sns.color_palette("husl", len(self.unique_keys)) def _transform(self) -> Dict[str, List[List[np.float64]]]: @@ -345,9 +346,10 @@ def plot_rdm_correlation( The generated figure containing the RDM comparison plot. """ - return RDMPlotOracle( - results_dict=rdm_dict, title=title, figsize=figsize, axis=ax - ).plot(**kwargs) + return RDMPlotOracle(results_dict=rdm_dict, + title=title, + figsize=figsize, + axis=ax).plot(**kwargs) def plot_distance( @@ -465,8 +467,7 @@ def __init__( ) # Call parent constructor to initialize self.fig and self.ax self.results_dict = results_dict self.palette = sns.color_palette( - palette, len(results_dict) - ) # Define a color palette + palette, len(results_dict)) # Define a color palette self.dataset_label = dataset_label # Define dataset label self.plot_error = plot_error self.label = label @@ -478,9 +479,9 @@ def __init__( def plot(self, **kwargs) -> None: """Plotting logic to plot the decoding scores across models where the x-axis are the model labels, and the y-axis are the decoding scores values.""" - x_positions = list( - range(1, len(self.results_dict) + 1) - ) # X positions for scatter points + x_positions = list(range(1, + len(self.results_dict) + + 1)) # X positions for scatter points for i, (key, results) in enumerate(self.results_dict.items()): if self.dataset_label == "visual": @@ -495,7 +496,7 @@ def plot(self, **kwargs) -> None: measure = "(cm)" else: if self.plot_error: - # betwen error and R^2 score, you want to plot the error + # between error and R^2 score, you want to plot the error score = [dict_el[0][1][self.label] for dict_el in results] self.plot_label = "Error score" # choice of label to plot, self.metric @@ -506,9 +507,10 @@ def plot(self, **kwargs) -> None: mean_error = np.mean(score) color = self.palette[i] - self.ax.scatter( - np.ones_like(score) * x_positions[i], score, color=color, alpha=0.3 - ) + self.ax.scatter(np.ones_like(score) * x_positions[i], + score, + color=color, + alpha=0.3) self.ax.scatter( x_positions[i], mean_error, @@ -522,10 +524,10 @@ def plot(self, **kwargs) -> None: self.ax.set_title(f"Comparison of {self.plot_label} Across Models") self.ax.set_xticks(x_positions) self.ax.set_xticklabels( - self.results_dict.keys() - ) # Set model names as x-tick labels + self.results_dict.keys()) # Set model names as x-tick labels self.ax.legend() # Show legend for model labels - sns.despine(ax=self.ax) # Remove top and right spines for aesthetic reasons + sns.despine( + ax=self.ax) # Remove top and right spines for aesthetic reasons def plot_decoding( @@ -640,15 +642,15 @@ def _multi_padding_check(self, embeddings_1, embeddings_2): # Padding the shorter embedding to match the number of layers in the longer embedding if self.num_layers_1 > self.num_layers_2: - embeddings_2 += [np.empty_like(embeddings_2[0])] * ( - self.num_layers_1 - self.num_layers_2 - ) + embeddings_2 += [np.empty_like(embeddings_2[0]) + ] * (self.num_layers_1 - self.num_layers_2) elif self.num_layers_2 > self.num_layers_1: - embeddings_1 += [np.empty_like(embeddings_1[0])] * ( - self.num_layers_2 - self.num_layers_1 - ) + embeddings_1 += [np.empty_like(embeddings_1[0]) + ] * (self.num_layers_2 - self.num_layers_1) - def _define_ax(self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Axes: + def _define_ax( + self, + axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Axes: """Define the ax on which to generate the plot. Parameters: @@ -662,9 +664,8 @@ def _define_ax(self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Ax """ if axis is None: if len(self.embeddings_list) == 2: - self._multi_padding_check( - self.embeddings_list[0], self.embeddings_list[1] - ) + self._multi_padding_check(self.embeddings_list[0], + self.embeddings_list[1]) self.fig, self.ax = plt.subplots( 2, max(self.num_layers_1, self.num_layers_2), @@ -683,13 +684,13 @@ def _define_ax(self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Ax return self.ax def _plot_dataset( - self, - ax: matplotlib.axes.Axes, - embedding: npt.NDArray, - label: str, - label_ind: int = None, - gray: bool = False, - idx_order: Tuple[int, int, int] = (0, 1, 2), + self, + ax: matplotlib.axes.Axes, + embedding: npt.NDArray, + label: str, + label_ind: int = None, + gray: bool = False, + idx_order: Tuple[int, int, int] = (0, 1, 2), ) -> matplotlib.axes.Axes: """ Plot the dataset embedding, for generic dataset. @@ -720,11 +721,9 @@ def _plot_dataset( if label.shape[0] == 1 and label.shape[1] != 1: label = label.T - if ( - 0 in np.unique(label[:, label_ind]) - and 1 in np.unique(label[:, label_ind]) - and len(np.unique(label[:, label_ind])) == 2 - ): + if (0 in np.unique(label[:, label_ind]) + and 1 in np.unique(label[:, label_ind]) + and len(np.unique(label[:, label_ind])) == 2): l_ind = label[:, label_ind] == 1 l_c = label[l_ind, label_ind] l = ax.scatter( @@ -762,12 +761,12 @@ def _plot_dataset( return ax def _plot_hippocampus( - self, - ax: matplotlib.axes.Axes, - embedding: npt.NDArray, - label: str, - gray: bool = False, - idx_order: Tuple[int, int, int] = (0, 1, 2), + self, + ax: matplotlib.axes.Axes, + embedding: npt.NDArray, + label: str, + gray: bool = False, + idx_order: Tuple[int, int, int] = (0, 1, 2), ) -> matplotlib.axes.Axes: """Plot the hippocampus embedding. @@ -834,12 +833,12 @@ def _plot_hippocampus( return ax def _plot_allen( - self, - ax: matplotlib.axes.Axes, - embedding: npt.NDArray, - label: str, - gray: bool = False, - idx_order: Tuple[int, int, int] = (0, 1, 2), + self, + ax: matplotlib.axes.Axes, + embedding: npt.NDArray, + label: str, + gray: bool = False, + idx_order: Tuple[int, int, int] = (0, 1, 2), ) -> matplotlib.axes.Axes: """Plot the Allen embedding. @@ -910,25 +909,27 @@ def plot_embedding_layers( f"{group_name}", fontsize=20, ) - labels_list = [self.labels[: self.sample_plot]] * num_layers + labels_list = [self.labels[:self.sample_plot]] * num_layers titles = [f"Layer {layer}" for layer in range(1, num_layers)] titles.append("Output layer") for i, (label, ax) in enumerate(zip(labels_list, axs)): - if ( - embeddings[i].shape[0] < embeddings[i].shape[1] - ): # should be num Samples X num Neurons + if (embeddings[i].shape[0] < embeddings[i].shape[1] + ): # should be num Samples X num Neurons embedding = embeddings[i].T else: embedding = embeddings[i] - embedding = embedding[: self.sample_plot, :] + embedding = embedding[:self.sample_plot, :] if self.dataset_label == "HPC": ax = self._plot_hippocampus(ax, embedding, label) elif self.dataset_label == "visual": ax = self._plot_allen(ax, embedding, label) else: - ax = self._plot_dataset(ax, embedding, label, label_ind=label_ind) + ax = self._plot_dataset(ax, + embedding, + label, + label_ind=label_ind) ax.set_title(titles[i], y=1) ax.axis("off") @@ -946,9 +947,10 @@ def plot_embedding(self, group_name: str, label_ind: int = None): The index of the label to be used for coloring the points in the embedding plot. """ - return self.plot_embedding_layers( - self.axs, self.embeddings, group_name, label_ind=label_ind - ) + return self.plot_embedding_layers(self.axs, + self.embeddings, + group_name, + label_ind=label_ind) def plot_compare(self, label_ind: int = None): """Plots embedding layers for models being compared @@ -1066,8 +1068,7 @@ def plot_embeddings( if not isinstance(data, Dict): if group_name is None: raise ValueError( - "If data is not a dictionary, group_name must be provided." - ) + "If data is not a dictionary, group_name must be provided.") data_dict = {group_name: [data]} for group_name, models in data_dict.items(): @@ -1078,9 +1079,9 @@ def plot_embeddings( dataset_label=dataset_label, sample_plot=sample_plot, axis=ax, - ).plot_embedding( - group_name=f"{group_name} instance {i}", label_ind=label_ind, **kwargs - ) + ).plot_embedding(group_name=f"{group_name} instance {i}", + label_ind=label_ind, + **kwargs) class _ActivationPlot: @@ -1124,7 +1125,9 @@ def __init__( self._define_ax(axis) self.fig.suptitle(title, fontsize=20) - def _define_ax(self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Axes: + def _define_ax( + self, + axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Axes: """Define the ax on which to generate the plot. Args: @@ -1135,9 +1138,9 @@ def _define_ax(self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Ax A ``matplotlib.axes.Axes`` on which to generate the plot. """ if axis is None: - self.fig, self.axes = plt.subplots( - self.num_layers + 1, 1, figsize=self.figsize - ) + self.fig, self.axes = plt.subplots(self.num_layers + 1, + 1, + figsize=self.figsize) else: self.axes = [axis] + [ axis.figure.add_subplot(self.num_layers + 1, 1, i + 2) @@ -1146,7 +1149,8 @@ def _define_ax(self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Ax def plot(self): """Handles plotting logic.""" - self.axes[0].imshow(self.input_data.T[:, 0 : self.sample_plot], aspect="auto") + self.axes[0].imshow(self.input_data.T[:, 0:self.sample_plot], + aspect="auto") self.axes[0].set_title("Input Data") self.axes[0].set_ylabel("Channel #") self.axes[0].set_xlabel("Time") @@ -1155,7 +1159,7 @@ def plot(self): # Plot the embeddings for each layer for i in range(self.num_layers): self.axes[i + 1].imshow( - self.embeddings[i][:, 0 : self.sample_plot], + self.embeddings[i][:, 0:self.sample_plot], cmap=self.cmap, aspect="auto", ) @@ -1247,14 +1251,14 @@ class _HeatMapsPlot: """ def __init__( - self, - cka_matrices: Dict[str, npt.NDArray], - annot: bool, - axis: Optional[matplotlib.axes.Axes], - show_cbar: bool = True, - cbar_label: str = "CKA score", - color_map: str = "magma", - figsize: Tuple = (15, 5), + self, + cka_matrices: Dict[str, npt.NDArray], + annot: bool, + axis: Optional[matplotlib.axes.Axes], + show_cbar: bool = True, + cbar_label: str = "CKA score", + color_map: str = "magma", + figsize: Tuple = (15, 5), ): self.cka_matrices = cka_matrices self.annot = annot @@ -1273,12 +1277,17 @@ def __init__( "vmin": 0, "vmax": 1, "cmap": color_map, - "cbar_kws": {"label": cbar_label, "orientation": "horizontal"}, + "cbar_kws": { + "label": cbar_label, + "orientation": "horizontal" + }, } if self.num_comparisons == 1: self.axs = [self.axs] # handle the 1 comparison case - def _define_ax(self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Axes: + def _define_ax( + self, + axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Axes: """Define the ax on which to generate the plot. Parameters: @@ -1291,9 +1300,9 @@ def _define_ax(self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Ax A ``matplotlib.axes.Axes`` on which to generate the plot. """ if axis is None: - self.fig, self.axs = plt.subplots( - 1, self.num_comparisons, figsize=self.figsize - ) + self.fig, self.axs = plt.subplots(1, + self.num_comparisons, + figsize=self.figsize) else: self.axs = axis @@ -1303,7 +1312,10 @@ def plot(self): """Handles plotting logic.""" for i, (key, value) in enumerate(self.cka_matrices.items()): - sns.heatmap(value, ax=self.axs[i], annot=self.annot, **self.heatmap_kwargs) + sns.heatmap(value, + ax=self.axs[i], + annot=self.annot, + **self.heatmap_kwargs) num_layers = value.shape[1] num_models = value.shape[0] @@ -1313,17 +1325,20 @@ def plot(self): if i == 0: self.axs[i].set_ylabel("Model Instantiation", fontsize=12) self.axs[i].set_yticks(np.arange(num_models) + 0.5) - self.axs[i].set_yticklabels([m for m in range(1, num_models + 1)]) + self.axs[i].set_yticklabels( + [m for m in range(1, num_models + 1)]) else: self.axs[i].set_ylabel("") self.axs[i].set_yticks([]) self.axs[i].set_xticks(np.arange(num_layers) + 0.5) - self.axs[i].set_xticklabels([f"L{l}" for l in range(1, num_layers + 1)]) + self.axs[i].set_xticklabels( + [f"L{l}" for l in range(1, num_layers + 1)]) # Adjust layout plt.subplots_adjust(wspace=0.1, right=0.9) - self.fig.suptitle("Similarity between model representations (CKA)", fontsize=16) + self.fig.suptitle("Similarity between model representations (CKA)", + fontsize=16) def plot_cka_heatmaps( @@ -1437,17 +1452,15 @@ def __init__( if len(self.rdms) != len(self.titles): raise ValueError( - "The two lists (rdms and titles) must have the same length." - ) + "The two lists (rdms and titles) must have the same length.") # Generate tick labels specific to the dataset if dataset_label == "visual": self.tick_labels = [str(i) for i in range(0, 930, 30)] elif self.dataset_label == "HPC": - self.tick_positions = ( - np.arange(0, 34, 2) / 10 - ) # Ticks at 0, 0.2, 0.4,..., 1.6 + self.tick_positions = (np.arange(0, 34, 2) / 10 + ) # Ticks at 0, 0.2, 0.4,..., 1.6 self.tick_labels = [ "0.0", "0.2", @@ -1485,7 +1498,9 @@ def __init__( else: self.tick_labels, _ = np.unique(labels, return_inverse=True) - def _define_ax(self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Axes: + def _define_ax( + self, + axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Axes: """Define the ax on which to generate the plot. Parameters: @@ -1516,12 +1531,10 @@ def plot(self): if self.dataset_label == "HPC": # Set the x and y ticks - self.ax[i].set_xticks( - self.tick_positions * len(rdm) // 1.6 / 2 - ) # Scale ticks to the range of data - self.ax[i].set_yticks( - self.tick_positions * len(rdm) // 1.6 / 2 - ) # Same for y-axis + self.ax[i].set_xticks(self.tick_positions * len(rdm) // 1.6 / + 2) # Scale ticks to the range of data + self.ax[i].set_yticks(self.tick_positions * len(rdm) // 1.6 / + 2) # Same for y-axis # Set the tick labels to show 0, 0.2, ..., 1.6 self.ax[i].set_xticklabels(self.tick_labels) @@ -1567,29 +1580,35 @@ def plot(self): size = rdm.shape[0] num_categories = len(self.tick_labels) block_size = size / num_categories - tick_positions = np.arange(block_size / 2, size, block_size) + tick_positions = np.arange(block_size / 2, size, + block_size) self.ax[i].set_xticks(tick_positions) self.ax[i].set_yticks(tick_positions) - self.ax[i].set_xticklabels( - self.tick_labels, rotation=90, ha="right" - ) + self.ax[i].set_xticklabels(self.tick_labels, + rotation=90, + ha="right") self.ax[i].set_yticklabels(self.tick_labels) else: - self.ax[i].set_xticks(np.linspace(0, rdm.shape[1] - 1, num_ticks)) - self.ax[i].set_yticks(np.linspace(0, rdm.shape[0] - 1, num_ticks)) - self.ax[i].set_xticklabels( - self.tick_labels, rotation=90, ha="right", fontsize=6 - ) + self.ax[i].set_xticks( + np.linspace(0, rdm.shape[1] - 1, num_ticks)) + self.ax[i].set_yticks( + np.linspace(0, rdm.shape[0] - 1, num_ticks)) + self.ax[i].set_xticklabels(self.tick_labels, + rotation=90, + ha="right", + fontsize=6) self.ax[i].set_yticklabels(self.tick_labels, fontsize=6) plt.suptitle("Representational Dissimilarity Matrix (RDM)") plt.tight_layout() plt.subplots_adjust(bottom=0.2) - self.fig.colorbar( - cax, ax=self.ax, orientation="horizontal", fraction=0.05, label=self.metric - ) + self.fig.colorbar(cax, + ax=self.ax, + orientation="horizontal", + fraction=0.05, + label=self.metric) def plot_rdm( diff --git a/cebra_lens/quantification/base.py b/cebra_lens/quantification/base.py index 887e6c9..b2a3615 100644 --- a/cebra_lens/quantification/base.py +++ b/cebra_lens/quantification/base.py @@ -14,7 +14,8 @@ class _BaseMetric: """ @abstractmethod - def compute(self, activations: Dict[str, npt.NDArray]) -> Dict[str, npt.NDArray]: + def compute(self, + activations: Dict[str, npt.NDArray]) -> Dict[str, npt.NDArray]: """ Every metric which inherits ``_BaseMetric`` needs to implement a compute function. The compute function is specific to a metric, e.g. intra-bin distance, RDM, CKA,... @@ -66,9 +67,8 @@ def save(self, filepath: str, data: Dict[str, npt.NDArray]) -> None: and the value is a npt.NDArray containing for all the models under that label the calculated data. """ filepath = Path(filepath) - custom_filepath = filepath.with_stem( - filepath.stem + f"_{self.__class__.__name__}" - ) + custom_filepath = filepath.with_stem(filepath.stem + + f"_{self.__class__.__name__}") with open(custom_filepath, "wb") as f: pickle.dump(data, f) diff --git a/cebra_lens/quantification/cka_metric.py b/cebra_lens/quantification/cka_metric.py index e892317..741da37 100644 --- a/cebra_lens/quantification/cka_metric.py +++ b/cebra_lens/quantification/cka_metric.py @@ -33,7 +33,9 @@ def __init__(self, comparisons: List[Tuple[str, str]]): ) self.comparisons = comparisons - def center_gram(self, gram: npt.NDArray, unbiased: bool = False) -> npt.NDArray: + def center_gram(self, + gram: npt.NDArray, + unbiased: bool = False) -> npt.NDArray: """ Center a symmetric Gram matrix. @@ -76,9 +78,10 @@ def center_gram(self, gram: npt.NDArray, unbiased: bool = False) -> npt.NDArray: return gram - def cka( - self, gram_x: npt.NDArray, gram_y: npt.NDArray, debiased: bool = False - ) -> np.float64: + def cka(self, + gram_x: npt.NDArray, + gram_y: npt.NDArray, + debiased: bool = False) -> np.float64: """Compute CKA. Parameters: @@ -120,9 +123,8 @@ def gram_linear(self, x: npt.NDArray) -> npt.NDArray: return x.dot(x.T) - def _compute_cka( - self, embeddings_1: List[npt.NDArray], embeddings_2: List[npt.NDArray] - ) -> npt.NDArray: + def _compute_cka(self, embeddings_1: List[npt.NDArray], + embeddings_2: List[npt.NDArray]) -> npt.NDArray: """ Compute the Centered Kernel Alignment (CKA) between two sets of embeddings for each layer. @@ -144,12 +146,12 @@ def _compute_cka( if len(embeddings_1) != len(embeddings_2): raise ValueError( - "CKA similarity comparison is done between smae or similar model architectures. The number of layers in embeddings_1 and embeddings_2 must be the same." + "CKA similarity comparison is done between same or similar model architectures. The number of layers in embeddings_1 and embeddings_2 must be the same." ) for i in range(len(embeddings_1)): if embeddings_1[i].shape != embeddings_2[i].shape: raise ValueError( - f"CKA similarity comparison is done between smae or similar model architectures. The shape of layer {i} in embeddings_1 and embeddings_2 must be the same." + f"CKA similarity comparison is done between same or similar model architectures. The shape of layer {i} in embeddings_1 and embeddings_2 must be the same." ) cka_matrix = np.zeros((1, len(embeddings_1))) @@ -187,14 +189,15 @@ def _compute_per_layer( for j in tqdm(range(len(embeddings_1))): if flag: # the situation when there multiple models inside model labels and the same number of models inside each label - cka_matrix[j, :] = self._compute_cka(embeddings_1[j], embeddings_2[j]) + cka_matrix[j, :] = self._compute_cka(embeddings_1[j], + embeddings_2[j]) else: - cka_matrix[j, :] = self._compute_cka(embeddings_1[j], embeddings_2) + cka_matrix[j, :] = self._compute_cka(embeddings_1[j], + embeddings_2) return cka_matrix - def compute( - self, activations: Dict[str, npt.NDArray], comparison: Tuple[str, str] - ) -> npt.NDArray: + def compute(self, activations: Dict[str, npt.NDArray], + comparison: Tuple[str, str]) -> npt.NDArray: """ Compute multi-layer Centered Kernel Alignment (CKA) between different sets of activations. @@ -233,19 +236,22 @@ def compute( embeddings_1 = activations_2 embeddings_2 = activations_1[0] - self.cka_matrix = self._compute_per_layer(embeddings_1, embeddings_2) + self.cka_matrix = self._compute_per_layer(embeddings_1, + embeddings_2) # example when compare intra model single_TR v single_TR, only compare to the first instantiation elif self.comparisonX == self.comparisonY: embeddings_1 = activations_1 embeddings_2 = activations_2[0] - self.cka_matrix = self._compute_per_layer(embeddings_1, embeddings_2) + self.cka_matrix = self._compute_per_layer(embeddings_1, + embeddings_2) else: # when the model labels have the same number of models, but are different labels embeddings_1 = activations_1 embeddings_2 = activations_2 - self.cka_matrix = self._compute_per_layer(embeddings_1, embeddings_2, True) + self.cka_matrix = self._compute_per_layer(embeddings_1, + embeddings_2, True) return self.cka_matrix @@ -254,14 +260,14 @@ def __name__(self) -> str: return "cka" def plot( - self, - cka_matrices: Dict[str, npt.NDArray], - annot: bool, - show_cbar: bool = True, - cbar_label: str = "CKA score", - color_map: str = "magma", - figsize: tuple = (15, 5), - ax: Optional[matplotlib.axes.Axes] = None, + self, + cka_matrices: Dict[str, npt.NDArray], + annot: bool, + show_cbar: bool = True, + cbar_label: str = "CKA score", + color_map: str = "magma", + figsize: tuple = (15, 5), + ax: Optional[matplotlib.axes.Axes] = None, ): """ Plot the CKA matrices as heatmaps. diff --git a/cebra_lens/quantification/decoding.py b/cebra_lens/quantification/decoding.py index 0c8c2fe..cb58fbe 100644 --- a/cebra_lens/quantification/decoding.py +++ b/cebra_lens/quantification/decoding.py @@ -56,16 +56,14 @@ def decoding( for n in params: train_decoder = cebra.KNNDecoder(n_neighbors=n, metric="cosine") train_valid_idx = int(len(embedding_train) / 9 * 8) - train_decoder.fit( - embedding_train[:train_valid_idx], label_train[:train_valid_idx, i] - ) + train_decoder.fit(embedding_train[:train_valid_idx], + label_train[:train_valid_idx, i]) pred = train_decoder.predict(embedding_train[train_valid_idx:]) err = label_train[train_valid_idx:, i] - pred errs.append(abs(err).sum()) - test_decoder = cebra.KNNDecoder( - n_neighbors=params[np.argmin(errs)], metric="cosine" - ) + test_decoder = cebra.KNNDecoder(n_neighbors=params[np.argmin(errs)], + metric="cosine") test_decoder.fit(embedding_train, label_train[:, i]) label_pred = test_decoder.predict(embedding_test) @@ -73,7 +71,8 @@ def decoding( predictions.append(label_pred) label_test_err = np.median(abs(label_pred - label_test[:, i])) labels_test_err.append(label_test_err) - label_test_score = sklearn.metrics.r2_score(label_test[:, i], label_pred) + label_test_score = sklearn.metrics.r2_score(label_test[:, i], + label_pred) labels_test_score.append(label_test_score) # transform it into an appropriate shape @@ -133,7 +132,8 @@ def __init__( self.output_only = output_only def output_information(self): - print("The decoding analysis initialized with the following parameters:") + print( + "The decoding analysis initialized with the following parameters:") print(f"Session ID: {self.session_id}") print(f"Dataset label: {self.dataset_label}") print(f"Layer type: {self.layer_type}") @@ -181,13 +181,11 @@ def _decode( Array containing the decoding results based on the given embeddings and labels. Has different structure depending on the dataset used: e.g. 1D array of structure test_score, pos_test_err, pos_test_score for HPC dataset, or test_score, test_err, test_acc for Allen visual dataset. """ - if ( - embedding_train.shape[0] < embedding_train.shape[1] - ): # should be samples X neurons + if (embedding_train.shape[0] + < embedding_train.shape[1]): # should be samples X neurons embedding_train = embedding_train.T - if ( - embedding_test.shape[0] < embedding_test.shape[1] - ): # should be samples X neurons + if (embedding_test.shape[0] + < embedding_test.shape[1]): # should be samples X neurons embedding_test = embedding_test.T if dataset_label == "visual": @@ -237,26 +235,27 @@ def compute( num_layers = 0 if model.solver_name_ not in [ - "single-session", - "single-session-aux", - "single-session-hybrid", - "single-session-full", - "multi-session", - "multi-session-aux", - "multiobjective-solver", + "single-session", + "single-session-aux", + "single-session-hybrid", + "single-session-full", + "multi-session", + "multi-session-aux", + "multiobjective-solver", ]: raise NotImplementedError( - f"Solver {model.solver_name_} is not yet implemented." - ) + f"Solver {model.solver_name_} is not yet implemented.") elif model.solver_name_ in [ - "multi-session", - "multi-session-aux", - "multiobjective-solver", + "multi-session", + "multi-session-aux", + "multiobjective-solver", ]: transform_kwargs.update({"session_id": self.session_id}) - train_embedding = model.transform(self.train_data, **transform_kwargs) - test_embedding = model.transform(self.test_data, **transform_kwargs) + train_embedding = model.transform(self.train_data, + **transform_kwargs) + test_embedding = model.transform(self.test_data, + **transform_kwargs) else: activations_train = get_activations_model( @@ -286,31 +285,29 @@ def compute( train_embedding = self.train_data test_embedding = self.test_data - results.update( - { - i: self._decode( - train_embedding, - self.train_label, - test_embedding, - self.test_label, - self.dataset_label, - ) - } - ) + results.update({ + i: + self._decode( + train_embedding, + self.train_label, + test_embedding, + self.test_label, + self.dataset_label, + ) + }) else: - results.update( - { - i: self._decode( - activations_train[keys[i - 1]], - self.train_label, - activations_test[keys[i - 1]], - self.test_label, - self.dataset_label, - ) - } - ) + results.update({ + i: + self._decode( + activations_train[keys[i - 1]], + self.train_label, + activations_test[keys[i - 1]], + self.test_label, + self.dataset_label, + ) + }) return results @@ -372,10 +369,8 @@ def plot( ) if self.output_only: - return plot_decoding( - results_dict, palette, self.dataset_label, label, plot_error, ax - ) + return plot_decoding(results_dict, palette, self.dataset_label, + label, plot_error, ax) else: - return plot_layer_decoding( - results_dict, title, self.dataset_label, label, plot_error, figsize - ) + return plot_layer_decoding(results_dict, title, self.dataset_label, + label, plot_error, figsize) diff --git a/cebra_lens/quantification/distance.py b/cebra_lens/quantification/distance.py index 3a34f47..3692e8e 100644 --- a/cebra_lens/quantification/distance.py +++ b/cebra_lens/quantification/distance.py @@ -18,9 +18,8 @@ class DistanceMetric: This class provides methods to compute distances between embeddings and centroids. """ - def compute_centroid( - self, embedding: npt.NDArray, indices: List[np.int64] - ) -> np.float64: + def compute_centroid(self, embedding: npt.NDArray, + indices: List[np.int64]) -> np.float64: """ Computes the centroid of a single embedding (e.g. single layer) for specified bin indices. @@ -36,12 +35,13 @@ def compute_centroid( np.float64 The computed centroid value. """ - bin_data = embedding[:, indices.flatten()] # Get data for the current bin + bin_data = embedding[:, + indices.flatten()] # Get data for the current bin return np.mean(bin_data, axis=1) # Compute centroid - def scale_embedding( - self, embedding: npt.NDArray, metric: str = "cosine" - ) -> npt.NDArray: + def scale_embedding(self, + embedding: npt.NDArray, + metric: str = "cosine") -> npt.NDArray: """ Scales the embedding data based on the specified metric. @@ -61,8 +61,7 @@ def scale_embedding( if metric == "euclidean": scaler = StandardScaler() return scaler.fit_transform( - embedding.T - ).T # Standardize across each dimension + embedding.T).T # Standardize across each dimension elif metric == "cosine": return embedding else: @@ -70,9 +69,10 @@ def scale_embedding( f"The scaling for metric {metric} is not yet implemented. Please use 'cosine' or 'euclidean'." ) - def compute_centroids( - self, embedding: npt.NDArray, indices: List[np.float64], metric: str = "cosine" - ) -> List[np.float64]: + def compute_centroids(self, + embedding: npt.NDArray, + indices: List[np.float64], + metric: str = "cosine") -> List[np.float64]: """ Computes the centroid of a single embedding (e.g. single layer) for all the bins. @@ -95,7 +95,8 @@ def compute_centroids( for bin_idx in range(indices.shape[0]): embedding_scaled = self.scale_embedding(embedding, metric) bin_indices = indices[bin_idx, :] - centroids.append(self.compute_centroid(embedding_scaled, bin_indices)) + centroids.append( + self.compute_centroid(embedding_scaled, bin_indices)) return centroids @@ -111,7 +112,9 @@ class Intrabin(DistanceMetric): The distance metric to use for computing distances (default is "cosine"). """ - def __init__(self, indices: List[np.int64], metric: Optional[str] = "cosine"): + def __init__(self, + indices: List[np.int64], + metric: Optional[str] = "cosine"): self.indices = indices self.metric = metric @@ -140,17 +143,16 @@ def _compute_distance(self, embedding: npt.NDArray) -> np.float64: bin_data, metric=self.metric ) # Pairwise distances within the bin -> distances is list of x1x2,x1x3,x1x4... mean_intra_distance = np.mean( - intra_distances - ) # Mean of the pairwise distances + intra_distances) # Mean of the pairwise distances distances.append(mean_intra_distance) return np.mean(distances) def plot( - self, - distance_dict: Dict[str, npt.NDArray], - title: str = "Intra-bin distance", - figsize: tuple = (15, 5), + self, + distance_dict: Dict[str, npt.NDArray], + title: str = "Intra-bin distance", + figsize: tuple = (15, 5), ) -> matplotlib.figure.Figure: """ Plots the intra-bin distances. @@ -219,32 +221,29 @@ def _compute_distance(self, embedding: npt.NDArray) -> np.float64: for i in range(len(self.repetition_indices[0])): rep_indices = self.repetition_indices[bin_idx][ - i - ] # Get indices for the current repetition + i] # Get indices for the current repetition embedding_scaled = self.scale_embedding(embedding, self.metric) repetition_centroids.append( - self.compute_centroid(embedding_scaled, rep_indices) - ) + self.compute_centroid(embedding_scaled, rep_indices)) # Compute pairwise distances between centroids using cosine distance - bin_distances = cdist( - repetition_centroids, repetition_centroids, metric=self.metric - ) + bin_distances = cdist(repetition_centroids, + repetition_centroids, + metric=self.metric) # Extract non-diagonal elements to get distances between different repetitions non_diagonal_distances = bin_distances[ - ~np.eye(bin_distances.shape[0], dtype=bool) - ] + ~np.eye(bin_distances.shape[0], dtype=bool)] mean_distance = np.mean(non_diagonal_distances) distances.append(mean_distance) return np.mean(distances) def plot( - self, - distance_dict: Dict[str, npt.NDArray], - title: str = "Inter-repetition distance", - figsize: tuple = (15, 5), + self, + distance_dict: Dict[str, npt.NDArray], + title: str = "Inter-repetition distance", + figsize: tuple = (15, 5), ) -> matplotlib.figure.Figure: """ Plots the inter-repetition distances. @@ -277,7 +276,9 @@ class Interbin(DistanceMetric): The distance metric to use for computing distances (default is "cosine"). """ - def __init__(self, indices: List[np.int64], metric: Optional[str] = "cosine"): + def __init__(self, + indices: List[np.int64], + metric: Optional[str] = "cosine"): self.indices = indices self.metric = metric @@ -297,24 +298,25 @@ def _compute_distance(self, embedding: npt.NDArray) -> np.float64: The mean inter-bin distance across the embedding (e.g. across one layer). """ - centroids = self.compute_centroids( - embedding=embedding, indices=self.indices, metric=self.metric - ) + centroids = self.compute_centroids(embedding=embedding, + indices=self.indices, + metric=self.metric) # Compute pairwise distances between centroids using metric distances = cdist(centroids, centroids, metric=self.metric) # Compute the mean inter-bin distance for each layer, excluding self-distances - non_diagonal_distances = distances[~np.eye(distances.shape[0], dtype=bool)] + non_diagonal_distances = distances[~np. + eye(distances.shape[0], dtype=bool)] mean_distance = np.mean(non_diagonal_distances) return mean_distance def plot( - self, - distance_dict: Dict[str, npt.NDArray], - title: str = "Inter-bin distance", - figsize: tuple = (15, 5), + self, + distance_dict: Dict[str, npt.NDArray], + title: str = "Inter-bin distance", + figsize: tuple = (15, 5), ) -> matplotlib.figure.Figure: """ Plots the inter-bin distances. @@ -386,10 +388,12 @@ def __init__( self.metric = metric self.distance_label = distance_label - self.indices, self.repetition_indices = self._define_indices(is_discrete_labels) + self.indices, self.repetition_indices = self._define_indices( + is_discrete_labels) def _define_indices( - self, is_discrete_labels: bool = None + self, + is_discrete_labels: bool = None ) -> Tuple[npt.NDArray, Optional[npt.NDArray]]: """ Defines the indices for each bin. @@ -429,9 +433,7 @@ def _define_indices( if is_discrete_labels: # just detect the unique values and find the indices of the bins (each bin is a unique value) # dataset_label is None and is_discrete_labels is True - idxs = discrete_binning( - label=self.label, - ) + idxs = discrete_binning(label=self.label, ) else: # dataset_label is HPC or visual/ is_discrete_labels is False (dataset_label is None) idxs = continuous_binning( @@ -444,16 +446,15 @@ def _define_indices( if self.distance_label == "interrep": # only relevant for visual dataset repetition_indices = repetition_binning( - indices=idxs, data=self.data, dataset_label=self.dataset_label - ) + indices=idxs, data=self.data, dataset_label=self.dataset_label) else: repetition_indices = None return idxs, repetition_indices def compute( - self, activations: List[Union[np.float64, npt.NDArray]] - ) -> List[np.float64]: + self, activations: List[Union[np.float64, + npt.NDArray]]) -> List[np.float64]: """ Computes specified type of distance for multiple layers of embedding data. @@ -472,23 +473,25 @@ def compute( elif self.distance_label == "intrabin": distance = Intrabin(self.indices, self.metric) elif self.distance_label == "interrep": - distance = Interrep(self.indices, self.repetition_indices, self.metric) + distance = Interrep(self.indices, self.repetition_indices, + self.metric) else: raise NotImplementedError( f"Distance {self.distance_label} not yet implemented. Please use 'interbin','interrep' or 'intrabin'." ) - return super().iterate_over_layers(activations, distance._compute_distance) + return super().iterate_over_layers(activations, + distance._compute_distance) @property def __name__(self): return self.distance_label def plot( - self, - distance_dict: Dict[str, npt.NDArray], - title: str = None, - figsize: tuple = (15, 5), + self, + distance_dict: Dict[str, npt.NDArray], + title: str = None, + figsize: tuple = (15, 5), ) -> matplotlib.figure.Figure: """ Plots the computed distances. diff --git a/cebra_lens/quantification/misc.py b/cebra_lens/quantification/misc.py index 4d06925..0b8593a 100644 --- a/cebra_lens/quantification/misc.py +++ b/cebra_lens/quantification/misc.py @@ -7,6 +7,7 @@ from typing import List import warnings + def normalize_minmax(rdm: npt.NDArray) -> npt.NDArray: """ Normalizes a given array using Min-Max normalization. @@ -95,11 +96,8 @@ def continuous_binning( num_bins = 30 if sample_mode == "sub_sample": - num_samples = ( - max_num_samples - if len(data) / num_bins >= max_num_samples - else int(len(data) / num_bins) - ) + num_samples = (max_num_samples if len(data) / num_bins + >= max_num_samples else int(len(data) / num_bins)) elif sample_mode == "all": num_samples = int(len(data) / num_bins) else: @@ -113,9 +111,8 @@ def continuous_binning( j = 0 for i in range(num_bins): - full_idxs = np.where( - (label[:] >= j * step_distance) & (label[:] < (j + 1) * step_distance) - )[0] + full_idxs = np.where((label[:] >= j * step_distance) + & (label[:] < (j + 1) * step_distance))[0] if sample_mode == "sub_sample": idxs[i, :] = sample(list(full_idxs), num_samples) @@ -144,11 +141,9 @@ def continuous_binning( if i == num_bins / 2: j = 0 - full_idxs = np.where( - (label[:, 0] >= j * step_distance) - & (label[:, 0] < (j + 1) * step_distance) - & (label[:, direction] == 1) - )[0] + full_idxs = np.where((label[:, 0] >= j * step_distance) + & (label[:, 0] < (j + 1) * step_distance) + & (label[:, direction] == 1))[0] if sample_mode == "sub_sample": idxs[i, :] = sample(list(full_idxs), num_samples) @@ -165,18 +160,13 @@ def continuous_binning( if len(data) < 1000: warnings.warn( "Continuous binning is not recommended for datasets with less than 1000 samples. " - "Consider using discrete binning instead.", - UserWarning - ) + "Consider using discrete binning instead.", UserWarning) num_bins = int( 0.005 * len(data) ) # 0.005 is a heuristic to get a reasonable number of bins for continuous data if sample_mode == "sub_sample": - num_samples = ( - max_num_samples - if len(data) / num_bins >= max_num_samples - else int(len(data) / num_bins) - ) + num_samples = (max_num_samples if len(data) / num_bins + >= max_num_samples else int(len(data) / num_bins)) elif sample_mode == "all": num_samples = int(len(data) / num_bins) else: @@ -196,9 +186,8 @@ def continuous_binning( for i in range(num_bins): lower_bin_border = round(min_value + i * step_distance, 2) higher_bin_border = round(min_value + (i + 1) * step_distance, 2) - full_idxs = np.where( - (label[:] >= lower_bin_border) & (label[:] < higher_bin_border) - )[0] + full_idxs = np.where((label[:] >= lower_bin_border) + & (label[:] < higher_bin_border))[0] indices.append(full_idxs) # Due do uneven number of samples in each bin, we will take the minimum number of samples from each bin, maybe need to discuss this further @@ -207,14 +196,16 @@ def continuous_binning( print("Number of samples per bin:", num_samples) idxs = np.zeros((num_bins, num_samples)) for i in range(num_bins): - idxs[i, :] = np.random.choice(indices[i], num_samples, replace=False) + idxs[i, :] = np.random.choice(indices[i], + num_samples, + replace=False) return idxs.astype(int), num_bins -def repetition_binning( - indices: npt.NDArray, data, dataset_label: str = "visual" -) -> List[npt.NDArray]: +def repetition_binning(indices: npt.NDArray, + data, + dataset_label: str = "visual") -> List[npt.NDArray]: """ Creates a list of indices for each repetition based on the provided indices and dataset label. @@ -252,7 +243,7 @@ def repetition_binning( for j in range(num_repetitions): - repetition_bin_idxs.append(indices[i][j * step : (j + 1) * step]) + repetition_bin_idxs.append(indices[i][j * step:(j + 1) * step]) repetition_idxs.append(repetition_bin_idxs) return repetition_idxs diff --git a/cebra_lens/quantification/rdm_metric.py b/cebra_lens/quantification/rdm_metric.py index 7151878..e734e37 100644 --- a/cebra_lens/quantification/rdm_metric.py +++ b/cebra_lens/quantification/rdm_metric.py @@ -49,11 +49,8 @@ def __init__( self.label_ind = label_ind self.dataset_label = dataset_label # check that label is 1D if dataset_label is not HPC/visual, and the label_ind is not provided - if ( - isinstance(self.label, np.ndarray) - and self.label.ndim != 1 - and self.dataset_label not in ["HPC", "visual"] - ): + if (isinstance(self.label, np.ndarray) and self.label.ndim != 1 + and self.dataset_label not in ["HPC", "visual"]): # if the dataset contains multiple labels check that if it is not HPC dataset the label_ind was given if self.label_ind != None: self.label = label[:, label_ind] @@ -75,10 +72,12 @@ def output_information(self): print("RDM class initialized with the following parameters:") if self.bool_oracle: print( - "The chosen analyis will plot the correlation of the RDMs with the Oracle RDM." + "The chosen analysis will plot the correlation of the RDMs with the Oracle RDM." ) else: - print("The chosen analysis will plot the RDMs, no Oracle RDM comparison.") + print( + "The chosen analysis will plot the RDMs, no Oracle RDM comparison." + ) if self.dataset_label is None: print( f"The dataset label is not specified, the RDMs will be computed based on the label index {self.label_ind},\n this label has been noted DISCRETE = {self.discrete}." @@ -123,9 +122,7 @@ def _define_indices(self) -> Tuple[npt.NDArray, Optional[int]]: if self.discrete: # just detect the unique values and find the indices of the bins (each bin is a unique value) # dataset_label is None and discrete is True - idxs = discrete_binning( - labels=self.label, - ) + idxs = discrete_binning(labels=self.label, ) else: # dataset_label is HPC or visual/ discrete is False (dataset_label is None) idxs, num_bins = continuous_binning( @@ -171,8 +168,7 @@ def _create_oracle_rdm(self): return oracle_rdm def _compute_per_layer( - self, layer_activation: npt.NDArray - ) -> Tuple[npt.NDArray, float]: + self, layer_activation: npt.NDArray) -> Tuple[npt.NDArray, float]: """ Computes the RDM for a given layer's activation. @@ -190,7 +186,8 @@ def _compute_per_layer( if layer_activation.shape[0] < layer_activation.shape[1]: layer_activation = layer_activation.T - rdm = pdist(layer_activation[self.idxs.flatten(), :], metric=self.metric) + rdm = pdist(layer_activation[self.idxs.flatten(), :], + metric=self.metric) if self.bool_oracle: oracle_rdm = self._create_oracle_rdm() comparison = 1 - correlation(oracle_rdm, rdm) @@ -217,11 +214,12 @@ def compute( A list of tuples, where each tuple contains the computed RDM and the correlation score with the Oracle RDM (if applicable) for each layer of a model. """ if isinstance( - activations, (np.ndarray, torch.Tensor) + activations, (np.ndarray, torch.Tensor) ): # if only one activation is passed instead of a list of arrays activations = [activations] - return super().iterate_over_layers(activations, self._compute_per_layer) + return super().iterate_over_layers(activations, + self._compute_per_layer) @property def __name__(self): diff --git a/cebra_lens/quantification/tsne.py b/cebra_lens/quantification/tsne.py index 9376d42..fd67b5d 100644 --- a/cebra_lens/quantification/tsne.py +++ b/cebra_lens/quantification/tsne.py @@ -44,7 +44,8 @@ def _compute_per_layer(self, layer_activation: npt.NDArray) -> npt.NDArray: layer_activation = layer_activation.T tsne = TSNE(n_components=3) - tsne_embedding = tsne.fit_transform(layer_activation[:, : self.num_samples].T) + tsne_embedding = tsne.fit_transform( + layer_activation[:, :self.num_samples].T) return tsne_embedding def compute( @@ -64,7 +65,8 @@ def compute( List[Union[float, npt.NDArray]] The 2D embedding produced by t-SNE for each layer of a model. """ - return super().iterate_over_layers(activations, self._compute_per_layer) + return super().iterate_over_layers(activations, + self._compute_per_layer) def _check_num_samples(self): """Checks if the number of samples is less than 200. If so, it sets the number of samples to 200 and prints a warning message.""" diff --git a/cebra_lens/utils.py b/cebra_lens/utils.py index 1c04f28..6afb58f 100644 --- a/cebra_lens/utils.py +++ b/cebra_lens/utils.py @@ -15,7 +15,8 @@ def get_data( - dataset_label: str = None, session_id: int = None + dataset_label: str = None, + session_id: int = None ) -> list[npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray]: """ Returns datasets based on the specified dataset label. If you are using a non standard dataset, you can add a new data loading function add it here. @@ -169,7 +170,8 @@ def plot_metric( def model_loader( - model_dir: str, groups: Dict[str, str] = {} + model_dir: str, + groups: Dict[str, str] = {} ) -> Dict[str, List[cebra.integrations.sklearn.cebra.CEBRA]]: """ Loads and categorizes CEBRA models from a given directory. @@ -213,8 +215,8 @@ def model_loader( for file in models_folder_path.iterdir(): if str(file).endswith((".pt", ".pth")): loaded_model = cebra.CEBRA.load( - file, backend="torch", map_location=torch.device("cpu") - ).to("cpu") + file, backend="torch", + map_location=torch.device("cpu")).to("cpu") key = groups.get(file.stem, file.stem) models.setdefault(key, []).append(loaded_model) print(f"Model {file.stem} loaded successfully.") diff --git a/cebra_lens/utils_allen.py b/cebra_lens/utils_allen.py index b37e325..5068aba 100644 --- a/cebra_lens/utils_allen.py +++ b/cebra_lens/utils_allen.py @@ -46,49 +46,39 @@ def get_datasets( train_datas.append( cebra.datasets.init( f"allen-movie1-ca-single-session-decoding-corrupt-{i}-repeat-{test_session}-train" - ) - ) + )) valid_datas.append( cebra.datasets.init( f"allen-movie1-ca-single-session-decoding-corrupt-{i}-repeat-{test_session}-test" - ) - ) + )) else: for i in range(4): train_datas.append( cebra.datasets.init( f"allen-movie1-ca-single-session-decoding-{i}-repeat-{test_session}-train" - ) - ) + )) valid_datas.append( cebra.datasets.init( f"allen-movie1-ca-single-session-decoding-{i}-repeat-{test_session}-test" - ) - ) + )) if pseudomice: for i in range(len(train_datas)): train_datas[i].neural = torch.from_numpy( obtain_pseudomice( - [train_datas[i].neural for i in range(len(train_datas))] - ) - ) + [train_datas[i].neural for i in range(len(train_datas))])) valid_datas[i].neural = torch.from_numpy( obtain_pseudomice( - [valid_datas[i].neural for i in range(len(train_datas))] - ) - ) + [valid_datas[i].neural for i in range(len(train_datas))])) # Add noise to the 4th mouse only if shot_noise is not None: # train_datas[0].neural = _add_shot_noise(train_datas[0].neural, scale_factor=shot_noise) - valid_datas[3].neural = _add_shot_noise( - valid_datas[3].neural, scale_factor=shot_noise - ) + valid_datas[3].neural = _add_shot_noise(valid_datas[3].neural, + scale_factor=shot_noise) elif gaussian_noise is not None: # train_datas[0].neural = _add_gaussian_noise(train_datas[0].neural, sigma=gaussian_noise) - valid_datas[3].neural = _add_gaussian_noise( - valid_datas[3].neural, sigma=gaussian_noise - ) + valid_datas[3].neural = _add_gaussian_noise(valid_datas[3].neural, + sigma=gaussian_noise) # discrete_labels = [np.tile(np.arange(900), 10) for i in range(len(mice))] discrete_labels_train = [np.tile(np.arange(900), 9) for i in range(mice)] @@ -118,16 +108,15 @@ def obtain_pseudomice(mice, num_neurons_per_mouse=80): pseudomice_matrix = None for i, session in enumerate(mice): session_length = session.shape[1] - selected_neurons = np.random.choice( - session_length, replace=False, size=num_neurons_per_mouse - ) + selected_neurons = np.random.choice(session_length, + replace=False, + size=num_neurons_per_mouse) neuron_ids.append(selected_neurons) if pseudomice_matrix is None: pseudomice_matrix = session[:, selected_neurons] else: pseudomice_matrix = np.concatenate( - (pseudomice_matrix, session[:, selected_neurons]), axis=1 - ) + (pseudomice_matrix, session[:, selected_neurons]), axis=1) pseudomouse = copy.deepcopy(mice[0]) pseudomouse = pseudomice_matrix @@ -154,7 +143,7 @@ def create_sequences(embedding, labels, seq_len=10): sequences = [] sequence_labels = [] for i in range(len(embedding) - seq_len): - seq = embedding[i : i + seq_len] + seq = embedding[i:i + seq_len] # Label is the frame number following the sequence label = labels[i + seq_len] sequences.append(seq) @@ -162,20 +151,21 @@ def create_sequences(embedding, labels, seq_len=10): return np.array(sequences), np.array(sequence_labels) -def decoding_frames( - embedding_train, embedding_test, label_train, label_test, time_window=1, seq_len=1 -): +def decoding_frames(embedding_train, + embedding_test, + label_train, + label_test, + time_window=1, + seq_len=1): """1-frame decoding. TODO(celia): Implement n-frames decoding. Started but not functional yet. """ if seq_len > 1: embedding_train, label_train = create_sequences( - embedding_train, label_train, seq_len - ) - embedding_test, label_test = create_sequences( - embedding_test, label_test, seq_len - ) + embedding_train, label_train, seq_len) + embedding_test, label_test = create_sequences(embedding_test, + label_test, seq_len) params = np.power(np.linspace(1, 10, 10, dtype=int), 2) errs = [] @@ -185,33 +175,27 @@ def decoding_frames( if seq_len > 1: train_decoder.fit( embedding_train[:train_valid_idx].reshape( - -1, seq_len * embedding_train.shape[2] - ), + -1, seq_len * embedding_train.shape[2]), label_train[:train_valid_idx], ) pred = train_decoder.predict( embedding_train[train_valid_idx:].reshape( - -1, seq_len * embedding_train.shape[2] - ) - ) + -1, seq_len * embedding_train.shape[2])) else: - train_decoder.fit( - embedding_train[:train_valid_idx], label_train[:train_valid_idx] - ) + train_decoder.fit(embedding_train[:train_valid_idx], + label_train[:train_valid_idx]) pred = train_decoder.predict(embedding_train[train_valid_idx:]) err = label_train[train_valid_idx:] - pred errs.append(abs(err).sum()) - test_decoder = cebra.KNNDecoder( - n_neighbors=params[np.argmin(errs)], metric="cosine" - ) + test_decoder = cebra.KNNDecoder(n_neighbors=params[np.argmin(errs)], + metric="cosine") if seq_len > 1: test_decoder.fit( - embedding_train.reshape(-1, seq_len * embedding_train.shape[2]), label_train - ) + embedding_train.reshape(-1, seq_len * embedding_train.shape[2]), + label_train) frame_pred = test_decoder.predict( - embedding_test.reshape(-1, seq_len * embedding_test.shape[2]) - ) + embedding_test.reshape(-1, seq_len * embedding_test.shape[2])) else: test_decoder.fit(embedding_train, label_train) frame_pred = test_decoder.predict(embedding_test) diff --git a/cebra_lens/utils_hpc.py b/cebra_lens/utils_hpc.py index dbb39db..47ea25e 100644 --- a/cebra_lens/utils_hpc.py +++ b/cebra_lens/utils_hpc.py @@ -26,7 +26,8 @@ def get_datasets( for i in rats: data = cebra.datasets.init(f"rat-hippocampus-single-{i}") - neural_train, neural_valid, label_train, label_valid = split_data_HPC(data) + neural_train, neural_valid, label_train, label_valid = split_data_HPC( + data) train_datas.append(neural_train) valid_datas.append(neural_valid) continuous_labels_train.append(label_train) @@ -80,7 +81,8 @@ def decoding_pos_dir( test_score = sklearn.metrics.r2_score(label_test[:, :2], prediction) pos_test_err = np.median(abs(prediction[:, 0] - label_test[:, 0])) - pos_test_score = sklearn.metrics.r2_score(label_test[:, 0], prediction[:, 0]) + pos_test_score = sklearn.metrics.r2_score(label_test[:, 0], prediction[:, + 0]) return test_score, pos_test_err, pos_test_score diff --git a/docs/docs/installation.md b/docs/docs/installation.md index 82a50b2..60872a1 100644 --- a/docs/docs/installation.md +++ b/docs/docs/installation.md @@ -3,5 +3,5 @@ Maybe here describe the requirement.txt installation - environment Also describe how to load a model and data --- so that in the metrics noteboook there is only the metric usage on the model and the data +-- so that in the metrics notebook there is only the metric usage on the model and the data diff --git a/tests/test_activations.py b/tests/test_activations.py index 7935dd3..8e8bd27 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -55,7 +55,9 @@ def test_get_activations_model_basic(monkeypatch): monkeypatch.setattr( activations, "_attach_hooks", - lambda *a, **kw: ({"test_layer": np.ones((5, 3))}, [], [3]), + lambda *a, **kw: ({ + "test_layer": np.ones((5, 3)) + }, [], [3]), ) result = get_activations_model(model, data, layer_type=torch.nn.Conv1d) assert isinstance(result, dict) diff --git a/tests/test_decoding.py b/tests/test_decoding.py index d53c556..a248156 100644 --- a/tests/test_decoding.py +++ b/tests/test_decoding.py @@ -17,13 +17,15 @@ def embeddings_labels(): def test_decoding_function(embeddings_labels): emb_train, emb_test, label_train, label_test = embeddings_labels - with patch("cebra_lens.quantification.decoding.cebra.KNNDecoder") as mock_knn: + with patch( + "cebra_lens.quantification.decoding.cebra.KNNDecoder") as mock_knn: mock_model = MagicMock() # Return prediction with correct shape each time mock_model.predict.side_effect = lambda x: np.random.rand(len(x)) mock_knn.return_value = mock_model - score, medians, r2s = decoding(emb_train, emb_test, label_train, label_test) + score, medians, r2s = decoding(emb_train, emb_test, label_train, + label_test) assert isinstance(score, float) assert len(medians) == label_train.shape[1] diff --git a/tests/test_misc.py b/tests/test_misc.py index ecb3768..17007c0 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -13,7 +13,8 @@ def test_discrete_binning_shape_and_values(): idxs = discrete_binning(labels) assert isinstance(idxs, np.ndarray) assert idxs.shape[0] == 3 # three unique labels - assert all(len(set(row)) == len(row) for row in idxs) # no duplicates in each bin + assert all(len(set(row)) == len(row) + for row in idxs) # no duplicates in each bin def test_continuous_binning_general_continuous(): @@ -27,6 +28,6 @@ def test_continuous_binning_general_continuous(): def test_repetition_binning_invalid_dataset(): with pytest.raises(NotImplementedError): - repetition_binning( - np.zeros((3, 90), dtype=int), np.random.rand(900, 10), dataset_label="HPC" - ) + repetition_binning(np.zeros((3, 90), dtype=int), + np.random.rand(900, 10), + dataset_label="HPC") diff --git a/tests/test_rdm.py b/tests/test_rdm.py index 5439aab..d6dd6b7 100644 --- a/tests/test_rdm.py +++ b/tests/test_rdm.py @@ -28,7 +28,7 @@ def test_define_indices_continuous(mock_binning, dummy_data, dummy_labels): def test_define_indices_discrete(mock_binning, dummy_data): labels = torch.tensor([0, 1, 0, 1, 2]) mock_binning.return_value = np.array([[0, 2], [1, 3]]) - rdm = RDM(data=dummy_data, label=labels, discrete=True) + rdm = RDM(data=dummy_data, label=labels, is_discrete_labels=True) idxs, bins = rdm._define_indices() assert isinstance(idxs, np.ndarray) assert bins is None @@ -37,16 +37,20 @@ def test_define_indices_discrete(mock_binning, dummy_data): def test_init_with_label_ind(): labels = np.array([[1, 2], [3, 4]]) data = torch.tensor(np.random.rand(2, 5), dtype=torch.float32) - rdm = RDM(data=data, label=labels, label_ind=0, dataset_label=None, discrete=True) + rdm = RDM(data=data, + label=labels, + label_ind=0, + dataset_label=None, + is_discrete_labels=True) assert rdm.label.tolist() == [1, 3] def test_create_oracle_rdm_custom(): rdm = RDM( data=torch.rand((10000, 5)), - label=torch.randint(0, 5, (10000,)), + label=torch.randint(0, 5, (10000, )), dataset_label=None, - discrete=True, + is_discrete_labels=True, ) rdm.idxs = np.array([[0, 1], [2, 3]]) oracle = rdm._create_oracle_rdm() @@ -56,8 +60,8 @@ def test_create_oracle_rdm_custom(): def test_compute_per_layer_and_bool_oracle(): rdm = RDM( data=torch.rand((10000, 5)), - label=torch.randint(0, 5, (10000,)), - discrete=True, + label=torch.randint(0, 5, (10000, )), + is_discrete_labels=True, ) rdm.idxs = np.array([[i] for i in range(10000)]) dummy_layer = np.random.rand(10000, 5) @@ -68,8 +72,8 @@ def test_compute_per_layer_and_bool_oracle(): def test_compute_single_activation_tensor(): data = torch.rand((10000, 5)) - label = torch.randint(0, 3, (10000,)) - rdm = RDM(data=data, label=label, discrete=False) + label = torch.randint(0, 3, (10000, )) + rdm = RDM(data=data, label=label, is_discrete_labels=False) rdm.idxs = np.array([[i] for i in range(10000)]) act = torch.rand((10000, 5)).numpy() result = rdm.compute(act) @@ -78,28 +82,18 @@ def test_compute_single_activation_tensor(): assert result[0][0].ndim == 2 # RDM squareform -def test_setters_work(): - rdm = RDM( - data=torch.rand((10000, 5)), label=torch.randint(0, 5, (10000,)), discrete=False - ) - rdm.set_num_bins(3) - rdm.set_bool_oracle(False) - assert rdm.num_bins == 3 - assert rdm.bool_oracle is False - - @patch("cebra_lens.quantification.rdm_metric.plot_rdm_correlation") @patch("cebra_lens.quantification.rdm_metric.plot_rdm_all") def test_plot(mock_all, mock_corr): - dummy_rdm = RDM( - data=torch.rand((10000, 5)), label=torch.randint(0, 5, (10000,)), discrete=False - ) + dummy_rdm = RDM(data=torch.rand((10000, 5)), + label=torch.randint(0, 5, (10000, )), + is_discrete_labels=False) dummy_rdm.bool_oracle = True dummy_rdm.plot({"group": [np.random.rand(5, 5)]}) assert mock_corr.called dummy_rdm.bool_oracle = False dummy_rdm.num_bins = 2 - dummy_rdm.discrete = True + dummy_rdm.is_discrete_labels = True dummy_rdm.plot({"group": [np.random.rand(5, 5)]}) assert mock_all.called diff --git a/tests/test_utils.py b/tests/test_utils.py index 2c01197..6c5eaa6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -29,6 +29,7 @@ def test_compute_metric_with_mock_metric_class(): dummy_data = {"group1": [np.array([1, 2]), np.array([3, 4])]} class DummyMetric: + def compute(self, sample): return sample.sum() @@ -37,7 +38,7 @@ def compute(self, sample): def test_compute_metric_with_decoding(): - dummy_data = {"group": [np.ones((5,)), np.ones((5,))]} + dummy_data = {"group": [np.ones((5, )), np.ones((5, ))]} mock_metric = MagicMock(spec=Decoding) mock_metric.compute.side_effect = lambda x: x.sum() @@ -46,10 +47,14 @@ def test_compute_metric_with_decoding(): def test_plot_metric_single_model(monkeypatch): + class DummyRDM: + def plot(self, data_dict, **kwargs): assert isinstance(data_dict, dict) return "plotted" - result = plot_metric({"data": np.array([1, 2, 3])}, DummyRDM(), group_name="test") + result = plot_metric({"data": np.array([1, 2, 3])}, + DummyRDM(), + group_name="test") assert result == "plotted" From 9669054e16fb512a8f74d0cfc1d0f7938a369a8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 12:02:35 +0200 Subject: [PATCH 04/12] Add gitactions for formatting and codespell --- .github/workflows/python-package.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 6492f2d..0e94fc6 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -45,6 +45,18 @@ jobs: python -m pip install --upgrade pip setuptools wheel pip install -r requirements.txt + - name: Run the formatter + run: | + make format + + - name: Run the spelling detector + run: | + make codespell + + - name: Check the documentation coverage + run: | + make interrogate + - name: Run all pytest tests shell: bash -el {0} run: | From cb28ccc88055e6452333bc27a43bad2bf1b50e16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 12:02:51 +0200 Subject: [PATCH 05/12] Update title of visual notebook --- demos/UsageDemoVISUAL.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/UsageDemoVISUAL.ipynb b/demos/UsageDemoVISUAL.ipynb index 10098f5..29402dc 100644 --- a/demos/UsageDemoVISUAL.ipynb +++ b/demos/UsageDemoVISUAL.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# CEBRA-Lens Demo\n", + "# CEBRA-Lens Demo: 2P visual coding dataset (Allen Institute)\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AdaptiveMotorControlLab/CEBRA-lens/blob/main/demos/UsageDemoVISUAL.ipynb)\n", "\n", From be26f03ca2ef6ef88bb53745d557202904c75102 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 12:11:36 +0200 Subject: [PATCH 06/12] Add the shell in yml git action file --- .github/workflows/python-package.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 0e94fc6..09708d2 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -46,14 +46,17 @@ jobs: pip install -r requirements.txt - name: Run the formatter + shell: bash -el {0} run: | make format - name: Run the spelling detector + shell: bash -el {0} run: | make codespell - name: Check the documentation coverage + shell: bash -el {0} run: | make interrogate From c3d11e2a6ff83b45ae82a64f6ba6523c6cde0832 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 12:17:05 +0200 Subject: [PATCH 07/12] Remove yapf on notebooks --- Makefile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Makefile b/Makefile index 9f3705f..58a0ee8 100644 --- a/Makefile +++ b/Makefile @@ -44,8 +44,6 @@ serve_page: format: yapf -i -p -r cebra_lens yapf -i -p -r tests - yapf -i -p -r docs/docs/examples - yapf -i -p -r docs/docs/conf.py isort cebra_lens/ isort tests/ From 89b9ac7544801070cde72e9febdf7798b792f822 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 16:57:22 +0200 Subject: [PATCH 08/12] Fix tests and add more --- cebra_lens/__init__.py | 3 +- cebra_lens/activations.py | 8 +- cebra_lens/quantification/__init__.py | 2 +- cebra_lens/quantification/cka_metric.py | 14 +-- .../{decoding.py => decoder.py} | 13 +-- cebra_lens/quantification/rdm_metric.py | 14 +-- cebra_lens/utils.py | 2 +- tests/test_activations.py | 7 +- tests/test_cka.py | 86 +++++++++++++++++++ tests/test_decoding.py | 32 ++++--- tests/test_rdm.py | 61 ++++++++++--- tests/test_utils.py | 2 +- 12 files changed, 194 insertions(+), 50 deletions(-) rename cebra_lens/quantification/{decoding.py => decoder.py} (97%) create mode 100644 tests/test_cka.py diff --git a/cebra_lens/__init__.py b/cebra_lens/__init__.py index eb46797..a212c8b 100644 --- a/cebra_lens/__init__.py +++ b/cebra_lens/__init__.py @@ -1,7 +1,6 @@ # example of structure so that you can directly use the functions get_layer_activations instead of having to do CEBRA_Lens.activations.get_layer_activations from .activations import * -from .quantification import * -from .quantification.decoding import * +from .quantification.decoder import * from .quantification.distance import * from .quantification.cka_metric import * from .quantification.rdm_metric import * diff --git a/cebra_lens/activations.py b/cebra_lens/activations.py index 69b38a0..c3154de 100644 --- a/cebra_lens/activations.py +++ b/cebra_lens/activations.py @@ -80,7 +80,8 @@ def get_cut_indices( cut_indices.append((0, 0)) elif layer_type == None: raise NotImplementedError( - "Padding handling not implemented for 'all'.") + "Padding handling not implemented to handle activations for all layer types.", + "Set layer_type to nn.Conv1d to use the default padding handling.") else: # need to analyze the padding from the last output of Conv1 and apply the same cut raise NotImplementedError( @@ -94,7 +95,7 @@ def get_activations_model( session_id: int = -1, name: str = "single", instance: int = 0, - layer_type: Type[nn.Module] = None, + layer_type: Type[nn.Module] = nn.Conv1d, ) -> Dict[str, npt.NDArray]: """ Extracts activations from a single model layer. @@ -112,7 +113,8 @@ def get_activations_model( instance : int The instance number for the model, used to differentiate between models from the same model category. layer_type : Type[nn.Module] - The type of layer to extract activations from. Defaults to None, meaning extracts activations from all layers. + The type of layer to extract activations from. None means it extracts activations from all layers. + Default is nn.Conv1d, which is the most common layer type used in CEBRA models. Returns: -------- diff --git a/cebra_lens/quantification/__init__.py b/cebra_lens/quantification/__init__.py index 2085eb1..7df8e76 100644 --- a/cebra_lens/quantification/__init__.py +++ b/cebra_lens/quantification/__init__.py @@ -2,6 +2,6 @@ from .rdm_metric import * from .misc import * from .distance import * -from .decoding import * +from .decoder import * from .base import * from .tsne import * diff --git a/cebra_lens/quantification/cka_metric.py b/cebra_lens/quantification/cka_metric.py index 741da37..06f9e30 100644 --- a/cebra_lens/quantification/cka_metric.py +++ b/cebra_lens/quantification/cka_metric.py @@ -8,9 +8,10 @@ from tqdm import tqdm import numpy as np from .base import _BaseMetric -from ..matplotlib import * +import cebra_lens.matplotlib as cebra_lens_matplotlib from typing import Optional, List, Dict, Tuple import numpy.typing as npt +import matplotlib class CKA(_BaseMetric): @@ -188,7 +189,8 @@ def _compute_per_layer( cka_matrix = np.zeros((len(embeddings_1), len(embeddings_1[0]))) for j in tqdm(range(len(embeddings_1))): if flag: - # the situation when there multiple models inside model labels and the same number of models inside each label + # the situation when there multiple models inside model labels and the same number of + # models inside each label cka_matrix[j, :] = self._compute_cka(embeddings_1[j], embeddings_2[j]) else: @@ -207,7 +209,8 @@ def compute(self, activations: Dict[str, npt.NDArray], Parameters: ----------- activations : Dict[str, npt.NDArray] - A dictionary where keys are strings which represent the model label and values are 2d lists with the corresponding activations per layer. + A dictionary where keys are strings which represent the model label and values are 2d lists + with the corresponding activations per layer. comparison : Tuple[str, str] A tuple containing the model labels to compare. @@ -227,7 +230,8 @@ def compute(self, activations: Dict[str, npt.NDArray], if len(activations_1) != len(activations_2): # if the number of models in a label is different from the other model label - # choose embeddings_1 for the one with more models, and then embeddings_2 just compare with the first model + # choose embeddings_1 for the one with more models, and then embeddings_2 just compare with + # the first model if len(activations_1) > len(activations_2): embeddings_1 = activations_1 embeddings_2 = activations_2[0] @@ -293,7 +297,7 @@ def plot( matplotlib.axes.Axes The axes on which the heatmap is plotted. """ - return plot_cka_heatmaps( + return cebra_lens_matplotlib.plot_cka_heatmaps( cka_matrices, annot, show_cbar, diff --git a/cebra_lens/quantification/decoding.py b/cebra_lens/quantification/decoder.py similarity index 97% rename from cebra_lens/quantification/decoding.py rename to cebra_lens/quantification/decoder.py index cb58fbe..3c5d8e4 100644 --- a/cebra_lens/quantification/decoding.py +++ b/cebra_lens/quantification/decoder.py @@ -5,12 +5,13 @@ from ..utils_hpc import decoding_pos_dir from ..activations import get_activations_model from .base import _BaseMetric -from ..matplotlib import * +import cebra_lens.matplotlib as cebra_lens_matplotlib import numpy.typing as npt -from typing import Dict, Type, Tuple +from typing import Dict, Type, Tuple, Optional import torch.nn as nn import sklearn.metrics import torch as pt +import matplotlib def decoding( @@ -118,7 +119,7 @@ def __init__( test_label: npt.NDArray, session_id: int = 0, dataset_label: str = None, - layer_type: Optional[Type[nn.Module]] = None, + layer_type: Optional[Type[nn.Module]] = nn.Conv1d, output_only: bool = True, ): @@ -315,7 +316,7 @@ def compute( def __name__(self): return "decode_by_layer" - def set_output_only(self, output_only): + def set_output_only(self, output_only: bool) -> None: """ Set the output_only parameter to True or False. If True, it will compute the decoding scores for the output embeddings of the model, otherwise it will compute the decoding scores for the activations of the model. @@ -369,8 +370,8 @@ def plot( ) if self.output_only: - return plot_decoding(results_dict, palette, self.dataset_label, + return cebra_lens_matplotlib.plot_decoding(results_dict, palette, self.dataset_label, label, plot_error, ax) else: - return plot_layer_decoding(results_dict, title, self.dataset_label, + return cebra_lens_matplotlib.plot_layer_decoding(results_dict, title, self.dataset_label, label, plot_error, figsize) diff --git a/cebra_lens/quantification/rdm_metric.py b/cebra_lens/quantification/rdm_metric.py index e734e37..1eeed4b 100644 --- a/cebra_lens/quantification/rdm_metric.py +++ b/cebra_lens/quantification/rdm_metric.py @@ -1,13 +1,15 @@ """All the functions relative to the Representation Dissimilarity Matrix (RDM) calculation""" +from typing import Dict, List, Optional import numpy as np from scipy.linalg import block_diag from typing import List, Optional, Tuple, Union from scipy.spatial.distance import correlation, pdist, squareform from .misc import discrete_binning, continuous_binning import torch +import matplotlib from .base import _BaseMetric -from ..matplotlib import * +import cebra_lens.matplotlib as cebra_lens_matplotlib import numpy.typing as npt @@ -21,8 +23,8 @@ class RDM(_BaseMetric): The data array of shape (num_samples, num_features). label : torch.Tensor The array of labels corresponding to the data. - discrete : bool, optional - Whether the labels are discrete or continuous. If None, it will be determined based on the dataset_label. + is_discrete_labels : bool, optional + Whether the labels are discrete or continuous. By default, it is False, meaning the labels are continuous. dataset_label : str, optional The dataset type, either 'visual' or 'HPC'. Default is 'visual'. metric : str, optional @@ -37,7 +39,7 @@ def __init__( self, data: torch.Tensor, label: torch.Tensor, - is_discrete_labels: bool = None, + is_discrete_labels: bool = False, dataset_label: str = None, metric: str = "correlation", bool_oracle: bool = True, @@ -254,9 +256,9 @@ def plot( The figure containing the plotted RDMs. """ if self.bool_oracle: - return plot_rdm_correlation(rdms) + return cebra_lens_matplotlib.plot_rdm_correlation(rdms) else: - return plot_rdm_all( + return cebra_lens_matplotlib.plot_rdm_all( rdms=rdms, labels=self.label, num_bins=self.num_bins, diff --git a/cebra_lens/utils.py b/cebra_lens/utils.py index 6afb58f..5aa15e7 100644 --- a/cebra_lens/utils.py +++ b/cebra_lens/utils.py @@ -6,7 +6,7 @@ import numpy.typing as npt from tqdm import tqdm from torch import nn -from .quantification.decoding import Decoding +from .quantification.decoder import Decoding from .quantification.rdm_metric import RDM from .quantification.cka_metric import CKA from .quantification.tsne import Tsne diff --git a/tests/test_activations.py b/tests/test_activations.py index 8e8bd27..34a8355 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -1,5 +1,5 @@ -import pytest import torch +import pytest import numpy as np from collections import namedtuple from unittest.mock import MagicMock @@ -22,7 +22,7 @@ def test_cut_array_with_cut(): np.testing.assert_array_equal(result, np.array([[2, 3, 4]])) -def test_get_cut_indices_conv1d(): +def test_get_cut_indices(): Offset = namedtuple("Offset", ["left", "right"]) # Mock the model's get_offset behavior @@ -32,6 +32,9 @@ def test_get_cut_indices_conv1d(): result = get_cut_indices(model_mock, torch.nn.Conv1d, [3, 3]) assert isinstance(result, list) assert all(isinstance(x, tuple) and len(x) == 2 for x in result) + + with pytest.raises(NotImplementedError, match="Padding handling not implemented*"): + get_cut_indices(model_mock, None, [3, 3]) def make_mock_cebra_model(): diff --git a/tests/test_cka.py b/tests/test_cka.py new file mode 100644 index 0000000..2e41021 --- /dev/null +++ b/tests/test_cka.py @@ -0,0 +1,86 @@ +import pytest +import numpy as np +import torch +from unittest.mock import patch, MagicMock +import cebra_lens + +@pytest.fixture +def dummy_comparisons(): + return [("A", "B")] + +@pytest.fixture +def dummy_cka(dummy_comparisons): + return cebra_lens.quantification.cka_metric.CKA(comparisons=dummy_comparisons) + +@pytest.fixture +def dummy_activations(): + # Simulate Conv1D and Linear layer activations as 2D arrays (samples, features) + batch_size = 10 + conv_channels = 4 + conv_length = 8 + linear_features = 5 + + # Conv1D output: (batch_size, conv_channels, conv_length) -> flatten to (batch_size, conv_channels * conv_length) + conv1d_activations_A = np.random.rand(batch_size, conv_channels, conv_length).reshape(batch_size, -1) + linear_activations_A = np.random.rand(batch_size, linear_features) + conv1d_activations_B = np.random.rand(batch_size, conv_channels, conv_length).reshape(batch_size, -1) + linear_activations_B = np.random.rand(batch_size, linear_features) + + # Each group has a list of 2D arrays (one per layer) + return { + "A": np.array([[np.random.rand(5, 10), np.random.rand(5, 10)], + [np.random.rand(5, 10), np.random.rand(5, 10)]]), + "B": np.array([[np.random.rand(5, 10), np.random.rand(5, 10)], + [np.random.rand(5, 10), np.random.rand(5, 10)]]), + } + +def test_center_gram_symmetry(dummy_cka): + mat = np.eye(5) + centered = dummy_cka.center_gram(mat) + assert np.allclose(centered, centered.T) + +def test_center_gram_unbiased(dummy_cka): + mat = np.eye(5) + centered = dummy_cka.center_gram(mat, unbiased=True) + assert np.allclose(centered, centered.T) + +def test_gram_linear(dummy_cka): + x = np.random.rand(10, 5) + gram = dummy_cka.gram_linear(x) + assert gram.shape == (10, 10) + assert np.allclose(gram, gram.T) + +def test_cka_value(dummy_cka): + x = np.random.rand(10, 5) + y = np.random.rand(10, 5) + gram_x = dummy_cka.gram_linear(x) + gram_y = dummy_cka.gram_linear(y) + val = dummy_cka.cka(gram_x, gram_y) + assert isinstance(val, float) or isinstance(val, np.floating) + +def test_compute_cka_shape(dummy_cka): + emb1 = [np.random.rand(5, 10), np.random.rand(5, 10)] + emb2 = [np.random.rand(5, 10), np.random.rand(5, 10)] + result = dummy_cka._compute_cka(emb1, emb2) + assert result.shape == (1, 2) + +def test_compute_per_layer_shape(dummy_cka): + emb1 = [ [np.random.rand(5, 10), np.random.rand(5, 10)] for _ in range(3) ] + emb2 = [ [np.random.rand(5, 10), np.random.rand(5, 10)] for _ in range(3) ] + result = dummy_cka._compute_per_layer(emb1, emb2, flag=True) + assert result.shape == (3, 2) + +def test_compute(dummy_cka, dummy_activations): + result = dummy_cka.compute(dummy_activations, ("A", "B")) + assert isinstance(result, np.ndarray) + + +def test_compute_intra_label(dummy_cka, dummy_activations): + result = dummy_cka.compute(dummy_activations, ("A", "A")) + assert isinstance(result, np.ndarray) + +@patch("cebra_lens.matplotlib.plot_cka_heatmaps") +def test_plot_calls_heatmap(mock_plot, dummy_cka): + cka_matrices = {"A": np.random.rand(2, 2)} + dummy_cka.plot(cka_matrices, annot=True) + assert mock_plot.called \ No newline at end of file diff --git a/tests/test_decoding.py b/tests/test_decoding.py index a248156..888b4e3 100644 --- a/tests/test_decoding.py +++ b/tests/test_decoding.py @@ -2,7 +2,7 @@ import numpy as np import torch from unittest.mock import patch, MagicMock -from cebra_lens.quantification.decoding import decoding, Decoding +import cebra_lens @pytest.fixture @@ -18,13 +18,13 @@ def test_decoding_function(embeddings_labels): emb_train, emb_test, label_train, label_test = embeddings_labels with patch( - "cebra_lens.quantification.decoding.cebra.KNNDecoder") as mock_knn: + "cebra.KNNDecoder") as mock_knn: mock_model = MagicMock() # Return prediction with correct shape each time mock_model.predict.side_effect = lambda x: np.random.rand(len(x)) mock_knn.return_value = mock_model - score, medians, r2s = decoding(emb_train, emb_test, label_train, + score, medians, r2s = cebra_lens.quantification.decoder.decoding(emb_train, emb_test, label_train, label_test) assert isinstance(score, float) @@ -41,7 +41,7 @@ def make_mock_cebra_model(): def test_decoding_class_output_only_true(): model = make_mock_cebra_model() - decoding_class = Decoding( + decoding_class = cebra_lens.quantification.decoder.Decoding( train_data=torch.rand((300, 100)), train_label=np.random.rand(300, 1), test_data=torch.rand((100, 100)), @@ -54,7 +54,7 @@ def test_decoding_class_output_only_true(): assert 0 in results -@patch("cebra_lens.quantification.decoding.get_activations_model") +@patch("cebra_lens.activations.get_activations_model") def test_decoding_class_output_only_false(mock_get_act): model = make_mock_cebra_model() mock_get_act.side_effect = lambda **kwargs: { @@ -62,7 +62,7 @@ def test_decoding_class_output_only_false(mock_get_act): "layer2": np.random.rand(1000, 1000), } - decoding_class = Decoding( + decoding_class = cebra_lens.quantification.decoder.Decoding( train_data=torch.rand((1000, 1000)), train_label=np.random.rand(1000, 1), test_data=torch.rand((1000, 1000)), @@ -73,11 +73,19 @@ def test_decoding_class_output_only_false(mock_get_act): results = decoding_class.compute(model) assert isinstance(results, dict) - assert len(results) == 3 # baseline + 2 layers + assert len(results) == 1 # only one Conv1d layer in the mock model + + decoding_class.layer_type = None + with pytest.raises(NotImplementedError, match="Padding handling not implemented*"): + decoding_class.compute(model) + + decoding_class.layer_type = torch.nn.Linear + with pytest.raises(NotImplementedError, match="Padding handling not implemented*"): + decoding_class.compute(model) def test_set_output_only(): - decoding_instance = Decoding( + decoding_instance = cebra_lens.quantification.decoder.Decoding( train_data=torch.rand((10, 5)), train_label=np.random.rand(10, 1), test_data=torch.rand((10, 5)), @@ -87,11 +95,11 @@ def test_set_output_only(): assert decoding_instance.output_only is False -@patch("cebra_lens.quantification.decoding.plot_decoding") -@patch("cebra_lens.quantification.decoding.plot_layer_decoding") -def test_plot_logic(mock_layer_plot, mock_decoding_plot): +@patch("cebra_lens.matplotlib.plot_decoding") +@patch("cebra_lens.matplotlib.plot_layer_decoding") +def test_decoder_plot(mock_layer_plot, mock_decoding_plot): dummy_result = {"modelA": {0: (0.9, [0.1], [0.8])}} - dec = Decoding( + dec = cebra_lens.quantification.decoder.Decoding( train_data=torch.rand((10, 10)), train_label=np.random.rand(10, 1), test_data=torch.rand((10, 10)), diff --git a/tests/test_rdm.py b/tests/test_rdm.py index d6dd6b7..2a11429 100644 --- a/tests/test_rdm.py +++ b/tests/test_rdm.py @@ -2,7 +2,7 @@ import numpy as np import torch from unittest.mock import patch, MagicMock -from cebra_lens.quantification.rdm_metric import RDM +import cebra_lens @pytest.fixture @@ -18,7 +18,7 @@ def dummy_labels(): @patch("cebra_lens.quantification.rdm_metric.continuous_binning") def test_define_indices_continuous(mock_binning, dummy_data, dummy_labels): mock_binning.return_value = (np.array([[0, 1], [2, 3]]), 2) - rdm = RDM(data=dummy_data, label=dummy_labels, dataset_label="visual") + rdm = cebra_lens.quantification.rdm_metric.RDM(data=dummy_data, label=dummy_labels, dataset_label="visual") idxs, bins = rdm._define_indices() assert isinstance(idxs, np.ndarray) assert bins == 2 @@ -28,7 +28,7 @@ def test_define_indices_continuous(mock_binning, dummy_data, dummy_labels): def test_define_indices_discrete(mock_binning, dummy_data): labels = torch.tensor([0, 1, 0, 1, 2]) mock_binning.return_value = np.array([[0, 2], [1, 3]]) - rdm = RDM(data=dummy_data, label=labels, is_discrete_labels=True) + rdm = cebra_lens.quantification.rdm_metric.RDM(data=dummy_data, label=labels, is_discrete_labels=True) idxs, bins = rdm._define_indices() assert isinstance(idxs, np.ndarray) assert bins is None @@ -37,7 +37,7 @@ def test_define_indices_discrete(mock_binning, dummy_data): def test_init_with_label_ind(): labels = np.array([[1, 2], [3, 4]]) data = torch.tensor(np.random.rand(2, 5), dtype=torch.float32) - rdm = RDM(data=data, + rdm = cebra_lens.quantification.rdm_metric.RDM(data=data, label=labels, label_ind=0, dataset_label=None, @@ -46,7 +46,7 @@ def test_init_with_label_ind(): def test_create_oracle_rdm_custom(): - rdm = RDM( + rdm = cebra_lens.quantification.rdm_metric.RDM( data=torch.rand((10000, 5)), label=torch.randint(0, 5, (10000, )), dataset_label=None, @@ -58,7 +58,7 @@ def test_create_oracle_rdm_custom(): def test_compute_per_layer_and_bool_oracle(): - rdm = RDM( + rdm = cebra_lens.quantification.rdm_metric.RDM( data=torch.rand((10000, 5)), label=torch.randint(0, 5, (10000, )), is_discrete_labels=True, @@ -73,7 +73,7 @@ def test_compute_per_layer_and_bool_oracle(): def test_compute_single_activation_tensor(): data = torch.rand((10000, 5)) label = torch.randint(0, 3, (10000, )) - rdm = RDM(data=data, label=label, is_discrete_labels=False) + rdm = cebra_lens.quantification.rdm_metric.RDM(data=data, label=label, is_discrete_labels=False) rdm.idxs = np.array([[i] for i in range(10000)]) act = torch.rand((10000, 5)).numpy() result = rdm.compute(act) @@ -82,10 +82,10 @@ def test_compute_single_activation_tensor(): assert result[0][0].ndim == 2 # RDM squareform -@patch("cebra_lens.quantification.rdm_metric.plot_rdm_correlation") -@patch("cebra_lens.quantification.rdm_metric.plot_rdm_all") -def test_plot(mock_all, mock_corr): - dummy_rdm = RDM(data=torch.rand((10000, 5)), +@patch("cebra_lens.matplotlib.plot_rdm_correlation") +@patch("cebra_lens.matplotlib.plot_rdm_all") +def test_rdm_plot(mock_all, mock_corr): + dummy_rdm = cebra_lens.quantification.rdm_metric.RDM(data=torch.rand((10000, 5)), label=torch.randint(0, 5, (10000, )), is_discrete_labels=False) dummy_rdm.bool_oracle = True @@ -97,3 +97,42 @@ def test_plot(mock_all, mock_corr): dummy_rdm.is_discrete_labels = True dummy_rdm.plot({"group": [np.random.rand(5, 5)]}) assert mock_all.called + + dummy_rdm.plot({"group": [np.random.rand(5, 5)]}, titles=None) + dummy_rdm.plot({}) + +def test_init_raises_keyerror_for_multilabel_without_label_ind(): + # label is 2D, dataset_label is not HPC/visual, label_ind not provided + labels = np.array([[1, 2], [3, 4]]) + data = torch.tensor(np.random.rand(2, 5), dtype=torch.float32) + with pytest.raises(KeyError): + cebra_lens.quantification.rdm_metric.RDM( + data=data, + label=labels, + is_discrete_labels=True, + dataset_label=None, + label_ind=None, + ) + +def test_define_indices_raises_valueerror_for_invalid_dataset_label(): + data = torch.tensor(np.random.rand(10, 5), dtype=torch.float32) + labels = torch.tensor(np.random.randint(0, 5, size=10), dtype=torch.int64) + with pytest.raises(ValueError): + cebra_lens.quantification.rdm_metric.RDM( + data=data, + label=labels, + is_discrete_labels=True, + dataset_label="invalid_label", + ) + +def test_compute_per_layer_transposes_if_needed(): + rdm = cebra_lens.quantification.rdm_metric.RDM( + data=torch.rand((10, 5)), + label=torch.randint(0, 5, (10, )), + is_discrete_labels=True, + ) + rdm.idxs = np.array([[i] for i in range(10)]) + # Provide activation with shape (features, samples) + dummy_layer = np.random.rand(5, 10) + rdm.bool_oracle = False + rdm._compute_per_layer(dummy_layer) \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index 6c5eaa6..d6da83f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import numpy as np from unittest.mock import patch, MagicMock from cebra_lens import extract_label, compute_metric, plot_metric, model_loader -from cebra_lens.quantification.decoding import Decoding +from cebra_lens.quantification.decoder import Decoding from cebra_lens.quantification.rdm_metric import RDM from cebra_lens.quantification.cka_metric import CKA From 8abc2661562ee94ae176434257f79ee4e5e1dd72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 16:58:36 +0200 Subject: [PATCH 09/12] Run formatter --- cebra_lens/__init__.py | 6 +-- cebra_lens/activations.py | 10 +++-- cebra_lens/matplotlib.py | 9 +++-- cebra_lens/quantification/__init__.py | 8 ++-- cebra_lens/quantification/base.py | 7 ++-- cebra_lens/quantification/cka_metric.py | 17 +++++---- cebra_lens/quantification/decoder.py | 31 ++++++++------- cebra_lens/quantification/distance.py | 10 +++-- cebra_lens/quantification/misc.py | 7 ++-- cebra_lens/quantification/rdm_metric.py | 16 ++++---- cebra_lens/quantification/tsne.py | 10 +++-- cebra_lens/utils.py | 12 +++--- cebra_lens/utils_allen.py | 7 ++-- cebra_lens/utils_hpc.py | 7 ++-- tests/test_activations.py | 20 +++++----- tests/test_cka.py | 50 ++++++++++++++++++------- tests/test_decoding.py | 23 +++++++----- tests/test_misc.py | 11 +++--- tests/test_rdm.py | 40 +++++++++++++------- tests/test_utils.py | 10 +++-- 20 files changed, 186 insertions(+), 125 deletions(-) diff --git a/cebra_lens/__init__.py b/cebra_lens/__init__.py index a212c8b..4bb3127 100644 --- a/cebra_lens/__init__.py +++ b/cebra_lens/__init__.py @@ -1,14 +1,14 @@ # example of structure so that you can directly use the functions get_layer_activations instead of having to do CEBRA_Lens.activations.get_layer_activations from .activations import * +from .matplotlib import * +from .quantification.cka_metric import * from .quantification.decoder import * from .quantification.distance import * -from .quantification.cka_metric import * from .quantification.rdm_metric import * from .quantification.tsne import * -from .matplotlib import * +from .utils import * from .utils_allen import * from .utils_hpc import * -from .utils import * # selects what files can be imported when doing from CEBRA_Lens import * --> keep env clean # __all__ = ['get_layer_activations'] diff --git a/cebra_lens/activations.py b/cebra_lens/activations.py index c3154de..34ffec8 100644 --- a/cebra_lens/activations.py +++ b/cebra_lens/activations.py @@ -1,13 +1,15 @@ """Functions to retrieve and handle layer activations""" +from typing import Dict, List, Optional, Tuple, Type + import cebra -import torch -import torch.nn as nn +import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt -from typing import Tuple, Dict, List, Type, Optional +import torch +import torch.nn as nn + from .matplotlib import plot_activations -import matplotlib.pyplot as plt def _cut_array(array: npt.NDArray, diff --git a/cebra_lens/matplotlib.py b/cebra_lens/matplotlib.py index 482db85..ede4b2b 100644 --- a/cebra_lens/matplotlib.py +++ b/cebra_lens/matplotlib.py @@ -1,14 +1,15 @@ """Matplotlib interface to CEBRA-Lens.""" +import random from abc import * -from typing import Optional, Tuple, List, Dict, Union -import seaborn as sns +from typing import Dict, List, Optional, Tuple, Union + import matplotlib.axes import matplotlib.pyplot as plt import numpy as np -import torch import numpy.typing as npt -import random +import seaborn as sns +import torch class _BasePlot: diff --git a/cebra_lens/quantification/__init__.py b/cebra_lens/quantification/__init__.py index 7df8e76..559aa3a 100644 --- a/cebra_lens/quantification/__init__.py +++ b/cebra_lens/quantification/__init__.py @@ -1,7 +1,7 @@ +from .base import * from .cka_metric import * -from .rdm_metric import * -from .misc import * -from .distance import * from .decoder import * -from .base import * +from .distance import * +from .misc import * +from .rdm_metric import * from .tsne import * diff --git a/cebra_lens/quantification/base.py b/cebra_lens/quantification/base.py index b2a3615..19b539e 100644 --- a/cebra_lens/quantification/base.py +++ b/cebra_lens/quantification/base.py @@ -1,11 +1,12 @@ -from tqdm import tqdm -import numpy as np import pickle import types -from typing import List, Union, Dict from abc import * from pathlib import Path +from typing import Dict, List, Union + +import numpy as np import numpy.typing as npt +from tqdm import tqdm class _BaseMetric: diff --git a/cebra_lens/quantification/cka_metric.py b/cebra_lens/quantification/cka_metric.py index 06f9e30..e970230 100644 --- a/cebra_lens/quantification/cka_metric.py +++ b/cebra_lens/quantification/cka_metric.py @@ -5,13 +5,16 @@ """ -from tqdm import tqdm +from typing import Dict, List, Optional, Tuple + +import matplotlib import numpy as np -from .base import _BaseMetric -import cebra_lens.matplotlib as cebra_lens_matplotlib -from typing import Optional, List, Dict, Tuple import numpy.typing as npt -import matplotlib +from tqdm import tqdm + +import cebra_lens.matplotlib as cebra_lens_matplotlib + +from .base import _BaseMetric class CKA(_BaseMetric): @@ -189,7 +192,7 @@ def _compute_per_layer( cka_matrix = np.zeros((len(embeddings_1), len(embeddings_1[0]))) for j in tqdm(range(len(embeddings_1))): if flag: - # the situation when there multiple models inside model labels and the same number of + # the situation when there multiple models inside model labels and the same number of # models inside each label cka_matrix[j, :] = self._compute_cka(embeddings_1[j], embeddings_2[j]) @@ -230,7 +233,7 @@ def compute(self, activations: Dict[str, npt.NDArray], if len(activations_1) != len(activations_2): # if the number of models in a label is different from the other model label - # choose embeddings_1 for the one with more models, and then embeddings_2 just compare with + # choose embeddings_1 for the one with more models, and then embeddings_2 just compare with # the first model if len(activations_1) > len(activations_2): embeddings_1 = activations_1 diff --git a/cebra_lens/quantification/decoder.py b/cebra_lens/quantification/decoder.py index 3c5d8e4..95effd7 100644 --- a/cebra_lens/quantification/decoder.py +++ b/cebra_lens/quantification/decoder.py @@ -1,17 +1,20 @@ +from typing import Dict, Optional, Tuple, Type + import cebra -import torch +import matplotlib import numpy as np -from ..utils_allen import decoding_frames -from ..utils_hpc import decoding_pos_dir -from ..activations import get_activations_model -from .base import _BaseMetric -import cebra_lens.matplotlib as cebra_lens_matplotlib import numpy.typing as npt -from typing import Dict, Type, Tuple, Optional -import torch.nn as nn import sklearn.metrics +import torch import torch as pt -import matplotlib +import torch.nn as nn + +import cebra_lens.matplotlib as cebra_lens_matplotlib + +from ..activations import get_activations_model +from ..utils_allen import decoding_frames +from ..utils_hpc import decoding_pos_dir +from .base import _BaseMetric def decoding( @@ -370,8 +373,10 @@ def plot( ) if self.output_only: - return cebra_lens_matplotlib.plot_decoding(results_dict, palette, self.dataset_label, - label, plot_error, ax) + return cebra_lens_matplotlib.plot_decoding(results_dict, palette, + self.dataset_label, + label, plot_error, ax) else: - return cebra_lens_matplotlib.plot_layer_decoding(results_dict, title, self.dataset_label, - label, plot_error, figsize) + return cebra_lens_matplotlib.plot_layer_decoding( + results_dict, title, self.dataset_label, label, plot_error, + figsize) diff --git a/cebra_lens/quantification/distance.py b/cebra_lens/quantification/distance.py index 3692e8e..7a1d1b2 100644 --- a/cebra_lens/quantification/distance.py +++ b/cebra_lens/quantification/distance.py @@ -1,14 +1,16 @@ "file containing all the functions relative to distance computing" +from typing import Dict, List, Optional, Tuple, Union + import numpy as np +import numpy.typing as npt from scipy.spatial.distance import cdist, pdist from sklearn.preprocessing import StandardScaler -from typing import List, Optional, Tuple, Union, Dict -from .misc import discrete_binning, repetition_binning, continuous_binning -from .base import _BaseMetric + from ..matplotlib import * -import numpy.typing as npt from ..utils import extract_label +from .base import _BaseMetric +from .misc import continuous_binning, discrete_binning, repetition_binning class DistanceMetric: diff --git a/cebra_lens/quantification/misc.py b/cebra_lens/quantification/misc.py index 0b8593a..f078f64 100644 --- a/cebra_lens/quantification/misc.py +++ b/cebra_lens/quantification/misc.py @@ -1,11 +1,12 @@ """misc functions like normalization and possibly others""" +import warnings from random import sample +from typing import List + import numpy as np -import torch import numpy.typing as npt -from typing import List -import warnings +import torch def normalize_minmax(rdm: npt.NDArray) -> npt.NDArray: diff --git a/cebra_lens/quantification/rdm_metric.py b/cebra_lens/quantification/rdm_metric.py index 1eeed4b..4af29aa 100644 --- a/cebra_lens/quantification/rdm_metric.py +++ b/cebra_lens/quantification/rdm_metric.py @@ -1,16 +1,18 @@ """All the functions relative to the Representation Dissimilarity Matrix (RDM) calculation""" -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple, Union + +import matplotlib import numpy as np +import numpy.typing as npt +import torch from scipy.linalg import block_diag -from typing import List, Optional, Tuple, Union from scipy.spatial.distance import correlation, pdist, squareform -from .misc import discrete_binning, continuous_binning -import torch -import matplotlib -from .base import _BaseMetric + import cebra_lens.matplotlib as cebra_lens_matplotlib -import numpy.typing as npt + +from .base import _BaseMetric +from .misc import continuous_binning, discrete_binning class RDM(_BaseMetric): diff --git a/cebra_lens/quantification/tsne.py b/cebra_lens/quantification/tsne.py index fd67b5d..89356c9 100644 --- a/cebra_lens/quantification/tsne.py +++ b/cebra_lens/quantification/tsne.py @@ -1,11 +1,13 @@ """Functions to transform data e.g. tSNE, other functions can be added""" -from sklearn.manifold import TSNE -import numpy as np -from .base import _BaseMetric -from ..matplotlib import * from typing import List, Optional, Union + +import numpy as np import numpy.typing as npt +from sklearn.manifold import TSNE + +from ..matplotlib import * +from .base import _BaseMetric class Tsne(_BaseMetric): diff --git a/cebra_lens/utils.py b/cebra_lens/utils.py index 5aa15e7..f421ac0 100644 --- a/cebra_lens/utils.py +++ b/cebra_lens/utils.py @@ -1,17 +1,19 @@ import pathlib +from typing import Any, Dict, List, Union + import cebra -import torch -from typing import Dict, List, Any, Union import numpy as np import numpy.typing as npt -from tqdm import tqdm +import torch from torch import nn +from tqdm import tqdm + +from .quantification.cka_metric import CKA from .quantification.decoder import Decoding from .quantification.rdm_metric import RDM -from .quantification.cka_metric import CKA from .quantification.tsne import Tsne -from .utils_hpc import get_datasets as get_datasets_hpc from .utils_allen import get_datasets as get_datasets_visual +from .utils_hpc import get_datasets as get_datasets_hpc def get_data( diff --git a/cebra_lens/utils_allen.py b/cebra_lens/utils_allen.py index 5068aba..9c40d17 100644 --- a/cebra_lens/utils_allen.py +++ b/cebra_lens/utils_allen.py @@ -1,11 +1,12 @@ """Utils relative to the visual dataset""" -import numpy as np -import torch +import copy + import cebra import cebra.datasets -import copy +import numpy as np import sklearn.metrics +import torch ######################################################################################################################## diff --git a/cebra_lens/utils_hpc.py b/cebra_lens/utils_hpc.py index 47ea25e..431d2fe 100644 --- a/cebra_lens/utils_hpc.py +++ b/cebra_lens/utils_hpc.py @@ -1,9 +1,10 @@ +from typing import List + import cebra -import numpy as np -import sklearn import cebra.datasets +import numpy as np import numpy.typing as npt -from typing import List +import sklearn def get_datasets( diff --git a/tests/test_activations.py b/tests/test_activations.py index 34a8355..7f3ff15 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -1,13 +1,12 @@ -import torch -import pytest -import numpy as np from collections import namedtuple from unittest.mock import MagicMock -from cebra_lens.activations import ( - get_activations_model, - _cut_array, - get_cut_indices, -) + +import numpy as np +import pytest +import torch + +from cebra_lens.activations import (_cut_array, get_activations_model, + get_cut_indices) def test_cut_array_no_cut(): @@ -32,8 +31,9 @@ def test_get_cut_indices(): result = get_cut_indices(model_mock, torch.nn.Conv1d, [3, 3]) assert isinstance(result, list) assert all(isinstance(x, tuple) and len(x) == 2 for x in result) - - with pytest.raises(NotImplementedError, match="Padding handling not implemented*"): + + with pytest.raises(NotImplementedError, + match="Padding handling not implemented*"): get_cut_indices(model_mock, None, [3, 3]) diff --git a/tests/test_cka.py b/tests/test_cka.py index 2e41021..20fa7fb 100644 --- a/tests/test_cka.py +++ b/tests/test_cka.py @@ -1,16 +1,22 @@ -import pytest +from unittest.mock import MagicMock, patch + import numpy as np +import pytest import torch -from unittest.mock import patch, MagicMock + import cebra_lens + @pytest.fixture def dummy_comparisons(): return [("A", "B")] + @pytest.fixture def dummy_cka(dummy_comparisons): - return cebra_lens.quantification.cka_metric.CKA(comparisons=dummy_comparisons) + return cebra_lens.quantification.cka_metric.CKA( + comparisons=dummy_comparisons) + @pytest.fixture def dummy_activations(): @@ -21,35 +27,47 @@ def dummy_activations(): linear_features = 5 # Conv1D output: (batch_size, conv_channels, conv_length) -> flatten to (batch_size, conv_channels * conv_length) - conv1d_activations_A = np.random.rand(batch_size, conv_channels, conv_length).reshape(batch_size, -1) + conv1d_activations_A = np.random.rand(batch_size, conv_channels, + conv_length).reshape(batch_size, -1) linear_activations_A = np.random.rand(batch_size, linear_features) - conv1d_activations_B = np.random.rand(batch_size, conv_channels, conv_length).reshape(batch_size, -1) + conv1d_activations_B = np.random.rand(batch_size, conv_channels, + conv_length).reshape(batch_size, -1) linear_activations_B = np.random.rand(batch_size, linear_features) # Each group has a list of 2D arrays (one per layer) return { - "A": np.array([[np.random.rand(5, 10), np.random.rand(5, 10)], - [np.random.rand(5, 10), np.random.rand(5, 10)]]), - "B": np.array([[np.random.rand(5, 10), np.random.rand(5, 10)], - [np.random.rand(5, 10), np.random.rand(5, 10)]]), + "A": + np.array([[np.random.rand(5, 10), + np.random.rand(5, 10)], + [np.random.rand(5, 10), + np.random.rand(5, 10)]]), + "B": + np.array([[np.random.rand(5, 10), + np.random.rand(5, 10)], + [np.random.rand(5, 10), + np.random.rand(5, 10)]]), } + def test_center_gram_symmetry(dummy_cka): mat = np.eye(5) centered = dummy_cka.center_gram(mat) assert np.allclose(centered, centered.T) + def test_center_gram_unbiased(dummy_cka): mat = np.eye(5) centered = dummy_cka.center_gram(mat, unbiased=True) assert np.allclose(centered, centered.T) + def test_gram_linear(dummy_cka): x = np.random.rand(10, 5) gram = dummy_cka.gram_linear(x) assert gram.shape == (10, 10) assert np.allclose(gram, gram.T) + def test_cka_value(dummy_cka): x = np.random.rand(10, 5) y = np.random.rand(10, 5) @@ -58,29 +76,33 @@ def test_cka_value(dummy_cka): val = dummy_cka.cka(gram_x, gram_y) assert isinstance(val, float) or isinstance(val, np.floating) + def test_compute_cka_shape(dummy_cka): emb1 = [np.random.rand(5, 10), np.random.rand(5, 10)] emb2 = [np.random.rand(5, 10), np.random.rand(5, 10)] result = dummy_cka._compute_cka(emb1, emb2) assert result.shape == (1, 2) + def test_compute_per_layer_shape(dummy_cka): - emb1 = [ [np.random.rand(5, 10), np.random.rand(5, 10)] for _ in range(3) ] - emb2 = [ [np.random.rand(5, 10), np.random.rand(5, 10)] for _ in range(3) ] + emb1 = [[np.random.rand(5, 10), np.random.rand(5, 10)] for _ in range(3)] + emb2 = [[np.random.rand(5, 10), np.random.rand(5, 10)] for _ in range(3)] result = dummy_cka._compute_per_layer(emb1, emb2, flag=True) assert result.shape == (3, 2) + def test_compute(dummy_cka, dummy_activations): result = dummy_cka.compute(dummy_activations, ("A", "B")) assert isinstance(result, np.ndarray) - - + + def test_compute_intra_label(dummy_cka, dummy_activations): result = dummy_cka.compute(dummy_activations, ("A", "A")) assert isinstance(result, np.ndarray) + @patch("cebra_lens.matplotlib.plot_cka_heatmaps") def test_plot_calls_heatmap(mock_plot, dummy_cka): cka_matrices = {"A": np.random.rand(2, 2)} dummy_cka.plot(cka_matrices, annot=True) - assert mock_plot.called \ No newline at end of file + assert mock_plot.called diff --git a/tests/test_decoding.py b/tests/test_decoding.py index 888b4e3..c1d05db 100644 --- a/tests/test_decoding.py +++ b/tests/test_decoding.py @@ -1,7 +1,9 @@ -import pytest +from unittest.mock import MagicMock, patch + import numpy as np +import pytest import torch -from unittest.mock import patch, MagicMock + import cebra_lens @@ -17,15 +19,14 @@ def embeddings_labels(): def test_decoding_function(embeddings_labels): emb_train, emb_test, label_train, label_test = embeddings_labels - with patch( - "cebra.KNNDecoder") as mock_knn: + with patch("cebra.KNNDecoder") as mock_knn: mock_model = MagicMock() # Return prediction with correct shape each time mock_model.predict.side_effect = lambda x: np.random.rand(len(x)) mock_knn.return_value = mock_model - score, medians, r2s = cebra_lens.quantification.decoder.decoding(emb_train, emb_test, label_train, - label_test) + score, medians, r2s = cebra_lens.quantification.decoder.decoding( + emb_train, emb_test, label_train, label_test) assert isinstance(score, float) assert len(medians) == label_train.shape[1] @@ -74,13 +75,15 @@ def test_decoding_class_output_only_false(mock_get_act): results = decoding_class.compute(model) assert isinstance(results, dict) assert len(results) == 1 # only one Conv1d layer in the mock model - + decoding_class.layer_type = None - with pytest.raises(NotImplementedError, match="Padding handling not implemented*"): + with pytest.raises(NotImplementedError, + match="Padding handling not implemented*"): decoding_class.compute(model) - + decoding_class.layer_type = torch.nn.Linear - with pytest.raises(NotImplementedError, match="Padding handling not implemented*"): + with pytest.raises(NotImplementedError, + match="Padding handling not implemented*"): decoding_class.compute(model) diff --git a/tests/test_misc.py b/tests/test_misc.py index 17007c0..5105229 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -1,11 +1,10 @@ -import pytest import numpy as np +import pytest import torch -from cebra_lens.quantification.misc import ( - discrete_binning, - continuous_binning, - repetition_binning, -) + +from cebra_lens.quantification.misc import (continuous_binning, + discrete_binning, + repetition_binning) def test_discrete_binning_shape_and_values(): diff --git a/tests/test_rdm.py b/tests/test_rdm.py index 2a11429..694d67d 100644 --- a/tests/test_rdm.py +++ b/tests/test_rdm.py @@ -1,7 +1,9 @@ -import pytest +from unittest.mock import MagicMock, patch + import numpy as np +import pytest import torch -from unittest.mock import patch, MagicMock + import cebra_lens @@ -18,7 +20,9 @@ def dummy_labels(): @patch("cebra_lens.quantification.rdm_metric.continuous_binning") def test_define_indices_continuous(mock_binning, dummy_data, dummy_labels): mock_binning.return_value = (np.array([[0, 1], [2, 3]]), 2) - rdm = cebra_lens.quantification.rdm_metric.RDM(data=dummy_data, label=dummy_labels, dataset_label="visual") + rdm = cebra_lens.quantification.rdm_metric.RDM(data=dummy_data, + label=dummy_labels, + dataset_label="visual") idxs, bins = rdm._define_indices() assert isinstance(idxs, np.ndarray) assert bins == 2 @@ -28,7 +32,9 @@ def test_define_indices_continuous(mock_binning, dummy_data, dummy_labels): def test_define_indices_discrete(mock_binning, dummy_data): labels = torch.tensor([0, 1, 0, 1, 2]) mock_binning.return_value = np.array([[0, 2], [1, 3]]) - rdm = cebra_lens.quantification.rdm_metric.RDM(data=dummy_data, label=labels, is_discrete_labels=True) + rdm = cebra_lens.quantification.rdm_metric.RDM(data=dummy_data, + label=labels, + is_discrete_labels=True) idxs, bins = rdm._define_indices() assert isinstance(idxs, np.ndarray) assert bins is None @@ -38,10 +44,10 @@ def test_init_with_label_ind(): labels = np.array([[1, 2], [3, 4]]) data = torch.tensor(np.random.rand(2, 5), dtype=torch.float32) rdm = cebra_lens.quantification.rdm_metric.RDM(data=data, - label=labels, - label_ind=0, - dataset_label=None, - is_discrete_labels=True) + label=labels, + label_ind=0, + dataset_label=None, + is_discrete_labels=True) assert rdm.label.tolist() == [1, 3] @@ -73,7 +79,9 @@ def test_compute_per_layer_and_bool_oracle(): def test_compute_single_activation_tensor(): data = torch.rand((10000, 5)) label = torch.randint(0, 3, (10000, )) - rdm = cebra_lens.quantification.rdm_metric.RDM(data=data, label=label, is_discrete_labels=False) + rdm = cebra_lens.quantification.rdm_metric.RDM(data=data, + label=label, + is_discrete_labels=False) rdm.idxs = np.array([[i] for i in range(10000)]) act = torch.rand((10000, 5)).numpy() result = rdm.compute(act) @@ -85,9 +93,10 @@ def test_compute_single_activation_tensor(): @patch("cebra_lens.matplotlib.plot_rdm_correlation") @patch("cebra_lens.matplotlib.plot_rdm_all") def test_rdm_plot(mock_all, mock_corr): - dummy_rdm = cebra_lens.quantification.rdm_metric.RDM(data=torch.rand((10000, 5)), - label=torch.randint(0, 5, (10000, )), - is_discrete_labels=False) + dummy_rdm = cebra_lens.quantification.rdm_metric.RDM( + data=torch.rand((10000, 5)), + label=torch.randint(0, 5, (10000, )), + is_discrete_labels=False) dummy_rdm.bool_oracle = True dummy_rdm.plot({"group": [np.random.rand(5, 5)]}) assert mock_corr.called @@ -97,10 +106,11 @@ def test_rdm_plot(mock_all, mock_corr): dummy_rdm.is_discrete_labels = True dummy_rdm.plot({"group": [np.random.rand(5, 5)]}) assert mock_all.called - + dummy_rdm.plot({"group": [np.random.rand(5, 5)]}, titles=None) dummy_rdm.plot({}) + def test_init_raises_keyerror_for_multilabel_without_label_ind(): # label is 2D, dataset_label is not HPC/visual, label_ind not provided labels = np.array([[1, 2], [3, 4]]) @@ -114,6 +124,7 @@ def test_init_raises_keyerror_for_multilabel_without_label_ind(): label_ind=None, ) + def test_define_indices_raises_valueerror_for_invalid_dataset_label(): data = torch.tensor(np.random.rand(10, 5), dtype=torch.float32) labels = torch.tensor(np.random.randint(0, 5, size=10), dtype=torch.int64) @@ -125,6 +136,7 @@ def test_define_indices_raises_valueerror_for_invalid_dataset_label(): dataset_label="invalid_label", ) + def test_compute_per_layer_transposes_if_needed(): rdm = cebra_lens.quantification.rdm_metric.RDM( data=torch.rand((10, 5)), @@ -135,4 +147,4 @@ def test_compute_per_layer_transposes_if_needed(): # Provide activation with shape (features, samples) dummy_layer = np.random.rand(5, 10) rdm.bool_oracle = False - rdm._compute_per_layer(dummy_layer) \ No newline at end of file + rdm._compute_per_layer(dummy_layer) diff --git a/tests/test_utils.py b/tests/test_utils.py index d6da83f..7054517 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,12 @@ -import pytest +from unittest.mock import MagicMock, patch + import numpy as np -from unittest.mock import patch, MagicMock -from cebra_lens import extract_label, compute_metric, plot_metric, model_loader +import pytest + +from cebra_lens import compute_metric, extract_label, model_loader, plot_metric +from cebra_lens.quantification.cka_metric import CKA from cebra_lens.quantification.decoder import Decoding from cebra_lens.quantification.rdm_metric import RDM -from cebra_lens.quantification.cka_metric import CKA def test_extract_label_single_dim(): From d6a02e2e41f62e3eaa8ba26179a8e7e07914b341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Fri, 20 Jun 2025 17:27:58 +0200 Subject: [PATCH 10/12] Fix imports with matplotlib --- cebra_lens/__init__.py | 2 +- cebra_lens/activations.py | 2 +- cebra_lens/quantification/cka_metric.py | 4 ++-- cebra_lens/quantification/decoder.py | 14 +++++++------- cebra_lens/quantification/distance.py | 2 +- cebra_lens/quantification/rdm_metric.py | 6 +++--- cebra_lens/quantification/tsne.py | 3 ++- cebra_lens/{matplotlib.py => utils_plot.py} | 0 demos/metric_template.py | 8 ++++---- tests/test_cka.py | 2 +- tests/test_decoding.py | 4 ++-- tests/test_rdm.py | 4 ++-- 12 files changed, 26 insertions(+), 25 deletions(-) rename cebra_lens/{matplotlib.py => utils_plot.py} (100%) diff --git a/cebra_lens/__init__.py b/cebra_lens/__init__.py index 4bb3127..a198adb 100644 --- a/cebra_lens/__init__.py +++ b/cebra_lens/__init__.py @@ -1,6 +1,5 @@ # example of structure so that you can directly use the functions get_layer_activations instead of having to do CEBRA_Lens.activations.get_layer_activations from .activations import * -from .matplotlib import * from .quantification.cka_metric import * from .quantification.decoder import * from .quantification.distance import * @@ -9,6 +8,7 @@ from .utils import * from .utils_allen import * from .utils_hpc import * +from .utils_plot import * # selects what files can be imported when doing from CEBRA_Lens import * --> keep env clean # __all__ = ['get_layer_activations'] diff --git a/cebra_lens/activations.py b/cebra_lens/activations.py index 34ffec8..2cdcecd 100644 --- a/cebra_lens/activations.py +++ b/cebra_lens/activations.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn -from .matplotlib import plot_activations +from .utils_plot import plot_activations def _cut_array(array: npt.NDArray, diff --git a/cebra_lens/quantification/cka_metric.py b/cebra_lens/quantification/cka_metric.py index e970230..6e810a7 100644 --- a/cebra_lens/quantification/cka_metric.py +++ b/cebra_lens/quantification/cka_metric.py @@ -12,7 +12,7 @@ import numpy.typing as npt from tqdm import tqdm -import cebra_lens.matplotlib as cebra_lens_matplotlib +from cebra_lens import utils_plot from .base import _BaseMetric @@ -300,7 +300,7 @@ def plot( matplotlib.axes.Axes The axes on which the heatmap is plotted. """ - return cebra_lens_matplotlib.plot_cka_heatmaps( + return utils_plot.plot_cka_heatmaps( cka_matrices, annot, show_cbar, diff --git a/cebra_lens/quantification/decoder.py b/cebra_lens/quantification/decoder.py index 95effd7..9f24ffb 100644 --- a/cebra_lens/quantification/decoder.py +++ b/cebra_lens/quantification/decoder.py @@ -9,7 +9,7 @@ import torch as pt import torch.nn as nn -import cebra_lens.matplotlib as cebra_lens_matplotlib +from cebra_lens import utils_plot from ..activations import get_activations_model from ..utils_allen import decoding_frames @@ -373,10 +373,10 @@ def plot( ) if self.output_only: - return cebra_lens_matplotlib.plot_decoding(results_dict, palette, - self.dataset_label, - label, plot_error, ax) + return utils_plot.plot_decoding(results_dict, palette, + self.dataset_label, label, + plot_error, ax) else: - return cebra_lens_matplotlib.plot_layer_decoding( - results_dict, title, self.dataset_label, label, plot_error, - figsize) + return utils_plot.plot_layer_decoding(results_dict, title, + self.dataset_label, label, + plot_error, figsize) diff --git a/cebra_lens/quantification/distance.py b/cebra_lens/quantification/distance.py index 7a1d1b2..c726802 100644 --- a/cebra_lens/quantification/distance.py +++ b/cebra_lens/quantification/distance.py @@ -7,8 +7,8 @@ from scipy.spatial.distance import cdist, pdist from sklearn.preprocessing import StandardScaler -from ..matplotlib import * from ..utils import extract_label +from ..utils_plot import * from .base import _BaseMetric from .misc import continuous_binning, discrete_binning, repetition_binning diff --git a/cebra_lens/quantification/rdm_metric.py b/cebra_lens/quantification/rdm_metric.py index 4af29aa..fdab4d1 100644 --- a/cebra_lens/quantification/rdm_metric.py +++ b/cebra_lens/quantification/rdm_metric.py @@ -9,7 +9,7 @@ from scipy.linalg import block_diag from scipy.spatial.distance import correlation, pdist, squareform -import cebra_lens.matplotlib as cebra_lens_matplotlib +from cebra_lens import utils_plot from .base import _BaseMetric from .misc import continuous_binning, discrete_binning @@ -258,9 +258,9 @@ def plot( The figure containing the plotted RDMs. """ if self.bool_oracle: - return cebra_lens_matplotlib.plot_rdm_correlation(rdms) + return utils_plot.plot_rdm_correlation(rdms) else: - return cebra_lens_matplotlib.plot_rdm_all( + return utils_plot.plot_rdm_all( rdms=rdms, labels=self.label, num_bins=self.num_bins, diff --git a/cebra_lens/quantification/tsne.py b/cebra_lens/quantification/tsne.py index 89356c9..6ea7aa0 100644 --- a/cebra_lens/quantification/tsne.py +++ b/cebra_lens/quantification/tsne.py @@ -2,11 +2,12 @@ from typing import List, Optional, Union +import matplotlib import numpy as np import numpy.typing as npt from sklearn.manifold import TSNE -from ..matplotlib import * +from ..utils_plot import * from .base import _BaseMetric diff --git a/cebra_lens/matplotlib.py b/cebra_lens/utils_plot.py similarity index 100% rename from cebra_lens/matplotlib.py rename to cebra_lens/utils_plot.py diff --git a/demos/metric_template.py b/demos/metric_template.py index 2ef3066..8af4f73 100644 --- a/demos/metric_template.py +++ b/demos/metric_template.py @@ -1,9 +1,9 @@ import numpy as np -from ..cebra_lens.quantification.base import _BaseMetric -from ..cebra_lens.matplotlib import * +from cebra_lens.quantification.base import _BaseMetric +from cebra_lens.utils_plot import * from typing import List, Optional, Union import numpy.typing as npt - +import matplotlib class NewMetric(_BaseMetric): """ @@ -86,7 +86,7 @@ def plot( The figure containing the NewMetric plot. """ - #Need to define the plot_newMetric function in the matplotlib.py + #Need to define the plot_newMetric function in the utils_plot.py return plot_newMetric( embeddings, labels, diff --git a/tests/test_cka.py b/tests/test_cka.py index 20fa7fb..801404d 100644 --- a/tests/test_cka.py +++ b/tests/test_cka.py @@ -101,7 +101,7 @@ def test_compute_intra_label(dummy_cka, dummy_activations): assert isinstance(result, np.ndarray) -@patch("cebra_lens.matplotlib.plot_cka_heatmaps") +@patch("cebra_lens.utils_plot.plot_cka_heatmaps") def test_plot_calls_heatmap(mock_plot, dummy_cka): cka_matrices = {"A": np.random.rand(2, 2)} dummy_cka.plot(cka_matrices, annot=True) diff --git a/tests/test_decoding.py b/tests/test_decoding.py index c1d05db..f47b067 100644 --- a/tests/test_decoding.py +++ b/tests/test_decoding.py @@ -98,8 +98,8 @@ def test_set_output_only(): assert decoding_instance.output_only is False -@patch("cebra_lens.matplotlib.plot_decoding") -@patch("cebra_lens.matplotlib.plot_layer_decoding") +@patch("cebra_lens.utils_plot.plot_decoding") +@patch("cebra_lens.utils_plot.plot_layer_decoding") def test_decoder_plot(mock_layer_plot, mock_decoding_plot): dummy_result = {"modelA": {0: (0.9, [0.1], [0.8])}} dec = cebra_lens.quantification.decoder.Decoding( diff --git a/tests/test_rdm.py b/tests/test_rdm.py index 694d67d..58c6be3 100644 --- a/tests/test_rdm.py +++ b/tests/test_rdm.py @@ -90,8 +90,8 @@ def test_compute_single_activation_tensor(): assert result[0][0].ndim == 2 # RDM squareform -@patch("cebra_lens.matplotlib.plot_rdm_correlation") -@patch("cebra_lens.matplotlib.plot_rdm_all") +@patch("cebra_lens.utils_plot.plot_rdm_correlation") +@patch("cebra_lens.utils_plot.plot_rdm_all") def test_rdm_plot(mock_all, mock_corr): dummy_rdm = cebra_lens.quantification.rdm_metric.RDM( data=torch.rand((10000, 5)), From 4794c82f9152d2f1436c7b1d056bd2a9fc20f34e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Tue, 24 Jun 2025 14:36:40 +0200 Subject: [PATCH 11/12] Remove outdated scripts (#48) --- scripts/CKA_analysis.py | 109 ----------------- scripts/RDM_analysis.py | 166 -------------------------- scripts/cebra_visualization.py | 103 ---------------- scripts/distance_analysis.py | 165 ------------------------- scripts/layer_activation_retrieval.py | 148 ----------------------- scripts/model_decoding.py | 105 ---------------- scripts/tSNE_visualization.py | 134 --------------------- 7 files changed, 930 deletions(-) delete mode 100644 scripts/CKA_analysis.py delete mode 100644 scripts/RDM_analysis.py delete mode 100644 scripts/cebra_visualization.py delete mode 100644 scripts/distance_analysis.py delete mode 100644 scripts/layer_activation_retrieval.py delete mode 100644 scripts/model_decoding.py delete mode 100644 scripts/tSNE_visualization.py diff --git a/scripts/CKA_analysis.py b/scripts/CKA_analysis.py deleted file mode 100644 index c017b37..0000000 --- a/scripts/CKA_analysis.py +++ /dev/null @@ -1,109 +0,0 @@ -import os -from tqdm import tqdm -import argparse -import pickle -from GithubFolder.src.cebra_lens import cebra_lens as lens -import matplotlib.pyplot as plt -import logging - - -def setup_logging(): - - # Get directory and filename - script_dir = os.path.dirname(os.path.abspath(__file__)) - script_filename = os.path.splitext(os.path.basename(__file__))[0] - - logs_dir = os.path.join(script_dir, "logs") - - if not os.path.exists(logs_dir): - os.makedirs(logs_dir) - - log_file_path = os.path.join(logs_dir, f"{script_filename}.log") - - logging.basicConfig( - filename=log_file_path, - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - -def main( - activations_filepath="data/activations/offset10.pkl", - bool_comput=0, - saving_filepath="data/CKA/offset10.pkl", -): - logging.info("Script started with arguments:") - for arg, value in locals().items(): - logging.info(f"{arg}: {value}") - - with open(activations_filepath, "rb") as f: - activations_dict = pickle.load(f) - - if bool_comput: - comparisons = [ - ("single_UT", "single_TR"), - ("multi_UT", "multi_TR"), - ("single_TR", "multi_TR"), - ("single_TR", "single_TR"), - ("multi_TR", "multi_TR"), - ] - - cka_matrices = {} - for comparison in tqdm(comparisons): - cka_matrix = lens.quantification.compute_multi_CKA_layers( - activations_dict=activations_dict, comparison=comparison - ) - cka_matrices[f"{comparison[0]}_v_{comparison[1]}"] = cka_matrix - - with open(saving_filepath, "wb") as f: - pickle.dump(cka_matrices, f) - - else: - - with open(saving_filepath, "rb") as f: - cka_matrices = pickle.load(f) - - fig = lens.plotting.plot_cka_heatmaps( - cka_matrices=cka_matrices, - annot=False, - ) - - plt.show() - - -if __name__ == "__main__": - - setup_logging() - - parser = argparse.ArgumentParser(description="Process some parameters.") - parser.add_argument( - "--activations_filepath", - type=str, - default="data/activations/offset10.pkl", - help="filepath of the activation's dictionnary", - ) - - parser.add_argument( - "--bool_comput", - type=int, - default=0, - help="If True, will recompute and overwrite the cka matrices (0 or 1)", - ) - - parser.add_argument( - "--saving_filepath", - type=str, - default=None, - help="filepath where to save the CKA dictionnary", - ) - - args = parser.parse_args() - - if args.saving_filepath is None: - filename = args.activations_filepath.split("/")[-1] - args.saving_filepath = os.path.join("data/CKA/", filename) - main( - args.activations_filepath, - args.bool_comput, - args.saving_filepath, - ) diff --git a/scripts/RDM_analysis.py b/scripts/RDM_analysis.py deleted file mode 100644 index 905d626..0000000 --- a/scripts/RDM_analysis.py +++ /dev/null @@ -1,166 +0,0 @@ -import pickle -import matplotlib.pyplot as plt -import argparse -import os -from GithubFolder.src.cebra_lens import cebra_lens as lens -import logging - - -def setup_logging(): - - # Get directory and filename - script_dir = os.path.dirname(os.path.abspath(__file__)) - script_filename = os.path.splitext(os.path.basename(__file__))[0] - - logs_dir = os.path.join(script_dir, "logs") - - if not os.path.exists(logs_dir): - os.makedirs(logs_dir) - - log_file_path = os.path.join(logs_dir, f"{script_filename}.log") - - logging.basicConfig( - filename=log_file_path, - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - -def main( - filepath="data/activations/offset10.pkl", - bool_comput=0, - saving_filepath="data/RDM/offset10.pkl", - session_id=3, - bool_example: bool = True, -): - logging.info("Script started with arguments:") - for arg, value in locals().items(): - logging.info(f"{arg}: {value}") - - # LOAD DATA - train_datas, _, discrete_labels_train, _ = ( - lens.utils_allen.get_single_session_datasets() - ) - - train_data = train_datas[session_id].neural - train_label = discrete_labels_train[session_id] - - with open(filepath, "rb") as f: - activations_dict = pickle.load(f) - - if bool_example: - - # example of single instance usage with plotting (neural vs multi here) - neural_rdm = lens.quantification.RDM.compute_single_RDM_layers( - train_data=train_data, - train_label=train_label, - activations=[train_data], - metric="euclidean", - bool_oracle=False, - ) - - multi_rdm = lens.quantification.RDM.compute_single_RDM_layers( - train_data=train_data, - train_label=train_label, - activations=activations_dict["multi"]["TR"][0], - metric="euclidean", - bool_oracle=False, - ) - # Normalize the RDMs using Min-Max normalization - rdm1_normalized = lens.quantification.misc.normalize_minmax(neural_rdm[0][0]) - rdm2_normalized = lens.quantification.misc.normalize_minmax(multi_rdm[-1][0]) - - fig1 = lens.plotting.plot_rdm( - [rdm1_normalized, rdm2_normalized], - ["Neural input", "Output Layer"], - metric="Normalized Euclidean distance", - ) - multi_rdm_corr = lens.quantification.RDM.compute_single_RDM_layers( - train_data=train_data, - train_label=train_label, - activations=activations_dict["multi"]["TR"][0], - metric="correlation", - bool_oracle=False, - ) - fig2 = lens.plotting.plot_rdm( - [multi_rdm_corr[0][0], multi_rdm_corr[-1][0]], - ["Layer 1", "Output Layer"], - metric="Correlation", - ) - - plt.show() - - if bool_comput: - - rdm_dict = lens.quantification.RDM.compute_multi_RDM_layers( - train_data=train_data, - train_label=train_label, - activations_dict=activations_dict, - dataset_label="visual", - metric="correlation", - bool_oracle=True, - ) - - with open(saving_filepath, "wb") as f: - pickle.dump(rdm_dict, f) - - else: - - with open(saving_filepath, "rb") as f: - rdm_dict = pickle.load(f) - - fig = lens.plotting.plot_rdm_correlation(rdm_dict=rdm_dict) - plt.show() - - -if __name__ == "__main__": - - setup_logging() - - parser = argparse.ArgumentParser(description="Process some parameters.") - parser.add_argument( - "--filepath", - type=str, - default="data/activations/offset10.pkl", - help="name of the activations (assuming they are under data/activations)", - ) - - parser.add_argument( - "--bool_comput", - type=int, - default=0, - help="If True, will recompute and overwrite the cka matrices (0 or 1)", - ) - - parser.add_argument( - "--saving_filepath", - type=str, - default=None, - help="name of the file where to save the RDM matrices (it will be under data/RDM/saving_filename)", - ) - - parser.add_argument( - "--session_id", - type=int, - default=3, - help="session id for the analysis, used to retrieve the correct data and multi-session model", - ) - parser.add_argument( - "--bool_example", - type=int, - default=1, - help="Shows an example usage.", - ) - args = parser.parse_args() - - if args.saving_filepath is None: - filename = args.filepath.split("/")[-1] - args.saving_filepath = os.path.join("data/CKA/", filename) - - main( - args.filepath, - args.bool_comput, - args.saving_filepath, - args.session_id, - args.bool_example, - ) diff --git a/scripts/cebra_visualization.py b/scripts/cebra_visualization.py deleted file mode 100644 index a314fda..0000000 --- a/scripts/cebra_visualization.py +++ /dev/null @@ -1,103 +0,0 @@ -import pickle -import argparse -from GithubFolder.src.cebra_lens import cebra_lens as lens -import matplotlib.pyplot as plt -import os -import logging - - -def setup_logging(): - - # Get directory and filename - script_dir = os.path.dirname(os.path.abspath(__file__)) - script_filename = os.path.splitext(os.path.basename(__file__))[0] - - logs_dir = os.path.join(script_dir, "logs") - - if not os.path.exists(logs_dir): - os.makedirs(logs_dir) - - log_file_path = os.path.join(logs_dir, f"{script_filename}.log") - - logging.basicConfig( - filename=log_file_path, - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - -def main( - activations_filepath="data/activations/offset10.pkl", - session_id=3, - dataset_label="visual", -): - - logging.info("Script started with arguments:") - for arg, value in locals().items(): - logging.info(f"{arg}: {value}") - - _, _, discrete_labels_train, _ = lens.utils_allen.get_single_session_datasets() - train_label = discrete_labels_train[session_id] - - with open(activations_filepath, "rb") as f: - activations_dict = pickle.load(f) - - fig1 = lens.plotting.compare_embeddings_layers( - activations_dict["single"]["UT"][0], - activations_dict["single"]["TR"][0], - labels=train_label, - dataset_label=dataset_label, - sample_plot=activations_dict["single"]["TR"][0][0].shape[1], - comparison_labels=("CEBRA embeddings", ["Untrained Single", "Trained Single"]), - ) - fig2 = lens.plotting.compare_embeddings_layers( - activations_dict["multi"]["UT"][0], - activations_dict["multi"]["TR"][0], - labels=train_label, - dataset_label=dataset_label, - sample_plot=activations_dict["multi"]["TR"][0][0].shape[1], - comparison_labels=("CEBRA embeddings", ["Untrained Multi", "Trained Multi"]), - ) - fig3 = lens.plotting.compare_embeddings_layers( - activations_dict["single"]["TR"][0], - activations_dict["multi"]["TR"][0], - labels=train_label, - dataset_label=dataset_label, - sample_plot=activations_dict["multi"]["TR"][0][0].shape[1], - comparison_labels=("CEBRA embeddings", ["Single", "Multi"]), - ) - - plt.show() - - -if __name__ == "__main__": - - setup_logging() - - parser = argparse.ArgumentParser(description="Process some parameters.") - parser.add_argument( - "--activations_filepath", - type=str, - default="data/activations/offset10.pkl", - help="Path to the activations file.", - ) - - parser.add_argument( - "--session_id", - type=int, - default=3, - help="Session ID to use for the analysis.", - ) - - parser.add_argument( - "--dataset_label", - type=str, - default="visual", - ) - - args = parser.parse_args() - - main( - args.activations_filepath, - args.session_id, - ) diff --git a/scripts/distance_analysis.py b/scripts/distance_analysis.py deleted file mode 100644 index bcaa68c..0000000 --- a/scripts/distance_analysis.py +++ /dev/null @@ -1,165 +0,0 @@ -import pickle -import argparse -from GithubFolder.src.cebra_lens import cebra_lens as lens -import matplotlib.pyplot as plt -import os -import logging - - -def setup_logging(): - - # Get directory and filename - script_dir = os.path.dirname(os.path.abspath(__file__)) - script_filename = os.path.splitext(os.path.basename(__file__))[0] - - logs_dir = os.path.join(script_dir, "logs") - - if not os.path.exists(logs_dir): - os.makedirs(logs_dir) - - log_file_path = os.path.join(logs_dir, f"{script_filename}.log") - - logging.basicConfig( - filename=log_file_path, - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - -def main( - activations_filepath, - bool_comput, - distance_filepath, - session_id, - dataset_label="visual", - metric="cosine", -): - - logging.info("Script started with arguments:") - for arg, value in locals().items(): - logging.info(f"{arg}: {value}") - - # LOAD DATA - train_datas, _, discrete_labels_train, _ = ( - lens.utils_allen.get_single_session_datasets() - ) - - train_data = train_datas[session_id].neural - train_label = discrete_labels_train[session_id] - - with open(activations_filepath, "rb") as f: - activations_dict = pickle.load(f) - - if bool_comput: - - interbin_distances_dict = ( - lens.quantification.distance.compute_multi_distance_layers( - data=train_data, - label=train_label, - activations_dict=activations_dict, - dataset_label=dataset_label, - metric=metric, - distance_label="interbin", - ) - ) - intrabin_distances_dict = ( - lens.quantification.distance.compute_multi_distance_layers( - data=train_data, - label=train_label, - activations_dict=activations_dict, - dataset_label=dataset_label, - metric=metric, - distance_label="intrabin", - ) - ) - interrep_distances_dict = ( - lens.quantification.distance.compute_multi_distance_layers( - data=train_data, - label=train_label, - activations_dict=activations_dict, - dataset_label=dataset_label, - metric=metric, - distance_label="interrep", - ) - ) - - distances = { - "inter-bin": interbin_distances_dict, - "intra-bin": intrabin_distances_dict, - "inter-rep": interrep_distances_dict, - } - with open(distance_filepath, "wb") as f: - pickle.dump(distances, f) - - else: - - with open(distance_filepath, "rb") as f: - distances = pickle.load(f) - - figs = [] - for key, value in distances.items(): - title = f"Distance: {key}" - - figs.append(lens.plotting.plot_distance(distance_dict=value, title=title)) - - plt.show() - - -if __name__ == "__main__": - - setup_logging() - - parser = argparse.ArgumentParser(description="Process some parameters.") - parser.add_argument( - "--activations_filepath", - type=str, - default="data/activations/offset10.pkl", - help="Activation's filepath", - ) - - parser.add_argument( - "--bool_comput", - type=int, - default=0, - help="If True, will recompute and overwrite the distances (0 or 1)", - ) - - parser.add_argument( - "--distance_filepath", - type=str, - default=None, - help="Saving filepath of the distances dictionnary", - ) - - parser.add_argument( - "--session_id", - type=int, - default=3, - help="session id for the analysis, used to retrieve the correct data and multi-session model", - ) - parser.add_argument( - "--dataset_label", - type=str, - default="visual", - help="session id for the analysis, used to retrieve the correct data and multi-session model", - ) - parser.add_argument( - "--metric", - type=str, - default="cosine", - help="metric to compute the distance: euclidean or cosine", - ) - args = parser.parse_args() - - if args.distance_filepath is None: - - filename = args.activations_filepath.split("/")[-1].split(".")[0] - args.distance_filepath = f"data/distances/{filename}.pkl" - - main( - args.activations_filepath, - args.bool_comput, - args.distance_filepath, - args.session_id, - args.dataset_label, - ) diff --git a/scripts/layer_activation_retrieval.py b/scripts/layer_activation_retrieval.py deleted file mode 100644 index 6b4822b..0000000 --- a/scripts/layer_activation_retrieval.py +++ /dev/null @@ -1,148 +0,0 @@ -import pickle -import argparse -import cebra -from GithubFolder.src.cebra_lens import cebra_lens as lens -import os -import logging - - -def setup_logging(): - - # Get directory and filename - script_dir = os.path.dirname(os.path.abspath(__file__)) - script_filename = os.path.splitext(os.path.basename(__file__))[0] - - logs_dir = os.path.join(script_dir, "logs") - - if not os.path.exists(logs_dir): - os.makedirs(logs_dir) - - log_file_path = os.path.join(logs_dir, f"{script_filename}.log") - - logging.basicConfig( - filename=log_file_path, - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - -def main( - model_name, session_id, activations_filepath, bool_plot_embeddings, layer_type -): - logging.info("Script started with arguments:") - for arg, value in locals().items(): - logging.info(f"{arg}: {value}") - - # LOAD DATA - train_datas, _, discrete_labels_train, _ = ( - lens.utils_allen.get_single_session_datasets() - ) - - train_data = train_datas[session_id].neural - train_label = discrete_labels_train[session_id] - - # LOAD MODELS - models = lens.model.model_loader(model_name=model_name) - - if bool_plot_embeddings: - - X = train_data - y = train_label - embeddings_single = [] - embeddings_multi = [] - - # Go to 5 max for plotting clarity (works even if there are less than 5 models) - for model in models["multi_TR"][:5]: - embeddings_multi.append(model.transform(X, session_id=session_id)) - for model in models["single_TR"][:5]: - embeddings_single.append(model.transform(X)) - - # Align the single session embeddings to the first rat - alignment = cebra.data.helper.OrthogonalProcrustesAlignment() - - for i in range(len(embeddings_single)): - embeddings_single[i] = alignment.fit_transform( - embeddings_single[0], embeddings_single[i], y, y - ) - - for i in range(len(embeddings_multi)): - embeddings_multi[i] = alignment.fit_transform( - embeddings_multi[0], embeddings_multi[i], y, y - ) - - embeddings_untrained_single = models["single_UT"][0].transform( - X - ) # only take the first untrained model for plotting - embeddings_untrained_multi = models["multi_UT"][0].transform( - X, session_id=session_id - ) # only take the first untrained model for plotting - - fig = lens.plotting.plot_embeddings_singlevmulti( - embeddings_single, - embeddings_multi, - embeddings_untrained_single, - embeddings_untrained_multi, - y, - ) - fig.show() - - activations = {} - activations = lens.activations.get_activations_multi_model( - models=models, - data=train_data, - session_id=session_id, - activations=activations, - layer_type=layer_type, - ) - - activations_dict = lens.activations.process_activations(activations) - - with open(activations_filepath, "wb") as f: - pickle.dump(activations_dict, f) - - -if __name__ == "__main__": - - setup_logging() - - parser = argparse.ArgumentParser(description="Process some parameters.") - parser.add_argument( - "--model_name", - type=str, - default="offset10", - help="name of the folder where the models (assuming they are under FinalModels/VISION)", - ) - - parser.add_argument( - "--session_id", - type=int, - default=3, - help="session id for the analysis, used to retrieve the correct data and multi-session model", - ) - parser.add_argument( - "--filepath", - type=str, - default="data/activations/offset10.pkl", - help="filename of the activations", - ) - parser.add_argument( - "--bool_plot_embeddings", - type=int, - default=0, - help="Plots the output embeddings of the models (0 or 1)", - ) - parser.add_argument( - "--layer_type", - type=str, - default="conv", - help="Type of layer: e.g. 'conv', 'all'", - ) - - args = parser.parse_args() - main( - args.model_name, - args.session_id, - args.activations_filepath, - args.bool_plot_embeddings, - args.layer_type, - ) diff --git a/scripts/model_decoding.py b/scripts/model_decoding.py deleted file mode 100644 index 856acd6..0000000 --- a/scripts/model_decoding.py +++ /dev/null @@ -1,105 +0,0 @@ -from GithubFolder.src.cebra_lens import cebra_lens as lens -import matplotlib.pyplot as plt -import argparse -import os -import logging - - -def setup_logging(): - - # Get directory and filename - script_dir = os.path.dirname(os.path.abspath(__file__)) - script_filename = os.path.splitext(os.path.basename(__file__))[0] - - logs_dir = os.path.join(script_dir, "logs") - - if not os.path.exists(logs_dir): - os.makedirs(logs_dir) - - log_file_path = os.path.join(logs_dir, f"{script_filename}.log") - - logging.basicConfig( - filename=log_file_path, - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - -def main(model_name, session_id, bool_plot_loss): - - logging.info("Script started with arguments:") - for arg, value in locals().items(): - logging.info(f"{arg}: {value}") - - # LOAD DATA - train_datas, valid_datas, discrete_labels_train, discrete_labels_val = ( - lens.utils_allen.get_single_session_datasets() - ) - - train_data = train_datas[session_id].neural - test_data = valid_datas[session_id].neural - train_label = discrete_labels_train[session_id] - test_label = discrete_labels_val[session_id] - - # LOAD MODELS - models = lens.model.model_loader(model_name=model_name) - - if bool_plot_loss: - - fig, axs = plt.subplots(1, 2, figsize=(15, 7)) - - # Plot for single models - for i in range(len(models["single_TR"])): - axs[0].plot(models["single_TR"][i].state_dict_["loss"], c="blue", alpha=0.6) - axs[0].set_xlabel("Steps", fontsize=15) - axs[0].set_ylabel("Loss", fontsize=15) - axs[0].set_title("Single-session", fontsize=20) - - # Plot for multi models - for i in range(len(models["multi_TR"])): - axs[1].plot( - models["multi_TR"][i].state_dict_["loss"], c="orange", alpha=0.6 - ) - axs[1].set_xlabel("Steps", fontsize=15) - axs[1].set_ylabel("Loss", fontsize=15) - axs[1].set_title("Multi-session", fontsize=20) - - fig.suptitle("Losses", fontsize=30) - plt.show() - - results_dict = lens.quantification.decoding.decode_models( - models=models, - train_data=train_data, - train_label=train_label, - test_data=test_data, - test_label=test_label, - session_id=3, - ) - - fig = lens.plotting.plot_decoding(results_dict=results_dict, palette_tr="cool") - plt.show() - - -if __name__ == "__main__": - - setup_logging() - - parser = argparse.ArgumentParser(description="Process some parameters.") - parser.add_argument( - "--model_name", - type=str, - default="offset10", - help="name of the folder where the models (assuming they are under FinalModels/VISION)", - ) - parser.add_argument( - "--session_id", - type=int, - default=3, - help="session id for the analysis, used to retrieve the correct data and multi-session model", - ) - parser.add_argument( - "--bool_plot_loss", type=int, default=1, help="Plots losses of the models" - ) - - args = parser.parse_args() - main(args.model_name, args.session_id, args.bool_plot_loss) diff --git a/scripts/tSNE_visualization.py b/scripts/tSNE_visualization.py deleted file mode 100644 index 4927f49..0000000 --- a/scripts/tSNE_visualization.py +++ /dev/null @@ -1,134 +0,0 @@ -import pickle -import argparse -from GithubFolder.src.cebra_lens import cebra_lens as lens -import matplotlib.pyplot as plt -import logging -import os - - -def setup_logging(): - - # Get directory and filename - script_dir = os.path.dirname(os.path.abspath(__file__)) - script_filename = os.path.splitext(os.path.basename(__file__))[0] - - logs_dir = os.path.join(script_dir, "logs") - - if not os.path.exists(logs_dir): - os.makedirs(logs_dir) - - log_file_path = os.path.join(logs_dir, f"{script_filename}.log") - - logging.basicConfig( - filename=log_file_path, - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - - -def main( - activations_filepath="data/activations/offset10.pkl", - tsne_filepath="data/tSNE/offset10.pkl", - session_id=3, - bool_comput=False, - num_samples=200, -): - logging.info("Script started with arguments:") - for arg, value in locals().items(): - logging.info(f"{arg}: {value}") - - _, _, discrete_labels_train, _ = lens.utils_allen.get_single_session_datasets() - train_label = discrete_labels_train[session_id] - - with open(activations_filepath, "rb") as f: - activations_dict = pickle.load(f) - - if bool_comput: - - tSNE_dict = lens.transform.run_tsne_and_save( - activations_dict, tsne_filepath, num_samples - ) - - else: - - with open(tsne_filepath, "rb") as f: - tSNE_dict = pickle.load(f) - - fig1 = lens.plotting.compare_embeddings_layers( - tSNE_dict["single"]["UT"][0], - tSNE_dict["single"]["TR"][0], - labels=train_label, - dataset_label="visual", - sample_plot=200, - ) - fig2 = lens.plotting.compare_embeddings_layers( - tSNE_dict["multi"]["UT"][0], - tSNE_dict["multi"]["TR"][0], - labels=train_label, - dataset_label="visual", - sample_plot=200, - ) - fig3 = lens.plotting.compare_embeddings_layers( - tSNE_dict["single"]["TR"][0], - tSNE_dict["multi"]["TR"][0], - labels=train_label, - dataset_label="visual", - sample_plot=200, - comparison_labels=("tSNE", ["Single", "Multi"]), - ) - - plt.show() - - -if __name__ == "__main__": - - setup_logging() - - parser = argparse.ArgumentParser(description="Process some parameters.") - parser.add_argument( - "--activations_filepath", - type=str, - default="data/activations/offset10.pkl", - help="Path to the activations file.", - ) - - parser.add_argument( - "--tsne_filepath", - type=str, - default=None, - help="Path to the tSNE embeddings file.", - ) - - parser.add_argument( - "--session_id", - type=int, - default=3, - help="Session ID to use for the analysis.", - ) - - parser.add_argument( - "--bool_comput", - type=int, - default=0, - help="If True, will recompute and overwrite the tSNE embeddings (0 or 1).", - ) - - parser.add_argument( - "--num_samples", - type=int, - default=200, - help="Number of samples to use for tSNE computation.", - ) - - args = parser.parse_args() - if args.tsne_filepath == None: - filename = args.activations_filepath.split("/")[-1].split(".")[0] - args.tsne_filepath = f"data/tsne/{filename}.pkl" - - main( - args.activations_filepath, - args.tsne_filepath, - args.session_id, - args.bool_comput, - args.num_samples, - ) From 3554fa145d28350742f22579376c7d07285533fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Tue, 24 Jun 2025 14:48:57 +0200 Subject: [PATCH 12/12] Add API to the jupyter-book (#49) * Add API to the jupyter-book * Update intro.md * Update pyproject.toml --------- Co-authored-by: Mackenzie Mathis --- Makefile | 1 + docs/_config.yml | 8 +++++ docs/_toc.yml | 5 +++- docs/docs/api/helpers.rst | 28 ++++++++++++++++++ docs/docs/api/metrics.rst | 37 +++++++++++++++++++++++ docs/docs/intro.md | 36 ++++++++++++++++++++--- pyproject.toml | 62 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 172 insertions(+), 5 deletions(-) create mode 100644 docs/docs/api/helpers.rst create mode 100644 docs/docs/api/metrics.rst create mode 100644 pyproject.toml diff --git a/Makefile b/Makefile index 58a0ee8..e5f5e71 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,7 @@ interrogate: cebra_lens docs: + export PYTHONPATH=$(pwd) jupyter-book build docs docs-touch: diff --git a/docs/_config.yml b/docs/_config.yml index 6145111..b6fe627 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -19,3 +19,11 @@ repository: launch_buttons: colab_url: "https://colab.research.google.com/github/AdaptiveMotorControlLab/demos/usage_demo.ipynb" + +sphinx: + extra_extensions: + - 'sphinx.ext.autodoc' + - 'sphinx.ext.napoleon' + - 'sphinx.ext.viewcode' + config: + add_module_names: False \ No newline at end of file diff --git a/docs/_toc.yml b/docs/_toc.yml index a9e5d12..8c35db9 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -17,4 +17,7 @@ parts: - file: docs/examples/UsageDemo.ipynb - file: docs/examples/ModelGenerator.ipynb - +- caption: API Reference + chapters: + - file: docs/api/metrics + - file: docs/api/helpers \ No newline at end of file diff --git a/docs/docs/api/helpers.rst b/docs/docs/api/helpers.rst new file mode 100644 index 0000000..b503399 --- /dev/null +++ b/docs/docs/api/helpers.rst @@ -0,0 +1,28 @@ +Helpers +======= + +Activations +----------- + +.. automodule:: cebra_lens.activations + :members: + :undoc-members: + :show-inheritance: + + +Plotting +----------- + +.. automodule:: cebra_lens.utils_plot + :members: + :undoc-members: + :show-inheritance: + + +Helpers +----------- + +.. automodule:: cebra_lens.utils + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/docs/api/metrics.rst b/docs/docs/api/metrics.rst new file mode 100644 index 0000000..5383ee1 --- /dev/null +++ b/docs/docs/api/metrics.rst @@ -0,0 +1,37 @@ +Metrics +======= + +CKA Metric +---------- + +.. automodule:: cebra_lens.quantification.cka_metric + :members: + :undoc-members: + :show-inheritance: + +RDM Metric +---------- + +.. automodule:: cebra_lens.quantification.rdm_metric + :members: + :undoc-members: + :show-inheritance: + + +Decoder Metric +---------------- + +.. automodule:: cebra_lens.quantification.decoder + :members: + :undoc-members: + :show-inheritance: + + + +Distance Metric +---------------- + +.. automodule:: cebra_lens.quantification.distance + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/docs/intro.md b/docs/docs/intro.md index b7fc3d2..100f09f 100644 --- a/docs/docs/intro.md +++ b/docs/docs/intro.md @@ -1,10 +1,38 @@ # CEBRA-Lens -# A Library for neural representational analysis +## A Library for mechanistic interpretability of CEBRA models -This repository contains the code for Eloise's semester's project "Engineering software for neural representation analysis"(SPRING 2025). Continuing on the work which Riccardo did for his semester project "Exploring nonlinear encoders for robust vision decoding" (FALL 2014). +[๐Ÿฆ“๐Ÿ”Ž CEBRA Lens](https://github.com/AdaptiveMotorControlLab/CEBRA-lens) -[Initial pitch](initial_pitch.pdf) | [Final report](final_report.pdf) | [CEBRA Lens](https://github.com/AdaptiveMotorControlLab/CEBRA-lens) +**CEBRA-Lens** is a Python library for analyzing and interpreting neural representations learned by models trained with [CEBRA](https://github.com/AdaptiveMotorControlLab/cebra). It provides tools for mechanistic interpretability, allowing users to probe, visualize, and understand the structure of learned embeddings. The library is designed to support in-depth analysis of representational geometry, feature selectivity, and latent space dynamics in neuroscience and beyond. ๐Ÿ‘‹ We welcome contributions and will continue to expand the library in the coming years. ```{tableofcontents} -``` \ No newline at end of file +``` + +# Acknowledgements + +- This repository contains the code for [Eloise's](https://github.com/eloisehabek) semester's project "Engineering software for neural representation analysis"(SPRING 2025), + building on [Riccardo's](https://github.com/riccardoprog) semester project "Exploring nonlinear encoders for robust vision decoding" (FALL 2024). +- The work was supervised by [Cรฉlia Benquet](https://github.com/CeliaBenquet) and [Mackenzie](https://github.com/MMathisLab) at the Mathis Laboratory of Adaptive Intelligence. +- We thank the [DeepDraw project](https://elifesciences.org/articles/81499) for some [source code](https://github.com/amathislab/DeepDraw) and analysis methods. + +# Contributing Guide + +### Steps to Contribute + +1. **Fork the repository** and create a new branch: + ```bash + git checkout -b your-feature-name + ``` + +2. **Make your changes** and ensure they are well-tested. + +3. **Format your code** using `isort` and `black`: + ```bash + isort . + black . + ``` +4. **Open a Pull Request** to the `main` branch with a clear description of your changes. + + + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1718470 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,62 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "cebra_lens" +version = "0.0.1" +description = "Tools for analyzing embeddings from CEBRA" +authors = [ + { name = "Your Name", email = "your@email.com" }, +] +readme = "README.md" +requires-python = ">=3.9" + +dependencies = [ + "cebra", + "joblib", + "numpy<2.0; platform_system=='Windows'", + "numpy<2.0; platform_system!='Windows' and python_version<'3.10'", + "numpy; platform_system!='Windows' and python_version>='3.10'", + "literate-dataclasses", + "scikit-learn", + "scipy", + "torch>=2.4.0", + "tqdm", + "matplotlib<3.11", + "matplotlib-inline", + "requests", + "pandas", + "plotly", + "seaborn", + "jupyter-book", + "ghp-import", + "ipykernel", + "jupyter", + "nbconvert", + "nbformat", + "pylint", + "toml", + "yapf", + "black", + "isort", + "coverage", + "pytest", + "licenseheaders", + "interrogate", + "codespell", + "cffconvert" +] + +[project.optional-dependencies] +dev = [ + "pytest", + "black", + "yapf", + "isort", + "pylint" +] + +[tool.setuptools.packages.find] +where = ["."] +exclude = ["tests*"]