@@ -67,19 +67,19 @@ TEST_P(QKVColumnParallelLinearTest, LoadFusedWeight) {
6767
6868 // generate random input and compare with the output
6969 auto input = torch::randn ({n_tokens, hidden_size}, options);
70- auto qkv = linear.forward (input);
70+ const auto [q, k, v] = linear.forward (input);
7171
7272 const int64_t kv_shard_id =
7373 n_kv_heads >= n_shards ? shard_id : n_kv_heads * shard_id / n_shards;
7474
7575 auto query = input.matmul (query_chunks[shard_id].t ());
76- EXPECT_TRUE (torch::allclose (qkv[ 0 ] , query, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 ));
76+ EXPECT_TRUE (torch::allclose (q , query, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 ));
7777
7878 auto key = input.matmul (key_chunks[kv_shard_id].t ());
79- EXPECT_TRUE (torch::allclose (qkv[ 1 ] , key, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 ));
79+ EXPECT_TRUE (torch::allclose (k , key, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 ));
8080
8181 auto value = input.matmul (value_chunks[kv_shard_id].t ());
82- EXPECT_TRUE (torch::allclose (qkv[ 2 ] , value, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 ));
82+ EXPECT_TRUE (torch::allclose (v , value, /* rtol=*/ 1e-5 , /* atol=*/ 1e-5 ));
8383 }
8484}
8585
0 commit comments