File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -71,6 +71,8 @@ class RegLossObjOneAPI : public ObjFunction {
7171 sycl::buffer<bst_float, 1 > weights_buf (is_null_weight ? NULL : info.weights_ .HostPointer (),
7272 is_null_weight ? 1 : info.weights_ .Size ());
7373
74+ const size_t n_targets = std::max (info.labels .Shape (1 ), static_cast <size_t >(1 ));
75+
7476 sycl::buffer<int , 1 > additional_input_buf (1 );
7577 {
7678 auto additional_input_acc = additional_input_buf.get_access <sycl::access::mode::write>();
@@ -92,7 +94,7 @@ class RegLossObjOneAPI : public ObjFunction {
9294 cgh.parallel_for <>(sycl::range<1 >(ndata), [=](sycl::id<1 > pid) {
9395 int idx = pid[0 ];
9496 bst_float p = Loss::PredTransform (preds_acc[idx]);
95- bst_float w = is_null_weight ? 1 .0f : weights_acc[idx];
97+ bst_float w = is_null_weight ? 1 .0f : weights_acc[idx/n_targets ];
9698 bst_float label = labels_acc[idx];
9799 if (label == 1 .0f ) {
98100 w *= scale_pos_weight;
@@ -125,7 +127,6 @@ class RegLossObjOneAPI : public ObjFunction {
125127
126128 void PredTransform (HostDeviceVector<float > *io_preds) const override {
127129 size_t const ndata = io_preds->Size ();
128-
129130 sycl::buffer<bst_float, 1 > io_preds_buf (io_preds->HostPointer (), io_preds->Size ());
130131
131132 qu_.submit ([&](sycl::handler& cgh) {
You can’t perform that action at this time.
0 commit comments