Skip to content

Commit 64bf1c3

Browse files
authored
rpc : check for null buffers in get/set/copy tensor endpoints (ggml-org#14868)
1 parent c12bbde commit 64bf1c3

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,7 +1055,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
10551055
GGML_ASSERT(ctx_ptr != nullptr);
10561056
ggml_context * ctx = ctx_ptr.get();
10571057
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
1058-
if (tensor == nullptr) {
1058+
if (tensor == nullptr || tensor->buffer == nullptr) {
10591059
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
10601060
return false;
10611061
}
@@ -1124,7 +1124,7 @@ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rp
11241124
GGML_ASSERT(ctx_ptr != nullptr);
11251125
ggml_context * ctx = ctx_ptr.get();
11261126
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1127-
if (tensor == nullptr) {
1127+
if (tensor == nullptr || tensor->buffer == nullptr) {
11281128
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
11291129
return false;
11301130
}
@@ -1192,7 +1192,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
11921192
GGML_ASSERT(ctx_ptr != nullptr);
11931193
ggml_context * ctx = ctx_ptr.get();
11941194
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
1195-
if (tensor == nullptr) {
1195+
if (tensor == nullptr || tensor->buffer == nullptr) {
11961196
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
11971197
return false;
11981198
}
@@ -1229,7 +1229,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
12291229

12301230
ggml_tensor * src = deserialize_tensor(ctx, &request.src);
12311231
ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
1232-
if (src == nullptr || dst == nullptr) {
1232+
if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
12331233
GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
12341234
return false;
12351235
}

0 commit comments

Comments
 (0)