@@ -775,11 +775,17 @@ def codegen_kernel(self, name: Optional[str] = None) -> str:
775775 """Called at the end to generate a final kernel string"""
776776 self .codegen_body ()
777777 code = IndentedBuffer ()
778- code .writeline ("compile_mps_shader('''" )
778+
779+ if V .graph .cpp_wrapper :
780+ code .writeline ('(R"MTL(' )
781+ else :
782+ code .writeline ("compile_mps_shader('''" )
783+
779784 idx_vars = self .active_range_trees ()
780785 with code .indent ():
781- for header in self .headers :
782- code .writeline (f"#include <c10/metal/{ header } .h>" )
786+ if not V .graph .cpp_wrapper :
787+ for header in self .headers :
788+ code .writeline (f"#include <c10/metal/{ header } .h>" )
783789 if self .inside_reduction :
784790 total_reduction_size = math .prod (
785791 t .numel for t in self .range_trees if t .is_reduction
@@ -833,7 +839,11 @@ def codegen_kernel(self, name: Optional[str] = None) -> str:
833839 code .splice (self .indexing_code )
834840 code .splice (self .body )
835841 code .writeline ("}" )
836- code .writeline ("''')" )
842+
843+ if V .graph .cpp_wrapper :
844+ code .writeline (')MTL");' )
845+ else :
846+ code .writeline ("''')" )
837847
838848 return code .getvalue ()
839849
@@ -858,15 +868,31 @@ def call_kernel(self, name: str, node: Any = None) -> None:
858868 )
859869 for v in self .active_range_trees ()
860870 ]
861- args += [f"threads=[{ ', ' .join (threads )} ]" ]
871+
872+ if V .graph .cpp_wrapper :
873+ args += [f"{ ', ' .join (threads )} " ]
874+ else :
875+ args += [f"threads=[{ ', ' .join (threads )} ]" ]
876+ else :
877+ if V .graph .cpp_wrapper :
878+ raise RuntimeError ("We should always have threads?" )
879+
862880 if self .inside_reduction :
863881 threads = [
864882 self .pexpr (sympy .Min (v .numel , self .max_threadgroup_size )) # type: ignore[misc]
865883 if v .is_reduction
866884 else "1"
867885 for v in self .active_range_trees ()
868886 ]
869- args += [f"group_size=[{ ', ' .join (threads )} ]" ]
887+ if V .graph .cpp_wrapper :
888+ args += [f"{{{ ', ' .join (threads )} }}" ]
889+ else :
890+ args += [f"group_size=[{ ', ' .join (threads )} ]" ]
891+ else :
892+ if V .graph .cpp_wrapper :
893+ # Add a None so that we always have a group_size in the
894+ # arguments. We won't use it if the value is None.
895+ args += [None ] # type: ignore[list-item]
870896
871897 wrapper .generate_kernel_call (
872898 name ,
@@ -900,9 +926,10 @@ def __init__(self, scheduler: Optional[Scheduler]) -> None:
900926 super ().__init__ (scheduler )
901927 wrapper = V .graph .wrapper_code
902928 if wrapper is not None :
903- wrapper .header .splice (
904- "from torch._inductor.runtime.runtime_utils import compile_mps_shader"
905- )
929+ if not V .graph .cpp_wrapper :
930+ wrapper .header .splice (
931+ "from torch._inductor.runtime.runtime_utils import compile_mps_shader"
932+ )
906933
907934 def define_kernel (
908935 self , src_code : str , node_schedule : list [SchedulerNode ], kernel : MetalKernel
@@ -914,10 +941,19 @@ def define_kernel(
914941 # TODO: Merge multiple kernels into a single library
915942 # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling
916943 mps_lib_name = f"mps_lib_{ wrapper .next_kernel_suffix ()} "
917- kernel_name = f"{ mps_lib_name } .generated_kernel"
944+
945+ if V .graph .cpp_wrapper :
946+ src_code = (
947+ f"at::native::mps::DynamicMetalShaderLibrary { mps_lib_name } "
948+ + src_code
949+ )
950+ kernel_name = f"{ mps_lib_name } _func"
951+ else :
952+ kernel_name = f"{ mps_lib_name } .generated_kernel"
953+
918954 wrapper .src_to_kernel [src_code ] = kernel_name
919955 origins , detailed_origins = get_kernel_metadata (node_schedule , wrapper )
920956 metadata_comment = f"{ origins } \n { detailed_origins } "
921- wrapper .define_kernel (mps_lib_name , src_code , metadata_comment )
957+ wrapper .define_kernel (mps_lib_name , src_code , metadata_comment , gpu = False )
922958
923959 return kernel_name
0 commit comments