Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add Safetensors.write!/2 for streamed write #8

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 72 additions & 31 deletions lib/safetensors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,70 @@ defmodule Safetensors do

@dtype_to_type for {k, v} <- @type_to_dtype, into: %{}, do: {v, k}

@doc """
Writes a map of tensors to a file.

Tensors are written into the file one by one, without the need to
dump all of them into the memory at once.
"""
@spec write!(path :: Path.t(), %{String.t() => Nx.Tensor.t()}) :: :ok
def write!(path, tensors) when is_map(tensors) do
File.open!(path, [:write, :raw], fn file ->
tensors = Enum.sort(tensors)

{header_entries, _offset} =
Enum.map_reduce(tensors, 0, fn {tensor_name, tensor}, offset ->
tensor_header_entry(tensor_name, tensor, offset)
end)

:ok = :file.write(file, header_binary(header_entries))

for {_tensor_name, tensor} <- tensors do
:ok = :file.write(file, tensor_to_binary(tensor))
end
end)

:ok
end

defp tensor_header_entry(tensor_name, tensor, offset) do
end_offset = offset + tensor_byte_size(tensor)

header_entry = {
tensor_name,
Jason.OrderedObject.new(
dtype: tensor |> Nx.type() |> type_to_dtype(),
shape: tensor |> Nx.shape() |> Tuple.to_list(),
data_offsets: [offset, end_offset]
)
}

{header_entry, end_offset}
end

defp header_binary(header_entries) do
header_json =
header_entries
|> Jason.OrderedObject.new()
|> Jason.encode!()

[<<byte_size(header_json)::unsigned-64-integer-little>>, header_json]
end

defp tensor_byte_size(tensor) do
{_, elem_size} = Nx.type(tensor)
elem_byte_size = div(elem_size, 8)
Nx.size(tensor) * elem_byte_size
end

defp tensor_to_binary(tensor) do
{_, elem_size} = Nx.type(tensor)

tensor
|> Nx.to_binary()
|> new_byte_order(elem_size, :little)
end

@doc """
Serializes the given map of tensors to iodata.

Expand All @@ -50,46 +114,23 @@ defmodule Safetensors do
"""
@spec dump(%{String.t() => Nx.Tensor.t()}) :: iodata()
def dump(tensors) when is_map(tensors) do
tensors = Enum.sort(tensors)

{header_entries, {buffer, _offset}} =
Enum.map_reduce(tensors, {[], 0}, fn {tensor_name, tensor}, {buffer, offset} ->
{_, elem_size} = Nx.type(tensor)

binary =
tensor
|> Nx.to_binary()
|> new_byte_order(elem_size, :little)

end_offset = offset + byte_size(binary)

header_entry = {
tensor_name,
Jason.OrderedObject.new(
dtype: tensor |> Nx.type() |> type_to_dtype(),
shape: tensor |> Nx.shape() |> Tuple.to_list(),
data_offsets: [offset, end_offset]
)
}

{header_entry, end_offset} = tensor_header_entry(tensor_name, tensor, offset)
binary = tensor_to_binary(tensor)
{header_entry, {[buffer, binary], end_offset}}
end)

header_json =
header_entries
|> Jason.OrderedObject.new()
|> Jason.encode!()

[
<<byte_size(header_json)::unsigned-64-integer-little>>,
header_json,
buffer
]
[header_binary(header_entries), buffer]
end

@doc """
Reads a safe tensor from file.
Reads a serialized map of tensors from a file.

Tensors are loaded into Nx one by one,
without the need to load the entire file from disk into memory.
Tensors are loaded into Nx one by one, without the need to load the
entire file from disk into memory.
"""
@spec read!(path :: Path.t()) :: %{String.t() => Nx.Tensor.t()}
def read!(path) do
Expand Down
14 changes: 14 additions & 0 deletions test/safetensors_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@ defmodule SafetensorsTest do

doctest Safetensors

@tag :tmp_dir
test "write", %{tmp_dir: tmp_dir} do
path = Path.join(tmp_dir, "safetensor")

data = %{test: Nx.tensor([[1, 2], [3, 4]], type: :s32)}
Safetensors.write!(path, data)

# source:
# https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L22-L25
# with the header padding removed and changed numbers
assert File.read!(path) ==
~s(<\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]}}\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00)
end

test "dump" do
binary =
%{test: Nx.tensor([[1, 2], [3, 4]], type: :s32)}
Expand Down
Loading