diff --git a/src/FindCalls.cpp b/src/FindCalls.cpp index 1fca6de1175c..8389d2a5ff7e 100644 --- a/src/FindCalls.cpp +++ b/src/FindCalls.cpp @@ -3,6 +3,8 @@ #include "ExternFuncArgument.h" #include "Function.h" #include "IRVisitor.h" +#include "Parameter.h" +#include #include namespace Halide { @@ -97,6 +99,64 @@ std::map build_environment(const std::vector &f for (const Function &f : funcs) { populate_environment_helper(f, &env, &order, true, true); } + + // Validate the environment: no Parameter (ImageParam, Generator + // Input, or scalar Param) may share a name with a Func in the + // pipeline. Such a collision otherwise causes confusing internal errors + // later in lowering. Output Funcs intentionally share names with their + // output buffer Parameters, so exclude buffer params whose name matches + // an output Func. + class FindParamNames : public IRVisitor { + using IRVisitor::visit; + void record(const Parameter &p) { + if (!p.defined()) { + return; + } + if (p.is_buffer()) { + buffer_names.insert(p.name()); + } else { + scalar_names.insert(p.name()); + } + } + void visit(const Variable *op) override { + record(op->param); + } + void visit(const Call *op) override { + IRVisitor::visit(op); + record(op->param); + } + + public: + std::set buffer_names; + std::set scalar_names; + } finder; + for (const auto &p : env) { + p.second.accept(&finder); + } + std::set output_names; + for (const Function &f : funcs) { + output_names.insert(f.name()); + } + for (const std::string &name : finder.buffer_names) { + if (output_names.count(name)) { + continue; + } + if (env.count(name)) { + user_error << "The name \"" << name << "\" is used for both " + << "an input buffer (ImageParam or Generator Input) " + << "and a Func in the same pipeline. " + << "Input buffers and Funcs must have distinct names.\n"; + } + } + for (const std::string &name : finder.scalar_names) { + if (env.count(name)) { + user_error << "The name \"" << name << "\" is used for both " + << "a scalar Param (or Generator Input scalar) " + << "and a Func in the same pipeline. " + << "Params and Funcs must have distinct names.\n"; + } + } + return env; } diff --git a/src/Generator.cpp b/src/Generator.cpp index b28bea0975c2..0a445ceabafa 100644 --- a/src/Generator.cpp +++ b/src/Generator.cpp @@ -1919,6 +1919,8 @@ void GeneratorInputBase::init_internals() { funcs_.clear(); for (size_t i = 0; i < array_size(); ++i) { auto name = array_name(i); + // Discourage future Funcs from having the same name as this Input. + Internal::unique_name(name); parameters_.emplace_back(gio_type(), kind() != ArgInfoKind::Scalar, dims(), name); auto &p = parameters_[i]; if (kind() != ArgInfoKind::Scalar) { diff --git a/src/Param.h b/src/Param.h index 57991dd12475..30864c2db435 100644 --- a/src/Param.h +++ b/src/Param.h @@ -37,6 +37,8 @@ class Param { << "is no longer used to control whether Halide functions take explicit " << "user_context arguments. Use set_custom_user_context() when jitting, " << "or add Target::UserContext to the Target feature set when compiling ahead of time."; + // Discourage future Funcs from having the same name as this Param. + Internal::unique_name(param.name()); } // Allow all Param<> variants friend access to each other diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index ca43f5f2cf40..4ee1d19ddb05 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -188,6 +188,7 @@ tests(GROUPS correctness infer_arguments.cpp inline_reduction.cpp inlined_generator.cpp + input_func_name_unique.cpp input_image_bounds_check.cpp input_larger_than_two_gigs.cpp integer_powers.cpp diff --git a/test/correctness/input_func_name_unique.cpp b/test/correctness/input_func_name_unique.cpp new file mode 100644 index 000000000000..0683a0c45134 --- /dev/null +++ b/test/correctness/input_func_name_unique.cpp @@ -0,0 +1,64 @@ +#include "Halide.h" +#include + +using namespace Halide; + +namespace { + +// Generator with an Input declared before any Func of the same +// name. The Func created inside generate() must be renamed so the +// pipeline compiles without collision. +class GenInputBufferThenFunc : public Halide::Generator { +public: + Input> input_foo{"foo"}; + Output> out{"out"}; + + void generate() { + Var x; + Func foo("foo"); + foo(x) = x; + out(x) = input_foo(x) + foo(x); + } +}; + +} // namespace + +HALIDE_REGISTER_GENERATOR(GenInputBufferThenFunc, gen_input_buffer_then_func) + +int main(int argc, char **argv) { + // An ImageParam followed by a Func of the same name: the Func is + // renamed so the names differ. + { + ImageParam ip(Int(32), 1, "foo"); + Func foo("foo"); + assert(ip.name() != foo.name() && + "ImageParam should reserve its name against later Funcs"); + } + + // A scalar Param followed by a Func of the same name: the Func is + // renamed so the names differ. + { + Param p("foo"); + Func foo("foo"); + assert(p.name() != foo.name() && + "Param should reserve its name against later Funcs"); + } + + // A Generator Input followed by a Func of the same name + // inside generate(): the Func is renamed and the pipeline compiles. + { + GeneratorContext ctx(get_jit_target_from_environment()); + Callable c = create_callable_from_generator(ctx, "gen_input_buffer_then_func"); + + Buffer in(10), out(10); + in.fill(0); + int r = c(in, out); + assert(r == 0); + for (int i = 0; i < 10; i++) { + assert(out(i) == i); + } + } + + printf("Success!\n"); + return 0; +} diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index fc9496af0244..234366c8d1b3 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -74,6 +74,9 @@ tests(GROUPS error impossible_constraints.cpp incomplete_target.cpp init_def_should_be_all_vars.cpp + input_buffer_func_name_collision.cpp + input_generator_buffer_func_name_collision.cpp + input_param_func_name_collision.cpp inspect_loop_level.cpp lerp_float_weight_out_of_range.cpp lerp_mismatch.cpp diff --git a/test/error/input_buffer_func_name_collision.cpp b/test/error/input_buffer_func_name_collision.cpp new file mode 100644 index 000000000000..14a663acc201 --- /dev/null +++ b/test/error/input_buffer_func_name_collision.cpp @@ -0,0 +1,22 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + // Func declared before an ImageParam of the same name. Lowering + // should produce a clean user_error rather than crashing. + Var x; + Func existing("foo"); + existing(x) = x; + + ImageParam ip(Int(32), 1, "foo"); + + Func out("out"); + out(x) = existing(x) + ip(x); + + out.compile_jit(); + + printf("Should not get here\n"); + return 0; +} diff --git a/test/error/input_generator_buffer_func_name_collision.cpp b/test/error/input_generator_buffer_func_name_collision.cpp new file mode 100644 index 000000000000..c36027899c77 --- /dev/null +++ b/test/error/input_generator_buffer_func_name_collision.cpp @@ -0,0 +1,35 @@ +#include "Halide.h" +#include + +using namespace Halide; + +namespace { + +// A Func declared as a member is constructed before the Input +// member (declaration order), so by the time the Input's Parameter is +// created, the Func has already reserved "foo". The Parameter still keeps +// its literal name, so they collide in the resulting pipeline. +class GenFuncBeforeInputBuffer : public Halide::Generator { +public: + Func foo{"foo"}; + Input> input_foo{"foo"}; + Output> out{"out"}; + + void generate() { + Var x; + foo(x) = x; + out(x) = input_foo(x) + foo(x); + } +}; + +} // namespace + +HALIDE_REGISTER_GENERATOR(GenFuncBeforeInputBuffer, gen_input_buffer_func_collision) + +int main(int argc, char **argv) { + GeneratorContext ctx(get_jit_target_from_environment()); + (void)create_callable_from_generator(ctx, "gen_input_buffer_func_collision"); + + printf("Should not get here\n"); + return 0; +} diff --git a/test/error/input_param_func_name_collision.cpp b/test/error/input_param_func_name_collision.cpp new file mode 100644 index 000000000000..212c555cb8b5 --- /dev/null +++ b/test/error/input_param_func_name_collision.cpp @@ -0,0 +1,22 @@ +#include "Halide.h" +#include + +using namespace Halide; + +int main(int argc, char **argv) { + // Func declared before a scalar Param of the same name. Lowering + // should produce a clean user_error rather than crashing. + Var x; + Func existing("foo"); + existing(x) = x; + + Param p("foo"); + + Func out("out"); + out(x) = existing(x) + p; + + out.compile_jit(); + + printf("Should not get here\n"); + return 0; +}