Skip to content

Commit 53cd6b1

Browse files
authored
remove bounds and type checks from IngredientCache (#937)
1 parent 0e1df67 commit 53cd6b1

11 files changed

+155
-46
lines changed

components/salsa-macro-rules/src/setup_accumulator_impl.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,13 @@ macro_rules! setup_accumulator_impl {
3434
static $CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Struct>> =
3535
$zalsa::IngredientCache::new();
3636

37-
$CACHE.get_or_create(zalsa, || {
38-
zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Struct>>()
39-
})
37+
// SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only
38+
// ingredient created by our jar is the struct ingredient.
39+
unsafe {
40+
$CACHE.get_or_create(zalsa, || {
41+
zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Struct>>()
42+
})
43+
}
4044
}
4145

4246
impl $zalsa::Accumulator for $Struct {

components/salsa-macro-rules/src/setup_input_struct.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,13 @@ macro_rules! setup_input_struct {
109109
static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> =
110110
$zalsa::IngredientCache::new();
111111

112-
CACHE.get_or_create(zalsa, || {
113-
zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>()
114-
})
112+
// SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only
113+
// ingredient created by our jar is the struct ingredient.
114+
unsafe {
115+
CACHE.get_or_create(zalsa, || {
116+
zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>()
117+
})
118+
}
115119
}
116120

117121
pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl<Self>, &mut $zalsa::Runtime) {

components/salsa-macro-rules/src/setup_interned_struct.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,14 @@ macro_rules! setup_interned_struct {
157157
$zalsa::IngredientCache::new();
158158

159159
let zalsa = db.zalsa();
160-
CACHE.get_or_create(zalsa, || {
161-
zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>()
162-
})
160+
161+
// SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only
162+
// ingredient created by our jar is the struct ingredient.
163+
unsafe {
164+
CACHE.get_or_create(zalsa, || {
165+
zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>()
166+
})
167+
}
163168
}
164169
}
165170

components/salsa-macro-rules/src/setup_tracked_fn.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,13 @@ macro_rules! setup_tracked_fn {
175175
impl $Configuration {
176176
fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> {
177177
let zalsa = db.zalsa();
178-
$FN_CACHE
179-
.get_or_create(zalsa, || zalsa.lookup_jar_by_type::<$fn_name>())
180-
.get_or_init(|| <dyn $Db as $Db>::zalsa_register_downcaster(db))
178+
179+
// SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the first
180+
// ingredient created by our jar is the function ingredient.
181+
unsafe {
182+
$FN_CACHE.get_or_create(zalsa, || zalsa.lookup_jar_by_type::<$fn_name>())
183+
}
184+
.get_or_init(|| <dyn $Db as $Db>::zalsa_register_downcaster(db))
181185
}
182186

183187
pub fn fn_ingredient_mut(db: &mut dyn $Db) -> &mut $zalsa::function::IngredientImpl<Self> {
@@ -195,9 +199,14 @@ macro_rules! setup_tracked_fn {
195199
db: &dyn $Db,
196200
) -> &$zalsa::interned::IngredientImpl<$Configuration> {
197201
let zalsa = db.zalsa();
198-
$INTERN_CACHE.get_or_create(zalsa, || {
199-
zalsa.lookup_jar_by_type::<$fn_name>().successor(0)
200-
})
202+
203+
// SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the second
204+
// ingredient created by our jar is the interned ingredient (given `needs_interner`).
205+
unsafe {
206+
$INTERN_CACHE.get_or_create(zalsa, || {
207+
zalsa.lookup_jar_by_type::<$fn_name>().successor(0)
208+
})
209+
}
201210
}
202211
}
203212
}

components/salsa-macro-rules/src/setup_tracked_struct.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,13 @@ macro_rules! setup_tracked_struct {
196196
static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> =
197197
$zalsa::IngredientCache::new();
198198

199-
CACHE.get_or_create(zalsa, || {
200-
zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>()
201-
})
199+
// SAFETY: `lookup_jar_by_type` returns a valid ingredient index, and the only
200+
// ingredient created by our jar is the struct ingredient.
201+
unsafe {
202+
CACHE.get_or_create(zalsa, || {
203+
zalsa.lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>()
204+
})
205+
}
202206
}
203207
}
204208

src/ingredient.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync {
177177
}
178178

179179
impl dyn Ingredient {
180-
/// Equivalent to the `downcast` methods on `any`.
180+
/// Equivalent to the `downcast` method on `Any`.
181+
///
181182
/// Because we do not have dyn-upcasting support, we need this workaround.
182183
pub fn assert_type<T: Any>(&self) -> &T {
183184
assert_eq!(
@@ -192,7 +193,27 @@ impl dyn Ingredient {
192193
unsafe { transmute_data_ptr(self) }
193194
}
194195

195-
/// Equivalent to the `downcast` methods on `any`.
196+
/// Equivalent to the `downcast` methods on `Any`.
197+
///
198+
/// Because we do not have dyn-upcasting support, we need this workaround.
199+
///
200+
/// # Safety
201+
///
202+
/// The contained value must be of type `T`.
203+
pub unsafe fn assert_type_unchecked<T: Any>(&self) -> &T {
204+
debug_assert_eq!(
205+
self.type_id(),
206+
TypeId::of::<T>(),
207+
"ingredient `{self:?}` is not of type `{}`",
208+
std::any::type_name::<T>()
209+
);
210+
211+
// SAFETY: Guaranteed by caller.
212+
unsafe { transmute_data_ptr(self) }
213+
}
214+
215+
/// Equivalent to the `downcast` method on `Any`.
216+
///
196217
/// Because we do not have dyn-upcasting support, we need this workaround.
197218
pub fn assert_type_mut<T: Any>(&mut self) -> &mut T {
198219
assert_eq!(

src/ingredient_cache.rs

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@ mod imp {
4747
/// Get a reference to the ingredient in the database.
4848
///
4949
/// If the ingredient index is not already in the cache, it will be loaded and cached.
50-
pub fn get_or_create<'db>(
50+
///
51+
/// # Safety
52+
///
53+
/// The `IngredientIndex` returned by the closure must reference a valid ingredient of
54+
/// type `I` in the provided zalsa database.
55+
pub unsafe fn get_or_create<'db>(
5156
&self,
5257
zalsa: &'db Zalsa,
5358
load_index: impl Fn() -> IngredientIndex,
@@ -57,9 +62,21 @@ mod imp {
5762
ingredient_index = self.get_or_create_index_slow(load_index).as_u32();
5863
};
5964

60-
zalsa
61-
.lookup_ingredient(IngredientIndex::from_unchecked(ingredient_index))
62-
.assert_type()
65+
// SAFETY: `ingredient_index` is initialized from a valid `IngredientIndex`.
66+
let ingredient_index = unsafe { IngredientIndex::new_unchecked(ingredient_index) };
67+
68+
// SAFETY: There are a two cases here:
69+
// - The `create_index` closure was called due to the data being uncached. In this
70+
// case, the caller guarantees the index is in-bounds and has the correct type.
71+
// - The index was cached. While the current database might not be the same database
72+
// the ingredient was initially loaded from, the `inventory` feature is enabled, so
73+
// ingredient indices are stable across databases. Thus the index is still in-bounds
74+
// and has the correct type.
75+
unsafe {
76+
zalsa
77+
.lookup_ingredient_unchecked(ingredient_index)
78+
.assert_type_unchecked()
79+
}
6380
}
6481

6582
#[cold]
@@ -134,14 +151,30 @@ mod imp {
134151
/// Get a reference to the ingredient in the database.
135152
///
136153
/// If the ingredient is not already in the cache, it will be created.
154+
///
155+
/// # Safety
156+
///
157+
/// The `IngredientIndex` returned by the closure must reference a valid ingredient of
158+
/// type `I` in the provided zalsa database.
137159
#[inline(always)]
138-
pub fn get_or_create<'db>(
160+
pub unsafe fn get_or_create<'db>(
139161
&self,
140162
zalsa: &'db Zalsa,
141163
create_index: impl Fn() -> IngredientIndex,
142164
) -> &'db I {
143165
let index = self.get_or_create_index(zalsa, create_index);
144-
zalsa.lookup_ingredient(index).assert_type::<I>()
166+
167+
// SAFETY: There are a two cases here:
168+
// - The `create_index` closure was called due to the data being uncached for the
169+
// provided database. In this case, the caller guarantees the index is in-bounds
170+
// and has the correct type.
171+
// - We verified the index was cached for the same database, by the nonce check.
172+
// Thus the initial safety argument still applies.
173+
unsafe {
174+
zalsa
175+
.lookup_ingredient_unchecked(index)
176+
.assert_type_unchecked::<I>()
177+
}
145178
}
146179

147180
pub fn get_or_create_index(
@@ -159,7 +192,9 @@ mod imp {
159192
};
160193

161194
// Unpack our `u64` into the nonce and index.
162-
let index = IngredientIndex::from_unchecked(cached_data as u32);
195+
//
196+
// SAFETY: The lower bits of `cached_data` are initialized from a valid `IngredientIndex`.
197+
let index = unsafe { IngredientIndex::new_unchecked(cached_data as u32) };
163198

164199
// SAFETY: We've checked against `UNINITIALIZED` (0) above and so the upper bits must be non-zero.
165200
let nonce = crate::nonce::Nonce::<StorageNonce>::from_u32(unsafe {

src/tracked_struct.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -975,19 +975,19 @@ mod tests {
975975
let mut d = DisambiguatorMap::default();
976976
// set up all 4 permutations of differing field values
977977
let h1 = IdentityHash {
978-
ingredient_index: IngredientIndex::from(0),
978+
ingredient_index: IngredientIndex::new(0),
979979
hash: 0,
980980
};
981981
let h2 = IdentityHash {
982-
ingredient_index: IngredientIndex::from(1),
982+
ingredient_index: IngredientIndex::new(1),
983983
hash: 0,
984984
};
985985
let h3 = IdentityHash {
986-
ingredient_index: IngredientIndex::from(0),
986+
ingredient_index: IngredientIndex::new(0),
987987
hash: 1,
988988
};
989989
let h4 = IdentityHash {
990-
ingredient_index: IngredientIndex::from(1),
990+
ingredient_index: IngredientIndex::new(1),
991991
hash: 1,
992992
};
993993
assert_eq!(d.disambiguate(h1), Disambiguator(0));
@@ -1005,42 +1005,42 @@ mod tests {
10051005
let mut d = IdentityMap::default();
10061006
// set up all 8 permutations of differing field values
10071007
let i1 = Identity {
1008-
ingredient_index: IngredientIndex::from(0),
1008+
ingredient_index: IngredientIndex::new(0),
10091009
hash: 0,
10101010
disambiguator: Disambiguator(0),
10111011
};
10121012
let i2 = Identity {
1013-
ingredient_index: IngredientIndex::from(1),
1013+
ingredient_index: IngredientIndex::new(1),
10141014
hash: 0,
10151015
disambiguator: Disambiguator(0),
10161016
};
10171017
let i3 = Identity {
1018-
ingredient_index: IngredientIndex::from(0),
1018+
ingredient_index: IngredientIndex::new(0),
10191019
hash: 1,
10201020
disambiguator: Disambiguator(0),
10211021
};
10221022
let i4 = Identity {
1023-
ingredient_index: IngredientIndex::from(1),
1023+
ingredient_index: IngredientIndex::new(1),
10241024
hash: 1,
10251025
disambiguator: Disambiguator(0),
10261026
};
10271027
let i5 = Identity {
1028-
ingredient_index: IngredientIndex::from(0),
1028+
ingredient_index: IngredientIndex::new(0),
10291029
hash: 0,
10301030
disambiguator: Disambiguator(1),
10311031
};
10321032
let i6 = Identity {
1033-
ingredient_index: IngredientIndex::from(1),
1033+
ingredient_index: IngredientIndex::new(1),
10341034
hash: 0,
10351035
disambiguator: Disambiguator(1),
10361036
};
10371037
let i7 = Identity {
1038-
ingredient_index: IngredientIndex::from(0),
1038+
ingredient_index: IngredientIndex::new(0),
10391039
hash: 1,
10401040
disambiguator: Disambiguator(1),
10411041
};
10421042
let i8 = Identity {
1043-
ingredient_index: IngredientIndex::from(1),
1043+
ingredient_index: IngredientIndex::new(1),
10441044
hash: 1,
10451045
disambiguator: Disambiguator(1),
10461046
};

src/zalsa.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,18 @@ impl IngredientIndex {
8181
const MAX_INDEX: u32 = 0x7FFF_FFFF;
8282

8383
/// Create an ingredient index from a `u32`.
84-
pub(crate) fn from(v: u32) -> Self {
84+
pub(crate) fn new(v: u32) -> Self {
8585
assert!(v <= Self::MAX_INDEX);
8686
Self(v)
8787
}
8888

8989
/// Create an ingredient index from a `u32`, without performing validating
9090
/// that the index is valid.
91-
pub(crate) fn from_unchecked(v: u32) -> Self {
91+
///
92+
/// # Safety
93+
///
94+
/// The index must be less than or equal to `IngredientIndex::MAX_INDEX`.
95+
pub(crate) unsafe fn new_unchecked(v: u32) -> Self {
9296
Self(v)
9397
}
9498

@@ -236,6 +240,7 @@ impl Zalsa {
236240
unsafe { T::memo_table(self, id, self.current_revision()) }
237241
}
238242

243+
/// Returns the ingredient at the given index, or panics if it is out-of-bounds.
239244
#[inline]
240245
pub fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient {
241246
let index = index.as_u32() as usize;
@@ -245,6 +250,19 @@ impl Zalsa {
245250
.as_ref()
246251
}
247252

253+
/// Returns the ingredient at the given index.
254+
///
255+
/// # Safety
256+
///
257+
/// The index must be in-bounds.
258+
#[inline]
259+
pub unsafe fn lookup_ingredient_unchecked(&self, index: IngredientIndex) -> &dyn Ingredient {
260+
let index = index.as_u32() as usize;
261+
262+
// SAFETY: Guaranteed by caller.
263+
unsafe { self.ingredients_vec.get_unchecked(index).as_ref() }
264+
}
265+
248266
pub(crate) fn ingredient_index_for_memo(
249267
&self,
250268
struct_ingredient_index: IngredientIndex,
@@ -331,7 +349,7 @@ impl Zalsa {
331349
fn insert_jar(&mut self, jar: ErasedJar) {
332350
let jar_type_id = (jar.type_id)();
333351

334-
let index = IngredientIndex::from(self.ingredients_vec.len() as u32);
352+
let index = IngredientIndex::new(self.ingredients_vec.len() as u32);
335353

336354
if self.jar_map.contains_key(&jar_type_id) {
337355
return;

src/zalsa_local.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,11 @@ impl QueryOrigin {
754754
QueryOriginKind::Assigned => {
755755
// SAFETY: `data.index` is initialized when the tag is `QueryOriginKind::Assigned`.
756756
let index = unsafe { self.data.index };
757-
let ingredient_index = IngredientIndex::from(self.metadata);
757+
758+
// SAFETY: `metadata` is initialized from a valid `IngredientIndex` when the tag
759+
// is `QueryOriginKind::Assigned`.
760+
let ingredient_index = unsafe { IngredientIndex::new_unchecked(self.metadata) };
761+
758762
QueryOriginRef::Assigned(DatabaseKeyIndex::new(ingredient_index, index))
759763
}
760764

0 commit comments

Comments
 (0)