Skip to content

Commit 75ef0a0

Browse files
authored
flambda-backend: Add partial adjoints of join_with and meet_with (#2479)
* add partial adjoints of meet and join * just use id as adjoints * restrict def. of partial adjoints
1 parent 030a263 commit 75ef0a0

File tree

5 files changed

+87
-122
lines changed

5 files changed

+87
-122
lines changed

typing/mode.ml

+28-30
Original file line numberDiff line numberDiff line change
@@ -504,11 +504,11 @@ module Lattices_mono = struct
504504

505505
type ('a, 'b, 'd) morph =
506506
| Id : ('a, 'a, 'd) morph (** identity morphism *)
507-
| Meet_with : 'a -> ('a, 'a, 'd * disallowed) morph
507+
| Meet_with : 'a -> ('a, 'a, 'l * 'r) morph
508508
(** Meet the input with the parameter *)
509509
| Imply : 'a -> ('a, 'a, disallowed * 'd) morph
510510
(** The right adjoint of [Meet_with] *)
511-
| Join_with : 'a -> ('a, 'a, disallowed * 'd) morph
511+
| Join_with : 'a -> ('a, 'a, 'l * 'r) morph
512512
(** Join the input with the parameter *)
513513
| Subtract : 'a -> ('a, 'a, 'd * disallowed) morph
514514
(** The left adjoint of [Join_with] *)
@@ -557,6 +557,7 @@ module Lattices_mono = struct
557557
| Proj (src, ax) -> Proj (src, ax)
558558
| Min_with ax -> Min_with ax
559559
| Meet_with c -> Meet_with c
560+
| Join_with c -> Join_with c
560561
| Subtract c -> Subtract c
561562
| Compose (f, g) ->
562563
let f = allow_left f in
@@ -579,6 +580,7 @@ module Lattices_mono = struct
579580
| Proj (src, ax) -> Proj (src, ax)
580581
| Max_with ax -> Max_with ax
581582
| Join_with c -> Join_with c
583+
| Meet_with c -> Meet_with c
582584
| Imply c -> Imply c
583585
| Compose (f, g) ->
584586
let f = allow_right f in
@@ -893,7 +895,9 @@ module Lattices_mono = struct
893895
| Imply c0, Imply c1 -> Some (Imply (meet dst c0 c1))
894896
| Subtract c0, Subtract c1 -> Some (Subtract (join dst c0 c1))
895897
| Imply c0, Join_with c1 when le dst c0 c1 -> Some (Join_with (max dst))
898+
| Imply c0, Meet_with c1 when le dst c0 c1 -> Some (Imply c0)
896899
| Subtract c0, Meet_with c1 when le dst c1 c0 -> Some (Meet_with (min dst))
900+
| Subtract c0, Join_with c1 when le dst c1 c0 -> Some (Subtract c0)
897901
| Meet_with c0, m1 when is_max c0 -> Some m1
898902
| Join_with c0, m1 when is_min c0 -> Some m1
899903
| Imply c0, m1 when is_max c0 -> Some m1
@@ -1045,6 +1049,10 @@ module Lattices_mono = struct
10451049
let g' = left_adjoint mid g in
10461050
Compose (g', f')
10471051
| Join_with c -> Subtract c
1052+
| Meet_with _c ->
1053+
(* The downward closure of [Meet_with c]'s image is all [x <= c].
1054+
For those, [x <= meet c y] is equivalent to [x <= y]. *)
1055+
Id
10481056
| Imply c -> Meet_with c
10491057
| Comonadic_to_monadic _ -> Monadic_to_comonadic_min
10501058
| Monadic_to_comonadic_max -> Comonadic_to_monadic dst
@@ -1072,6 +1080,10 @@ module Lattices_mono = struct
10721080
Compose (g', f')
10731081
| Meet_with c -> Imply c
10741082
| Subtract c -> Join_with c
1083+
| Join_with _c ->
1084+
(* The upward closure of [Join_with c]'s image is all [x >= c].
1085+
For those, [join c y <= x] is equivalent to [y <= x]. *)
1086+
Id
10751087
| Comonadic_to_monadic _ -> Monadic_to_comonadic_max
10761088
| Monadic_to_comonadic_min -> Comonadic_to_monadic dst
10771089
| Local_to_regional -> Regional_to_local
@@ -1346,11 +1358,9 @@ module Comonadic_with_regionality = struct
13461358

13471359
let proj ax m = Solver.via_monotone (proj_obj ax) (Proj (Obj.obj, ax)) m
13481360

1349-
let meet_const c m =
1350-
Solver.via_monotone obj (Meet_with c) (Solver.disallow_right m)
1361+
let meet_const c m = Solver.via_monotone obj (Meet_with c) m
13511362

1352-
let join_const c m =
1353-
Solver.via_monotone obj (Join_with c) (Solver.disallow_left m)
1363+
let join_const c m = Solver.via_monotone obj (Join_with c) m
13541364

13551365
let min_with ax m =
13561366
Solver.via_monotone Obj.obj (Min_with ax) (Solver.disallow_right m)
@@ -1411,11 +1421,9 @@ module Comonadic_with_locality = struct
14111421

14121422
let proj ax m = Solver.via_monotone (proj_obj ax) (Proj (Obj.obj, ax)) m
14131423

1414-
let meet_const c m =
1415-
Solver.via_monotone obj (Meet_with c) (Solver.disallow_right m)
1424+
let meet_const c m = Solver.via_monotone obj (Meet_with c) m
14161425

1417-
let join_const c m =
1418-
Solver.via_monotone obj (Join_with c) (Solver.disallow_left m)
1426+
let join_const c m = Solver.via_monotone obj (Join_with c) m
14191427

14201428
let min_with ax m =
14211429
Solver.via_monotone Obj.obj (Min_with ax) (Solver.disallow_right m)
@@ -1483,11 +1491,9 @@ module Monadic = struct
14831491
by [Solver_polarized], but some remain, such as the [Min_with] below which
14841492
is inverted from [Max_with]. *)
14851493

1486-
let meet_const c m =
1487-
Solver.via_monotone obj (Join_with c) (Solver.disallow_right m)
1494+
let meet_const c m = Solver.via_monotone obj (Join_with c) m
14881495

1489-
let join_const c m =
1490-
Solver.via_monotone obj (Meet_with c) (Solver.disallow_left m)
1496+
let join_const c m = Solver.via_monotone obj (Meet_with c) m
14911497

14921498
let max_with ax m =
14931499
Solver.via_monotone Obj.obj (Min_with ax) (Solver.disallow_left m)
@@ -1744,34 +1750,30 @@ module Value = struct
17441750
| Comonadic ax -> min_with_comonadic ax m
17451751

17461752
let join_with_monadic ax c { monadic; comonadic } =
1747-
let comonadic = Comonadic.disallow_left comonadic in
17481753
let monadic = Monadic.join_with ax c monadic in
17491754
{ monadic; comonadic }
17501755

17511756
let join_with_comonadic ax c { monadic; comonadic } =
1752-
let monadic = Monadic.disallow_left monadic in
17531757
let comonadic = Comonadic.join_with ax c comonadic in
17541758
{ comonadic; monadic }
17551759

1756-
let join_with :
1757-
type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (disallowed * r) t =
1760+
let join_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t
1761+
=
17581762
fun ax c m ->
17591763
match ax with
17601764
| Monadic ax -> join_with_monadic ax c m
17611765
| Comonadic ax -> join_with_comonadic ax c m
17621766

17631767
let meet_with_monadic ax c { monadic; comonadic } =
1764-
let comonadic = Comonadic.disallow_right comonadic in
17651768
let monadic = Monadic.meet_with ax c monadic in
17661769
{ monadic; comonadic }
17671770

17681771
let meet_with_comonadic ax c { monadic; comonadic } =
1769-
let monadic = Monadic.disallow_right monadic in
17701772
let comonadic = Comonadic.meet_with ax c comonadic in
17711773
{ comonadic; monadic }
17721774

1773-
let meet_with :
1774-
type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * disallowed) t =
1775+
let meet_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t
1776+
=
17751777
fun ax c m ->
17761778
match ax with
17771779
| Monadic ax -> meet_with_monadic ax c m
@@ -2004,34 +2006,30 @@ module Alloc = struct
20042006
| Comonadic ax -> min_with_comonadic ax m
20052007

20062008
let join_with_monadic ax c { monadic; comonadic } =
2007-
let comonadic = Comonadic.disallow_left comonadic in
20082009
let monadic = Monadic.join_with ax c monadic in
20092010
{ monadic; comonadic }
20102011

20112012
let join_with_comonadic ax c { monadic; comonadic } =
2012-
let monadic = Monadic.disallow_left monadic in
20132013
let comonadic = Comonadic.join_with ax c comonadic in
20142014
{ comonadic; monadic }
20152015

2016-
let join_with :
2017-
type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (disallowed * r) t =
2016+
let join_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t
2017+
=
20182018
fun ax c m ->
20192019
match ax with
20202020
| Monadic ax -> join_with_monadic ax c m
20212021
| Comonadic ax -> join_with_comonadic ax c m
20222022

20232023
let meet_with_monadic ax c { monadic; comonadic } =
2024-
let comonadic = Comonadic.disallow_right comonadic in
20252024
let monadic = Monadic.meet_with ax c monadic in
20262025
{ monadic; comonadic }
20272026

20282027
let meet_with_comonadic ax c { monadic; comonadic } =
2029-
let monadic = Monadic.disallow_right monadic in
20302028
let comonadic = Comonadic.meet_with ax c comonadic in
20312029
{ comonadic; monadic }
20322030

2033-
let meet_with :
2034-
type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * disallowed) t =
2031+
let meet_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t
2032+
=
20352033
fun ax c m ->
20362034
match ax with
20372035
| Monadic ax -> meet_with_monadic ax c m

typing/mode_intf.mli

+7-7
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,13 @@ module type S = sig
308308

309309
val min_with : ('m, 'a, 'l * 'r) axis -> 'm -> ('l * disallowed) t
310310

311-
val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * disallowed) t
311+
val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t
312312

313-
val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> (disallowed * 'r) t
313+
val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t
314314

315315
val comonadic_to_monadic : ('l * 'r) Comonadic.t -> ('r * 'l) Monadic.t
316316

317-
val meet_const : Const.t -> ('l * 'r) t -> ('l * disallowed) t
317+
val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t
318318

319319
val imply : Const.t -> ('l * 'r) t -> (disallowed * 'r) t
320320
end
@@ -340,7 +340,7 @@ module type S = sig
340340

341341
include Common with module Const := Const
342342

343-
val meet_const : Const.t -> ('l * 'r) t -> ('l * disallowed) t
343+
val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t
344344
end
345345

346346
(** Represents a mode axis in this product whose constant is ['a], and
@@ -413,15 +413,15 @@ module type S = sig
413413

414414
val min_with : ('m, 'a, 'l * 'r) axis -> 'm -> ('l * disallowed) t
415415

416-
val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * disallowed) t
416+
val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t
417417

418-
val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> (disallowed * 'r) t
418+
val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t
419419

420420
val zap_to_legacy : lr -> Const.t
421421

422422
val zap_to_ceil : ('l * allowed) t -> Const.t
423423

424-
val meet_const : Const.t -> ('l * 'r) t -> ('l * disallowed) t
424+
val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t
425425

426426
val imply : Const.t -> ('l * 'r) t -> (disallowed * 'r) t
427427

typing/solver.ml

+6-28
Original file line numberDiff line numberDiff line change
@@ -311,40 +311,16 @@ module Solver_mono (C : Lattices_mono) = struct
311311
type a l.
312312
log:_ -> a C.obj -> a -> (a, l * allowed) morphvar -> (unit, a) Result.t =
313313
fun ~log obj a (Amorphvar (v, f) as mv) ->
314-
(* Requested [a <= f v], therefore [f' a <= v], where [f'] is the left
315-
adjoint of [f]. We should just apply [f'] to [a] and use that to
316-
constrain [v].
317-
318-
However, we aim to support a wider of notion of adjunctions (see
319-
[solver_intf.mli] for context). Say [f : B' -> A'] and [f' : A' -> B'].
320-
Note that [f' a] is known to be well-defined only if [a \in A] where [A]
321-
is some convex sublattice of [A'].
322-
323-
Note that we don't request the [A] of [f] from [Lattices_mono] for
324-
simplicity. Instead, note that we need to check [a] against [f v] anyway,
325-
and the bound of the latter is a subset of [A]. Therefore, once we make
326-
sure [a] is within the bound of [f v], we are free to apply [f'] to [a].
327-
Concretely:
328-
329-
1. If [a <= (f v).lower], immediately succeed
330-
2. If not [a <= (f v).upper], immediately fail
331-
3. Note that at this point, we still can't ensure that [a >= (f v).lower].
332-
(We don't assume total ordering, for best generality)
333-
Therefore, we set [a] to [join a (f v).lower].
334-
335-
Note how the "convex" condition plays here: (2) and (3) together ensures
336-
[(f v).lower <= a <= (f v).upper]. Note that [v \in B], therefore
337-
[f v \in A]. By convexity, we have [a \in A], and thus we can safely
338-
apply [f'] to [a].
339-
*)
340314
let mlower = mlower obj mv in
341315
let mupper = mupper obj mv in
342316
if C.le obj a mlower
343317
then Ok ()
344318
else if not (C.le obj a mupper)
345319
then Error mupper
346320
else
347-
let a = C.join obj a mlower in
321+
(* At this point we know [a <= f v], therefore [a] is in the downward
322+
closure of [f]'s image. Therefore, asking [a <= f v] is equivalent to
323+
asking [f' a <= v]. *)
348324
let f' = C.left_adjoint obj f in
349325
let src = C.src obj f in
350326
let a' = C.apply src f' a in
@@ -395,7 +371,6 @@ module Solver_mono (C : Lattices_mono) = struct
395371
else if not (C.le obj mlower a)
396372
then Error mlower
397373
else
398-
let a = C.meet obj a mupper in
399374
let f' = C.right_adjoint obj f in
400375
let src = C.src obj f in
401376
let a' = C.apply src f' a in
@@ -464,6 +439,9 @@ module Solver_mono (C : Lattices_mono) = struct
464439
match submode_cmv ~log dst (mlower dst mv) mu with
465440
| Error a -> Error (mlower dst mv, a)
466441
| Ok () ->
442+
(* At this point, we know that [f v <= g u.upper], which means [f v]
443+
lies within the downward closure of [g]'s image. Therefore, asking [f
444+
v <= g u] is equivalent to asking [g' f v <= u] *)
467445
let g' = C.left_adjoint dst g in
468446
let src = C.src dst g in
469447
let g'f = C.compose src g' (C.disallow_right f) in

typing/solver_intf.mli

+28-26
Original file line numberDiff line numberDiff line change
@@ -120,32 +120,34 @@ module type Lattices_mono = sig
120120

121121
(* Usual notion of adjunction:
122122
Given two morphisms [f : A -> B] and [g : B -> A], we require [f a <= b]
123-
iff [a <= g b].
124-
125-
Our solver accepts a wider notion of adjunction and only requires the same
126-
condition on convex sublattices. To be specific, if [f] and [g] form a
127-
usual adjunction between [A] and [B], and [A] is a convex sublattice of
128-
[A'], and [B] is a convex sublattice of [B'], we say that [f] and [g]
129-
form a partial adjunction between [A'] and [B']. We do not require [f] to
130-
be defined on [A'\A]. Similar for [g].
131-
132-
Definition of convex sublattice can be found at:
133-
https://en.wikipedia.org/wiki/Lattice_(order)#Sublattices
134-
135-
For example: Define [A = B = {0, 1, 2}] with total ordering. Define both
136-
[f] and [g] to be the identity function. Obviously [f] and [g] form a usual
137-
adjunction. Now, further define [A'] = [A], and [B'] = [{0, 1, 2, 3}] with
138-
total ordering. Obviously [A] is a convex sublattice of [A'], and [B] of
139-
[B']. Then we say [f] and [g] forms a partial adjunction between [A'] and
140-
[B'].
141-
142-
The feature allows the user to invoke [f a <= b'], where [a \in A] and [b'
143-
\in B']. Similarly, they can invoke [a' <= g b], where [a' \in A'] and [b
144-
\in B].
145-
146-
Moreover, if [a' \in A'\A], it is still fine to apply [f] to [a'], but the
147-
result should not be used as a left mode. This is unfortunately not
148-
enforcable by the ocaml type system, and we have to rely on user's caution.
123+
iff [a <= g b] for each [a \in A] and [b \in B].
124+
125+
Our solver accepts a wider notion of adjunction: Given two morphisms [f : A
126+
-> B] and [g : B -> A], we require [f a <= b] iff [a <= g b] for each [a]
127+
in the downward closure of [g]'s image and [b \in B].
128+
129+
We say [f] is a partial left adjoint of [g], because [f] is only
130+
constrained in part of its domain. As a result, [f] is not unique, since
131+
its valuation out of the constrained range can be arbitrarily chosen.
132+
133+
Dually, we can define the concept of partial right adjoint. Since partial
134+
adjoints are not unique, they don't form a pair: i.e., a partial left
135+
joint of a partial right adjoint of [f] is not [f] in general.
136+
137+
Concretely, the solver provides/requires the following guarantees
138+
(continuing the example above):
139+
140+
For the user of the [Solvers_polarized].
141+
- [g] applied to a right mode [m] can be used as a right mode without
142+
any restriction.
143+
- [f] applied to to a left mode [m] can be used as a left mode, given that
144+
the [m] is fully within the downward closure of [g]. This is unfortunately
145+
not enforcable by the ocaml type system, and we have to rely on user's
146+
caution.
147+
148+
For the supplier of the [Lattices_mono]:
149+
- The result of [left_adjoint g] is applied only on the downward closure of
150+
[g]'s image.
149151
*)
150152

151153
(* Note that [left_adjoint] and [right_adjoint] returns a [morph] weaker than

0 commit comments

Comments
 (0)