 #!/usr/bin/env python3
 
-# Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -1367,11 +1367,13 @@ def shm_basic_infer(
     big_shm_name="",
     big_shm_size=64,
     default_shm_byte_size=64,
+    register_offset=0,
     shm_output_offset=0,
     shm_output_byte_size=64,
     protocol="http",
     use_system_shared_memory=False,
     use_cuda_shared_memory=False,
+    override_model_name=None,
 ):
     # Lazy shm imports...
     if use_system_shared_memory:
@@ -1381,20 +1383,34 @@ def shm_basic_infer(
     else:
         raise Exception("No shared memory type specified")
 
+    if override_model_name is None:
+        model_name = "simple"
+    else:
+        model_name = override_model_name
+
+    if model_name.startswith("libtorch"):
+        output_names = ["OUTPUT__0", "OUTPUT__1"]
+    else:
+        output_names = ["OUTPUT0", "OUTPUT1"]
+
     input0_data = np.arange(start=0, stop=16, dtype=np.int32)
     input1_data = np.ones(shape=16, dtype=np.int32)
     inputs = []
     outputs = []
     if protocol == "http":
         inputs.append(httpclient.InferInput("INPUT0", [1, 16], "INT32"))
         inputs.append(httpclient.InferInput("INPUT1", [1, 16], "INT32"))
-        outputs.append(httpclient.InferRequestedOutput("OUTPUT0", binary_data=True))
-        outputs.append(httpclient.InferRequestedOutput("OUTPUT1", binary_data=False))
+        outputs.append(
+            httpclient.InferRequestedOutput(output_names[0], binary_data=True)
+        )
+        outputs.append(
+            httpclient.InferRequestedOutput(output_names[1], binary_data=False)
+        )
     else:
         inputs.append(grpcclient.InferInput("INPUT0", [1, 16], "INT32"))
         inputs.append(grpcclient.InferInput("INPUT1", [1, 16], "INT32"))
-        outputs.append(grpcclient.InferRequestedOutput("OUTPUT0"))
-        outputs.append(grpcclient.InferRequestedOutput("OUTPUT1"))
+        outputs.append(grpcclient.InferRequestedOutput(output_names[0]))
+        outputs.append(grpcclient.InferRequestedOutput(output_names[1]))
 
     inputs[0].set_shared_memory("input0_data", default_shm_byte_size)
 
@@ -1414,9 +1430,9 @@ def shm_basic_infer(
 
     try:
         results = triton_client.infer(
-            "simple", inputs, model_version="", outputs=outputs
+            model_name, inputs, model_version="", outputs=outputs
         )
-        output = results.get_output("OUTPUT0")
+        output = results.get_output(output_names[0])
         if protocol == "http":
             output_datatype = output["datatype"]
             output_shape = output["shape"]
@@ -1427,11 +1443,16 @@ def shm_basic_infer(
 
         if use_system_shared_memory:
             output_data = shm.get_contents_as_numpy(
-                shm_op0_handle, output_dtype, output_shape
+                shm_op0_handle,
+                output_dtype,
+                output_shape,
+                offset=register_offset + shm_output_offset,
             )
         elif use_cuda_shared_memory:
             output_data = cudashm.get_contents_as_numpy(
-                shm_op0_handle, output_dtype, output_shape
+                shm_op0_handle,
+                output_dtype,
+                output_shape,
             )
 
         tester.assertTrue(
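
The new register_offset parameter accounts for regions that are registered at a nonzero offset into a larger shared-memory block: the output is then read back at register_offset + shm_output_offset rather than from the start of the region. Below is a minimal sketch of that registration pattern, assuming a tritonclient version where register_system_shared_memory and get_contents_as_numpy accept an offset (as this diff relies on); the region name, key, and sizes are illustrative.

import numpy as np
import tritonclient.http as httpclient
import tritonclient.utils.shared_memory as shm

REGISTER_OFFSET = 32  # illustrative: data starts 32 bytes into the block
BYTE_SIZE = 64  # 16 INT32 values

client = httpclient.InferenceServerClient("localhost:8000")

# Create a local block large enough for the offset plus the data.
handle = shm.create_shared_memory_region(
    "output0_data", "/output0_data", REGISTER_OFFSET + BYTE_SIZE
)

# Register only the tail of the block with Triton, starting at the offset.
client.register_system_shared_memory(
    "output0_data", "/output0_data", BYTE_SIZE, offset=REGISTER_OFFSET
)

# ... run an inference that writes OUTPUT0 into the region ...

# Read back from the same offset in the local block, mirroring the
# helper's register_offset + shm_output_offset arithmetic above.
output_data = shm.get_contents_as_numpy(
    handle, np.int32, [1, 16], offset=REGISTER_OFFSET
)

client.unregister_system_shared_memory("output0_data")
shm.destroy_shared_memory_region(handle)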
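
The override_model_name hook exists because output tensor naming differs by backend: the PyTorch (libtorch) backend exposes OUTPUT__0/OUTPUT__1, while the default "simple" model uses OUTPUT0/OUTPUT1, so the helper now derives the requested-output names from the model name. A hypothetical call site follows; the positional handle order is assumed from the helper's body (only shm_op0_handle is visible in these hunks), and the model name is illustrative, since any name starting with "libtorch" selects the double-underscore names.

def run_libtorch_shm_case(tester, triton_client, handles):
    # handles: input/output shared-memory handles created by the caller's setup.
    shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle = handles
    shm_basic_infer(
        tester,
        triton_client,
        shm_ip0_handle,
        shm_ip1_handle,
        shm_op0_handle,
        shm_op1_handle,
        protocol="grpc",
        use_system_shared_memory=True,
        register_offset=32,  # regions registered 32 bytes into their blocks
        override_model_name="libtorch_int32_int32_int32",  # illustrative name
    )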