Skip to content

Commit df09cb3

Browse files
committed
Change kv_cache as init parameter, fix KV cache dtype issues, and add MPS support
1. KV cache as init parameter - Move `kv_cache` from `fit()` to `__init__()` in both `TabICLClassifier` and `TabICLRegressor`, following scikit-learn convention that all configuration belongs in the constructor. 2. Fix KV cache dtype mismatch - When AMP is enabled, KV projections are computed in float16 and stored in the cache. Loading such a cache on CPU/MPS (or on CUDA without AMP) causes errors - Auto-upcasts float16/bfloat16 cache to float32 when loading on CPU, MPS, or CUDA without AMP, with a `UserWarning` 3. MPS (Apple Silicon) support - Skip auto-batching on MPS in `InferenceManager` - Fix `DiskTensor.__setitem__` to correctly move MPS tensors to CPU before disk write. - Auto-upcast KV cache to float32 on MPS (same as CPU). 4. Update README accordingly
1 parent e3a5ed8 commit df09cb3

File tree

7 files changed

+155
-87
lines changed

7 files changed

+155
-87
lines changed

README.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ reg.fit(X_train, y_train)
5959
reg.predict(X_test)
6060
```
6161

62-
To speed up repeated inference on the same training data, enable KV caching during `fit`. Note that this consumes additional memory to store the cached projections, so consider the trade-off
63-
for your use case:
62+
To speed up repeated inference on the same training data, enable KV caching. The cache is built during `fit` and reused across `predict` calls. Note that this consumes additional memory to store the cached projections, so consider the trade-off for your use case:
6463

6564
```python
66-
clf.fit(X_train, y_train, kv_cache=True) # caches key-value projections for training data
65+
clf = TabICLClassifier(kv_cache=True)
66+
clf.fit(X_train, y_train) # caches key-value projections for training data
6767
clf.predict(X_test) # fast: only processes test data by reusing the cached context
6868
```
6969

@@ -99,10 +99,11 @@ clf = TabICLClassifier(
9999
average_logits=True, # average logits (True) or probabilities (False)
100100
support_many_classes=True, # handle >10 classes automatically
101101
batch_size=8, # ensemble members processed together, lower to save memory
102+
kv_cache=False, # cache training data KV projections for faster repeated inference
102103
model_path=None, # path to checkpoint, None downloads from Hugging Face
103104
allow_auto_download=True, # auto-download checkpoint if not found locally
104105
checkpoint_version="tabicl-classifier-v2-20260212.ckpt", # pretrained checkpoint version
105-
device=None, # inference device, None auto-selects CUDA or CPU
106+
device=None, # inference device, None auto-selects CUDA or CPU; specify "mps" for Apple Silicon
106107
use_amp="auto", # automatic mixed precision for faster inference
107108
use_fa3="auto", # Flash Attention 3 for Hopper GPUs (e.g. H100)
108109
offload_mode="auto", # automatically decide when to use cpu/disk offloading

src/tabicl/model/inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def __getitem__(self, indices) -> Tensor:
369369

370370
def __setitem__(self, indices, value: Tensor) -> None:
371371
"""Write to the tensor (automatically persists to disk)."""
372-
if value.device.type != "cpu":
372+
if not value.is_cpu:
373373
value = value.cpu()
374374
self._tensor[indices] = value
375375

@@ -1128,8 +1128,8 @@ def __call__(
11281128
if not auto_batch:
11291129
return self._run_forward(forward_fn, self._prepare_inputs(inputs))
11301130

1131-
# CPU execution: batching not supported currently
1132-
if self.exe_device.type == "cpu":
1131+
# CPU/MPS execution: batching not supported (requires CUDA memory APIs)
1132+
if self.exe_device.type in ("cpu", "mps"):
11331133
return forward_fn(**inputs)
11341134

11351135
# Extract shape/dtype info

src/tabicl/model/kv_cache.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,17 @@ def __setitem__(self, indices, other: KVCacheEntry):
5252
self.key[indices] = other.key
5353
self.value[indices] = other.value
5454

55-
def to(self, device) -> KVCacheEntry:
56-
"""Move this entry to the given device. Returns a new KVCacheEntry."""
55+
def to(self, device, dtype=None) -> KVCacheEntry:
56+
"""Move this entry to the given device and optionally cast dtype.
57+
58+
Returns a new KVCacheEntry.
59+
"""
5760
if not self.is_valid():
5861
return KVCacheEntry()
59-
return KVCacheEntry(key=self.key.to(device), value=self.value.to(device))
62+
return KVCacheEntry(
63+
key=self.key.to(device=device, dtype=dtype),
64+
value=self.value.to(device=device, dtype=dtype),
65+
)
6066

6167
@staticmethod
6268
def concat(entries: List[KVCacheEntry], dim: int = 0) -> KVCacheEntry:
@@ -117,16 +123,16 @@ def __setitem__(self, indices, other: KVCache):
117123
"""Write batch-sliced entries into this pre-allocated cache."""
118124
for idx, other_entry in other.kv.items():
119125
if idx in self.kv:
120-
assert self.kv[idx].is_valid(), f"Cannot write to cache index {idx} because it is not valid."
121-
device = self.kv[idx].key.device
122-
self.kv[idx][indices] = other_entry.to(device)
126+
target = self.kv[idx]
127+
assert target.is_valid(), f"Cannot write to cache index {idx} because it is not valid."
128+
self.kv[idx][indices] = other_entry.to(target.key.device, dtype=target.key.dtype)
123129

124-
def to(self, device) -> KVCache:
125-
"""Move all entries to the given device.
130+
def to(self, device, dtype=None) -> KVCache:
131+
"""Move all entries to the given device and optionally cast dtype.
126132
127133
Returns a new cache of the same subclass type.
128134
"""
129-
moved_kv = {idx: entry.to(device) for idx, entry in self.kv.items()}
135+
moved_kv = {idx: entry.to(device, dtype=dtype) for idx, entry in self.kv.items()}
130136
return self.__class__(kv=moved_kv)
131137

132138
@staticmethod
@@ -155,7 +161,7 @@ def concat(caches: List[KVCache], dim: int = 0) -> KVCache:
155161
merged_kv[idx] = KVCacheEntry.concat(entries, dim=dim)
156162
return KVCache(kv=merged_kv)
157163

158-
def preallocate(self, reference: KVCache, batch_shape: tuple, device="cpu"):
164+
def preallocate(self, reference: KVCache, batch_shape: tuple, device="cpu", dtype=None):
159165
"""Pre-allocate entries in this cache based on shapes from a reference.
160166
161167
K/V tensors always have shape ``(*batch, num_heads, seq_len, head_dim)``.
@@ -173,14 +179,19 @@ def preallocate(self, reference: KVCache, batch_shape: tuple, device="cpu"):
173179
174180
device : str or torch.device
175181
Device on which to allocate the tensors.
182+
183+
dtype : torch.dtype or None
184+
Data type for the allocated tensors. If None, uses the reference
185+
entry's dtype.
176186
"""
177187
for idx, ref_entry in reference.kv.items():
178188
if ref_entry.is_valid():
189+
target_dtype = dtype if dtype is not None else ref_entry.key.dtype
179190
key_shape = batch_shape + ref_entry.key.shape[-3:]
180191
value_shape = batch_shape + ref_entry.value.shape[-3:]
181192
self.kv[idx] = KVCacheEntry(
182-
key=torch.zeros(key_shape, dtype=ref_entry.key.dtype, device=device),
183-
value=torch.zeros(value_shape, dtype=ref_entry.value.dtype, device=device),
193+
key=torch.zeros(key_shape, dtype=target_dtype, device=device),
194+
value=torch.zeros(value_shape, dtype=target_dtype, device=device),
184195
)
185196

186197

@@ -293,23 +304,26 @@ def slice_batch(self, start: int, end: int) -> TabICLCache:
293304
num_classes=self.num_classes,
294305
)
295306

296-
def to(self, device) -> TabICLCache:
297-
"""Move all cached tensors to the given device.
307+
def to(self, device, dtype=None) -> TabICLCache:
308+
"""Move all cached tensors to the given device and optionally cast dtype.
298309
299310
Parameters
300311
----------
301312
device : str or torch.device
302313
Target device (e.g. ``'cpu'``, ``'cuda:0'``).
303314
315+
dtype : torch.dtype or None
316+
Target dtype. If None, preserves the existing dtype.
317+
304318
Returns
305319
-------
306320
TabICLCache
307321
New cache with all tensors on the target device.
308322
"""
309323
return TabICLCache(
310-
col_cache=self.col_cache.to(device) if self.col_cache else KVCache(),
311-
row_repr=self.row_repr.to(device) if self.row_repr is not None else None,
312-
icl_cache=self.icl_cache.to(device) if self.icl_cache else KVCache(),
324+
col_cache=self.col_cache.to(device, dtype=dtype) if self.col_cache else KVCache(),
325+
row_repr=self.row_repr.to(device=device, dtype=dtype) if self.row_repr is not None else None,
326+
icl_cache=self.icl_cache.to(device, dtype=dtype) if self.icl_cache else KVCache(),
313327
train_shape=self.train_shape,
314328
num_classes=self.num_classes,
315329
)

src/tabicl/sklearn/base.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,43 @@ def _build_inference_config(self) -> None:
158158
else:
159159
self.inference_config_ = self.inference_config
160160

161+
def _move_cache_to_device(self) -> None:
162+
"""Move KV cache to the current device, auto-upcasting if needed.
163+
164+
When the cache contains reduced-precision tensors (float16/bfloat16)
165+
and the target environment cannot use them directly (CPU, MPS, or
166+
CUDA without AMP), the tensors are upcast to float32 and a warning
167+
is emitted.
168+
"""
169+
if not (hasattr(self, "model_kv_cache_") and self.model_kv_cache_ is not None):
170+
return
171+
172+
use_amp, _ = self._resolve_amp_fa3()
173+
# CPU and MPS do not support float16 attention; CUDA needs AMP on
174+
needs_upcast = self.device_.type in ("cpu", "mps") or not use_amp
175+
upcast_dtype = torch.float32 if needs_upcast else None
176+
177+
# Warn once if we are actually upcasting reduced-precision tensors
178+
if upcast_dtype is not None:
179+
first_cache = next(iter(self.model_kv_cache_.values()))
180+
cache_dtype = next(iter(first_cache.col_cache.kv.values())).key.dtype
181+
if cache_dtype != torch.float32:
182+
if self.device_.type in ("cpu", "mps"):
183+
reason = f"{self.device_.type.upper()} does not support float16/bfloat16 attention"
184+
else:
185+
reason = "AMP is not enabled"
186+
warnings.warn(
187+
f"KV cache contains {cache_dtype} tensors (typically from AMP). "
188+
f"Automatically upcasting to float32 because {reason}.",
189+
UserWarning,
190+
stacklevel=3,
191+
)
192+
193+
device_cache = OrderedDict()
194+
for method, cache in self.model_kv_cache_.items():
195+
device_cache[method] = cache.to(self.device_, dtype=upcast_dtype)
196+
self.model_kv_cache_ = device_cache
197+
161198
def __getstate__(self):
162199
"""Customize pickle serialization.
163200
@@ -277,12 +314,8 @@ def __setstate__(self, state):
277314
# Reconstruct inference config
278315
self._build_inference_config()
279316

280-
# Move KV cache to device
281-
if hasattr(self, "model_kv_cache_") and self.model_kv_cache_ is not None:
282-
device_cache = OrderedDict()
283-
for method, cache in self.model_kv_cache_.items():
284-
device_cache[method] = cache.to(self.device_)
285-
self.model_kv_cache_ = device_cache
317+
# Move KV cache to device, auto-upcasting if needed
318+
self._move_cache_to_device()
286319

287320
def save(
288321
self,
@@ -332,7 +365,7 @@ def save(
332365
if not save_training_data and not (save_kv_cache and has_kv_cache):
333366
raise ValueError(
334367
"Cannot exclude training data when KV cache is not available or not being saved. "
335-
"Either set save_training_data=True, or fit with kv_cache=True and set save_kv_cache=True."
368+
"Either set save_training_data=True, or set kv_cache=True during init and save_kv_cache=True."
336369
)
337370

338371
# Set temporary flags for __getstate__
@@ -376,11 +409,7 @@ def load(cls, path: str | Path, device: Optional[str | torch.device] = None) ->
376409
obj._resolve_device()
377410
obj.model_.to(obj.device_)
378411
obj._build_inference_config()
379-
if hasattr(obj, "model_kv_cache_") and obj.model_kv_cache_ is not None:
380-
device_cache = OrderedDict()
381-
for method, cache in obj.model_kv_cache_.items():
382-
device_cache[method] = cache.to(obj.device_)
383-
obj.model_kv_cache_ = device_cache
412+
obj._move_cache_to_device()
384413

385414
return obj
386415

src/tabicl/sklearn/classifier.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,23 @@ class TabICLClassifier(ClassifierMixin, TabICLBaseEstimator):
8484
Adjust this parameter based on available memory. Lower values use less memory but may
8585
be slower.
8686
87+
kv_cache : bool or str, default=False
88+
Controls caching of training data computations to speed up subsequent
89+
``predict_proba``/``predict`` calls. The cache is built during ``fit()``.
90+
91+
- False: No caching.
92+
- True or "kv": Cache key-value projections from both column embedding
93+
and ICL transformer layers. Fast inference but memory-heavy for large
94+
training sets.
95+
- "repr": Cache column embedding KV projections and row interaction outputs
96+
(representations). Uses ~24x less memory than "kv" for the ICL part,
97+
at the cost of re-running the ICL transformer at predict time.
98+
99+
The cache retains whatever dtype the model produced during ``fit()``
100+
(float16 when AMP is active, float32 otherwise). If the cache is later
101+
loaded on CPU or on CUDA without AMP, the tensors are automatically
102+
upcast to float32 to avoid dtype-mismatch errors.
103+
87104
model_path : Optional[str | Path] = None
88105
Path to the pre-trained model checkpoint file.
89106
- If provided and the file exists, it's loaded directly.
@@ -108,8 +125,10 @@ class TabICLClassifier(ClassifierMixin, TabICLBaseEstimator):
108125
- `'tabicl-classifier-v1-20250208.ckpt'`: The version used in our TabICLv1 paper.
109126
110127
device : Optional[str or torch.device], default=None
111-
Device to use for inference. If None, defaults to CUDA if available, else CPU.
112-
Can be specified as a string ('cuda', 'cpu') or a torch.device object.
128+
Device to use for inference. If None, automatically selects CUDA if
129+
available, otherwise CPU. Can be specified as a string (``'cuda'``,
130+
``'cpu'``, ``'mps'``) or a ``torch.device`` object. MPS (Apple Silicon
131+
GPU) is supported but must be explicitly requested.
113132
114133
use_amp : bool or "auto", default="auto"
115134
Controls automatic mixed precision (AMP) for inference.
@@ -237,14 +256,14 @@ class TabICLClassifier(ClassifierMixin, TabICLBaseEstimator):
237256
The inference configuration.
238257
239258
cache_mode_ : str or None
240-
The caching mode used when ``fit()`` was called with ``kv_cache``.
241-
One of ``"kv"``, ``"repr"``, or ``None`` (when no caching is used).
259+
The resolved caching mode, set during ``fit()`` based on the ``kv_cache``
260+
init parameter. One of ``"kv"``, ``"repr"``, or ``None`` (no caching).
242261
243262
model_kv_cache_ : OrderedDict[str, TabICLCache] or None
244263
Pre-computed KV caches for training data, keyed by normalization method.
245-
Created when ``fit()`` is called with ``kv_cache=True``. When set, ``predict_proba()``
246-
reuses the cached key-value projections instead of re-processing training data,
247-
enabling faster inference on multiple test sets.
264+
Created during ``fit()`` when ``kv_cache`` is enabled. When set,
265+
``predict_proba()`` reuses the cached key-value projections instead of
266+
re-processing training data, enabling faster inference on multiple test sets.
248267
"""
249268

250269
def __init__(
@@ -258,6 +277,7 @@ def __init__(
258277
average_logits: bool = True,
259278
support_many_classes: bool = True,
260279
batch_size: Optional[int] = 8,
280+
kv_cache: bool | str = False,
261281
model_path: Optional[str | Path] = None,
262282
allow_auto_download: bool = True,
263283
checkpoint_version: str = "tabicl-classifier-v2-20260212.ckpt",
@@ -280,6 +300,7 @@ def __init__(
280300
self.average_logits = average_logits
281301
self.support_many_classes = support_many_classes
282302
self.batch_size = batch_size
303+
self.kv_cache = kv_cache
283304
self.model_path = model_path
284305
self.allow_auto_download = allow_auto_download
285306
self.checkpoint_version = checkpoint_version
@@ -386,7 +407,7 @@ def _load_model(self) -> None:
386407
self.model_.load_state_dict(checkpoint["state_dict"])
387408
self.model_.eval()
388409

389-
def fit(self, X: np.ndarray, y: np.ndarray, kv_cache: bool | str = False) -> TabICLClassifier:
410+
def fit(self, X: np.ndarray, y: np.ndarray) -> TabICLClassifier:
390411
"""Fit the classifier to training data.
391412
392413
Prepares the model for prediction by:
@@ -395,6 +416,7 @@ def fit(self, X: np.ndarray, y: np.ndarray, kv_cache: bool | str = False) -> Tab
395416
3. Fitting the ensemble generator to create transformed dataset views
396417
4. Loading the pre-trained TabICL model
397418
5. Optionally pre-computing KV caches for training data to speed up inference
419+
(controlled by the ``kv_cache`` init parameter)
398420
399421
The model itself is not trained on the data; it uses in-context learning
400422
at inference time.
@@ -407,17 +429,6 @@ def fit(self, X: np.ndarray, y: np.ndarray, kv_cache: bool | str = False) -> Tab
407429
y : array-like of shape (n_samples,)
408430
Training target labels.
409431
410-
kv_cache : bool or str, default=False
411-
Controls caching of training data computations to speed up subsequent
412-
``predict_proba``/``predict`` calls.
413-
- False: No caching.
414-
- True or "kv": Cache key-value projections from both column embedding
415-
and ICL transformer layers. Fast inference but memory-heavy for large
416-
training sets.
417-
- "repr": Cache column embedding KV projections and row interaction outputs
418-
(representations). Uses ~24x less memory than "kv" for the ICL part,
419-
at the cost of re-running the ICL transformer at predict time.
420-
421432
Returns
422433
-------
423434
self : TabICLClassifier
@@ -454,7 +465,7 @@ def fit(self, X: np.ndarray, y: np.ndarray, kv_cache: bool | str = False) -> Tab
454465
self.n_classes_ = len(self.y_encoder_.classes_)
455466

456467
if self.n_classes_ > self.model_.max_classes:
457-
if kv_cache:
468+
if self.kv_cache:
458469
raise ValueError(
459470
f"KV caching is not supported when the number of classes ({self.n_classes_}) exceeds the max number "
460471
f"of classes ({self.model_.max_classes}) natively supported by the model."
@@ -491,13 +502,13 @@ def fit(self, X: np.ndarray, y: np.ndarray, kv_cache: bool | str = False) -> Tab
491502
self.ensemble_generator_.fit(X, y)
492503

493504
self.model_kv_cache_ = None
494-
if kv_cache:
495-
if kv_cache is True or kv_cache == "kv":
505+
if self.kv_cache:
506+
if self.kv_cache is True or self.kv_cache == "kv":
496507
self.cache_mode_ = "kv"
497-
elif kv_cache == "repr":
508+
elif self.kv_cache == "repr":
498509
self.cache_mode_ = "repr"
499510
else:
500-
raise ValueError(f"Invalid kv_cache value '{kv_cache}'. Expected False, True, 'kv', or 'repr'.")
511+
raise ValueError(f"Invalid kv_cache value '{self.kv_cache}'. Expected False, True, 'kv', or 'repr'.")
501512
self._build_kv_cache()
502513

503514
return self

0 commit comments

Comments
 (0)