
Commit 5759870

JacobSzwejbka authored and facebook-github-bot committed
Update Demo Scripts To Use .ptd (retry)
Summary: Add options for checkpoint loading in the demo script.
Differential Revision: D70498387
1 parent 9227cdc commit 5759870

4 files changed, +55 -18 lines

CMakeLists.txt

+5 -4
@@ -248,14 +248,15 @@ cmake_dependent_option(
   "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
 )

-if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
+if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
   set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
+  set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON)
+  set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
+  set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
 endif()

-if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
-  set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
+if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
   set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
-  set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
 endif()

 if(EXECUTORCH_BUILD_EXTENSION_MODULE)
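
With this change, enabling EXECUTORCH_BUILD_EXTENSION_TRAINING on its own pulls in the data loader, flat tensor, module, and tensor extensions. A minimal configure-and-build sketch from the repository root (the build directory name and parallelism flag are illustrative, not part of this commit):

# build directory name is illustrative
cmake -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON -Bcmake-out .
cmake --build cmake-out --target extension_training -j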

extension/training/CMakeLists.txt

+1 -1
@@ -26,7 +26,7 @@ target_include_directories(
 target_include_directories(extension_training PUBLIC ${EXECUTORCH_ROOT}/..)
 target_compile_options(extension_training PUBLIC ${_common_compile_options})
 target_link_libraries(extension_training executorch_core
-                      extension_data_loader extension_module extension_tensor)
+                      extension_data_loader extension_module extension_tensor extension_flat_tensor)


 list(TRANSFORM _train_xor__srcs PREPEND "${EXECUTORCH_ROOT}/")

extension/training/examples/XOR/export_model.py

+20 -8
@@ -11,14 +11,14 @@
 import os

 import torch
-from executorch.exir import to_edge
+from executorch.exir import ExecutorchBackendConfig, to_edge

 from executorch.extension.training.examples.XOR.model import Net, TrainingNet
 from torch.export import export
 from torch.export.experimental import _export_forward_backward


-def _export_model():
+def _export_model(external_mutable_weights: bool = False):
     net = TrainingNet(Net())
     x = torch.randn(1, 2)

@@ -30,7 +30,11 @@ def _export_model():
     # Lower the graph to edge dialect.
     ep = to_edge(ep)
     # Lower the graph to executorch.
-    ep = ep.to_executorch()
+    ep = ep.to_executorch(
+        config=ExecutorchBackendConfig(
+            external_mutable_weights=external_mutable_weights
+        )
+    )
     return ep


@@ -44,19 +48,27 @@ def main() -> None:
         "--outdir",
         type=str,
         required=True,
-        help="Path to the directory to write xor.pte files to",
+        help="Path to the directory to write xor.pte and xor.ptd files to",
+    )
+    parser.add_argument(
+        "--external",
+        action="store_true",
+        help="Export the model with external weights",
     )
     args = parser.parse_args()

-    ep = _export_model()
+    ep = _export_model(args.external)

     # Write out the .pte file.
     os.makedirs(args.outdir, exist_ok=True)
     outfile = os.path.join(args.outdir, "xor.pte")
     with open(outfile, "wb") as fp:
-        fp.write(
-            ep.buffer,
-        )
+        ep.write_to_file(fp)
+
+    if args.external:
+        # current infra doesnt easily allow renaming this file, so just hackily do it here.
+        ep._tensor_data["xor"] = ep._tensor_data.pop("_default_external_constant")
+        ep.write_tensor_data_to_file(args.outdir)


 if __name__ == "__main__":

extension/training/examples/XOR/train.cpp

+29 -5
@@ -23,12 +23,18 @@ using executorch::extension::training::optimizer::SGDOptions;
 using executorch::runtime::Error;
 using executorch::runtime::Result;
 DEFINE_string(model_path, "xor.pte", "Model serialized in flatbuffer format.");
+DEFINE_string(ptd_path, "", "Model weights serialized in flatbuffer format.");

 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
-  if (argc != 1) {
+  if (argc == 0) {
+    ET_LOG(Error, "Please provide a model path.");
+    return 1;
+  } else if (argc > 2) {
     std::string msg = "Extra commandline args: ";
-    for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
+    for (int i = 2 /* skip argv[0] (pte path) and argv[1] (ptd path) */;
+         i < argc;
+         i++) {
       msg += argv[i];
     }
     ET_LOG(Error, "%s", msg.c_str());
@@ -46,7 +52,21 @@ int main(int argc, char** argv) {
   auto loader = std::make_unique<executorch::extension::FileDataLoader>(
       std::move(loader_res.get()));

-  auto mod = executorch::extension::training::TrainingModule(std::move(loader));
+  std::unique_ptr<executorch::extension::FileDataLoader> ptd_loader = nullptr;
+  if (!FLAGS_ptd_path.empty()) {
+    executorch::runtime::Result<executorch::extension::FileDataLoader>
+        ptd_loader_res =
+            executorch::extension::FileDataLoader::from(FLAGS_ptd_path.c_str());
+    if (ptd_loader_res.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to open ptd file: %s", FLAGS_ptd_path.c_str());
+      return 1;
+    }
+    ptd_loader = std::make_unique<executorch::extension::FileDataLoader>(
+        std::move(ptd_loader_res.get()));
+  }
+
+  auto mod = executorch::extension::training::TrainingModule(
+      std::move(loader), nullptr, nullptr, nullptr, std::move(ptd_loader));

   // Create full data set of input and labels.
   std::vector<std::pair<
@@ -70,7 +90,10 @@ int main(int argc, char** argv) {
   // Get the params and names
   auto param_res = mod.named_parameters("forward");
   if (param_res.error() != Error::Ok) {
-    ET_LOG(Error, "Failed to get named parameters");
+    ET_LOG(
+        Error,
+        "Failed to get named parameters, error: %d",
+        static_cast<int>(param_res.error()));
     return 1;
   }

@@ -112,5 +135,6 @@ int main(int argc, char** argv) {
         std::string(param.first.data()), param.second});
   }

-  executorch::extension::flat_tensor::save_ptd("xor.ptd", param_map, 16);
+  executorch::extension::flat_tensor::save_ptd(
+      "trained_xor.ptd", param_map, 16);
 }
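
A sketch of how the updated trainer might be invoked once built (the binary name and file paths are assumptions; the --model_path and --ptd_path flags come from the DEFINE_string calls above):

# binary name and location are illustrative
./train_xor --model_path=/tmp/xor/xor.pte --ptd_path=/tmp/xor/xor.ptd

When --ptd_path is omitted, the program still runs against a .pte with embedded weights; in either case the trained parameters are saved to trained_xor.ptd in the working directory.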
