Skip to content

Commit 3b6665b

Browse files
committed
xe: conv: jit: add heuristic to handle model tie
1 parent 808227d commit 3b6665b

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

src/gpu/intel/conv/jit/tiler.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,31 @@ void sort_by_model_scores(params_generator_t &params_gen, const config_t &cfg,
12181218
}
12191219
params_gen.sort(0, params_gen.configs(),
12201220
[&](const blocking_params_t &p) { return -eff_scores.at(p.id()); });
1221+
1222+
// Heuristics when model tie is detected
// The walk below compares the front entry of the (already sorted) parameter
// vector against the entries that follow it and stops at the first entry whose
// efficiency score differs from the front entry's score.
// NOTE(review): indexing params_vec[0] assumes the vector holds at least one
// entry - confirm the caller guarantees a non-empty generator here.
1223+
auto &params_vec = params_gen.params_vec();
1224+
// p_best is a reference to slot 0, not a copy: after a swap it observes the
// entry that was swapped into the front.
auto &p_best = params_vec[0];
1225+
for (auto &p_next : params_gen.params_vec()) {
1226+
// Do not compare the front entry with itself.
if (&p_best == &p_next) continue;
1227+
// Scores are compared with exact equality; any difference ends the tie walk.
if (eff_scores.at(p_best.id()) != eff_scores.at(p_next.id())) break;
1228+
1229+
// NOTE(review): this condition and the size lambda below do not depend on the
// loop variable; they are re-evaluated for every tied entry.
if (cfg.prb().is_bwd_w && cfg.allow_global_reduction()) {
1230+
// As the model estimate is the same, prefer fewer atomic reductions
1231+
// to reduce contention on L3 cache lines
// size() is the product of the loop and iter tile values over the mb, ow,
// oh and od dimensions (presumably the reduction dimensions of the
// backward-by-weights problem - TODO confirm).
1232+
auto size = [&](const tile_t &loop, const tile_t &iter) {
1233+
return iter.get(pvars::mb) * iter.get(pvars::ow)
1234+
* iter.get(pvars::oh) * iter.get(pvars::od)
1235+
* loop.get(pvars::mb) * loop.get(pvars::ow)
1236+
* loop.get(pvars::oh) * loop.get(pvars::od);
1237+
};
1238+
auto &b0 = p_best.blocking(), &b1 = p_next.blocking();
1239+
// Promote the tied entry with the larger product to the front slot.
if (size(b0.loop(), b0.iter()) < size(b1.loop(), b1.iter())) {
1240+
// Exchanges the two vector elements in place; the previous front entry
// ends up in p_next's slot.
std::swap(p_best, p_next);
1241+
continue;
1242+
}
1243+
}
1244+
}
1245+
12211246
#ifdef DNNL_DEV_MODE
12221247
using namespace ir_utils;
12231248
std::vector<std::string> headers

src/gpu/intel/jit/ir/blocking.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ class params_generator_t {
527527
// Read-only accessor: returns a const reference to the params_vec_ member.
const std::vector<blocking_params_t> &params_vec() const {
528528
return params_vec_;
529529
}
530+
// Mutable accessor: returns a non-const reference to the params_vec_ member,
// so callers can modify or reorder the stored elements in place.
std::vector<blocking_params_t> &params_vec() {
    return params_vec_;
}
530531

531532
// Reports whether params_vec_ currently holds no elements.
bool is_empty() const {
    return params_vec_.empty();
}
532533

0 commit comments

Comments
 (0)