Skip to content

Commit a60c106

Browse files
committed
cann : refactor ACL graph cache
Move the graph property checking code into methods of LRU cache. Signed-off-by: Wang Weixuan <[email protected]>
1 parent dea9ba2 commit a60c106

File tree

2 files changed

+167
-181
lines changed

2 files changed

+167
-181
lines changed

ggml/src/ggml-cann/common.h

Lines changed: 149 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,60 @@ struct ggml_graph_node_properties {
229229
// op
230230
ggml_op node_op;
231231
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
232+
233+
/**
234+
* @brief Check if a ggml tensor node matches this property set.
235+
*
236+
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
237+
* to determine whether the current node matches these previously recorded properties.
238+
*
239+
* @param node The current ggml tensor node.
240+
* @return true if all compared fields match (data addresses are not compared for GGML_OP_VIEW nodes); false otherwise.
241+
*/
242+
bool has_matching_properties(ggml_tensor * node) {
    // Data addresses (of the node and of its sources) are only compared for non-view nodes.
    const bool is_view = (node->op == GGML_OP_VIEW);

    if (!is_view && node->data != this->node_address) {
        return false;
    }
    if (this->node_op != node->op) {
        return false;
    }

    // Shape and strides of the node itself.
    for (int dim = 0; dim < GGML_MAX_DIMS; ++dim) {
        if (this->ne[dim] != node->ne[dim] || this->nb[dim] != node->nb[dim]) {
            return false;
        }
    }

    // Every source slot: either both sides are empty, or address/shape/strides agree.
    for (int slot = 0; slot < GGML_MAX_SRC; ++slot) {
        ggml_tensor * input = node->src[slot];

        if (input == nullptr) {
            // No source now, so no source address may have been recorded either.
            if (this->src_address[slot] != nullptr) {
                return false;
            }
            continue;
        }

        if (!is_view && input->data != this->src_address[slot]) {
            return false;
        }

        for (int dim = 0; dim < GGML_MAX_DIMS; ++dim) {
            if (this->src_ne[slot][dim] != input->ne[dim] || this->src_nb[slot][dim] != input->nb[dim]) {
                return false;
            }
        }
    }

    // Operation parameters are compared only for these three ops.
    switch (node->op) {
        case GGML_OP_SCALE:
        case GGML_OP_UNARY:
        case GGML_OP_GLU:
            return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
        default:
            return true;
    }
}
232286
};
233287

234288
struct ggml_cann_graph {
@@ -241,6 +295,79 @@ struct ggml_cann_graph {
241295
aclmdlRI graph = nullptr;
242296

243297
std::vector<ggml_graph_node_properties> ggml_graph_properties;
298+
299+
/**
300+
* @brief Create a new CANN graph from a ggml computation graph.
301+
*
302+
* This function creates a new ggml_cann_graph object and fills its node properties
303+
* (operation type, dimensions, strides, input sources, and operation parameters)
304+
* based on the current ggml computation graph.
305+
*
306+
* Each node in the ggml graph is mapped to a property entry in the new CANN graph:
307+
* - node address
308+
* - operation type
309+
* - shape (ne) and strides (nb)
310+
* - source tensor addresses
311+
* - operation parameters
312+
*
313+
* @param cgraph The current ggml computation graph.
314+
* @return Pointer to the newly created ggml_cann_graph object.
315+
*/
316+
static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
    // The caller receives a raw pointer to a heap-allocated graph.
    ggml_cann_graph * result = new ggml_cann_graph();

    // One property record per node, stored in graph order.
    const int n_nodes = cgraph->n_nodes;
    result->ggml_graph_properties.resize(n_nodes);

    for (int idx = 0; idx < n_nodes; ++idx) {
        ggml_tensor * tensor = cgraph->nodes[idx];
        auto &        record = result->ggml_graph_properties[idx];

        // Identity of the node: output address and operation.
        record.node_op      = tensor->op;
        record.node_address = tensor->data;

        // Shape and strides of the node.
        std::copy_n(tensor->ne, GGML_MAX_DIMS, record.ne);
        std::copy_n(tensor->nb, GGML_MAX_DIMS, record.nb);

        // Operation parameters are copied verbatim.
        memcpy(record.op_params, tensor->op_params, GGML_MAX_OP_PARAMS);

        // Sources: record address/shape/strides, or zero out the slot when absent.
        for (int slot = 0; slot < GGML_MAX_SRC; ++slot) {
            ggml_tensor * input = tensor->src[slot];

            if (input == nullptr) {
                record.src_address[slot] = nullptr;
                std::fill_n(record.src_ne[slot], GGML_MAX_DIMS, 0);
                std::fill_n(record.src_nb[slot], GGML_MAX_DIMS, 0);
                continue;
            }

            record.src_address[slot] = input->data;
            std::copy_n(input->ne, GGML_MAX_DIMS, record.src_ne[slot]);
            std::copy_n(input->nb, GGML_MAX_DIMS, record.src_nb[slot]);
        }
    }

    return result;
}
347+
348+
/**
349+
* @brief Check whether this CANN graph matches the given ggml computation graph.
350+
*
351+
* This function compares the number of nodes and each node's properties
352+
* (operation type, dimensions, strides, inputs, and operation parameters)
353+
* to determine whether this CANN graph matches the given ggml graph.
354+
*
355+
* @param cgraph The current ggml computation graph.
356+
* @return true if this CANN graph matches the ggml graph; false otherwise.
357+
*/
358+
bool matches_cgraph(ggml_cgraph * cgraph) {
    const int n_nodes = cgraph->n_nodes;

    // A different node count can never match.
    if (static_cast<size_t>(n_nodes) != this->ggml_graph_properties.size()) {
        return false;
    }

    // Advance while each recorded property set agrees with the node at the same index.
    int idx = 0;
    while (idx < n_nodes && this->ggml_graph_properties[idx].has_matching_properties(cgraph->nodes[idx])) {
        ++idx;
    }

    // The graphs match only if the scan reached the end without a mismatch.
    return idx == n_nodes;
}
244371
};
245372

246373
/**
@@ -272,15 +399,6 @@ struct ggml_cann_graph_lru_cache {
272399
cache_list.push_front(new_node);
273400
}
274401

275-
/**
276-
* @brief Move an existing graph to the front of the cache.
277-
* @param node Pointer to the ggml_cann_graph to move.
278-
*/
279-
void move_to_front(ggml_cann_graph * node) {
280-
cache_list.remove(node);
281-
cache_list.push_front(node);
282-
}
283-
284402
/**
285403
* @brief Clear all graphs from the cache (also frees memory).
286404
*/
@@ -295,6 +413,28 @@ struct ggml_cann_graph_lru_cache {
295413
* @brief Destructor that clears the cache and frees all cached graphs.
296414
*/
297415
~ggml_cann_graph_lru_cache() { clear(); }
416+
417+
/**
418+
* @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
419+
*
420+
* This function iterates through the cached CANN graphs stored in the LRU cache and
421+
* compares them against the given ggml computation graph. If a matching graph is found,
422+
* it is promoted to the front of the LRU cache and the function returns true. Otherwise,
423+
* the function returns false.
424+
*
425+
* @param cgraph The current ggml computation graph.
426+
* @return true if found; false otherwise.
427+
*/
428+
bool find_and_move_to_front(ggml_cgraph * cgraph) {
    // Search the cache for a graph matching `cgraph`; on a hit, move that entry to the
    // front of cache_list and return true. Return false when no cached graph matches.
    //
    // Iterate by value: each `candidate` is a copy of the stored pointer, not a reference
    // into the list. Iterating by reference and passing that reference to remove() would
    // erase the element it refers to, leaving push_front() to read a dangling reference.
    for (ggml_cann_graph * candidate : this->cache_list) {
        if (!candidate->matches_cgraph(cgraph)) {
            continue;
        }

        // The list is modified here, so iteration must not continue; we return immediately.
        cache_list.remove(candidate);
        cache_list.push_front(candidate);
        return true;
    }
    return false;
}
298438
};
299439
#endif // USE_ACL_GRAPH
300440

0 commit comments

Comments
 (0)