Skip to content

Commit

Permalink
Change early-bound default args in Python bindings to late-bound (#7347)
Browse files Browse the repository at this point in the history
In PyBind11, if you specify a default argument for a method, it is evaluated when the Python module is initialized, *not* when the method is called (as you might expect in C++). For defaults that are just constants/literals, this is no big deal, but when calling get_*_target_from_environment, this means it is called at module init time -- also normally not a big deal (since the values ~never change at runtime anyway), with one big exception (no pun intended): if the function throws an exception (e.g. via calling user_assert() or similar), that exception is thrown at Module-initialization time, which is a much more inscrutable crash, and one that is very hard to recover from.

This may seem unlikely, but can happen pretty easily if you set (say) HL_JIT_TARGET=host-cuda (or other gpu) and the given GPU runtime isn't present on the given system; the current behavior is basically "make if impossible for the libHalidePython bindings to run", whereas what we want is "runtime exception thrown when you call the method".

This changes the relevant methods to use `Target()` as the default value, and inside the method wrapper, if the value passed equals `Target()`, it replaces the value with the righ `get_*_target_from_environment()` call.

(This turned up while doing some testing of #6924 on a system without Vulkan available)
  • Loading branch information
steven-johnson authored Feb 14, 2023
1 parent 8bd07fb commit 7963cd4
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 80 deletions.
24 changes: 12 additions & 12 deletions python_bindings/src/halide/halide_/PyBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,27 +587,27 @@ void define_buffer(py::module &m) {
})

.def(
"copy_to_device", [](Buffer<> &b, const Target &t) -> int {
return b.copy_to_device(t);
"copy_to_device", [](Buffer<> &b, const Target &target) -> int {
return b.copy_to_device(to_jit_target(target));
},
py::arg("target") = get_jit_target_from_environment())
py::arg("target") = Target())

.def(
"copy_to_device", [](Buffer<> &b, const DeviceAPI &d, const Target &t) -> int {
return b.copy_to_device(d, t);
"copy_to_device", [](Buffer<> &b, const DeviceAPI &d, const Target &target) -> int {
return b.copy_to_device(d, to_jit_target(target));
},
py::arg("device_api"), py::arg("target") = get_jit_target_from_environment())
py::arg("device_api"), py::arg("target") = Target())
.def(
"device_malloc", [](Buffer<> &b, const Target &t) -> int {
return b.device_malloc(t);
"device_malloc", [](Buffer<> &b, const Target &target) -> int {
return b.device_malloc(to_jit_target(target));
},
py::arg("target") = get_jit_target_from_environment())
py::arg("target") = Target())

.def(
"device_malloc", [](Buffer<> &b, const DeviceAPI &d, const Target &t) -> int {
return b.device_malloc(d, t);
"device_malloc", [](Buffer<> &b, const DeviceAPI &d, const Target &target) -> int {
return b.device_malloc(d, to_jit_target(target));
},
py::arg("device_api"), py::arg("target") = get_jit_target_from_environment())
py::arg("device_api"), py::arg("target") = Target())

.def(
"set_min", [](Buffer<> &b, const std::vector<int> &mins) -> void {
Expand Down
120 changes: 90 additions & 30 deletions python_bindings/src/halide/halide_/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,39 +212,98 @@ void define_func(py::module &m) {

.def("store_in", &Func::store_in, py::arg("memory_type"))

.def("compile_to", &Func::compile_to, py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment())

.def("compile_to_bitcode", (void(Func::*)(const std::string &, const std::vector<Argument> &, const std::string &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment())
.def("compile_to_bitcode", (void(Func::*)(const std::string &, const std::vector<Argument> &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment())

.def("compile_to_llvm_assembly", (void(Func::*)(const std::string &, const std::vector<Argument> &, const std::string &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment())
.def("compile_to_llvm_assembly", (void(Func::*)(const std::string &, const std::vector<Argument> &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment())

.def("compile_to_object", (void(Func::*)(const std::string &, const std::vector<Argument> &, const std::string &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment())
.def("compile_to_object", (void(Func::*)(const std::string &, const std::vector<Argument> &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment())

.def("compile_to_header", &Func::compile_to_header, py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment())

.def("compile_to_assembly", (void(Func::*)(const std::string &, const std::vector<Argument> &, const std::string &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment())
.def("compile_to_assembly", (void(Func::*)(const std::string &, const std::vector<Argument> &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment())

.def("compile_to_c", &Func::compile_to_c, py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment())

.def("compile_to_lowered_stmt", &Func::compile_to_lowered_stmt, py::arg("filename"), py::arg("arguments"), py::arg("fmt") = Text, py::arg("target") = get_target_from_environment())

.def("compile_to_file", &Func::compile_to_file, py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment())

.def("compile_to_static_library", &Func::compile_to_static_library, py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment())
.def(
"compile_to", [](Func &f, const std::map<OutputFileType, std::string> &output_files, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) {
f.compile_to(output_files, args, fn_name, to_aot_target(target));
},
py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target())
.def(
"compile_to_bitcode", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) {
f.compile_to_bitcode(filename, args, fn_name, to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target())
.def(
"compile_to_bitcode", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const Target &target) {
f.compile_to_bitcode(filename, args, "", to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("target") = Target())
.def(
"compile_to_llvm_assembly", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) {
f.compile_to_llvm_assembly(filename, args, fn_name, to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target())
.def(
"compile_to_llvm_assembly", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const Target &target) {
f.compile_to_llvm_assembly(filename, args, "", to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("target") = Target())
.def(
"compile_to_object", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) {
f.compile_to_object(filename, args, fn_name, to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target())
.def(
"compile_to_object", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const Target &target) {
f.compile_to_object(filename, args, "", to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("target") = Target())
.def(
"compile_to_header", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) {
f.compile_to_header(filename, args, fn_name, to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target())
.def(
"compile_to_assembly", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) {
f.compile_to_assembly(filename, args, fn_name, to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target())
.def(
"compile_to_assembly", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const Target &target) {
f.compile_to_assembly(filename, args, "", to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("target") = Target())
.def(
"compile_to_c", [](Func &f, const std::string &filename, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) {
f.compile_to_c(filename, args, fn_name, to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target())
.def(
"compile_to_lowered_stmt", [](Func &f, const std::string &filename, const std::vector<Argument> &args, StmtOutputFormat fmt, const Target &target) {
f.compile_to_lowered_stmt(filename, args, fmt, to_aot_target(target));
},
py::arg("filename"), py::arg("arguments"), py::arg("fmt") = Text, py::arg("target") = Target())
.def(
"compile_to_file", [](Func &f, const std::string &filename_prefix, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) {
f.compile_to_file(filename_prefix, args, fn_name, to_aot_target(target));
},
py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target())
.def(
"compile_to_static_library", [](Func &f, const std::string &filename_prefix, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) {
f.compile_to_static_library(filename_prefix, args, fn_name, to_aot_target(target));
},
py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target())

.def("compile_to_multitarget_static_library", &Func::compile_to_multitarget_static_library, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets"))
.def("compile_to_multitarget_object_files", &Func::compile_to_multitarget_object_files, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets"), py::arg("suffixes"))

// TODO: useless until Module is defined.
.def("compile_to_module", &Func::compile_to_module, py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment())
.def(
"compile_to_module", [](Func &f, const std::vector<Argument> &args, const std::string &fn_name, const Target &target) -> Module {
return f.compile_to_module(args, fn_name, to_aot_target(target));
},
py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target())

.def("compile_jit", &Func::compile_jit, py::arg("target") = get_jit_target_from_environment())
.def(
"compile_jit", [](Func &f, const Target &target) {
f.compile_jit(to_jit_target(target));
},
py::arg("target") = Target())

.def("compile_to_callable", &Func::compile_to_callable, py::arg("arguments"), py::arg("target") = get_jit_target_from_environment())
.def(
"compile_to_callable", [](Func &f, const std::vector<Argument> &args, const Target &target) {
return f.compile_to_callable(args, to_jit_target(target));
},
py::arg("arguments"), py::arg("target") = Target())

.def("has_update_definition", &Func::has_update_definition)
.def("num_update_definitions", &Func::num_update_definitions)
Expand Down Expand Up @@ -285,34 +344,35 @@ void define_func(py::module &m) {

.def(
"infer_input_bounds", [](Func &f, const py::object &dst, const Target &target) -> void {
const Target t = to_jit_target(target);
// dst could be Buffer<>, vector<Buffer>, or vector<int>
try {
Buffer<> b = dst.cast<Buffer<>>();
f.infer_input_bounds(b, target);
f.infer_input_bounds(b, t);
return;
} catch (...) {
// fall thru
}

try {
std::vector<Buffer<>> v = dst.cast<std::vector<Buffer<>>>();
f.infer_input_bounds(Realization(std::move(v)), target);
f.infer_input_bounds(Realization(std::move(v)), t);
return;
} catch (...) {
// fall thru
}

try {
std::vector<int32_t> v = dst.cast<std::vector<int32_t>>();
f.infer_input_bounds(v, target);
f.infer_input_bounds(v, t);
return;
} catch (...) {
// fall thru
}

throw py::value_error("Invalid arguments to infer_input_bounds");
},
py::arg("dst"), py::arg("target") = get_jit_target_from_environment())
py::arg("dst"), py::arg("target") = Target())

.def("in_", (Func(Func::*)(const Func &)) & Func::in, py::arg("f"))
.def("in_", (Func(Func::*)(const std::vector<Func> &fs)) & Func::in, py::arg("fs"))
Expand Down
14 changes: 14 additions & 0 deletions python_bindings/src/halide/halide_/PyHalide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,19 @@ std::vector<Expr> collect_print_args(const py::args &args) {
return v;
}

Target to_jit_target(const Target &target) {
if (target != Target()) {
return target;
}
return get_jit_target_from_environment();
}

Target to_aot_target(const Target &target) {
if (target != Target()) {
return target;
}
return get_target_from_environment();
}

} // namespace PythonBindings
} // namespace Halide
2 changes: 2 additions & 0 deletions python_bindings/src/halide/halide_/PyHalide.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ std::vector<T> args_to_vector(const py::args &args, size_t start_offset = 0, siz

std::vector<Expr> collect_print_args(const py::args &args);
Expr double_to_expr_check(double v);
Target to_jit_target(const Target &target);
Target to_aot_target(const Target &target);

} // namespace PythonBindings
} // namespace Halide
Expand Down
Loading

0 comments on commit 7963cd4

Please # to comment.