diff --git a/python_bindings/src/halide/halide_/PyBuffer.cpp b/python_bindings/src/halide/halide_/PyBuffer.cpp index 7b32f9824349..6d1461a33ec4 100644 --- a/python_bindings/src/halide/halide_/PyBuffer.cpp +++ b/python_bindings/src/halide/halide_/PyBuffer.cpp @@ -587,27 +587,27 @@ void define_buffer(py::module &m) { }) .def( - "copy_to_device", [](Buffer<> &b, const Target &t) -> int { - return b.copy_to_device(t); + "copy_to_device", [](Buffer<> &b, const Target &target) -> int { + return b.copy_to_device(to_jit_target(target)); }, - py::arg("target") = get_jit_target_from_environment()) + py::arg("target") = Target()) .def( - "copy_to_device", [](Buffer<> &b, const DeviceAPI &d, const Target &t) -> int { - return b.copy_to_device(d, t); + "copy_to_device", [](Buffer<> &b, const DeviceAPI &d, const Target &target) -> int { + return b.copy_to_device(d, to_jit_target(target)); }, - py::arg("device_api"), py::arg("target") = get_jit_target_from_environment()) + py::arg("device_api"), py::arg("target") = Target()) .def( - "device_malloc", [](Buffer<> &b, const Target &t) -> int { - return b.device_malloc(t); + "device_malloc", [](Buffer<> &b, const Target &target) -> int { + return b.device_malloc(to_jit_target(target)); }, - py::arg("target") = get_jit_target_from_environment()) + py::arg("target") = Target()) .def( - "device_malloc", [](Buffer<> &b, const DeviceAPI &d, const Target &t) -> int { - return b.device_malloc(d, t); + "device_malloc", [](Buffer<> &b, const DeviceAPI &d, const Target &target) -> int { + return b.device_malloc(d, to_jit_target(target)); }, - py::arg("device_api"), py::arg("target") = get_jit_target_from_environment()) + py::arg("device_api"), py::arg("target") = Target()) .def( "set_min", [](Buffer<> &b, const std::vector &mins) -> void { diff --git a/python_bindings/src/halide/halide_/PyFunc.cpp b/python_bindings/src/halide/halide_/PyFunc.cpp index 966b35bb4388..6c6e38ec7501 100644 --- a/python_bindings/src/halide/halide_/PyFunc.cpp +++ b/python_bindings/src/halide/halide_/PyFunc.cpp @@ -212,39 +212,98 @@ void define_func(py::module &m) { .def("store_in", &Func::store_in, py::arg("memory_type")) - .def("compile_to", &Func::compile_to, py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_bitcode", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_bitcode", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_llvm_assembly", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_llvm_assembly", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_object", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_object", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_header", &Func::compile_to_header, py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) - - .def("compile_to_assembly", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_assembly", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_c", &Func::compile_to_c, py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) - - .def("compile_to_lowered_stmt", &Func::compile_to_lowered_stmt, py::arg("filename"), py::arg("arguments"), py::arg("fmt") = Text, py::arg("target") = get_target_from_environment()) - - .def("compile_to_file", &Func::compile_to_file, py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) - - .def("compile_to_static_library", &Func::compile_to_static_library, py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) + .def( + "compile_to", [](Func &f, const std::map &output_files, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to(output_files, args, fn_name, to_aot_target(target)); + }, + py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_bitcode", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_bitcode(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_bitcode", [](Func &f, const std::string &filename, const std::vector &args, const Target &target) { + f.compile_to_bitcode(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_llvm_assembly", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_llvm_assembly(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_llvm_assembly", [](Func &f, const std::string &filename, const std::vector &args, const Target &target) { + f.compile_to_llvm_assembly(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_object", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_object(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_object", [](Func &f, const std::string &filename, const std::vector &args, const Target &target) { + f.compile_to_object(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_header", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_header(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_assembly", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_assembly(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_assembly", [](Func &f, const std::string &filename, const std::vector &args, const Target &target) { + f.compile_to_assembly(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_c", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_c(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_lowered_stmt", [](Func &f, const std::string &filename, const std::vector &args, StmtOutputFormat fmt, const Target &target) { + f.compile_to_lowered_stmt(filename, args, fmt, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fmt") = Text, py::arg("target") = Target()) + .def( + "compile_to_file", [](Func &f, const std::string &filename_prefix, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_file(filename_prefix, args, fn_name, to_aot_target(target)); + }, + py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_static_library", [](Func &f, const std::string &filename_prefix, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_static_library(filename_prefix, args, fn_name, to_aot_target(target)); + }, + py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) .def("compile_to_multitarget_static_library", &Func::compile_to_multitarget_static_library, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets")) .def("compile_to_multitarget_object_files", &Func::compile_to_multitarget_object_files, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets"), py::arg("suffixes")) // TODO: useless until Module is defined. - .def("compile_to_module", &Func::compile_to_module, py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) + .def( + "compile_to_module", [](Func &f, const std::vector &args, const std::string &fn_name, const Target &target) -> Module { + return f.compile_to_module(args, fn_name, to_aot_target(target)); + }, + py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) - .def("compile_jit", &Func::compile_jit, py::arg("target") = get_jit_target_from_environment()) + .def( + "compile_jit", [](Func &f, const Target &target) { + f.compile_jit(to_jit_target(target)); + }, + py::arg("target") = Target()) - .def("compile_to_callable", &Func::compile_to_callable, py::arg("arguments"), py::arg("target") = get_jit_target_from_environment()) + .def( + "compile_to_callable", [](Func &f, const std::vector &args, const Target &target) { + return f.compile_to_callable(args, to_jit_target(target)); + }, + py::arg("arguments"), py::arg("target") = Target()) .def("has_update_definition", &Func::has_update_definition) .def("num_update_definitions", &Func::num_update_definitions) @@ -285,10 +344,11 @@ void define_func(py::module &m) { .def( "infer_input_bounds", [](Func &f, const py::object &dst, const Target &target) -> void { + const Target t = to_jit_target(target); // dst could be Buffer<>, vector, or vector try { Buffer<> b = dst.cast>(); - f.infer_input_bounds(b, target); + f.infer_input_bounds(b, t); return; } catch (...) { // fall thru @@ -296,7 +356,7 @@ void define_func(py::module &m) { try { std::vector> v = dst.cast>>(); - f.infer_input_bounds(Realization(std::move(v)), target); + f.infer_input_bounds(Realization(std::move(v)), t); return; } catch (...) { // fall thru @@ -304,7 +364,7 @@ void define_func(py::module &m) { try { std::vector v = dst.cast>(); - f.infer_input_bounds(v, target); + f.infer_input_bounds(v, t); return; } catch (...) { // fall thru @@ -312,7 +372,7 @@ void define_func(py::module &m) { throw py::value_error("Invalid arguments to infer_input_bounds"); }, - py::arg("dst"), py::arg("target") = get_jit_target_from_environment()) + py::arg("dst"), py::arg("target") = Target()) .def("in_", (Func(Func::*)(const Func &)) & Func::in, py::arg("f")) .def("in_", (Func(Func::*)(const std::vector &fs)) & Func::in, py::arg("fs")) diff --git a/python_bindings/src/halide/halide_/PyHalide.cpp b/python_bindings/src/halide/halide_/PyHalide.cpp index d18ccf01b725..430ad690420d 100644 --- a/python_bindings/src/halide/halide_/PyHalide.cpp +++ b/python_bindings/src/halide/halide_/PyHalide.cpp @@ -111,5 +111,19 @@ std::vector collect_print_args(const py::args &args) { return v; } +Target to_jit_target(const Target &target) { + if (target != Target()) { + return target; + } + return get_jit_target_from_environment(); +} + +Target to_aot_target(const Target &target) { + if (target != Target()) { + return target; + } + return get_target_from_environment(); +} + } // namespace PythonBindings } // namespace Halide diff --git a/python_bindings/src/halide/halide_/PyHalide.h b/python_bindings/src/halide/halide_/PyHalide.h index 2eefb1f463bf..64003c339bb3 100644 --- a/python_bindings/src/halide/halide_/PyHalide.h +++ b/python_bindings/src/halide/halide_/PyHalide.h @@ -34,6 +34,8 @@ std::vector args_to_vector(const py::args &args, size_t start_offset = 0, siz std::vector collect_print_args(const py::args &args); Expr double_to_expr_check(double v); +Target to_jit_target(const Target &target); +Target to_aot_target(const Target &target); } // namespace PythonBindings } // namespace Halide diff --git a/python_bindings/src/halide/halide_/PyPipeline.cpp b/python_bindings/src/halide/halide_/PyPipeline.cpp index 2300ab5e76cb..069ea9394df8 100644 --- a/python_bindings/src/halide/halide_/PyPipeline.cpp +++ b/python_bindings/src/halide/halide_/PyPipeline.cpp @@ -77,40 +77,98 @@ void define_pipeline(py::module &m) { py::arg("index")) .def("print_loop_nest", &Pipeline::print_loop_nest) - .def("compile_to", &Pipeline::compile_to, - py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_bitcode", &Pipeline::compile_to_bitcode, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_llvm_assembly", &Pipeline::compile_to_llvm_assembly, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_object", &Pipeline::compile_to_object, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_header", &Pipeline::compile_to_header, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_assembly", &Pipeline::compile_to_assembly, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_c", &Pipeline::compile_to_c, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_file", &Pipeline::compile_to_file, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_static_library", &Pipeline::compile_to_static_library, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_lowered_stmt", &Pipeline::compile_to_lowered_stmt, - py::arg("filename"), py::arg("arguments"), py::arg("format") = StmtOutputFormat::Text, py::arg("target") = get_target_from_environment()) - - .def("compile_to_multitarget_static_library", &Pipeline::compile_to_multitarget_static_library, - py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets")) - .def("compile_to_multitarget_object_files", &Pipeline::compile_to_multitarget_object_files, - py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets"), py::arg("suffixes")) - - .def("compile_to_module", &Pipeline::compile_to_module, - py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment(), py::arg("linkage") = LinkageType::ExternalPlusMetadata) - - .def("compile_jit", &Pipeline::compile_jit, py::arg("target") = get_jit_target_from_environment()) - - .def("compile_to_callable", &Pipeline::compile_to_callable, py::arg("arguments"), py::arg("target") = get_jit_target_from_environment()) + .def( + "compile_to", [](Pipeline &p, const std::map &output_files, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to(output_files, args, fn_name, to_aot_target(target)); + }, + py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + + .def( + "compile_to_bitcode", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_bitcode(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_bitcode", [](Pipeline &p, const std::string &filename, const std::vector &args, const Target &target) { + p.compile_to_bitcode(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_llvm_assembly", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_llvm_assembly(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_llvm_assembly", [](Pipeline &p, const std::string &filename, const std::vector &args, const Target &target) { + p.compile_to_llvm_assembly(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_object", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_object(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_object", [](Pipeline &p, const std::string &filename, const std::vector &args, const Target &target) { + p.compile_to_object(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_header", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_header(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_assembly", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_assembly(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_assembly", [](Pipeline &p, const std::string &filename, const std::vector &args, const Target &target) { + p.compile_to_assembly(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_c", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_c(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_lowered_stmt", [](Pipeline &p, const std::string &filename, const std::vector &args, StmtOutputFormat fmt, const Target &target) { + p.compile_to_lowered_stmt(filename, args, fmt, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fmt") = Text, py::arg("target") = Target()) + .def( + "compile_to_file", [](Pipeline &p, const std::string &filename_prefix, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_file(filename_prefix, args, fn_name, to_aot_target(target)); + }, + py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_static_library", [](Pipeline &p, const std::string &filename_prefix, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_static_library(filename_prefix, args, fn_name, to_aot_target(target)); + }, + py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + + .def("compile_to_multitarget_static_library", &Pipeline::compile_to_multitarget_static_library, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets")) + .def("compile_to_multitarget_object_files", &Pipeline::compile_to_multitarget_object_files, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets"), py::arg("suffixes")) + + .def( + "compile_to_module", [](Pipeline &p, const std::vector &args, const std::string &fn_name, const Target &target, LinkageType linkage_type) -> Module { + return p.compile_to_module(args, fn_name, to_aot_target(target), linkage_type); + }, + py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target(), py::arg("linkage") = LinkageType::ExternalPlusMetadata) + + .def( + "compile_jit", [](Pipeline &p, const Target &target) { + p.compile_jit(to_jit_target(target)); + }, + py::arg("target") = Target()) + + .def( + "compile_to_callable", [](Pipeline &p, const std::vector &args, const Target &target) { + return p.compile_to_callable(args, to_jit_target(target)); + }, + py::arg("arguments"), py::arg("target") = Target()) .def( "realize", [](Pipeline &p, Buffer<> buffer, const Target &target) -> void { @@ -146,10 +204,11 @@ void define_pipeline(py::module &m) { .def( "infer_input_bounds", [](Pipeline &p, const py::object &dst, const Target &target) -> void { + const Target t = to_jit_target(target); // dst could be Buffer<>, vector, or vector try { Buffer<> b = dst.cast>(); - p.infer_input_bounds(b, target); + p.infer_input_bounds(b, t); return; } catch (...) { // fall thru @@ -157,7 +216,7 @@ void define_pipeline(py::module &m) { try { std::vector> v = dst.cast>>(); - p.infer_input_bounds(Realization(std::move(v)), target); + p.infer_input_bounds(Realization(std::move(v)), t); return; } catch (...) { // fall thru @@ -165,7 +224,7 @@ void define_pipeline(py::module &m) { try { std::vector v = dst.cast>(); - p.infer_input_bounds(v, target); + p.infer_input_bounds(v, t); return; } catch (...) { // fall thru @@ -173,7 +232,7 @@ void define_pipeline(py::module &m) { throw py::value_error("Invalid arguments to infer_input_bounds"); }, - py::arg("dst"), py::arg("target") = get_jit_target_from_environment()) + py::arg("dst"), py::arg("target") = Target()) .def("infer_arguments", [](Pipeline &p) -> std::vector { return p.infer_arguments();