Skip to content

Commit a5940c6

Browse files
committed
Fix conv3d signature to pass device as arg due to metal 0d3558f
1 parent 4c556e7 commit a5940c6

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed

lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,10 +1667,19 @@ class Conv3dOpConversionPattern
16671667
auto outputDtype =
16681668
dtypeAttr ? dtypeAttr.getValue() : ttcore::DataType::BFloat16;
16691669

1670+
// Emit SSA operands in ODS order (input, weight, bias, device) to
1671+
// maintain correct index mapping, then arrange args in the C++ API
1672+
// call order (input, weight, device, bias, ...).
1673+
auto inputAttr = emitter.emit(srcOp.getInput());
1674+
auto weightAttr = emitter.emit(srcOp.getWeight());
1675+
auto biasAttr = emitter.emit(srcOp.getBias());
1676+
auto deviceAttr = emitter.emit(srcOp.getDevice());
1677+
16701678
llvm::SmallVector<mlir::Attribute> args{
1671-
emitter.emit(srcOp.getInput()),
1672-
emitter.emit(srcOp.getWeight()),
1673-
emitter.emit(srcOp.getBias()),
1679+
inputAttr,
1680+
weightAttr,
1681+
deviceAttr,
1682+
biasAttr,
16741683
emitter.emitConv3dConfig(srcOp.getConv3dConfig()),
16751684
emitter.emit(outputDtype),
16761685
emitter.emit(srcOp.getOutChannels()),

lib/OpModel/TTNN/TTNNOpModel.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5111,6 +5111,7 @@ llvm::Expected<OpConstraints> OpModel<Conv3dOp>::getOpConstraints(
51115111
auto conv3dOpQuery = [=, &specs]() {
51125112
return ::ttnn::graph::query_op_constraints(
51135113
::ttnn::experimental::conv3d, device, specs.inputSpec, specs.weightSpec,
5114+
std::optional<::tt::tt_metal::distributed::MeshDevice *>(device),
51145115
specs.biasSpec, specs.config, specs.dtype, specs.outputChannels,
51155116
specs.kernelSize, specs.stride, specs.padding,
51165117
std::array<uint32_t, 3>{1, 1, 1}, specs.paddingMode, specs.groups,
@@ -5156,6 +5157,7 @@ llvm::Expected<size_t> OpModel<Conv3dOp>::getOpRuntime(
51565157
auto conv3dOpRuntime = [=, &specs]() {
51575158
return ::ttnn::graph::query_op_runtime(
51585159
::ttnn::experimental::conv3d, device, specs.inputSpec, specs.weightSpec,
5160+
std::optional<::tt::tt_metal::distributed::MeshDevice *>(device),
51595161
specs.biasSpec, specs.config, specs.dtype, specs.outputChannels,
51605162
specs.kernelSize, specs.stride, specs.padding,
51615163
std::array<uint32_t, 3>{1, 1, 1}, specs.paddingMode, specs.groups,

runtime/lib/ttnn/operations/conv/conv3d.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ void run(const ::tt::target::ttnn::Conv3dOp *op, ProgramContext &context) {
9494
"Memory config must exist for device tensors");
9595

9696
::ttnn::Tensor out = ::ttnn::experimental::conv3d(
97-
input, weight, bias, conv3dConfig, outputDtype, op->out_channels(),
98-
kernelSize, stride, padding, std::array<uint32_t, 3>{1, 1, 1},
99-
op->padding_mode()->str(), op->groups(), outputMemoryConfig,
100-
deviceComputeConfig);
97+
input, weight, &targetDevice, bias, conv3dConfig, outputDtype,
98+
op->out_channels(), kernelSize, stride, padding,
99+
std::array<uint32_t, 3>{1, 1, 1}, op->padding_mode()->str(), op->groups(),
100+
outputMemoryConfig, deviceComputeConfig);
101101

102102
tensorPool.insertTTNNTensorAndValidate(op->out(), out);
103103
}

0 commit comments

Comments
 (0)