Skip to content

Commit 1b7d972

Browse files
committed
feat(rpc): comp-time op metadata and graph validation
Refactor `ggml_op` and `GGML_OP_METADATA` using X-Macros. This ensures compile-time synchronization between the enum and metadata. `ggml_op_metadata_check()` verifies this at compile time during `ggml_init`. This enables robust graph validation in the RPC server. Previously, malformed graphs (e.g., ADD with NULL src[1]) could cause crashes. `validate_graph_operands` now uses the X-Macro-generated metadata (`ggml_op_get_n_src`) to check for required non-null source operands before execution. Invalid graphs are rejected early. Adds `test_op_metadata_counts` to verify the metadata system. Signed-off-by: Ville Vesilehto <[email protected]>
1 parent 6a2bc8b commit 1b7d972

File tree

4 files changed

+221
-91
lines changed

4 files changed

+221
-91
lines changed

ggml/include/ggml.h

Lines changed: 123 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -425,102 +425,119 @@ extern "C" {
425425
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
426426
};
427427

428+
// Helper macro for the X-Macro list of operations
429+
// Format: XX(op_name, n_src_value)
430+
#define GGML_OP_LIST(XX) \
431+
XX(GGML_OP_NONE, 0) \
432+
XX(GGML_OP_DUP, 1) \
433+
XX(GGML_OP_ADD, 2) \
434+
XX(GGML_OP_ADD1, 2) \
435+
XX(GGML_OP_ACC, 2) \
436+
XX(GGML_OP_SUB, 2) \
437+
XX(GGML_OP_MUL, 2) \
438+
XX(GGML_OP_DIV, 2) \
439+
XX(GGML_OP_SQR, 1) \
440+
XX(GGML_OP_SQRT, 1) \
441+
XX(GGML_OP_LOG, 1) \
442+
XX(GGML_OP_SIN, 1) \
443+
XX(GGML_OP_COS, 1) \
444+
XX(GGML_OP_SUM, 1) \
445+
XX(GGML_OP_SUM_ROWS, 1) \
446+
XX(GGML_OP_MEAN, 1) \
447+
XX(GGML_OP_ARGMAX, 1) \
448+
XX(GGML_OP_COUNT_EQUAL, 2) \
449+
XX(GGML_OP_REPEAT, 2) \
450+
XX(GGML_OP_REPEAT_BACK, 2) \
451+
XX(GGML_OP_CONCAT, 2) \
452+
XX(GGML_OP_SILU_BACK, 2) \
453+
XX(GGML_OP_NORM, 1) \
454+
XX(GGML_OP_RMS_NORM, 1) \
455+
XX(GGML_OP_RMS_NORM_BACK, 2) \
456+
XX(GGML_OP_GROUP_NORM, 1) \
457+
XX(GGML_OP_L2_NORM, 1) \
458+
XX(GGML_OP_MUL_MAT, 2) \
459+
XX(GGML_OP_MUL_MAT_ID, 3) \
460+
XX(GGML_OP_OUT_PROD, 2) \
461+
XX(GGML_OP_SCALE, 1) \
462+
XX(GGML_OP_SET, 2) \
463+
XX(GGML_OP_CPY, 2) \
464+
XX(GGML_OP_CONT, 1) \
465+
XX(GGML_OP_RESHAPE, 1) \
466+
XX(GGML_OP_VIEW, 1) \
467+
XX(GGML_OP_PERMUTE, 1) \
468+
XX(GGML_OP_TRANSPOSE, 1) \
469+
XX(GGML_OP_GET_ROWS, 2) \
470+
XX(GGML_OP_GET_ROWS_BACK, 3) \
471+
XX(GGML_OP_DIAG, 1) \
472+
XX(GGML_OP_DIAG_MASK_INF, 1) \
473+
XX(GGML_OP_DIAG_MASK_ZERO, 1) \
474+
XX(GGML_OP_SOFT_MAX, 2) \
475+
XX(GGML_OP_SOFT_MAX_BACK, 2) \
476+
XX(GGML_OP_ROPE, 3) \
477+
XX(GGML_OP_ROPE_BACK, 3) \
478+
XX(GGML_OP_CLAMP, 1) \
479+
XX(GGML_OP_CONV_TRANSPOSE_1D, 2) \
480+
XX(GGML_OP_IM2COL, 2) \
481+
XX(GGML_OP_IM2COL_BACK, 2) \
482+
XX(GGML_OP_CONV_2D_DW, 2) \
483+
XX(GGML_OP_CONV_TRANSPOSE_2D, 2) \
484+
XX(GGML_OP_POOL_1D, 1) \
485+
XX(GGML_OP_POOL_2D, 1) \
486+
XX(GGML_OP_POOL_2D_BACK, 2) \
487+
XX(GGML_OP_UPSCALE, 1) \
488+
XX(GGML_OP_PAD, 1) \
489+
XX(GGML_OP_PAD_REFLECT_1D, 1) \
490+
XX(GGML_OP_ARANGE, 0) \
491+
XX(GGML_OP_TIMESTEP_EMBEDDING, 1) \
492+
XX(GGML_OP_ARGSORT, 1) \
493+
XX(GGML_OP_LEAKY_RELU, 1) \
494+
XX(GGML_OP_FLASH_ATTN_EXT, 4) \
495+
XX(GGML_OP_FLASH_ATTN_BACK, 4) \
496+
XX(GGML_OP_SSM_CONV, 2) \
497+
XX(GGML_OP_SSM_SCAN, 6) \
498+
XX(GGML_OP_WIN_PART, 1) \
499+
XX(GGML_OP_WIN_UNPART, 1) \
500+
XX(GGML_OP_GET_REL_POS, 1) \
501+
XX(GGML_OP_ADD_REL_POS, 3) \
502+
XX(GGML_OP_RWKV_WKV6, 6) \
503+
XX(GGML_OP_GATED_LINEAR_ATTN, 5) \
504+
XX(GGML_OP_RWKV_WKV7, 7) \
505+
XX(GGML_OP_UNARY, 1) \
506+
XX(GGML_OP_MAP_CUSTOM1, 1) \
507+
XX(GGML_OP_MAP_CUSTOM2, 2) \
508+
XX(GGML_OP_MAP_CUSTOM3, 3) \
509+
XX(GGML_OP_CUSTOM, -1) \
510+
XX(GGML_OP_CROSS_ENTROPY_LOSS, 2) \
511+
XX(GGML_OP_CROSS_ENTROPY_LOSS_BACK, 3) \
512+
XX(GGML_OP_OPT_STEP_ADAMW, 5)
513+
428514
// available tensor operations:
429515
enum ggml_op {
430-
GGML_OP_NONE = 0,
431-
432-
GGML_OP_DUP,
433-
GGML_OP_ADD,
434-
GGML_OP_ADD1,
435-
GGML_OP_ACC,
436-
GGML_OP_SUB,
437-
GGML_OP_MUL,
438-
GGML_OP_DIV,
439-
GGML_OP_SQR,
440-
GGML_OP_SQRT,
441-
GGML_OP_LOG,
442-
GGML_OP_SIN,
443-
GGML_OP_COS,
444-
GGML_OP_SUM,
445-
GGML_OP_SUM_ROWS,
446-
GGML_OP_MEAN,
447-
GGML_OP_ARGMAX,
448-
GGML_OP_COUNT_EQUAL,
449-
GGML_OP_REPEAT,
450-
GGML_OP_REPEAT_BACK,
451-
GGML_OP_CONCAT,
452-
GGML_OP_SILU_BACK,
453-
GGML_OP_NORM, // normalize
454-
GGML_OP_RMS_NORM,
455-
GGML_OP_RMS_NORM_BACK,
456-
GGML_OP_GROUP_NORM,
457-
GGML_OP_L2_NORM,
458-
459-
GGML_OP_MUL_MAT,
460-
GGML_OP_MUL_MAT_ID,
461-
GGML_OP_OUT_PROD,
462-
463-
GGML_OP_SCALE,
464-
GGML_OP_SET,
465-
GGML_OP_CPY,
466-
GGML_OP_CONT,
467-
GGML_OP_RESHAPE,
468-
GGML_OP_VIEW,
469-
GGML_OP_PERMUTE,
470-
GGML_OP_TRANSPOSE,
471-
GGML_OP_GET_ROWS,
472-
GGML_OP_GET_ROWS_BACK,
473-
GGML_OP_DIAG,
474-
GGML_OP_DIAG_MASK_INF,
475-
GGML_OP_DIAG_MASK_ZERO,
476-
GGML_OP_SOFT_MAX,
477-
GGML_OP_SOFT_MAX_BACK,
478-
GGML_OP_ROPE,
479-
GGML_OP_ROPE_BACK,
480-
GGML_OP_CLAMP,
481-
GGML_OP_CONV_TRANSPOSE_1D,
482-
GGML_OP_IM2COL,
483-
GGML_OP_IM2COL_BACK,
484-
GGML_OP_CONV_2D_DW,
485-
GGML_OP_CONV_TRANSPOSE_2D,
486-
GGML_OP_POOL_1D,
487-
GGML_OP_POOL_2D,
488-
GGML_OP_POOL_2D_BACK,
489-
GGML_OP_UPSCALE, // nearest interpolate
490-
GGML_OP_PAD,
491-
GGML_OP_PAD_REFLECT_1D,
492-
GGML_OP_ARANGE,
493-
GGML_OP_TIMESTEP_EMBEDDING,
494-
GGML_OP_ARGSORT,
495-
GGML_OP_LEAKY_RELU,
496-
497-
GGML_OP_FLASH_ATTN_EXT,
498-
GGML_OP_FLASH_ATTN_BACK,
499-
GGML_OP_SSM_CONV,
500-
GGML_OP_SSM_SCAN,
501-
GGML_OP_WIN_PART,
502-
GGML_OP_WIN_UNPART,
503-
GGML_OP_GET_REL_POS,
504-
GGML_OP_ADD_REL_POS,
505-
GGML_OP_RWKV_WKV6,
506-
GGML_OP_GATED_LINEAR_ATTN,
507-
GGML_OP_RWKV_WKV7,
508-
509-
GGML_OP_UNARY,
510-
511-
GGML_OP_MAP_CUSTOM1,
512-
GGML_OP_MAP_CUSTOM2,
513-
GGML_OP_MAP_CUSTOM3,
514-
515-
GGML_OP_CUSTOM,
516-
517-
GGML_OP_CROSS_ENTROPY_LOSS,
518-
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
519-
GGML_OP_OPT_STEP_ADAMW,
520-
516+
#define GGML_OP_ENUM_MEMBER(op_name, ...) op_name,
517+
GGML_OP_LIST(GGML_OP_ENUM_MEMBER)
518+
#undef GGML_OP_ENUM_MEMBER
521519
GGML_OP_COUNT,
522520
};
523521

522+
// metadata for ggml_op
523+
typedef struct {
524+
int n_src; // number of arguments
525+
} ggml_op_metadata_t;
526+
527+
static const ggml_op_metadata_t GGML_OP_METADATA[GGML_OP_COUNT] = {
528+
#define GGML_OP_METADATA_ENTRY(op_name, n_src_val) [op_name] = {.n_src = n_src_val},
529+
GGML_OP_LIST(GGML_OP_METADATA_ENTRY)
530+
#undef GGML_OP_METADATA_ENTRY
531+
};
532+
533+
// Inline function to get the number of source operands for an operation
534+
static inline int ggml_op_get_n_src(enum ggml_op op) {
535+
if (op >= 0 && op < GGML_OP_COUNT) {
536+
return GGML_OP_METADATA[op].n_src;
537+
}
538+
return -2; // invalid op
539+
}
540+
524541
enum ggml_unary_op {
525542
GGML_UNARY_OP_ABS,
526543
GGML_UNARY_OP_SGN,
@@ -2186,6 +2203,21 @@ extern "C" {
21862203
GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
21872204
GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
21882205

2206+
// Compile-time check helper function
2207+
// Asserts that GGML_OP_METADATA is updated when ggml_op changes.
2208+
// Ensure all ggml_op values are handled. Missing case = compile error.
2209+
// Relies on compiler warnings treated as errors (-Werror=switch-enum or similar).
2210+
static inline void ggml_op_metadata_check(void) {
2211+
enum ggml_op op = GGML_OP_NONE; // Dummy value
2212+
switch (op) {
2213+
#define GGML_OP_SWITCH_CASE(op_name, ...) case op_name: (void)GGML_OP_METADATA[op_name].n_src; break;
2214+
GGML_OP_LIST(GGML_OP_SWITCH_CASE)
2215+
#undef GGML_OP_SWITCH_CASE
2216+
case GGML_OP_COUNT: break;
2217+
// NOTE: No default case. Compiler warning/error for unhandled enum value is the goal.
2218+
}
2219+
}
2220+
21892221
#ifdef __cplusplus
21902222
}
21912223
#endif

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,43 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
754754
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
755755
}
756756

757+
// Helper function to validate graph operands before computation
758+
static bool validate_graph_operands(const ggml_cgraph *graph) {
759+
GGML_PRINT_DEBUG("[%s] Validating graph with %d nodes\n", __func__, graph->n_nodes);
760+
for (uint32_t i = 0; i < (uint32_t)graph->n_nodes; ++i) {
761+
const ggml_tensor* node = graph->nodes[i];
762+
// Initial null check added for safety.
763+
if (node == nullptr) {
764+
GGML_LOG_ERROR("[%s] Graph node %d is null.\n", __func__, i);
765+
return false;
766+
}
767+
768+
const int n_src = ggml_op_get_n_src(node->op);
769+
770+
if (n_src == -1) {
771+
// Ops like GGML_OP_CUSTOM have variable inputs, cannot validate here.
772+
GGML_PRINT_DEBUG("[%s] Skipping operand validation for node %d (op %s, name '%s') with variable inputs.\n", __func__, i, ggml_op_name(node->op), node->name);
773+
continue;
774+
} else if (n_src == -2) {
775+
GGML_LOG_ERROR("[%s] Graph node %d (name '%s') has invalid op type %d.\n", __func__, i, node->name, (int)node->op);
776+
return false;
777+
} else if (n_src > GGML_MAX_SRC) {
778+
GGML_LOG_ERROR("[%s] Graph node %d (op %s, name '%s') requires %d sources, exceeding GGML_MAX_SRC (%d).\n", __func__, i, ggml_op_name(node->op), node->name, n_src, GGML_MAX_SRC);
779+
return false;
780+
}
781+
782+
// Check required source operands
783+
for (int s_idx = 0; s_idx < n_src; ++s_idx) {
784+
if (node->src[s_idx] == nullptr) {
785+
GGML_LOG_ERROR("[%s] Graph node %d (op %s, name '%s') missing required input src[%d].\n", __func__, i, ggml_op_name(node->op), node->name, s_idx);
786+
return false;
787+
}
788+
}
789+
}
790+
GGML_PRINT_DEBUG("[%s] Graph validation successful\n", __func__);
791+
return true;
792+
}
793+
757794
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
758795
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
759796
std::vector<uint8_t> input;
@@ -1357,6 +1394,11 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
13571394
return false;
13581395
}
13591396
}
1397+
1398+
if (!validate_graph_operands(graph)) {
1399+
return false;
1400+
}
1401+
13601402
ggml_status status = ggml_backend_graph_compute(backend, graph);
13611403
response.result = status;
13621404
return true;

ggml/src/ggml.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
13871387
// initialize time system (required on Windows)
13881388
ggml_time_init();
13891389

1390+
ggml_op_metadata_check();
1391+
1392+
13901393
for (int i = 0; i < (1 << 16); ++i) {
13911394
union {
13921395
uint16_t u16;

tests/test-backend-ops.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4546,6 +4546,59 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
45464546
test_cases.emplace_back(new test_falcon(2));
45474547
#endif
45484548

4549+
// Verify that the ggml_op_metadata_t correctly validates n_src
4550+
{
4551+
struct test_op_metadata_counts : public test_case {
4552+
std::string op_desc(ggml_tensor * t) override {
4553+
GGML_UNUSED(t);
4554+
return "OP_METADATA_COUNTS";
4555+
}
4556+
4557+
ggml_tensor * build_graph(ggml_context * ctx) override {
4558+
bool all_passed = true;
4559+
4560+
struct {
4561+
ggml_op op;
4562+
int expected_n_src;
4563+
const char* name;
4564+
} test_ops[] = {
4565+
{GGML_OP_NONE, 0, "NONE"},
4566+
{GGML_OP_UNARY, 1, "UNARY"},
4567+
{GGML_OP_ADD, 2, "ADD"},
4568+
{GGML_OP_MUL, 2, "MUL"},
4569+
{GGML_OP_ROPE, 3, "ROPE"},
4570+
{GGML_OP_FLASH_ATTN_EXT, 4, "FLASH_ATTN_EXT"},
4571+
{GGML_OP_CUSTOM, -1, "CUSTOM"}
4572+
};
4573+
4574+
// Test each operation's metadata
4575+
for (const auto& test : test_ops) {
4576+
int n_src = ggml_op_get_n_src(test.op);
4577+
if (n_src != test.expected_n_src) {
4578+
fprintf(stderr, "ERROR: Expected n_src=%d for GGML_OP_%s but got %d\n",
4579+
test.expected_n_src, test.name, n_src);
4580+
all_passed = false;
4581+
}
4582+
}
4583+
4584+
if (!all_passed) {
4585+
GGML_ASSERT("One or more metadata checks failed");
4586+
}
4587+
4588+
// Create a dummy tensor that will be used for backend comparison
4589+
ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);
4590+
ggml_set_name(a, "a");
4591+
4592+
ggml_tensor * result = ggml_scale(ctx, a, 1.0f);
4593+
ggml_set_name(result, "result");
4594+
4595+
return result;
4596+
}
4597+
};
4598+
4599+
test_cases.push_back(std::make_unique<test_op_metadata_counts>());
4600+
}
4601+
45494602
return test_cases;
45504603
}
45514604

0 commit comments

Comments
 (0)