[FEA] Support raft::KeyValuePair Inputs in linalg::map #2913
viclafargue
left a comment
LGTM! Even though this might need a second pair of eyes looking at it.
    map_offset(handle, out_ref_view, op);
    map_offset(handle, out_view, op);
Could we use a known trusted reference (such as Thrust) to perform the reference transformation here instead of reusing the same KVP mapping?
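A rough host-side sketch of what an independent reference could look like. It uses std::transform as a stand-in for thrust::transform so it compiles without CUDA. KvpSketch, AddKvp and reference_map are hypothetical names for illustration, not part of raft or of this PR.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for raft::KeyValuePair<int, float>.
struct KvpSketch {
  int key;
  float value;
};

// Hypothetical operation under test: adds a constant pair to each element.
struct AddKvp {
  KvpSketch scalar;
  KvpSketch operator()(const KvpSketch& in) const
  {
    return KvpSketch{in.key + scalar.key, in.value + scalar.value};
  }
};

// Independent reference: relies only on the standard library, so a bug in
// the mapping code under test cannot leak into the expected values.
inline std::vector<KvpSketch> reference_map(const std::vector<KvpSketch>& in, AddKvp op)
{
  std::vector<KvpSketch> out(in.size());
  std::transform(in.begin(), in.end(), out.begin(), op);
  return out;
}
```

The point is that the expected output comes from a code path that does not share any logic with the KVP mapping being tested.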
aamijar
left a comment
Hi Tarang, although I don't have experience with them myself, I've heard that concepts can be nicer than SFINAE now that we have upgraded to C++20.
dantegd
left a comment
Just had a few comments and questions
    */

    template <typename K, typename V>
    struct IOType<KeyValuePair<K, V>, 1, std::enable_if_t<sizeof(KeyValuePair<K, V>) == 4>> {
The vectorized I/O here works by type-punning KeyValuePair<K,V> through int32_t/int2/int4, no? This is only safe if KeyValuePair is trivially copyable (i.e., memcpy-safe). If someone later creates a KeyValuePair with a non-trivially-copyable type this would silently produce undefined behavior. Adding static_assert(std::is_trivially_copyable_v<KeyValuePair<K, V>>) here would catch this at compile time and make the assumption explicit.
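Something along these lines, as a minimal sketch. KvpSketch, pack_word and unpack_word are hypothetical names, and the bytewise copy through a 32-bit word only approximates what the vectorized path does:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical stand-in for raft::KeyValuePair; not the real definition.
template <typename K, typename V>
struct KvpSketch {
  K key;
  V value;
};

// Copies a 4-byte pair into one 32-bit word, roughly what a vectorized store does.
template <typename K, typename V>
std::uint32_t pack_word(const KvpSketch<K, V>& kvp)
{
  // Makes the memcpy-safety assumption explicit and checked at compile time.
  static_assert(std::is_trivially_copyable_v<KvpSketch<K, V>>,
                "bytewise copies require a trivially copyable pair");
  static_assert(sizeof(KvpSketch<K, V>) == sizeof(std::uint32_t),
                "this sketch only covers the 4-byte case");
  std::uint32_t word = 0;
  std::memcpy(&word, &kvp, sizeof(word));
  return word;
}

// Inverse of pack_word: rebuilds the pair from the 32-bit word.
template <typename K, typename V>
KvpSketch<K, V> unpack_word(std::uint32_t word)
{
  static_assert(std::is_trivially_copyable_v<KvpSketch<K, V>>,
                "bytewise copies require a trivially copyable pair");
  static_assert(sizeof(KvpSketch<K, V>) == sizeof(std::uint32_t),
                "this sketch only covers the 4-byte case");
  KvpSketch<K, V> kvp{};
  std::memcpy(&kvp, &word, sizeof(word));
  return kvp;
}
```

With the assertion in place, instantiating the template with a key or value type that has a non-trivial copy constructor fails to compile instead of silently misbehaving.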
    #pragma once

    #include <raft/core/kvp.hpp>
Adding #include <raft/core/kvp.hpp> here means every file that includes vectorized.cuh now also pulls in kvp.hpp, even if it doesn't use KeyValuePair. This could affect compile times, though kvp.hpp is small, so the impact may be minimal in practice. The concern is more about hygiene and precedent: if every utility header pulls in "just one more small include," it compounds over time.
We could consider:
(1) is a forward declaration sufficient here? or
(2) should the KVP IOType specializations live in a separate header (e.g., vectorized_kvp.cuh) that's only included where needed?
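For option (1), a small sketch of why a forward declaration may be enough. The sizeof inside enable_if_t is a dependent expression, so it is only evaluated when the specialization is instantiated, and by then the full definition has to be visible. KvpSketch and IoWordSketch are hypothetical stand-ins; the real forward declaration would need to match the namespace and template parameters in kvp.hpp.

```cpp
#include <cstdint>
#include <type_traits>

// Forward declaration only: no definition is needed to declare the
// partial specialization below.
template <typename K, typename V>
struct KvpSketch;

// Hypothetical, simplified stand-in for IOType (vector-length parameter omitted).
template <typename T, typename Enable = void>
struct IoWordSketch;

// sizeof(KvpSketch<K, V>) is dependent, so it is evaluated at instantiation time.
template <typename K, typename V>
struct IoWordSketch<KvpSketch<K, V>, std::enable_if_t<sizeof(KvpSketch<K, V>) == 8>> {
  using type = std::int64_t;
};

// The full definition arrives later, as it would in a translation unit that
// includes the KVP header itself before using the specialization.
template <typename K, typename V>
struct KvpSketch {
  K key;
  V value;
};
```

The trade-off is that any file instantiating the KVP specialization must still include the real definition first, which is the kind of implicit ordering requirement that option (2) avoids.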
    uniform(handle, r, fval2.data(), len, float(-1.0), float(1.0));
    uniform(handle, r, fval3.data(), len, float(-1.0), float(1.0));

    raft::device_resources handle{stream};
This declares a new handle{stream} inside the if constexpr block, which shadows the class member handle. It's fine because both refer to the same stream, but shadowing class members is risky: it can cause subtle bugs if someone later modifies the code expecting to use the class member. Consider renaming to local_handle or just using the existing this->handle.
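A toy illustration of the hazard. HandleSketch and FixtureSketch are made-up types, and the two handles deliberately carry different stream ids (unlike in this PR) so the difference is observable:

```cpp
// Hypothetical stand-in for raft::device_resources.
struct HandleSketch {
  int stream_id;
};

// Hypothetical stand-in for the test fixture, which owns a handle member.
struct FixtureSketch {
  HandleSketch handle{1};

  constexpr int shadowed() const
  {
    HandleSketch handle{2};   // shadows the member: every use below sees this one
    return handle.stream_id;  // 2, not the member's 1
  }

  constexpr int renamed() const
  {
    HandleSketch local_handle{2};   // distinct name, so nothing is hidden
    (void)local_handle;             // unused in this sketch
    return this->handle.stream_id;  // unambiguously the member
  }
};
```

Compiling with -Wshadow (GCC/Clang) flags the first variant.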
    struct CompareKVP {
      float eps;
      CompareKVP(float eps_) : eps(eps_) {}
      CompareKVP(KVP eps_) : eps(eps_.value) {}
This constructor takes a KVP but only uses eps_.value as the tolerance, ignoring the key. This appears to be for template compatibility with the test macros (which pass params.tolerance), right?
This could be confusing to a reader who might expect both key and value to be used as tolerances. Consider adding a comment like
// For template compatibility; only value is used as float tolerance
or changing the test infrastructure to pass the tolerance directly as a float.
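Roughly what the commented version could look like. This is a sketch only: KvpSketch stands in for the KVP alias, and the operator() shown here (exact key match, value within eps) is an assumption, not necessarily what the test's comparator does.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-in for the KVP alias used in the test.
struct KvpSketch {
  int key;
  float value;
};

struct CompareKvpSketch {
  float eps;

  explicit CompareKvpSketch(float eps_) : eps(eps_) {}

  // For template compatibility with test macros that pass the tolerance as a
  // pair; only .value is used as the float tolerance, the key is ignored.
  explicit CompareKvpSketch(KvpSketch eps_) : eps(eps_.value) {}

  // Assumed semantics: keys must match exactly, values within eps.
  bool operator()(const KvpSketch& a, const KvpSketch& b) const
  {
    return a.key == b.key && std::fabs(a.value - b.value) <= eps;
  }
};
```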
    const std::vector<MapInputs<KVP, int>> inputs_kvp_i32 = {
      {KVP{0, 0.000001f}, 1024 * 1024, 1234ULL, KVP{10, 1.5f}}};
    MAP_TEST_KVP((MapTest<KVP, int>), MapTestKVP_i32, inputs_kvp_i32);
    MAP_TEST_KVP((MapOffsetTest<KVP, int>), MapOffsetTestKVP_i32, inputs_kvp_i32);

    const std::vector<MapInputs<KVP, size_t>> inputs_kvp_i64 = {
      {KVP{0, 0.000001f}, 1024 * 1024, 1234ULL, KVP{5, 2.3f}}};
    MAP_TEST_KVP((MapTest<KVP, size_t>), MapTestKVP_i64, inputs_kvp_i64);
    MAP_TEST_KVP((MapOffsetTest<KVP, size_t>), MapOffsetTestKVP_i64, inputs_kvp_i64);
This only covers 8-byte variants, no?
Without tests for 4-byte and 16-byte variants, we have no CI coverage verifying that those IOType specializations work correctly. The vectorized load paths differ, so a bug in one size category might not be caught by tests of another. Should we add test cases for the other sizes to ensure full coverage?
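As a starting point, these key/value combinations should land in each size category on mainstream ABIs. KvpSketch and kvp_bytes are hypothetical; the sizes would need re-checking against the real raft::KeyValuePair, whose layout or alignment may differ.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for raft::KeyValuePair.
template <typename K, typename V>
struct KvpSketch {
  K key;
  V value;
};

// Size of the pair in bytes, usable in compile-time checks.
template <typename K, typename V>
constexpr std::size_t kvp_bytes()
{
  return sizeof(KvpSketch<K, V>);
}

// One candidate pair per vectorized-load size category.
static_assert(kvp_bytes<std::int16_t, std::int16_t>() == 4, "4-byte candidate");
static_assert(kvp_bytes<std::int32_t, float>() == 8, "8-byte candidate");
static_assert(kvp_bytes<std::int64_t, double>() == 16, "16-byte candidate");
```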
Similar to how vectorized loads are handled for numeric datatypes, we use SFINAE principles to do vectorized loads wherever possible for KVP types (using the sizeof check at compile time).
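A simplified host-only sketch of that sizeof-based dispatch, for readers unfamiliar with the pattern. All names here are hypothetical: IoTypeSketch drops the vector-length parameter that the real IOType takes, and Word2Sketch/Word4Sketch stand in for CUDA's int2/int4 so the example compiles without nvcc.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for raft::KeyValuePair.
template <typename K, typename V>
struct KvpSketch {
  K key;
  V value;
};

// Host-side stand-ins for CUDA's int2 and int4 vector types.
struct Word2Sketch {
  std::int32_t x, y;
};
struct Word4Sketch {
  std::int32_t x, y, z, w;
};

// Primary template is left undefined: unsupported sizes fail to compile.
template <typename T, typename Enable = void>
struct IoTypeSketch;

// Each specialization is enabled only when the pair has the matching size.
template <typename K, typename V>
struct IoTypeSketch<KvpSketch<K, V>, std::enable_if_t<sizeof(KvpSketch<K, V>) == 4>> {
  using type = std::int32_t;
};

template <typename K, typename V>
struct IoTypeSketch<KvpSketch<K, V>, std::enable_if_t<sizeof(KvpSketch<K, V>) == 8>> {
  using type = Word2Sketch;
};

template <typename K, typename V>
struct IoTypeSketch<KvpSketch<K, V>, std::enable_if_t<sizeof(KvpSketch<K, V>) == 16>> {
  using type = Word4Sketch;
};
```

Because the three conditions are mutually exclusive, at most one specialization is viable for any given K and V, so there is no ambiguity between them.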