Skip to content

Commit c0850be

Browse files
authored
flambda-backend: unboxed float32 (#2520)
* squash for review * address comments
1 parent acdd1ae commit c0850be

40 files changed

+2303
-54
lines changed

asmcomp/cmmgen.ml

+2
Original file line numberDiff line numberDiff line change
@@ -363,13 +363,15 @@ let exttype_of_sort (s : Jkind.Sort.const) =
363363
| Bits32 -> XInt32
364364
| Bits64 -> XInt64
365365
| Void -> Misc.fatal_error "Cmmgen.exttype_of_sort: void encountered"
366+
| Float32 -> Misc.fatal_error "Cmmgen.exttype_of_sort: float32 encountered"
366367

367368
let machtype_of_sort (s : Jkind.Sort.const) =
368369
match s with
369370
| Value -> typ_val
370371
| Float64 -> typ_float
371372
| Word | Bits32 | Bits64 -> typ_int
372373
| Void -> Misc.fatal_error "Cmmgen.machtype_of_sort: void encountered"
374+
| Float32 -> Misc.fatal_error "Cmmgen.machtype_of_sort: float32 encountered"
373375

374376
let is_unboxed_number_cmm ~strict cmm =
375377
let r = ref No_result in

bytecomp/symtable.ml

+2-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ let rec transl_const = function
152152
Const_base(Const_int i) -> Obj.repr i
153153
| Const_base(Const_char c) -> Obj.repr c
154154
| Const_base(Const_string (s, _, _)) -> Obj.repr s
155-
| Const_base(Const_float32 f) -> float32_of_string f
155+
| Const_base(Const_float32 f)
156+
| Const_base(Const_unboxed_float32 f) -> float32_of_string f
156157
| Const_base(Const_float f)
157158
| Const_base(Const_unboxed_float f) -> Obj.repr (float_of_string f)
158159
| Const_base(Const_int32 i)

lambda/lambda.ml

+2
Original file line numberDiff line numberDiff line change
@@ -1748,6 +1748,7 @@ let constant_layout: constant -> layout = function
17481748
| Const_float _ -> Pvalue (Pboxedfloatval Pfloat64)
17491749
| Const_float32 _ -> Pvalue (Pboxedfloatval Pfloat32)
17501750
| Const_unboxed_float _ -> Punboxed_float Pfloat64
1751+
| Const_unboxed_float32 _ -> Punboxed_float Pfloat32
17511752

17521753
let structured_constant_layout = function
17531754
| Const_base const -> constant_layout const
@@ -1763,6 +1764,7 @@ let layout_of_extern_repr : extern_repr -> _ = function
17631764
begin match s with
17641765
| Value -> layout_any_value
17651766
| Float64 -> layout_unboxed_float Pfloat64
1767+
| Float32 -> layout_unboxed_float Pfloat32
17661768
| Word -> layout_unboxed_nativeint
17671769
| Bits32 -> layout_unboxed_int32
17681770
| Bits64 -> layout_unboxed_int64

lambda/matching.ml

+8-5
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,7 @@ let can_group discr pat =
11601160
| Constant (Const_float _), Constant (Const_float _)
11611161
| Constant (Const_float32 _), Constant (Const_float32 _)
11621162
| Constant (Const_unboxed_float _), Constant (Const_unboxed_float _)
1163+
| Constant (Const_unboxed_float32 _), Constant (Const_unboxed_float32 _)
11631164
| Constant (Const_int32 _), Constant (Const_int32 _)
11641165
| Constant (Const_int64 _), Constant (Const_int64 _)
11651166
| Constant (Const_nativeint _), Constant (Const_nativeint _)
@@ -1186,9 +1187,10 @@ let can_group discr pat =
11861187
( Any
11871188
| Constant
11881189
( Const_int _ | Const_char _ | Const_string _ | Const_float _
1189-
| Const_float32 _ | Const_unboxed_float _ | Const_int32 _
1190-
| Const_int64 _ | Const_nativeint _ | Const_unboxed_int32 _
1191-
| Const_unboxed_int64 _ | Const_unboxed_nativeint _ )
1190+
| Const_float32 _ | Const_unboxed_float _ | Const_unboxed_float32 _
1191+
| Const_int32 _ | Const_int64 _ | Const_nativeint _
1192+
| Const_unboxed_int32 _ | Const_unboxed_int64 _
1193+
| Const_unboxed_nativeint _ )
11921194
| Construct _ | Tuple _ | Record _ | Array _ | Variant _ | Lazy ) ) ->
11931195
false
11941196

@@ -2891,7 +2893,7 @@ let combine_constant value_kind loc arg cst partial ctx def
28912893
make_test_sequence value_kind loc fail (Pfloatcomp (Pfloat64, CFneq))
28922894
(Pfloatcomp (Pfloat64, CFlt)) arg
28932895
const_lambda_list
2894-
| Const_float32 _ ->
2896+
| Const_float32 _ | Const_unboxed_float32 _ ->
28952897
(* Should be caught in do_compile_matching. *)
28962898
Misc.fatal_error "Found unexpected float32 literal pattern."
28972899
| Const_unboxed_float _ ->
@@ -3567,7 +3569,8 @@ and do_compile_matching ~scopes value_kind repr partial ctx pmh =
35673569
compile_no_test ~scopes value_kind
35683570
(divide_record ~scopes lbl.lbl_all ph)
35693571
Context.combine repr partial ctx pm
3570-
| Constant (Const_float32 _) -> Parmatch.raise_matched_float32 ()
3572+
| Constant (Const_float32 _ | Const_unboxed_float32 _) ->
3573+
Parmatch.raise_matched_float32 ()
35713574
| Constant cst ->
35723575
compile_test
35733576
(compile_match ~scopes value_kind repr partial)

lambda/printlambda.ml

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ let rec struct_const ppf = function
2525
| Const_base(Const_string (s, _, _)) -> fprintf ppf "%S" s
2626
| Const_immstring s -> fprintf ppf "#%S" s
2727
| Const_base(Const_float f) -> fprintf ppf "%s" f
28-
| Const_base(Const_float32 f) -> fprintf ppf "%s" f
28+
| Const_base(Const_float32 f) -> fprintf ppf "%ss" f
2929
| Const_base(Const_unboxed_float f) ->
3030
fprintf ppf "%s" (Misc.format_as_unboxed_literal f)
31+
| Const_base(Const_unboxed_float32 f) ->
32+
fprintf ppf "%ss" (Misc.format_as_unboxed_literal f)
3133
| Const_base(Const_int32 n) -> fprintf ppf "%lil" n
3234
| Const_base(Const_int64 n) -> fprintf ppf "%LiL" n
3335
| Const_base(Const_nativeint n) -> fprintf ppf "%nin" n

lambda/translcore.ml

+3
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ let layout_pat sort p = layout p.pat_env p.pat_loc sort p.pat_type
5656
let check_record_field_sort loc sort =
5757
match Jkind.Sort.get_default_value sort with
5858
| Value | Float64 | Bits32 | Bits64 | Word -> ()
59+
| Float32 ->
60+
(* CR mslater: (float32) float32# records *)
61+
Misc.fatal_error "Found unboxed float32 record field."
5962
| Void -> raise (Error (loc, Illegal_void_record_field))
6063

6164
(* Forward declaration -- to be filled in by Translmod.transl_module *)

lambda/translprim.ml

+2-1
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,10 @@ let lookup_primitive loc ~poly_mode ~poly_sort pos p =
700700
| "%obj_magic" -> Primitive(Pobj_magic layout, 1)
701701
| "%array_to_iarray" -> Primitive (Parray_to_iarray, 1)
702702
| "%array_of_iarray" -> Primitive (Parray_of_iarray, 1)
703-
(* CR mslater: (float32) unboxed *)
704703
| "%unbox_float" -> Primitive(Punbox_float Pfloat64, 1)
705704
| "%box_float" -> Primitive(Pbox_float (Pfloat64, mode), 1)
705+
| "%unbox_float32" -> Primitive(Punbox_float Pfloat32, 1)
706+
| "%box_float32" -> Primitive(Pbox_float (Pfloat32, mode), 1)
706707
| "%get_header" -> Primitive (Pget_header mode, 1)
707708
| "%atomic_load" ->
708709
Primitive ((Patomic_load {immediate_or_pointer=Pointer}), 1)

middle_end/closure/closure.ml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ let rec close ({ backend; fenv; cenv ; mutable_vars; kinds; catch_env } as env)
10271027
| Const_base (Const_string (s, _, _)) ->
10281028
str (Uconst_string s)
10291029
| Const_base(Const_float x) -> str (Uconst_float (float_of_string x))
1030-
| Const_base(Const_float32 _) ->
1030+
| Const_base(Const_float32 _ | Const_unboxed_float32 _) ->
10311031
Misc.fatal_error "float32 is not supported in closure. Consider using flambda2."
10321032
| Const_base (Const_unboxed_float _ | Const_unboxed_int32 _
10331033
| Const_unboxed_int64 _ | Const_unboxed_nativeint _) ->

middle_end/flambda/closure_conversion.ml

+2-2
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ let rec declare_const t (const : Lambda.structured_constant)
137137
register_const t
138138
(Allocated_const (Float (float_of_string c)))
139139
Names.const_float
140-
| Const_base (Const_float32 _) ->
141-
Misc.fatal_error "float32 is not supported in closure. Consider using flambda2."
140+
| Const_base (Const_float32 _ | Const_unboxed_float32 _) ->
141+
Misc.fatal_error "float32 is not supported in flambda. Consider using flambda2."
142142
| Const_base (Const_int32 c) ->
143143
register_const t (Allocated_const (Int32 c))
144144
Names.const_int32

otherlibs/alpha/alpha.ml

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
module Float32 = Float32
2+
module Float32_u = Float32_u

otherlibs/alpha/alpha.mli

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
module Float32 = Float32
2+
module Float32_u = Float32_u

otherlibs/alpha/dune

+6-1
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,18 @@
4242
(alpha.cma as alpha/alpha.cma)
4343
(alpha.mli as alpha/alpha.mli)
4444
(float32.mli as alpha/float32.mli)
45+
(float32_u.mli as alpha/float32_u.mli)
4546
(.alpha.objs/byte/alpha.cmi as alpha/alpha.cmi)
4647
(.alpha.objs/byte/alpha.cmt as alpha/alpha.cmt)
4748
(.alpha.objs/byte/alpha.cmti as alpha/alpha.cmti)
4849
(.alpha.objs/native/alpha.cmx as alpha/alpha.cmx)
4950
(.alpha.objs/byte/float32.cmi as alpha/float32.cmi)
5051
(.alpha.objs/byte/float32.cmt as alpha/float32.cmt)
5152
(.alpha.objs/byte/float32.cmti as alpha/float32.cmti)
52-
(.alpha.objs/native/float32.cmx as alpha/float32.cmx))
53+
(.alpha.objs/native/float32.cmx as alpha/float32.cmx)
54+
(.alpha.objs/byte/float32_u.cmi as alpha/float32_u.cmi)
55+
(.alpha.objs/byte/float32_u.cmt as alpha/float32_u.cmt)
56+
(.alpha.objs/byte/float32_u.cmti as alpha/float32_u.cmti)
57+
(.alpha.objs/native/float32_u.cmx as alpha/float32_u.cmx))
5358
(section lib)
5459
(package ocaml))

otherlibs/alpha/float32_u.ml

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
(**************************************************************************)
2+
(* *)
3+
(* OCaml *)
4+
(* *)
5+
(* Chris Casinghino, Jane Street, New York *)
6+
(* *)
7+
(* Copyright 2023 Jane Street Group LLC *)
8+
(* *)
9+
(* All rights reserved. This file is distributed under the terms of *)
10+
(* the GNU Lesser General Public License version 2.1, with the *)
11+
(* special exception on linking described in the file LICENSE. *)
12+
(* *)
13+
(**************************************************************************)
14+
15+
open! Stdlib
16+
17+
[@@@ocaml.flambda_o3]
18+
19+
external box_float : float# -> (float[@local_opt]) = "%box_float"
20+
21+
external unbox_float : (float[@local_opt]) -> float# = "%unbox_float"
22+
23+
external box_int32 : int32# -> (int32[@local_opt]) = "%box_int32"
24+
25+
external unbox_int32 : (int32[@local_opt]) -> int32# = "%unbox_int32"
26+
27+
external to_float32 : float32# -> (float32[@local_opt]) = "%box_float32"
28+
29+
external of_float32 : (float32[@local_opt]) -> float32# = "%unbox_float32"
30+
31+
(* CR layouts: Investigate whether it's worth making these things externals.
32+
Are there situations where the middle-end won't inline them and remove the
33+
boxing/unboxing? *)
34+
35+
let[@inline always] neg x = of_float32 (Float32.neg (to_float32 x))
36+
37+
let[@inline always] add x y = of_float32 (Float32.add (to_float32 x) (to_float32 y))
38+
39+
let[@inline always] sub x y = of_float32 (Float32.sub (to_float32 x) (to_float32 y))
40+
41+
let[@inline always] mul x y = of_float32 (Float32.mul (to_float32 x) (to_float32 y))
42+
43+
let[@inline always] div x y = of_float32 (Float32.div (to_float32 x) (to_float32 y))
44+
45+
let[@inline always] pow x y = of_float32 (Float32.pow (to_float32 x) (to_float32 y))
46+
47+
module Operators = struct
48+
let[@inline always] ( ~-. ) x = of_float32 (Float32.neg (to_float32 x))
49+
50+
let[@inline always] ( +. ) x y = of_float32 (Float32.add (to_float32 x) (to_float32 y))
51+
52+
let[@inline always] ( -. ) x y = of_float32 (Float32.sub (to_float32 x) (to_float32 y))
53+
54+
let[@inline always] ( *. ) x y = of_float32 (Float32.mul (to_float32 x) (to_float32 y))
55+
56+
let[@inline always] ( /. ) x y = of_float32 (Float32.div (to_float32 x) (to_float32 y))
57+
58+
let[@inline always] ( ** ) x y = of_float32 (Float32.pow (to_float32 x) (to_float32 y))
59+
end
60+
61+
let[@inline always] fma x y z = of_float32 (Float32.fma (to_float32 x) (to_float32 y) (to_float32 z))
62+
63+
let[@inline always] rem x y = of_float32 (Float32.rem (to_float32 x) (to_float32 y))
64+
65+
let[@inline always] succ x = of_float32 (Float32.succ (to_float32 x))
66+
67+
let[@inline always] pred x = of_float32 (Float32.pred (to_float32 x))
68+
69+
let[@inline always] abs x = of_float32 (Float32.abs (to_float32 x))
70+
71+
let[@inline always] is_finite x = Float32.is_finite (to_float32 x)
72+
73+
let[@inline always] is_infinite x = Float32.is_infinite (to_float32 x)
74+
75+
let[@inline always] is_nan x = Float32.is_nan (to_float32 x)
76+
77+
let[@inline always] is_integer x = Float32.is_integer (to_float32 x)
78+
79+
let[@inline always] of_int x = of_float32 (Float32.of_int x)
80+
81+
let[@inline always] to_int x = Float32.to_int (to_float32 x)
82+
83+
let[@inline always] of_float x = of_float32 (Float32.of_float (box_float x))
84+
85+
let[@inline always] to_float x = unbox_float (Float32.to_float (to_float32 x))
86+
87+
let[@inline always] of_bits x = of_float32 (Float32.of_bits (box_int32 x))
88+
89+
let[@inline always] to_bits x = unbox_int32 (Float32.to_bits (to_float32 x))
90+
91+
let[@inline always] of_string x = of_float32 (Float32.of_string x)
92+
93+
let[@inline always] to_string x = Float32.to_string (to_float32 x)
94+
95+
type fpclass = Stdlib.fpclass =
96+
FP_normal
97+
| FP_subnormal
98+
| FP_zero
99+
| FP_infinite
100+
| FP_nan
101+
102+
let[@inline always] classify_float x = Float32.classify_float (to_float32 x)
103+
104+
let[@inline always] sqrt x = of_float32 (Float32.sqrt (to_float32 x))
105+
106+
let[@inline always] cbrt x = of_float32 (Float32.cbrt (to_float32 x))
107+
108+
let[@inline always] exp x = of_float32 (Float32.exp (to_float32 x))
109+
110+
let[@inline always] exp2 x = of_float32 (Float32.exp2 (to_float32 x))
111+
112+
let[@inline always] log x = of_float32 (Float32.log (to_float32 x))
113+
114+
let[@inline always] log10 x = of_float32 (Float32.log10 (to_float32 x))
115+
116+
let[@inline always] log2 x = of_float32 (Float32.log2 (to_float32 x))
117+
118+
let[@inline always] expm1 x = of_float32 (Float32.expm1 (to_float32 x))
119+
120+
let[@inline always] log1p x = of_float32 (Float32.log1p (to_float32 x))
121+
122+
let[@inline always] cos x = of_float32 (Float32.cos (to_float32 x))
123+
124+
let[@inline always] sin x = of_float32 (Float32.sin (to_float32 x))
125+
126+
let[@inline always] tan x = of_float32 (Float32.tan (to_float32 x))
127+
128+
let[@inline always] acos x = of_float32 (Float32.acos (to_float32 x))
129+
130+
let[@inline always] asin x = of_float32 (Float32.asin (to_float32 x))
131+
132+
let[@inline always] atan x = of_float32 (Float32.atan (to_float32 x))
133+
134+
let[@inline always] atan2 x y = of_float32 (Float32.atan2 (to_float32 x) (to_float32 y))
135+
136+
let[@inline always] hypot x y = of_float32 (Float32.hypot (to_float32 x) (to_float32 y))
137+
138+
let[@inline always] cosh x = of_float32 (Float32.cosh (to_float32 x))
139+
140+
let[@inline always] sinh x = of_float32 (Float32.sinh (to_float32 x))
141+
142+
let[@inline always] tanh x = of_float32 (Float32.tanh (to_float32 x))
143+
144+
let[@inline always] acosh x = of_float32 (Float32.acosh (to_float32 x))
145+
146+
let[@inline always] asinh x = of_float32 (Float32.asinh (to_float32 x))
147+
148+
let[@inline always] atanh x = of_float32 (Float32.atanh (to_float32 x))
149+
150+
let[@inline always] erf x = of_float32 (Float32.erf (to_float32 x))
151+
152+
let[@inline always] erfc x = of_float32 (Float32.erfc (to_float32 x))
153+
154+
let[@inline always] trunc x = of_float32 (Float32.trunc (to_float32 x))
155+
156+
let[@inline always] round x = of_float32 (Float32.round (to_float32 x))
157+
158+
let[@inline always] ceil x = of_float32 (Float32.ceil (to_float32 x))
159+
160+
let[@inline always] floor x = of_float32 (Float32.floor (to_float32 x))
161+
162+
let[@inline always] next_after x y = of_float32 (Float32.next_after (to_float32 x) (to_float32 y))
163+
164+
let[@inline always] copy_sign x y = of_float32 (Float32.copy_sign (to_float32 x) (to_float32 y))
165+
166+
let[@inline always] sign_bit x = Float32.sign_bit (to_float32 x)
167+
168+
let[@inline always] ldexp x i = of_float32 (Float32.ldexp (to_float32 x) i)
169+
170+
type t = float32#
171+
172+
let[@inline always] compare x y = Float32.compare (to_float32 x) (to_float32 y)
173+
174+
let[@inline always] equal x y = Float32.equal (to_float32 x) (to_float32 y)
175+
176+
let[@inline always] min x y = of_float32 (Float32.min (to_float32 x) (to_float32 y))
177+
178+
let[@inline always] max x y = of_float32 (Float32.max (to_float32 x) (to_float32 y))
179+
180+
let[@inline always] min_num x y = of_float32 (Float32.min_num (to_float32 x) (to_float32 y))
181+
182+
let[@inline always] max_num x y = of_float32 (Float32.max_num (to_float32 x) (to_float32 y))

0 commit comments

Comments
 (0)