Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ jobs:
python -m pip install --upgrade pip setuptools wheel
pip install -r requirements.txt

- name: Run the formatter
shell: bash -el {0}
run: |
make format

- name: Run the spelling detector
shell: bash -el {0}
run: |
make codespell

- name: Check the documentation coverage
shell: bash -el {0}
run: |
make interrogate

- name: Run all pytest tests
shell: bash -el {0}
run: |
Expand Down
55 changes: 55 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Package version. Not referenced by any target in this Makefile.
CEBRA_LENS_VERSION := 0.1.0.dev0

# Build the wheel and source distribution.
dist:
	python3 -m pip install virtualenv
	python3 -m pip install --upgrade build twine
	python3 -m build --wheel --sdist

# Alias for `dist`.
build: dist

# Run the test suite, re-running previously failed tests first (--ff).
test:
	python -m pytest --ff tests

# Check docstring coverage of the package; fails below 80 %.
interrogate:
	interrogate \
		--ignore-property-decorators \
		--ignore-init-method \
		--verbose \
		--ignore-semiprivate \
		--ignore-private \
		--ignore-magic \
		--omit-covered-files \
		-f 80 \
		cebra_lens

# Build the documentation with the repository root on PYTHONPATH.
# The variable is set on the same recipe line as the command because every
# recipe line runs in its own shell, so a separate `export` line would be lost.
# $(CURDIR) is used because `$(pwd)` is expanded by make as an (undefined)
# make variable, not as the shell's `pwd` command.
docs:
	PYTHONPATH="$(CURDIR)" jupyter-book build docs

# Touch every markdown source so that jupyter-book rebuilds all pages.
# NOTE(review): this target uses docs/docs while `docs` and `docs-strict`
# use docs -- confirm which directory is the jupyter-book root.
docs-touch:
	find docs/docs -iname '*.md' -exec touch {} \;
	jupyter-book build docs/docs

# Build the documentation with --strict, continuing past errors (--keep-going).
docs-strict:
	jupyter-book build docs --keep-going --strict

# Serve the built docs on http://127.0.0.1:8080
serve_docs:
	python -m http.server 8080 --bind 127.0.0.1 -d docs/_build/html

# Serve the entire page. Runs the same command as serve_docs, so it is kept
# as an alias instead of a duplicated recipe.
serve_page: serve_docs

# Format code in the main package and the tests (rewrites files in place).
format:
	yapf -i -p -r cebra_lens
	yapf -i -p -r tests
	isort cebra_lens/
	isort tests/

# Spell-check the package, the tests and the markdown docs.
codespell:
	codespell cebra_lens/ tests/ docs/docs/*.md -L "nce, nd"


# None of the targets produces a file named after itself, so all are phony.
# Without this, `make dist` (and `make docs`, `make build`) is skipped as
# "up to date" as soon as a directory of the same name exists.
.PHONY: dist build test interrogate docs docs-touch docs-strict \
        serve_docs serve_page format codespell
9 changes: 4 additions & 5 deletions cebra_lens/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# Re-export the submodules' public names so that functions such as get_layer_activations can be called directly from the package, instead of via cebra_lens.activations.get_layer_activations.
from .activations import *
from .quantification import *
from .quantification.decoding import *
from .quantification.distance import *
from .quantification.cka_metric import *
from .quantification.decoder import *
from .quantification.distance import *
from .quantification.rdm_metric import *
from .quantification.tsne import *
from .matplotlib import *
from .utils import *
from .utils_allen import *
from .utils_hpc import *
from .utils import *
from .utils_plot import *

# __all__ selects which names are exported by `from cebra_lens import *`, keeping the importing namespace clean.
# __all__ = ['get_layer_activations']
64 changes: 32 additions & 32 deletions cebra_lens/activations.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
"""Functions to retrieve and handle layer activations"""

from typing import Dict, List, Optional, Tuple, Type

import cebra
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from typing import Tuple, Dict, List, Type, Optional
from .matplotlib import plot_activations
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from .utils_plot import plot_activations


def _cut_array(
array: npt.NDArray, cut_indices: Tuple[np.int64, np.int64]
) -> npt.NDArray:
def _cut_array(array: npt.NDArray,
cut_indices: Tuple[np.int64, np.int64]) -> npt.NDArray:
"""
Slices the input array based on the provided cut indices.
This is used to remove the padding from activations in `get_activations_model`.
Expand All @@ -36,7 +37,7 @@ def _cut_array(
sliced_array = array
else:
# Otherwise, slice the array
sliced_array = array[:, start : end if end != 0 else start :]
sliced_array = array[:, start:end if end != 0 else start:]
return sliced_array


Expand Down Expand Up @@ -80,10 +81,13 @@ def get_cut_indices(
# add for output layer
cut_indices.append((0, 0))
elif layer_type == None:
raise NotImplementedError("Padding handling not implemented for 'all'.")
raise NotImplementedError(
"Padding handling not implemented to handle activations for all layer types.",
"Set layer_type to nn.Conv1d to use the default padding handling.")
else:
# need to analyze the padding from the last output of Conv1 and apply the same cut
raise NotImplementedError(f"Padding handling not implemented for {layer_type}.")
raise NotImplementedError(
f"Padding handling not implemented for {layer_type}.")
return cut_indices


Expand All @@ -93,7 +97,7 @@ def get_activations_model(
session_id: int = -1,
name: str = "single",
instance: int = 0,
layer_type: Type[nn.Module] = None,
layer_type: Type[nn.Module] = nn.Conv1d,
) -> Dict[str, npt.NDArray]:
"""
Extracts activations from a single model layer.
Expand All @@ -111,7 +115,8 @@ def get_activations_model(
instance : int
The instance number for the model, used to differentiate between models from the same model category.
layer_type : Type[nn.Module]
The type of layer to extract activations from. Defaults to None, meaning extracts activations from all layers.
The type of layer to extract activations from. None means it extracts activations from all layers.
Default is nn.Conv1d, which is the most common layer type used in CEBRA models.

Returns:
--------
Expand All @@ -125,26 +130,25 @@ def get_activations_model(
activations = {}
transform_kwargs = {}
if model.solver_name_ in [
"multi-session",
"multi-session-aux",
"multiobjective-solver",
"multi-session",
"multi-session-aux",
"multiobjective-solver",
]:

model_ = model.model_[session_id]
transform_kwargs.update({"session_id": session_id})

elif model.solver_name_ in [
"single-session",
"single-session-aux",
"single-session-hybrid",
"single-session-full",
"single-session",
"single-session-aux",
"single-session-hybrid",
"single-session-full",
]:
model_ = model.model_

else:
raise NotImplementedError(
f"Solver {model.solver_name_} is not yet implemented."
)
f"Solver {model.solver_name_} is not yet implemented.")

activations, handles, conv_layer_info = _attach_hooks(
activations=activations,
Expand Down Expand Up @@ -209,14 +213,14 @@ def process_activations(
name=model_name,
instance=i,
layer_type=layer_type,
)
)
))

return activations


# Function to create a hook that stores the activations in the dictionary
def _get_activation(name: str, activations: Dict):

def hook(model, input, output):
activations[name] = output.detach().squeeze().numpy()

Expand Down Expand Up @@ -262,8 +266,7 @@ def _attach_hooks(
# attach hook to the layer_type and to the output layer
if isinstance(model.net[i], layer_type) or i == len(model.net) - 1:
hook, activations = _get_activation(
f"{name}_{instance}_layer_{num_layer}", activations
)
f"{name}_{instance}_layer_{num_layer}", activations)
if isinstance(model.net[i], layer_type):
conv_layer_info.append(model.net[i].kernel_size[0])
handle = model.net[i].register_forward_hook(hook)
Expand Down Expand Up @@ -298,8 +301,7 @@ def _attach_hooks(

else:
hook, activations = _get_activation(
f"{name}_{instance}_layer_{num_layer}", activations
)
f"{name}_{instance}_layer_{num_layer}", activations)

handle = model.net[i].register_forward_hook(hook)
handles.append(handle)
Expand All @@ -309,8 +311,7 @@ def _attach_hooks(


def aggregate_activations(
activations: Dict[str, npt.NDArray],
) -> Dict[str, npt.NDArray]:
activations: Dict[str, npt.NDArray], ) -> Dict[str, npt.NDArray]:
"""
Aggregates activations by model identifier aka. instance.
This function takes a dictionary of activations where the keys are strings containing model identifiers and layer information,
Expand Down Expand Up @@ -387,8 +388,7 @@ def get_activations(
activations = activations or {}

aggregated_activations = aggregate_activations(
process_activations(models, data, session_id, activations, layer_type)
)
process_activations(models, data, session_id, activations, layer_type))

activations_dict = {}
for key, value in aggregated_activations.items():
Expand Down
8 changes: 4 additions & 4 deletions cebra_lens/quantification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .base import *
from .cka_metric import *
from .rdm_metric import *
from .misc import *
from .decoder import *
from .distance import *
from .decoding import *
from .base import *
from .misc import *
from .rdm_metric import *
from .tsne import *
15 changes: 8 additions & 7 deletions cebra_lens/quantification/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from tqdm import tqdm
import numpy as np
import pickle
import types
from typing import List, Union, Dict
from abc import *
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
import numpy.typing as npt
from tqdm import tqdm


class _BaseMetric:
Expand All @@ -14,7 +15,8 @@ class _BaseMetric:
"""

@abstractmethod
def compute(self, activations: Dict[str, npt.NDArray]) -> Dict[str, npt.NDArray]:
def compute(self,
activations: Dict[str, npt.NDArray]) -> Dict[str, npt.NDArray]:
"""
Every metric which inherits ``_BaseMetric`` needs to implement a compute function.
The compute function is specific to a metric, e.g. intra-bin distance, RDM, CKA,...
Expand Down Expand Up @@ -66,9 +68,8 @@ def save(self, filepath: str, data: Dict[str, npt.NDArray]) -> None:
and the value is a npt.NDArray containing for all the models under that label the calculated data.
"""
filepath = Path(filepath)
custom_filepath = filepath.with_stem(
filepath.stem + f"_{self.__class__.__name__}"
)
custom_filepath = filepath.with_stem(filepath.stem +
f"_{self.__class__.__name__}")
with open(custom_filepath, "wb") as f:
pickle.dump(data, f)

Expand Down
Loading