Skip to content

Commit 6ec4c02

Browse files
authored
Add col bcast for where binary variant TTS and TST (#27757)
### Ticket
#27624

### Problem description
The `where` op LLK does not have column-broadcast support.

### What's changed
Provide native column-broadcast support for the TTS and TST variants.

For 32x32 (~56% improvement):
- LLK version: 6159.00 ns
- Legacy version: 13953.00 ns

For 1024x1024 (~80% improvement):
- LLK version: 45608.00 ns
- Legacy version: 230796.00 ns

### Checklist
- [ ] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes https://github.com/tenstorrent/tt-metal/actions/runs/17673906164
- [ ] [Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml) CI with demo tests passes (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/17673909619
- [ ] [Model regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) CI passes (if applicable)
- [ ] [Device performance regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) CI passes (if applicable)
- [ ] (For models and ops writers) [Single-card demo tests](https://github.com/tenstorrent/tt-metal/actions/workflows/single-card-demo-tests.yaml) CI passes (if applicable). See [recommended dev flow](https://github.com/tenstorrent/tt-metal/blob/main/models/docs/MODEL_ADD.md#a-recommended-dev-flow-on-github-for-adding-new-models).
- [ ] [Galaxy quick](https://github.com/tenstorrent/tt-metal/actions/workflows/tg-quick-trigger.yaml) CI passes (if applicable)
- [ ] [TG demo tests, for Llama](https://github.com/tenstorrent/tt-metal/actions/workflows/tg-demo-tests.yaml) CI passes, if applicable, because of current Llama work
- [ ] (For runtime and ops writers) [T3000 unit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-unit-tests.yaml) CI passes (if applicable, since this is run on push to main)
- [ ] (For models and ops writers) [T3000 demo tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-demo-tests.yaml) CI passes (if applicable, since this is required for release)
- [ ] New/Existing tests provide coverage for changes
1 parent 9e05089 commit 6ec4c02

File tree

9 files changed

+829
-77
lines changed

9 files changed

+829
-77
lines changed

tests/ttnn/unit_tests/operations/eltwise/test_where.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ def torch_equal_nan(a, b):
2626
((1, 1, 32, 32), (1, 1, 32, 32), (1, 1, 32, 32)), # LLK
2727
((2, 3, 64, 128), (2, 3, 64, 128), (2, 3, 64, 128)), # LLK
2828
((3, 2, 3, 64, 128), (3, 2, 3, 64, 128), (3, 2, 3, 64, 128)), # LLK
29+
((1, 1, 1024, 1024), (1, 1, 1024, 1), (1, 1, 1024, 1024)), # A, Bcol, C
30+
((1, 1, 1024, 1), (1, 1, 1024, 1024), (1, 1, 1024, 1024)), # Acol, B, C
31+
((1, 1, 1024, 1024), (1, 1, 1024, 1024), (1, 1, 1024, 1)), # A, B, Ccol
32+
((1, 1, 64, 1), (1, 1, 64, 64), (1, 1, 64, 64)), # Acol, B, C
2933
((256,), (256,), (256,)), # LLK
3034
# Bcast cases for dims -5, -4, -3 (outer dims)
3135
((128, 128), (2, 2, 2, 128, 128), (2, 2, 128, 128)),
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#include <cstdint>
6+
7+
#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
8+
#include "compute_kernel_api/eltwise_unary/where.h"
9+
#include "compute_kernel_api/eltwise_unary/fill.h"
10+
#include "ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp"
11+
#include "ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_sfpu.hpp"
12+
13+
namespace NAMESPACE {
14+
15+
/**
 * Evaluates where(predicate, true, false) for loop indices [tile_start, freq),
 * with the "true" operand supplied as a scalar instead of a circular buffer.
 *
 * Synchronization depends on the BCAST_PRED / BCAST_FALSE compile-time macros:
 *  - non-zero: the CB is waited on once before the loop and popped once after it,
 *    so the same front tile(s) are read by every iteration;
 *  - zero/undefined: the CB is waited on and popped once per iteration.
 *
 * @param predicate_cb         CB holding the predicate; copied to destination register 0.
 * @param false_cb             CB holding the "false" operand; copied to destination register 2.
 * @param cb_out               CB that receives the packed result.
 * @param true_scalar          Raw 32-bit "true" value. Read as a float when
 *                             FILL_WITH_VALUE_FLOAT is defined, passed through unchanged
 *                             when FILL_WITH_VALUE_INT is defined.
 * @param freq                 Exclusive upper bound of the loop index.
 * @param tile_start           First loop index to process.
 * @param num_tiles_per_cycle  Tile count used for every CB wait/reserve/push/pop.
 */
ALWI void process_tile(
    tt::CBIndex predicate_cb,
    tt::CBIndex false_cb,
    tt::CBIndex cb_out,
    uint32_t true_scalar,
    uint32_t freq,
    uint32_t tile_start,
    uint32_t num_tiles_per_cycle) {
    using namespace ckernel;

    // Broadcast inputs are acquired once and stay at the CB front for the whole call.
#if BCAST_PRED
    cb_wait_front(predicate_cb, num_tiles_per_cycle);
#endif
#if BCAST_FALSE
    cb_wait_front(false_cb, num_tiles_per_cycle);
#endif

    uint32_t tile_idx = tile_start;
    while (tile_idx < freq) {
        // Non-broadcast inputs need a fresh wait on every iteration.
#if !BCAST_PRED
        cb_wait_front(predicate_cb, num_tiles_per_cycle);
#endif
#if !BCAST_FALSE
        cb_wait_front(false_cb, num_tiles_per_cycle);
#endif

        cb_reserve_back(cb_out, num_tiles_per_cycle);

        tile_regs_acquire();

        // Destination register 0 <- predicate
        copy_tile_init(predicate_cb);
        copy_tile(predicate_cb, 0, 0);

        // Destination register 1 <- scalar "true" value
        fill_tile_init();
#ifdef FILL_WITH_VALUE_FLOAT
        // NOTE(review): the uint32_t bits are read through a float pointer; confirm the
        // toolchain's aliasing rules tolerate this.
        const float* const true_as_float = reinterpret_cast<const float*>(&true_scalar);
        FILL_LLK(1, *true_as_float);
#endif
#ifdef FILL_WITH_VALUE_INT
        FILL_LLK(1, true_scalar);
#endif

        // Destination register 2 <- "false" operand
        copy_tile_init(false_cb);
        copy_tile(false_cb, 0, 2);

        // where(predicate, true, false); the result is left in register 0.
        where_tile_init();
        WHERE_LLK(0, 1, 2, 0);

        tile_regs_commit();
        tile_regs_wait();

        pack_tile(0, cb_out);
        tile_regs_release();

        cb_push_back(cb_out, num_tiles_per_cycle);

        // Release the per-iteration (non-broadcast) inputs.
#if !BCAST_PRED
        cb_pop_front(predicate_cb, num_tiles_per_cycle);
#endif
#if !BCAST_FALSE
        cb_pop_front(false_cb, num_tiles_per_cycle);
#endif

        ++tile_idx;
    }

    // Release the broadcast inputs now that every iteration has consumed them.
#if BCAST_PRED
    cb_pop_front(predicate_cb, num_tiles_per_cycle);
#endif
#if BCAST_FALSE
    cb_pop_front(false_cb, num_tiles_per_cycle);
#endif
}
95+
96+
void MAIN {
97+
uint32_t num_tiles = get_arg_val<uint32_t>(0);
98+
uint32_t tile_freq = get_arg_val<uint32_t>(1);
99+
uint32_t tile_start = get_arg_val<uint32_t>(2);
100+
const uint32_t true_scalar = get_arg_val<uint32_t>(3);
101+
102+
constexpr uint32_t num_tiles_per_cycle = get_compile_time_arg_val(0);
103+
104+
if (num_tiles == 0) {
105+
return;
106+
}
107+
108+
constexpr auto predicate_cb = tt::CBIndex::c_0;
109+
constexpr auto false_cb = tt::CBIndex::c_1;
110+
constexpr auto cb_out = tt::CBIndex::c_3;
111+
112+
unary_op_init_common(predicate_cb, cb_out);
113+
114+
uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
115+
uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;
116+
117+
for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) {
118+
process_tile(predicate_cb, false_cb, cb_out, true_scalar, tile_freq, tile_start, num_tiles_per_cycle);
119+
}
120+
121+
if (remaining_iterations > 0) {
122+
process_tile(
123+
predicate_cb, false_cb, cb_out, true_scalar, remaining_iterations, tile_start, num_tiles_per_cycle);
124+
}
125+
}
126+
} // namespace NAMESPACE
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#include <cstdint>
6+
7+
#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
8+
#include "compute_kernel_api/eltwise_unary/where.h"
9+
#include "compute_kernel_api/eltwise_unary/fill.h"
10+
#include "ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp"
11+
#include "ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_sfpu.hpp"
12+
13+
namespace NAMESPACE {
14+
15+
/**
 * Evaluates where(predicate, true, false) for loop indices [tile_start, freq),
 * with the "false" operand supplied as a scalar instead of a circular buffer.
 *
 * Synchronization depends on the BCAST_PRED / BCAST_TRUE compile-time macros:
 *  - non-zero: the CB is waited on once before the loop and popped once after it,
 *    so the same front tile(s) are read by every iteration;
 *  - zero/undefined: the CB is waited on and popped once per iteration.
 *
 * @param predicate_cb         CB holding the predicate; copied to destination register 0.
 * @param true_cb              CB holding the "true" operand; copied to destination register 1.
 * @param cb_out               CB that receives the packed result.
 * @param false_scalar         Raw 32-bit "false" value. Read as a float when
 *                             FILL_WITH_VALUE_FLOAT is defined, passed through unchanged
 *                             when FILL_WITH_VALUE_INT is defined.
 * @param freq                 Exclusive upper bound of the loop index.
 * @param tile_start           First loop index to process.
 * @param num_tiles_per_cycle  Tile count used for every CB wait/reserve/push/pop.
 */
ALWI void process_tile(
    tt::CBIndex predicate_cb,
    tt::CBIndex true_cb,
    tt::CBIndex cb_out,
    uint32_t false_scalar,
    uint32_t freq,
    uint32_t tile_start,
    uint32_t num_tiles_per_cycle) {
    using namespace ckernel;

    // Broadcast inputs are acquired once and stay at the CB front for the whole call.
#if BCAST_PRED
    cb_wait_front(predicate_cb, num_tiles_per_cycle);
#endif
#if BCAST_TRUE
    cb_wait_front(true_cb, num_tiles_per_cycle);
#endif

    uint32_t tile_idx = tile_start;
    while (tile_idx < freq) {
        // Non-broadcast inputs need a fresh wait on every iteration.
#if !BCAST_PRED
        cb_wait_front(predicate_cb, num_tiles_per_cycle);
#endif
#if !BCAST_TRUE
        cb_wait_front(true_cb, num_tiles_per_cycle);
#endif

        cb_reserve_back(cb_out, num_tiles_per_cycle);

        tile_regs_acquire();

        // Destination register 0 <- predicate
        copy_tile_init(predicate_cb);
        copy_tile(predicate_cb, 0, 0);

        // Destination register 1 <- "true" operand
        copy_tile_init(true_cb);
        copy_tile(true_cb, 0, 1);

        // Destination register 2 <- scalar "false" value
        fill_tile_init();
#ifdef FILL_WITH_VALUE_FLOAT
        // NOTE(review): the uint32_t bits are read through a float pointer; confirm the
        // toolchain's aliasing rules tolerate this.
        const float* const false_as_float = reinterpret_cast<const float*>(&false_scalar);
        FILL_LLK(2, *false_as_float);
#endif
#ifdef FILL_WITH_VALUE_INT
        FILL_LLK(2, false_scalar);
#endif

        // where(predicate, true, false); the result is left in register 0.
        where_tile_init();
        WHERE_LLK(0, 1, 2, 0);

        tile_regs_commit();
        tile_regs_wait();

        pack_tile(0, cb_out);
        tile_regs_release();

        cb_push_back(cb_out, num_tiles_per_cycle);

        // Release the per-iteration (non-broadcast) inputs.
#if !BCAST_PRED
        cb_pop_front(predicate_cb, num_tiles_per_cycle);
#endif
#if !BCAST_TRUE
        cb_pop_front(true_cb, num_tiles_per_cycle);
#endif

        ++tile_idx;
    }

    // Release the broadcast inputs now that every iteration has consumed them.
#if BCAST_PRED
    cb_pop_front(predicate_cb, num_tiles_per_cycle);
#endif
#if BCAST_TRUE
    cb_pop_front(true_cb, num_tiles_per_cycle);
#endif
}
95+
96+
void MAIN {
97+
uint32_t num_tiles = get_arg_val<uint32_t>(0);
98+
uint32_t tile_freq = get_arg_val<uint32_t>(1);
99+
uint32_t tile_start = get_arg_val<uint32_t>(2);
100+
const uint32_t false_scalar = get_arg_val<uint32_t>(3);
101+
102+
constexpr uint32_t num_tiles_per_cycle = get_compile_time_arg_val(0);
103+
104+
if (num_tiles == 0) {
105+
return;
106+
}
107+
108+
constexpr auto predicate_cb = tt::CBIndex::c_0;
109+
constexpr auto true_cb = tt::CBIndex::c_1;
110+
constexpr auto cb_out = tt::CBIndex::c_3;
111+
112+
unary_op_init_common(predicate_cb, cb_out);
113+
114+
uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
115+
uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;
116+
117+
for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) {
118+
process_tile(predicate_cb, true_cb, cb_out, false_scalar, tile_freq, tile_start, num_tiles_per_cycle);
119+
}
120+
121+
if (remaining_iterations > 0) {
122+
process_tile(
123+
predicate_cb, true_cb, cb_out, false_scalar, remaining_iterations, tile_start, num_tiles_per_cycle);
124+
}
125+
}
126+
} // namespace NAMESPACE

0 commit comments

Comments
 (0)