#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#
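
"""AOT preprocessing for the ExecuTorch MPS backend.

Serializes an edge-dialect `ExportedProgram` into an `MPSGraph` FlatBuffer
plus a constant data segment, wrapped in a fixed-size header that the MPS
runtime parses to locate both pieces.
"""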

import logging
from typing import ClassVar, Dict, final, List, Tuple

import torch

from executorch.backends.apple.mps.operators.node_visitor import (
    get_node_visitors,
    NodeVisitor,
    process_output_node,
    process_placeholder_nodes,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    Buffer,
    DataSegment,
    MPSGraph,
    MPSTensor,
    OpType,
)
from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
    convert_to_flatbuffer,
)
from executorch.exir._serialize._program import Cord
from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
from executorch.exir.program._program import _transform
from torch.export.exported_program import ExportedProgram

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


@final
class MPSBackend(BackendDetails):
    @staticmethod
    def slice_len_max(s):
        """Returns the number of elements covered by slice `s` (at least 1)."""
        assert s.start is not None
        assert s.stop is not None
        step = 1
        if s.step is not None:
            step = s.step
        return max((s.stop - s.start) // step, 1)
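
    # Layout of the serialized blob produced by `preprocess` (little-endian):
    #   bytes [0, 4)    reserved (zero)
    #   bytes [4, 8)    magic "MP00"                 -> MAGIC_IX
    #   bytes [8, 16)   offset of the data segment   -> DATA_SEGMENT_OFFSET_IX
    #   bytes [16, 24)  size of the data segment     -> DATA_SEGMENT_SIZE_IX
    # The FlatBuffer-serialized graph follows the header, padded to a 16-byte
    # boundary, and the constant data segment comes last.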
    MAGIC_IX: ClassVar[slice] = slice(4, 8)
    DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16)
    DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24)

    # Magic bytes that should be at the beginning of the header.
    EXPECTED_MAGIC: ClassVar[bytes] = b"MP00"
    # The length of the header in bytes (4 + 4 + 8 + 8 = 24).
    EXPECTED_LENGTH: ClassVar[int] = (
        4
        + slice_len_max(MAGIC_IX)
        + slice_len_max(DATA_SEGMENT_OFFSET_IX)
        + slice_len_max(DATA_SEGMENT_SIZE_IX)
    )

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        # The EdgeIR nodes are processed in the following order:
        # 1. First process the inputs fed to the graph (in the same
        #    order as the args from forward(*args)) and generate a unique
        #    id for each input placeholder. Each input id is appended to
        #    the `input_ids` array of the FlatBuffer schema.
        # 2. Process the nodes of the graph (e.g. `call_function`). For each
        #    EdgeIR node, create an equivalent MPS node in the FlatBuffer,
        #    from which the MPSGraph is constructed at runtime. During this
        #    process, any constant visited in the EdgeIR is added to the
        #    final MPS FlatBuffer schema. Each constant id is appended to
        #    the `constant_ids` array of the schema.
        # 3. After all the inputs, nodes, and constants are added to the
        #    FlatBuffer graph, process the `output` nodes and add their ids
        #    to the `output_ids` array of the schema.

        # TODO: Remove this once we have better support for the dim-order ops.
        edge_program = _transform(edge_program, DimOrderOpsRevertPass())

        mps_graph = MPSGraph(
            version="0",
            mps_nodes=[],
            mps_values=[],
            input_ids=[],
            output_ids=[],
            constant_ids=[],
            graph_type=OpType.mps_graph,
            constant_segment=DataSegment(0, 0),
        )
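
        # Whether to lower the graph in fp16 (the default). Callers can
        # override this through a compile spec, e.g. (assumed usage)
        # CompileSpec("use_fp16", bytes([False])).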
        convert_model_to_fp16 = True
        for spec in compile_specs:
            if spec.key == "use_fp16":
                convert_model_to_fp16 = bool(list(bytes(spec.value))[0])

        logging.debug(f"Convert model to FP16: {convert_model_to_fp16}")
        node_visitors = get_node_visitors(edge_program, convert_model_to_fp16)

        if logging.DEBUG >= logging.root.level:
            edge_program.graph.print_tabular()

        process_placeholder_nodes(
            edge_program,
            edge_program.graph_module,
            mps_graph,
            node_visitors["placeholder"],
        )

        op_handler = {
            "call_function": MPSBackend.handle_call_function,
            "placeholder": MPSBackend.handle_placeholder,
            "output": MPSBackend.handle_output,
            "get_attr": MPSBackend.handle_get_attr,
        }

        for node in edge_program.graph_module.graph.nodes:
            if node.op not in op_handler:
                raise RuntimeError(f"{node.op} is not supported in MPS")
            else:
                op_handler[node.op](edge_program, node_visitors, node, mps_graph)

        segment_data, mps_graph = _extract_constant_segment(mps_graph)
        if logging.DEBUG >= logging.root.level:
            pretty_print(mps_graph)

        # Pad the constant segment to a 16-byte boundary.
        padding_length = _padding_required(len(segment_data), 16)
        if padding_length > 0:
            segment_data.append(b"\x00" * padding_length)

        # Combine the serialized graph with the segment data.
        combined = Cord()
        graph_bytes = convert_to_flatbuffer(mps_graph)
        data_segment_offset: int = MPSBackend.EXPECTED_LENGTH
        data_segment_offset = data_segment_offset + len(graph_bytes)
        graph_padding_length = _padding_required(data_segment_offset, 16)
        data_segment_offset = data_segment_offset + graph_padding_length
        data_segment_size = len(segment_data)

        # Assemble the fixed-size header described above.
        data: bytes = (
            b"\x00\x00\x00\x00"
            + MPSBackend.EXPECTED_MAGIC
            + data_segment_offset.to_bytes(8, byteorder="little")
            + data_segment_size.to_bytes(8, byteorder="little")
        )
        assert len(data) == MPSBackend.EXPECTED_LENGTH

        combined.append(data)
        combined.append(graph_bytes)
        if graph_padding_length > 0:
            combined.append(b"\x00" * graph_padding_length)
        # Append the segment data to the end of the padded mps graph.
        combined.append(segment_data)

        return PreprocessResult(processed_bytes=bytes(combined))
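
    # Note: `preprocess` is not called directly; it runs while the program is
    # delegated to this backend, e.g. (a sketch, assuming the standard
    # partitioner-based lowering flow):
    #   edge = edge.to_backend(MPSPartitioner(compile_specs))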

    @staticmethod
    def handle_call_function(
        _: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        logging.info(f"Visiting: {node}, {node.target.__name__}")

        if (
            "delegation_tag" in node.meta
            and "metal_kernel" in node.meta["delegation_tag"]
        ):
            logging.info(
                f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!"
            )
            mps_graph.graph_type = OpType.metal_kernel

        if node.target.__name__ in node_visitors:
            node_visitors[node.target.__name__].define_node(node, mps_graph)
        else:
            pretty_print(mps_graph)
            raise RuntimeError(
                f"For {node}, {node.op}:{node.target.__name__} is not supported in MPS delegate"
            )

    @staticmethod
    def handle_placeholder(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        # Constants are handled directly when visiting the nodes.
        pass

    @staticmethod
    def handle_output(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        for output_nodes in node.args:
            for output_node in output_nodes:
                process_output_node(output_node, mps_graph, node_visitors[node.op])

    @staticmethod
    def handle_get_attr(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        # `get_attr` nodes need no work here; attributes are serialized as
        # constants when the nodes that use them are visited.
        pass


def _padding_required(offset: int, alignment: int) -> int:
    """Returns the padding required to align `offset` to `alignment`."""
    remainder: int = offset % alignment
    if remainder != 0:
        return alignment - remainder
    return 0


def _extract_constant_segment(mps_graph: MPSGraph) -> Tuple[Cord, MPSGraph]:
    """Extracts the constant segment from the MPSGraph and returns the updated MPSGraph along with the segment data."""
    # Note that the beginning of the segment data is not aligned; alignment
    # must be handled by the caller.
    segment_data = Cord()
    offset = 0
    for i in range(len(mps_graph.mps_values)):
        tensor = mps_graph.mps_values[i]
        if tensor.constant_buffer_size > 0:
            # The buffer is already force-aligned, so it needs no padding.
            segment_data.append(tensor.constant_buffer.storage)
            # Reset the buffer to empty.
            tensor.constant_buffer = Buffer(storage=b"")
            # Update the segment offset.
            tensor.segment_offset = offset
            offset += tensor.constant_buffer_size

    return segment_data, mps_graph


def tensor_to_str(mps_tensor: MPSTensor):
    tensor_str = "MPSTensor("
    tensor_str += "datatype=" + str(mps_tensor.datatype) + ", "
    tensor_str += "num_dims=" + str(mps_tensor.num_dims) + ", "
    tensor_str += "dims=" + str(mps_tensor.dims) + ", "
    tensor_str += "constant_buffer_size=" + str(mps_tensor.constant_buffer_size) + ", "
    tensor_str += "segment_offset=" + str(mps_tensor.segment_offset)
    tensor_str += ")"
    return tensor_str


def pretty_print(mps_graph: MPSGraph):
    logging.info("Serialized MPSGraph:")
    logging.info(f"  Version: {mps_graph.version}")
    logging.info("  MPS nodes: ")
    for i in range(len(mps_graph.mps_nodes)):
        logging.info(f"    [{i}]: {mps_graph.mps_nodes[i]}")
    logging.info("  MPS values: ")
    for i in range(len(mps_graph.mps_values)):
        logging.info(f"    [{i}]: {tensor_to_str(mps_graph.mps_values[i])}")
    logging.info("  Input ids:")
    for in_id in mps_graph.input_ids:
        logging.info(f"    {in_id}")
    logging.info("  Constant ids:")
    for constant_id in mps_graph.constant_ids:
        logging.info(f"    {constant_id}")
    logging.info("  Output ids:")
    for out_id in mps_graph.output_ids:
        logging.info(f"    {out_id}")
    logging.info(f"  Constant segment: {mps_graph.constant_segment}")