@@ -31,7 +31,12 @@ using namespace executorch::backends::aoti;
 
 // Global storage for tensors and their metadata
 std::unordered_set<std::shared_ptr<Tensor>> tensors;
-std::unordered_map<Tensor*, bool> is_tensor_own_memory;
+
+// Reference counting for memory addresses
+// Maps memory address to number of tensors using it
+// Special value: NOT_OWN (-1) means tensor never owns the memory
+constexpr int32_t NOT_OWN = -1;
+std::unordered_map<void*, int32_t> memory_to_n_tensor;
 
 extern "C" {
 
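The new map replaces the per-tensor is_tensor_own_memory flag with a per-buffer reference count: owned allocations enter the map at 1, every additional view of the same buffer bumps the count, and caller-owned blobs are pinned at the NOT_OWN sentinel so they are never freed. The snippet below is a minimal standalone sketch of that contract; the release_memory helper and its direct use of free() are illustrative assumptions, not code from this patch, which spreads the same logic across the shims in the hunks that follow.

#include <cstdint>
#include <cstdlib>
#include <unordered_map>

constexpr int32_t NOT_OWN = -1;  // sentinel: the buffer is caller-owned
std::unordered_map<void*, int32_t> memory_to_n_tensor;

// Hypothetical helper (assumption, for illustration only): drop one tensor's
// claim on `ptr`. Owned buffers are freed when the last claim disappears;
// NOT_OWN buffers are never freed here.
void release_memory(void* ptr) {
  auto it = memory_to_n_tensor.find(ptr);
  if (it == memory_to_n_tensor.end() || it->second == NOT_OWN) {
    return;  // untracked or externally owned memory: nothing to free
  }
  if (--it->second == 0) {
    free(it->first);
    memory_to_n_tensor.erase(it);
  }
}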
@@ -110,7 +115,18 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
   // Store the tensor so it doesn't get destroyed
   tensors.insert(tensor);
   *ret_new_tensor = tensor.get();
-  is_tensor_own_memory[tensor.get()] = false;
+
+  // Check if this memory address is already being tracked
+  auto memory_it = memory_to_n_tensor.find(adjusted_data);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it == memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is already being tracked by another tensor",
+      adjusted_data);
+
+  // Mark this memory as NOT_OWN since tensor created from blob never owns
+  // memory
+  memory_to_n_tensor[adjusted_data] = NOT_OWN;
 
   ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: successfull");
   return Error::Ok;
@@ -192,59 +208,98 @@ AOTITorchError aoti_torch_empty_strided(
   // Store the tensor so it doesn't get destroyed
   tensors.insert(tensor);
   *ret_new_tensor = tensor.get();
-  is_tensor_own_memory[tensor.get()] = true;
+
+  // This tensor owns the memory it allocated, set reference count to 1
+  memory_to_n_tensor[ptr] = 1;
 
   ET_LOG(Debug, "aoti_torch_empty_strided: successfull");
   return Error::Ok;
 }
 
 AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) {
   ET_LOG(Debug, "aoti_torch_delete_tensor_object: entered");
-  // Find tensor in the set
+
+  // Handle null tensor pointer
+  if (tensor == nullptr) {
+    ET_LOG(Debug, "aoti_torch_delete_tensor_object: null tensor");
+    return Error::Ok;
+  }
+
+  // Check if tensor exists in our tracking
+  bool found_in_tensors = false;
   for (auto it = tensors.begin(); it != tensors.end(); ++it) {
     if (it->get() == tensor) {
-      auto tensor_ptr = *it;
+      found_in_tensors = true;
+      break;
+    }
+  }
 
-      // Check ownership before cleaning up
-      auto ownership_it = is_tensor_own_memory.find(tensor);
-      bool owns_memory = (ownership_it != is_tensor_own_memory.end())
-          ? ownership_it->second
-          : false;
+  // If tensor not found in our tracking, it's invalid
+  ET_CHECK_OR_RETURN_ERROR(
+      found_in_tensors, InvalidArgument, "Didn't find tensor %p", tensor);
 
-      // Clean up ownership metadata
-      is_tensor_own_memory.erase(tensor);
+  // Find and delete the tensor
+  for (auto it = tensors.begin(); it != tensors.end(); ++it) {
+    if (it->get() == tensor) {
+      // Get the tensor before erasing
+      auto tensor_ptr = *it;
+      void* data_ptr = tensor_ptr->mutable_data_ptr();
 
-      if (owns_memory) {
-        // et tensor owns the memory; need to free it manually
-        void* data_ptr = tensor_ptr->mutable_data_ptr();
+      // Find the reference count for this memory address
+      auto memory_it = memory_to_n_tensor.find(data_ptr);
+      if (memory_it != memory_to_n_tensor.end()) {
+        int32_t ref_count = memory_it->second;
 
-        // Check if it's Metal GPU memory
-        if (metal_is_device_pointer(data_ptr)) {
-          // This is Metal GPU memory - the Metal helper will handle cleanup
-          // Metal buffers are automatically managed by ARC when the buffer is
-          // released
+        if (ref_count == NOT_OWN) {
+          // Tensor never owned the memory, skip freeing
+          // Just remove tensor from tracking
           tensors.erase(it);
           ET_LOG(
               Debug,
-              "aoti_torch_delete_tensor_object: successfull (Metal GPU memory)");
+              "aoti_torch_delete_tensor_object: tensor doesn't own memory, skipping free");
           return Error::Ok;
+        } else if (ref_count == 1) {
+          // Only current tensor using this memory, free it
+          // Check if it's Metal GPU memory
+          if (metal_is_device_pointer(data_ptr)) {
+            metal_deallocate_buffer(data_ptr);
+          } else {
+            // This is CPU memory - free immediately
+            free(data_ptr);
+            data_ptr = nullptr;
+            ET_LOG(
+                Debug, "aoti_torch_delete_tensor_object: freeing CPU memory");
+          }
+
+          // Remove from memory tracking
+          memory_to_n_tensor.erase(memory_it);
+        } else if (ref_count > 1) {
+          // Other tensors still using this memory, just decrement count
+          memory_to_n_tensor[data_ptr] = ref_count - 1;
+          ET_LOG(
+              Debug,
+              "aoti_torch_delete_tensor_object: decremented ref count from %d to %d",
+              ref_count,
+              ref_count - 1);
         }
-
-        // This is CPU memory - free immediately
-        free(data_ptr);
+      } else {
+        ET_CHECK_OR_RETURN_ERROR(
+            false,
+            Internal,
+            "Internal error: memory not found during deletion");
       }
-      // else: Don't free memory since the tensor doesn't own it
 
-      // Remove from set (this will call the destructor if it's the last
+      // Remove tensor from set (this will call the destructor if it's the last
       // reference)
       tensors.erase(it);
-      ET_LOG(
-          Debug, "aoti_torch_delete_tensor_object: successfull (CPU memory)");
+      ET_LOG(Debug, "aoti_torch_delete_tensor_object: successfull");
       return Error::Ok;
     }
   }
-  ET_LOG(Error, "Didn't find tensor %p", tensor);
-  return Error::InvalidArgument;
+
+  // This should never be reached since we found it above
+  ET_CHECK_OR_RETURN_ERROR(
+      false, Internal, "Internal error: tensor not found after validation");
 }
 
 AOTITorchError aoti_torch_copy_(
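Read together with the hunk below (which removes an aoti_torch_create_tensor_from_blob_v2 call whose argument order this sketch reuses), the deletion logic above means a blob-backed tensor is only ever untracked, never freed. A hedged usage sketch: dtype_float32 and device_type_cpu are placeholders for whatever dtype/device codes the caller already has (not names from this file), and error returns are ignored for brevity.

// Caller-owned buffer wrapped as a tensor: tracked as NOT_OWN.
std::vector<float> host_buffer(8);  // lifetime managed by the caller
int64_t sizes[2] = {2, 4};
int64_t strides[2] = {4, 1};

Tensor* t = nullptr;
aoti_torch_create_tensor_from_blob_v2(
    host_buffer.data(), /*ndim=*/2, sizes, strides, /*storage_offset=*/0,
    dtype_float32, device_type_cpu, /*device_index=*/0, &t,
    /*layout=*/0, /*opaque_metadata=*/nullptr, /*opaque_metadata_size=*/0);
// memory_to_n_tensor[host_buffer.data()] == NOT_OWN

aoti_torch_delete_tensor_object(t);
// The tensor is removed from `tensors`, but host_buffer is left untouched;
// freeing it remains the caller's responsibility.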
@@ -375,75 +430,105 @@ AOTITorchError aoti_torch__reinterpret_tensor(
       InvalidArgument,
       "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
 
+  // Check if storage_offset is not 0 - return error if not
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset));
+
+  // Get the device info from the source tensor to perform device_index
+  // validation
+  int32_t device_type = 0;
+  int32_t device_index = 0;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type));
+
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index));
+
+  // Ensure device_index is always 0
+  ET_CHECK_OR_RETURN_ERROR(
+      device_index == 0,
+      InvalidArgument,
+      "device_index must be 0, got: %d",
+      device_index);
+
   // Get the dtype from the source tensor
   int32_t dtype = 0;
   ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(self, &dtype));
 
   // Validate dtype using SupportedDTypes
   ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
 
-  int32_t device_type = 0;
-  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type));
+  // Get the original data pointer from the source tensor
+  void* data_ptr = self->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
 
-  int32_t device_index = 0;
-  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index));
+  // Check if the given memory is in the map, if not return error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  // Convert sizes using utility function from utils.h
+  std::vector<aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+
+  // Convert strides using utility function from utils.h
+  std::vector<aten::StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create new tensor view that reinterprets the same memory with different
+  // shape/strides This creates a view, not a copy - the data pointer is shared
+  std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
+      data_ptr, // Reuse the same memory from source tensor
+      sizes, // New sizes with explicit SizesType
+      strides, // New strides with explicit StridesType
+      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
+  );
 
-  // Get the base data pointer from the source tensor
-  void* base_data_ptr = self->mutable_data_ptr();
   ET_CHECK_OR_RETURN_ERROR(
-      base_data_ptr != nullptr,
+      tensor != nullptr,
       InvalidArgument,
-      "Source tensor has null data pointer");
+      "Failed to create reinterpreted tensor view");
 
-  // Calculate new tensor size in elements for logging
-  int64_t new_numel = 1;
-  for (int64_t i = 0; i < ndim; i++) {
-    new_numel *= sizes_ptr[i];
-  }
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
 
-  ET_LOG(
-      Debug,
-      "aoti_torch__reinterpret_tensor: base_data_ptr=%p, new_numel=%lld, storage_offset=%lld",
-      base_data_ptr,
-      new_numel,
-      storage_offset);
-
-  // Create a new tensor view that shares the same underlying storage
-  // This is the correct way to implement reinterpret_tensor - as a view, not a
-  // copy
-  AOTITorchError create_err = aoti_torch_create_tensor_from_blob_v2(
-      base_data_ptr, // Same underlying data pointer
-      ndim, // New dimensions
-      sizes_ptr, // New sizes
-      strides_ptr, // New strides
-      storage_offset, // Storage offset (will be handled properly now)
-      dtype,
-      device_type,
-      device_index,
-      ret_new_tensor,
-      0, // layout (default)
-      nullptr, // opaque_metadata
-      0 // opaque_metadata_size
-  );
+  *ret_new_tensor = tensor.get();
 
-  if (create_err != Error::Ok) {
-    ET_LOG(Error, "failed to create reinterpreted tensor view");
-    return create_err;
-  }
+  // Increment the reference count for this memory address only if it is owned
+  // by tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
 
   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull");
   return Error::Ok;
 }
 
 // Cleanup function for clearing global state
 void cleanup_memory() {
-  is_tensor_own_memory.clear();
-  if (!tensors.empty()) {
-    ET_LOG(Error, "Warning: tensors not empty during cleanup");
+  // Use aoti_torch_delete_tensor_object to properly delete each tensor
+  // Note: We need to collect tensor pointers first since deletion modifies the
+  // set
+  std::vector<Tensor*> tensor_ptrs;
+  tensor_ptrs.reserve(tensors.size());
+  for (const auto& tensor_shared : tensors) {
+    tensor_ptrs.push_back(tensor_shared.get());
   }
 
+  // Now delete each tensor - this will modify the global tensors set
+  for (Tensor* tensor_ptr : tensor_ptrs) {
+    aoti_torch_delete_tensor_object(tensor_ptr);
+  }
+
+  // tensors set should now be empty, but ensure it's cleared
+  tensors.clear();
+
   // Clean up Metal resources
   metal_cleanup_resources();
+
+  ET_LOG(Info, "Cleared all tensors and Metal resources");
 }
 
 } // extern "C"
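Since cleanup_memory now routes every live tensor through aoti_torch_delete_tensor_object, shared buffers are released exactly once no matter how many views still exist. A hedged sketch of how a host might drive teardown; shutdown_backend and live_handles are assumptions for illustration, not part of this patch.

// Delete the handles the host still tracks, then let cleanup_memory() sweep
// anything that remains and release Metal resources.
void shutdown_backend(std::vector<Tensor*>& live_handles) {
  for (Tensor* handle : live_handles) {
    aoti_torch_delete_tensor_object(handle);  // ref counts drop; each buffer freed at most once
  }
  live_handles.clear();
  cleanup_memory();  // sweeps leftover tensors and calls metal_cleanup_resources()
}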