87 commits
40a3b28
convert checkpoint entrypoint
brian-dellabetta Feb 26, 2026
e334817
Update convert_checkpoint README to reflect implementation
brian-dellabetta Feb 26, 2026
dfbcf30
quality fixeS
brian-dellabetta Feb 26, 2026
4c39a6f
quality fixes
brian-dellabetta Feb 26, 2026
10039f3
include kv_cache_scheme
brian-dellabetta Feb 27, 2026
3075189
move converters to folder
brian-dellabetta Feb 27, 2026
5a1ca69
cleanup
brian-dellabetta Feb 27, 2026
58d0230
stylefixes
brian-dellabetta Feb 27, 2026
f5de237
unit tests
brian-dellabetta Feb 27, 2026
5a64346
missing init file
brian-dellabetta Feb 27, 2026
712f9ce
config_group key
brian-dellabetta Mar 2, 2026
c65eef1
drop autoawq todo stub
brian-dellabetta Mar 3, 2026
66cdd0f
Converter docstrings
brian-dellabetta Mar 3, 2026
0ba80e3
break out into convert module
brian-dellabetta Mar 3, 2026
96125d6
test fixes
brian-dellabetta Mar 3, 2026
cb9c123
rename
brian-dellabetta Mar 3, 2026
88b2691
rename
brian-dellabetta Mar 3, 2026
6a2c622
rename
brian-dellabetta Mar 3, 2026
d7590a6
working example
brian-dellabetta Mar 3, 2026
278f177
docstrings/cosmetics
brian-dellabetta Mar 3, 2026
b852c3c
stricter validate, test for entrypoint
brian-dellabetta Mar 3, 2026
fea715e
test fixes
brian-dellabetta Mar 3, 2026
d3ddadf
style fixes
brian-dellabetta Mar 3, 2026
60ac422
exec_jobs helper, remove test
brian-dellabetta Mar 4, 2026
33deb5d
stylefix
brian-dellabetta Mar 4, 2026
934abd2
p1
brian-dellabetta Mar 4, 2026
8c673b8
example
brian-dellabetta Mar 4, 2026
40e2430
impl
brian-dellabetta Mar 4, 2026
4d72ba7
example
brian-dellabetta Mar 4, 2026
cbf7470
check1
brian-dellabetta Mar 4, 2026
fd53387
claude unit tests
brian-dellabetta Mar 4, 2026
783a281
docstring updates
brian-dellabetta Mar 4, 2026
acabb1f
cosmetic
brian-dellabetta Mar 4, 2026
65a0e10
comment fix
brian-dellabetta Mar 5, 2026
da6d332
claude consolidate tensors
brian-dellabetta Mar 5, 2026
0b85709
claude consolidate tensors
brian-dellabetta Mar 5, 2026
a1ab6c8
claude consolidate tensors
brian-dellabetta Mar 5, 2026
2107988
updates
brian-dellabetta Mar 6, 2026
cc8cf45
claude
brian-dellabetta Mar 6, 2026
98681fe
merge
brian-dellabetta Mar 6, 2026
035a911
remove consolidate
brian-dellabetta Mar 6, 2026
d9f99ce
remove consolidate
brian-dellabetta Mar 6, 2026
58ecae1
cleanup
brian-dellabetta Mar 6, 2026
b03f20c
model config
brian-dellabetta Mar 9, 2026
55d0560
Merge branch 'main' into bdellabe/ds32-to-bfloat16
brian-dellabetta Mar 9, 2026
98764b3
mem overflow on update_global_scales, max_memory issue
brian-dellabetta Mar 9, 2026
878fa9b
use torch.nn.Linear
brian-dellabetta Mar 9, 2026
6961d51
ready to run
brian-dellabetta Mar 10, 2026
27ce7aa
Merge branch 'main' into bdellabe/ds32-to-bfloat16
brian-dellabetta Mar 10, 2026
63ab5da
bf16 kernel p1
brian-dellabetta Mar 11, 2026
4c86202
contiguous
brian-dellabetta Mar 11, 2026
489521e
torch native kernel
brian-dellabetta Mar 11, 2026
15c2c7f
claude less memory kernel
brian-dellabetta Mar 11, 2026
4cd6962
working
brian-dellabetta Mar 11, 2026
991e330
readme updates
brian-dellabetta Mar 11, 2026
71d4b2b
update example based on changes
brian-dellabetta Mar 13, 2026
2ad62f6
drop act_quant from mla
brian-dellabetta Mar 13, 2026
bbc4427
clean job for loop
brian-dellabetta Mar 16, 2026
69ebd26
Merge branch 'main' into bdellabe/ds32-to-bfloat16
brian-dellabetta Mar 16, 2026
f062fde
Merge branch 'main' into bdellabe/ds32-to-bfloat16
brian-dellabetta Mar 17, 2026
76b2163
cleanup after e2e proof-of-concept
brian-dellabetta Mar 19, 2026
8283978
more cleanup
brian-dellabetta Mar 19, 2026
29c9888
Merge branch 'main' into bdellabe/ds32-to-bfloat16
brian-dellabetta Mar 19, 2026
87398cc
stylefixes
brian-dellabetta Mar 19, 2026
86588d3
single-threaded comment
brian-dellabetta Mar 19, 2026
9693b45
more example comments
brian-dellabetta Mar 19, 2026
8069130
claude unit tests
brian-dellabetta Mar 20, 2026
8a21aa3
Merge branch 'main' into bdellabe/ds32-to-bfloat16
brian-dellabetta Mar 20, 2026
182b8dc
cleanup
brian-dellabetta Mar 20, 2026
0ae156b
stylefixes
brian-dellabetta Mar 20, 2026
d184f48
update_config cleanup
brian-dellabetta Mar 20, 2026
886d08d
merge conflicts resolved
brian-dellabetta Mar 24, 2026
4c74306
post merge bugfix
brian-dellabetta Mar 24, 2026
ba33789
Merge branch 'main' into bdellabe/ds32-to-bfloat16
brian-dellabetta Mar 29, 2026
a3247fe
Merge branch 'main' into bdellabe/ds32-to-bfloat16
brian-dellabetta Apr 1, 2026
320ce06
requires/is_required_by interface with build inverse weights map
brian-dellabetta Apr 1, 2026
9fbce9f
example update, dequantizer rename
brian-dellabetta Apr 1, 2026
b6f516f
typo
brian-dellabetta Apr 1, 2026
2de9cd7
typo
brian-dellabetta Apr 1, 2026
acc01d0
typo
brian-dellabetta Apr 1, 2026
e57a2c1
typo
brian-dellabetta Apr 1, 2026
70343c5
working poc with requires
brian-dellabetta Apr 2, 2026
44e89f2
update reindex test
brian-dellabetta Apr 2, 2026
47c9e36
single job fix
brian-dellabetta Apr 2, 2026
129a9fe
requires -> get_dependencies, is_required bool
brian-dellabetta Apr 2, 2026
23af737
unit test bugfix
brian-dellabetta Apr 2, 2026
4ed0bc1
docstrings
brian-dellabetta Apr 2, 2026
29 changes: 29 additions & 0 deletions examples/convert_checkpoint/deepseek32_fpblock_example.py
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from compressed_tensors.entrypoints.convert import (
    convert_checkpoint,
    FP8BlockDequantizer,
)

# deepseek-ai/DeepSeek-V3.2 checkpoint has layers that are quantized in the FP8
# quant method's FP8_BLOCK scheme. This script will upconvert to bfloat16 so that
# the model can be compressed in another configuration.
MODEL_ID = "deepseek-ai/DeepSeek-V3.2"
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-bf16"

# Convert DeepSeek-V3.2 back to dense bfloat16 format
convert_checkpoint(
    model_stub=MODEL_ID,
    save_directory=SAVE_DIR,
    converter=FP8BlockDequantizer(
        # `deepseek-ai/DeepSeek-V3.2` fp8-block-quantized layers, found by inspection
        targets=[
            r"re:.*mlp.*\.(gate_up|gate|up|down)_proj$",
            r"re:.*self_attn.*\.(kv_b|o|q_a|q_b)_proj$",
            r"re:.*self_attn.kv_a_proj_with_mqa$",
            r"re:.*self_attn.indexer.(wk|wq_b)$",
        ],
    ),
    max_workers=4,
)
25 changes: 25 additions & 0 deletions examples/convert_checkpoint/qwen3_fpblock_example.py
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from compressed_tensors.entrypoints.convert import (
    convert_checkpoint,
    FP8BlockDequantizer,
)

MODEL_ID = "qwen-community/Qwen3-4B-FP8"
# str.rstrip strips a character set, not a suffix; removesuffix drops "-FP8" exactly
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1].removesuffix("-FP8")

# Convert Qwen3-4B-FP8 back to dense bfloat16 format
convert_checkpoint(
    model_stub=MODEL_ID,
    save_directory=SAVE_DIR,
    converter=FP8BlockDequantizer(
        # qwen-community/Qwen3-4B-FP8's fp8-block-quantized layers, found by inspection
        targets=[
            r"re:.*mlp.*\.(gate_up|gate|up|down)_proj$",
            r"re:.*self_attn.*\.(q|k|v|o)_proj$",
        ],
        weight_block_size=[128, 128],
    ),
    max_workers=8,
)
68 changes: 52 additions & 16 deletions src/compressed_tensors/entrypoints/convert/convert_checkpoint.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
import shutil
from collections.abc import Callable
@@ -13,8 +14,12 @@
validate_file,
write_checkpoint_quantization_config,
)
from compressed_tensors.entrypoints.convert.converters import Converter
from compressed_tensors.entrypoints.convert.converters import (
Converter,
build_inverse_weight_maps,
)
from compressed_tensors.utils.safetensors_load import (
find_safetensors_index_file,
get_checkpoint_files,
is_weights_file,
update_safetensors_index,
@@ -32,11 +37,14 @@ def convert_checkpoint(
max_workers: int = 1,
):
"""
Convert a model checkpoint to compressed-tensors format without loading it up
in memory, instead operating directly on the model safetensors files. This
entrypoint operates on a model stub or folder containing weights saved in
safetensors files, and updates the corresponding quantization_config field in
the config.json. All additional files will be copied to new checkpoint.
Convert a model checkpoint to either:
- its equivalent quantized format in compressed-tensors, or
- the unquantized format
without loading it into memory, instead operating directly on the model
safetensors files. This entrypoint accepts a model stub or a folder containing
weights saved in safetensors files, and updates the corresponding
quantization_config field in config.json. All additional files are copied
to the new checkpoint.

:param model_stub: huggingface model hub or path to local weights files
:param save_directory: new checkpoint will be saved in this directory.
@@ -45,38 +53,59 @@
:param converter: converter to apply to the checkpoint,
e.g. conversion of some layers from some format to compressed-tensors
:param max_workers: number of worker threads used to process files in parallel
"""
# validate arguments
# get all model_files for checkpoint
model_files = get_checkpoint_files(model_stub)

# 0. collect safetensors files, copy files
# Read weight map from safetensors.index.json
index_file = find_safetensors_index_file(model_files)
with open(index_file, "r") as f:
weight_map: dict[str, str] = json.load(f)["weight_map"]
Reviewer comment:
This will error if there is no index file. Use something like this instead

def get_weight_map(model_files):
  index_file = find_safetensors_index_file(model_files)
  if index_file is not None:
    with open(index_file, "r") as f:
        return json.load(f)["weight_map"]
  else:
    with safe_open(SAFE_WEIGHTS_NAME, "r") as file:
      return {tensor: SAFE_WEIGHTS_NAME for tensor in file.keys()}


# Build inverse_weight_maps, so that each job knows how to load up every necessary
# weight and its dependencies
inverse_weight_maps = build_inverse_weight_maps(
weight_map=weight_map,
model_files=model_files,
converters=[converter],
)

# Build validation/conversion jobs, copy over any other file
validate_jobs = []
convert_jobs = []
for file_path, resolved_path in model_files.items():
save_path = Path(save_directory) / file_path

if file_path.endswith("safetensors"):
validate_jobs.append((validate_file, resolved_path, converter))
convert_jobs.append((convert_file, resolved_path, save_path, converter))
assert (
file_path in inverse_weight_maps
), f"Could not find inverse_weight_map for file {file_path}"
validate_jobs.append(
(validate_file, inverse_weight_maps[file_path], converter)
)
convert_jobs.append(
(convert_file, inverse_weight_maps[file_path], save_path, converter)
)

else:
if is_weights_file(file_path):
logger.warning(f"Skip processing for weights file {file_path}")
save_path.parent.mkdir(parents=True, exist_ok=True)
logger.info(f"Copying {file_path} {save_path}")
shutil.copyfile(resolved_path, save_path)
if str(resolved_path) != str(save_path):
save_path.parent.mkdir(parents=True, exist_ok=True)
logger.info(f"Copying {file_path} to {save_path}")
shutil.copyfile(resolved_path, save_path)

# 1. validate quantizable tensors fail fast before long-running quantization
# Validate before long-running processing jobs
exec_jobs(validate_jobs, max_workers, desc="Validating")

# 2-5. quantize and compress weights
# Process weights, accumulating total bytes used and the new weight_map
total_size = 0
weight_map = dict()
convert_results = exec_jobs(convert_jobs, max_workers, desc="Converting")
for _total_size, _weight_map in convert_results:
total_size += _total_size
weight_map.update(_weight_map)

# 5. update config and safetensors index
# Update config and safetensors index
write_checkpoint_quantization_config(save_directory, converter)
update_safetensors_index(save_directory, total_size, weight_map)

@@ -93,6 +122,13 @@ def exec_jobs(
:param desc: tqdm description
"""
results = []

# For easier debugging, don't run single-threaded jobs via ThreadPoolExecutor
if max_workers == 1:
for job in tqdm.tqdm(jobs, desc=desc):
results.append(job[0](*job[1:]))
return results

with ThreadPoolExecutor(max_workers) as executor:
futures = [executor.submit(*job) for job in jobs]
for future in tqdm.tqdm(as_completed(futures), total=len(futures), desc=desc):
43 changes: 30 additions & 13 deletions src/compressed_tensors/entrypoints/convert/convert_file.py
@@ -7,9 +7,13 @@
from compressed_tensors import __version__ as ct_version
from compressed_tensors.base import COMPRESSION_VERSION_NAME, QUANTIZATION_CONFIG_NAME
from compressed_tensors.entrypoints.convert import Converter
from compressed_tensors.utils.safetensors_load import find_config_path
from compressed_tensors.utils.safetensors_load import (
InverseWeightMap,
find_config_path,
load_tensors_from_inverse_weight_map,
)
from loguru import logger
from safetensors.torch import load_file, save_file
from safetensors.torch import save_file


__all__ = [
@@ -34,17 +38,20 @@ def write_checkpoint_quantization_config(
:param converter: Converter instance whose create_config() produces the
updated quantization config
"""
quant_config = converter.create_config()

quant_config_data = quant_config.model_dump()
quant_config_data[COMPRESSION_VERSION_NAME] = ct_version
quant_config_data = None
if (quant_config := converter.create_config()) is not None:
quant_config_data = quant_config.model_dump()
quant_config_data[COMPRESSION_VERSION_NAME] = ct_version

config_file_path = find_config_path(save_directory)
if config_file_path is not None:
with open(config_file_path, "r") as file:
config_data = json.load(file)

config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
if quant_config_data is None:
Reviewer comment:
qconfig field is not guaranteed to exist

Suggested change
if quant_config_data is None:
if quant_config_data is None and QUANTIZATION_CONFIG_NAME in config_data:

del config_data[QUANTIZATION_CONFIG_NAME]
else:
config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data

with open(config_file_path, "w") as file:
json.dump(config_data, file, indent=2, sort_keys=True)
@@ -57,35 +64,45 @@


def validate_file(
file_path: str | os.PathLike,
inverse_weight_map: InverseWeightMap,
converter: Converter,
):
"""
Validate that each quantizable tensor referenced by the inverse weight map can be quantized.

:param file_path: safetensors file to validate
:param inverse_weight_map: mapping of resolved source file path ->
list of tensor names to load from that file. Precomputed by
build_inverse_weight_maps() in the job-building phase.
Example: {"/path/shard0.safetensors": ["q_proj.weight"],
"/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]}
:param converter: converter we wish to apply to the checkpoint,
e.g. conversion of some layers from some format to compressed-tensors
"""
tensors = load_file(file_path)
tensors = load_tensors_from_inverse_weight_map(inverse_weight_map)

converter.validate(tensors)


def convert_file(
file_path: str | os.PathLike,
inverse_weight_map: InverseWeightMap,
save_path: str | os.PathLike,
converter: Converter,
) -> tuple[int, dict[str, str]]:
"""
Convert tensors in a given safetensors file

:param file_path: safetensors file to process
:param inverse_weight_map: mapping of resolved source file path ->
list of tensor names to load from that file. Precomputed by
build_inverse_weight_maps() in the job-building phase.
Example: {"/path/shard0.safetensors": ["q_proj.weight"],
"/path/shard1.safetensors": ["k_proj.weight", "v_proj.weight"]}
:param save_path: save path of file with quantized weights
:param converter: converter we wish to apply to the checkpoint,
e.g. conversion of some layers from some format to compressed-tensors
:returns: tuple of (total_size, weight_map), respectively the total size in bytes
of the saved file and dictionary of weight name -> save path
"""
tensors = load_file(file_path)
tensors = load_tensors_from_inverse_weight_map(inverse_weight_map)

converter.process(tensors)

@@ -6,3 +6,4 @@

from .base import *
from .modelopt_nvfp4 import *
from .fp8block_dequantizer import *
97 changes: 96 additions & 1 deletion src/compressed_tensors/entrypoints/convert/converters/base.py
@@ -3,11 +3,15 @@

from __future__ import annotations

from collections import defaultdict
from typing import TYPE_CHECKING, Protocol

import torch
from compressed_tensors.utils.safetensors_load import InverseWeightMap


__all__ = ["Converter", "build_inverse_weight_maps"]

if TYPE_CHECKING:
from compressed_tensors.quantization import QuantizationConfig

@@ -42,9 +46,100 @@ def validate(self, tensors: dict[str, torch.Tensor]):
"""
pass

def create_config(self) -> QuantizationConfig:
def create_config(self) -> QuantizationConfig | None:
"""
Create compressed-tensors QuantizationConfig so that it can be set in the
new model checkpoint's config.json.
If the converter is moving checkpoint to full-precision, have this function
return None, and quantization_config will be removed from config.json
Comment on lines +53 to +54
Reviewer comment:
Suggested change
If the converter is moving checkpoint to full-precision, have this function
return None, and quantization_config will be removed from config.json
A return value of `None` means that quantization_config will be removed from config.json

"""
pass
Reviewer comment:
Make this an abstract method to force implementers to make a decision here

Suggested change
pass
raise NotImplementedError()


def get_dependencies(self, weight_name: str) -> dict[str, bool]:
"""
Given a weight name, return a dictionary of all dependency weight names, so that
weights can be processed correctly and in a parallelized fashion.
If a dependency is optional, the value associated with the key should be False.
Reviewer comment:
Please give an example of an "optional dependency" to make the concept clear

If the value is True, it is assumed the weight is required and will error out
during the job build phase if not found.
If there are no dependencies, an empty dict should be returned.

:returns: dict[str, bool] {dependency weight name -> whether it is required}
"""
pass
Reviewer comment:
Either abstract method or return empty dict
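To make the required/optional distinction concrete, here is a hypothetical converter implementing `get_dependencies()`. The tensor-name suffixes (`.weight_scale_inv`, `.weight_zero_point`) are illustrative assumptions, not the library's actual naming: the scale is marked required because dequantization cannot proceed without it, while the zero-point is marked optional because symmetric checkpoints omit it.

```python
class Fp8BlockDepsSketch:
    """Illustrative get_dependencies() implementation for a block dequantizer."""

    def get_dependencies(self, weight_name: str) -> dict[str, bool]:
        if weight_name.endswith(".weight"):
            base = weight_name[: -len(".weight")]
            return {
                base + ".weight_scale_inv": True,    # required: cannot dequantize without it
                base + ".weight_zero_point": False,  # optional: absent in symmetric schemes
            }
        # scale/zero-point tensors themselves depend on nothing
        return {}


deps = Fp8BlockDepsSketch().get_dependencies("model.layers.0.mlp.up_proj.weight")
print(deps)
```

With this shape, the job-building phase errors out only if the scale is missing; a missing zero-point is silently skipped.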



def build_inverse_weight_maps(
Reviewer comment:
This probably goes in a utils file, not the base class file

weight_map: dict[str, str],
model_files: dict[str, str],
converters: list[Converter],
) -> dict[str, InverseWeightMap]:
"""
For a given output shard, precompute exactly which tensors to load from
which source files — including required partner tensors from other shards.

This is necessary because some converters require that a set of tensors are
accessible in order for them to be processed correctly.

:param weight_map: tensor name -> shard filename (from safetensors.index.json)
:param model_files: shard filename -> resolved absolute path
:param converters: converters whose get_dependencies() determine the partner
tensors each weight needs
:return: {shard filename: {resolved_file_path: [tensor_names_to_load]}}
"""

def get_dependencies_recursive(
weight_name: str, converters: list[Converter], current_deps: dict[str, bool]
) -> dict[str, bool]:
for converter in converters:
for dep, is_required in converter.get_dependencies(weight_name).items():
if dep not in current_deps:
current_deps[dep] = is_required
get_dependencies_recursive(dep, converters, current_deps)

return current_deps

# map of weight name -> ( map of dependency name -> is_required )
weight_deps_dict: dict[str, set[str]] = defaultdict(set)
Reviewer comment:
Suggested change
weight_deps_dict: dict[str, set[str]] = defaultdict(set)
weight_deps_dict: dict[str, dict[str, bool]] = dict()

for weight_name, weight_shard_name in weight_map.items():
weight_deps_dict[weight_name] = get_dependencies_recursive(
weight_name, converters, {}
)
assert (
weight_name not in weight_deps_dict[weight_name]
), f"{weight_name} found in dependencies {weight_deps_dict[weight_name]}"

# set of all dependencies (i.e. all weight names required by another)
all_dependencies: set[str] = set()
for values in weight_deps_dict.values():
for value in values:
all_dependencies.add(value)
Comment on lines +112 to +115
Reviewer comment:
Suggested change
all_dependencies: set[str] = set()
for values in weight_deps_dict.values():
for value in values:
all_dependencies.add(value)
all_dependencies: set[str] = set().union(*weight_deps_dict.values())


inverse_weight_maps: dict[str, InverseWeightMap] = defaultdict(
lambda: defaultdict(list)
)
for weight_name, weight_shard_name in weight_map.items():
if weight_name in all_dependencies:
# weight is a partner to some other primary tensor, skip it
continue

# weight is purely a primary weight, is not a dependency of anything
# add it and all its dependencies
inverse_weight_map: InverseWeightMap = inverse_weight_maps[weight_shard_name]
dependency_weights = weight_deps_dict[weight_name]
for weight_to_add_name, is_required in [
(weight_name, True),
*dependency_weights.items(),
]:
if weight_to_add_name not in weight_map:
if is_required:
raise ValueError(
f"Required weight {weight_to_add_name} not found in weight map"
)
else:
continue
weight_to_add_shard_name = weight_map[weight_to_add_name]
resolved_path = model_files.get(weight_to_add_shard_name)
Reviewer comment:
Suggested change
resolved_path = model_files.get(weight_to_add_shard_name)
resolved_path = model_files[weight_to_add_shard_name]

inverse_weight_map[resolved_path].append(weight_to_add_name)

# return dicts, not defaultdicts, to avoid silent errors
Reviewer comment:
ty

return {k: dict(v) for k, v in inverse_weight_maps.items()}
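A toy walk-through of the grouping logic, as a simplified non-recursive sketch (shard names and tensor names are made up for illustration): a primary weight on shard0 pulls its required scale from shard1 into the same job, and the scale tensor gets no job of its own.

```python
from collections import defaultdict


def build_inverse_weight_maps_sketch(weight_map, model_files, get_deps):
    # Simplified, non-recursive analogue of build_inverse_weight_maps
    deps = {name: get_deps(name) for name in weight_map}
    all_dependencies = set().union(*deps.values())
    out = defaultdict(lambda: defaultdict(list))
    for weight_name, shard_name in weight_map.items():
        if weight_name in all_dependencies:
            continue  # partner tensor: loaded by its primary weight's job
        for name in [weight_name, *deps[weight_name]]:
            out[shard_name][model_files[weight_map[name]]].append(name)
    return {k: dict(v) for k, v in out.items()}


weight_map = {
    "up_proj.weight": "shard0.safetensors",
    "up_proj.weight_scale_inv": "shard1.safetensors",
}
model_files = {
    "shard0.safetensors": "/ckpt/shard0.safetensors",
    "shard1.safetensors": "/ckpt/shard1.safetensors",
}
get_deps = lambda n: (
    {n.removesuffix(".weight") + ".weight_scale_inv": True}
    if n.endswith(".weight")
    else {}
)
print(build_inverse_weight_maps_sketch(weight_map, model_files, get_deps))
```

The resulting map keys jobs by output shard, with each job listing exactly which tensors to load from which resolved source file, including cross-shard dependencies.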