@@ -45,31 +45,54 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
4545 // Memory node binding
4646 if (numa_available () != -1 ) {
4747 int mem_node_id = numa_node_of_cpu (omp_cpu_ids.front ());
48- // Verify all CPUs are on the same NUMA node
49- for (size_t i = 1 ; i < omp_cpu_ids.size (); ++i) {
50- int node_id = numa_node_of_cpu (omp_cpu_ids[i]);
51- TORCH_CHECK (node_id == mem_node_id, " CPU " , omp_cpu_ids[i],
52- " is on NUMA node " , node_id, " , but CPU " ,
53- omp_cpu_ids.front (), " is on NUMA node " , mem_node_id,
54- " . All CPUs should be on the same NUMA node for optimal "
55- " performance. Memory will be bound to NUMA node " ,
56- mem_node_id, " ." );
48+ std::set<int > node_ids;
49+ for (const auto & cpu_id : omp_cpu_ids) {
50+ int node_id = numa_node_of_cpu (cpu_id);
51+ if (node_id != -1 ) {
52+ node_ids.insert (node_id);
53+ }
54+ TORCH_WARN (node_id == mem_node_id, " CPU " , cpu_id, " is on NUMA node " ,
55+ node_id, " , but CPU " , omp_cpu_ids.front (),
56+ " is on NUMA node " , mem_node_id,
57+ " . All CPUs should be on the same NUMA node for optimal "
58+ " performance. Memory will be bound to NUMA node " ,
59+ mem_node_id, " ." );
5760 }
58- bitmask* mask = numa_parse_nodestring (std::to_string (mem_node_id).c_str ());
59- bitmask* src_mask = numa_get_membind ();
60-
61- int pid = getpid ();
61+ // Concatenate all node_ids into a single comma-separated string
62+ if (!node_ids.empty ()) {
63+ std::string node_ids_str;
64+ for (const int node_id : node_ids) {
65+ if (!node_ids_str.empty ()) {
66+ node_ids_str += " ," ;
67+ }
68+ node_ids_str += std::to_string (node_id);
69+ }
6270
63- // move all existing pages to the specified numa node.
64- *(src_mask->maskp ) = *(src_mask->maskp ) ^ *(mask->maskp );
65- int page_num = numa_migrate_pages (pid, src_mask, mask);
66- if (page_num == -1 ) {
67- TORCH_WARN (" numa_migrate_pages failed. errno: " + std::to_string (errno));
71+ bitmask* mask = numa_parse_nodestring (node_ids_str.c_str ());
72+ bitmask* src_mask = numa_get_membind ();
73+
74+ int pid = getpid ();
75+
76+ if (mask && src_mask) {
77+ // move all existing pages to the specified numa node.
78+ *(src_mask->maskp ) = *(src_mask->maskp ) ^ *(mask->maskp );
79+ int page_num = numa_migrate_pages (pid, src_mask, mask);
80+ if (page_num == -1 ) {
81+ TORCH_WARN (" numa_migrate_pages failed. errno: " +
82+ std::to_string (errno));
83+ }
84+
85+ // restrict memory allocation node.
86+ numa_set_membind (mask);
87+ numa_set_strict (1 );
88+
89+ numa_free_nodemask (mask);
90+ numa_free_nodemask (src_mask);
91+ } else {
92+ TORCH_WARN (" numa_parse_nodestring or numa_get_membind failed. errno: " +
93+ std::to_string (errno));
94+ }
6895 }
69-
70- // restrict memory allocation node.
71- numa_set_membind (mask);
72- numa_set_strict (1 );
7396 }
7497
7598 // OMP threads binding
0 commit comments