diff --git a/src/AddImageChecks.cpp b/src/AddImageChecks.cpp index b9ef94564049..eba08b820d14 100644 --- a/src/AddImageChecks.cpp +++ b/src/AddImageChecks.cpp @@ -89,7 +89,6 @@ class FindBuffers : public IRGraphVisitor { r.param = op->param; r.type = op->param.type(); r.dimensions = op->param.dimensions(); - r.used_on_host = false; buffers[op->param.name()] = r; } else if (op->reduction_domain.defined()) { // The bounds of reduction domains are not yet defined, @@ -208,6 +207,7 @@ Stmt add_image_checks_inner(Stmt s, vector asserts_device_not_dirty; vector buffer_rewrites; vector msan_checks; + vector set_host_dirty; // Inject the code that conditionally returns if we're in inference mode Expr maybe_return_condition = const_false(); @@ -649,6 +649,16 @@ Stmt add_image_checks_inner(Stmt s, // If we have no device support, we can't handle // device_dirty, so every buffer touched needs checking. asserts_device_not_dirty.push_back(AssertStmt::make(!device_dirty, error)); + + // However, if it's an output, we do still need to set the host + // dirty bit in case the result is fed to a later GPU + // kernel. + if (is_output_buffer) { + Expr set = + Call::make(Int(32), Call::buffer_set_host_dirty, + {handle, const_true()}, Call::Extern); + set_host_dirty.push_back(Evaluate::make(set)); + } } } @@ -678,6 +688,10 @@ Stmt add_image_checks_inner(Stmt s, } }; + // After all asserts, set host dirty on outputs if this is a CPU-only + // pipeline + prepend_stmts(&set_host_dirty); + // Inject the code that checks the host pointers. prepend_stmts(&asserts_host_non_null); prepend_stmts(&asserts_host_alignment); diff --git a/src/Lower.cpp b/src/Lower.cpp index c248684ea5d9..cd7ccc9a03f4 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -326,10 +326,6 @@ void lower_impl(const vector &output_funcs, debug(1) << "Selecting a GPU API for extern stages...\n"; s = select_gpu_api(s, t); log("Lowering after selecting a GPU API for extern stages:", s); - } else { - debug(1) << "Injecting host-dirty marking...\n"; - s = inject_host_dev_buffer_copies(s, t); - log("Lowering after injecting host-dirty marking:", s); } debug(1) << "Simplifying...\n";