Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions src/slangpy_ext/utils/slangpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,38 @@ namespace {
}
} // anonymous namespace

// Type cache for get_this/update_this method presence.
static std::unordered_map<PyTypeObject*, bool> s_has_get_this_cache;
static std::unordered_map<PyTypeObject*, bool> s_has_update_this_cache;

inline bool has_get_this(nb::handle obj)
{
PyTypeObject* type = Py_TYPE(obj.ptr());

auto it = s_has_get_this_cache.find(type);
if (it != s_has_get_this_cache.end())
return it->second;

bool result = nb::hasattr(obj, "get_this");
s_has_get_this_cache[type] = result;

return result;
}

inline bool has_update_this(nb::handle obj)
{
PyTypeObject* type = Py_TYPE(obj.ptr());

auto it = s_has_update_this_cache.find(type);
if (it != s_has_update_this_cache.end())
return it->second;

bool result = nb::hasattr(obj, "update_this");
s_has_update_this_cache[type] = result;

return result;
}

// Implementation of to_string methods
std::string NativeSlangType::to_string() const
{
Expand Down Expand Up @@ -921,10 +953,9 @@ void NativeCallDataCache::get_value_signature(const ref<SignatureBuilder> builde
auto type_name = nb::str(nb::getattr(o.type(), "__name__"));
*builder << type_name.c_str() << "\n";

// Handle objects with get_this method.
auto get_this = nb::getattr(o, "get_this", nb::none());
if (!get_this.is_none()) {
auto this_ = get_this();
// Handle objects with get_this method (cached type check).
if (has_get_this(o)) {
auto this_ = nb::getattr(o, "get_this")();
get_value_signature(builder, this_);
return;
}
Expand Down Expand Up @@ -1018,8 +1049,8 @@ nb::object unpack_arg(nb::object arg, std::optional<nb::list> refs)
{
auto obj = arg;

// If object has 'get_this', read it.
if (nb::hasattr(obj, "get_this")) {
// If object has 'get_this', read it (cached type check).
if (has_get_this(obj)) {
obj = nb::getattr(obj, "get_this")();
}

Expand Down Expand Up @@ -1060,8 +1091,8 @@ nb::object unpack_arg(nb::object arg, std::optional<nb::list> refs)

void pack_arg(nanobind::object arg, nanobind::object unpacked_arg)
{
// If object has 'update_this', update it.
if (nb::hasattr(arg, "update_this")) {
// If object has 'update_this', update it (cached type check).
if (has_update_this(arg)) {
nb::getattr(arg, "update_this")(unpacked_arg);
}

Expand Down