@@ -88,3 +88,35 @@ func TestRhaiFeaturesRocm(t *testing.T) {
8888 Tags (t , KftoRocm , MultiNodeGpu (2 , support .AMD ))
8989 sdktests .RunRhaiFeaturesAllTest (t , support .AMD )
9090}
91+
92+ // Multi-GPU CUDA tests - 2 nodes, 2 GPUs each (requires 4 total NVIDIA GPUs)
93+ func TestRhaiTrainingProgressionMultiGpuCuda (t * testing.T ) {
94+ Tags (t , KftoCuda , MultiNodeMultiGpu (2 , support .NVIDIA , 2 ))
95+ sdktests .RunRhaiFeaturesProgressionMultiGpuTest (t , support .NVIDIA , 2 , 2 )
96+ }
97+
98+ func TestRhaiJitCheckpointingMultiGpuCuda (t * testing.T ) {
99+ Tags (t , KftoCuda , MultiNodeMultiGpu (2 , support .NVIDIA , 2 ))
100+ sdktests .RunRhaiFeaturesCheckpointMultiGpuTest (t , support .NVIDIA , 2 , 2 )
101+ }
102+
103+ func TestRhaiFeaturesMultiGpuCuda (t * testing.T ) {
104+ Tags (t , KftoCuda , MultiNodeMultiGpu (2 , support .NVIDIA , 2 ))
105+ sdktests .RunRhaiFeaturesAllMultiGpuTest (t , support .NVIDIA , 2 , 2 )
106+ }
107+
108+ // Multi-GPU ROCm tests - 2 nodes, 2 GPUs each (requires 4 total AMD GPUs)
109+ func TestRhaiTrainingProgressionMultiGpuRocm (t * testing.T ) {
110+ Tags (t , KftoRocm , MultiNodeMultiGpu (2 , support .AMD , 2 ))
111+ sdktests .RunRhaiFeaturesProgressionMultiGpuTest (t , support .AMD , 2 , 2 )
112+ }
113+
114+ func TestRhaiJitCheckpointingMultiGpuRocm (t * testing.T ) {
115+ Tags (t , KftoRocm , MultiNodeMultiGpu (2 , support .AMD , 2 ))
116+ sdktests .RunRhaiFeaturesCheckpointMultiGpuTest (t , support .AMD , 2 , 2 )
117+ }
118+
119+ func TestRhaiFeaturesMultiGpuRocm (t * testing.T ) {
120+ Tags (t , KftoRocm , MultiNodeMultiGpu (2 , support .AMD , 2 ))
121+ sdktests .RunRhaiFeaturesAllMultiGpuTest (t , support .AMD , 2 , 2 )
122+ }
0 commit comments