Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/AddImageChecks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ class FindBuffers : public IRGraphVisitor {
r.param = op->param;
r.type = op->param.type();
r.dimensions = op->param.dimensions();
r.used_on_host = false;
buffers[op->param.name()] = r;
} else if (op->reduction_domain.defined()) {
// The bounds of reduction domains are not yet defined,
Expand Down Expand Up @@ -208,6 +207,7 @@ Stmt add_image_checks_inner(Stmt s,
vector<Stmt> asserts_device_not_dirty;
vector<Stmt> buffer_rewrites;
vector<Stmt> msan_checks;
vector<Stmt> set_host_dirty;

// Inject the code that conditionally returns if we're in inference mode
Expr maybe_return_condition = const_false();
Expand Down Expand Up @@ -649,6 +649,16 @@ Stmt add_image_checks_inner(Stmt s,
// If we have no device support, we can't handle
// device_dirty, so every buffer touched needs checking.
asserts_device_not_dirty.push_back(AssertStmt::make(!device_dirty, error));

// However, if it's an output, we do still need to set the host
// dirty bit in case the result is fed to a later GPU
// kernel.
if (is_output_buffer) {
Expr set =
Call::make(Int(32), Call::buffer_set_host_dirty,
{handle, const_true()}, Call::Extern);
set_host_dirty.push_back(Evaluate::make(set));
}
}
}

Expand Down Expand Up @@ -678,6 +688,10 @@ Stmt add_image_checks_inner(Stmt s,
}
};

// After all asserts, set host dirty on outputs if this is a CPU-only
// pipeline
prepend_stmts(&set_host_dirty);

// Inject the code that checks the host pointers.
prepend_stmts(&asserts_host_non_null);
prepend_stmts(&asserts_host_alignment);
Expand Down
4 changes: 0 additions & 4 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,6 @@ void lower_impl(const vector<Function> &output_funcs,
debug(1) << "Selecting a GPU API for extern stages...\n";
s = select_gpu_api(s, t);
log("Lowering after selecting a GPU API for extern stages:", s);
} else {
debug(1) << "Injecting host-dirty marking...\n";
s = inject_host_dev_buffer_copies(s, t);
log("Lowering after injecting host-dirty marking:", s);
}

debug(1) << "Simplifying...\n";
Expand Down
Loading