
Commit 3c70047

Merge pull request #23 from mcabbott/lazy

Lazy loading of LazyArrays

2 parents f446c22 + c9c1fb1

File tree: 6 files changed, +48 -13 lines

Diff for: Project.toml (+4 -3)

@@ -5,21 +5,21 @@ version = "0.2.2"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
-LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
 LazyStack = "1fad7336-0346-5a1a-a56f-a06ba010965b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
 Compat = "2.2, 3"
-LazyArrays = "0.12, 0.13, 0.14, 0.15, 0.16"
 LazyStack = "0.0.4, 0.0.5, 0.0.6, 0.0.7, 0.0.8"
 MacroTools = "0.5"
 OffsetArrays = "0.11, 1.0"
+Requires = "0.5, 1"
 StaticArrays = "0.10, 0.11, 0.12"
 ZygoteRules = "0.1, 0.2"
 julia = "1"
@@ -28,10 +28,11 @@ julia = "1"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Einsum = "b7d42ee7-0b51-5a75-98ca-779d3107e4c0"
 JuliennedArrays = "5cadff95-7770-533d-a838-a1bf817ee6e0"
+LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
 LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Compat", "Einsum", "JuliennedArrays", "LoopVectorization", "Statistics", "Strided"]
+test = ["Test", "Compat", "Einsum", "JuliennedArrays", "LazyArrays", "LoopVectorization", "Statistics", "Strided"]

Diff for: docs/src/options.md (+2)

@@ -79,11 +79,13 @@ In the following example, the product `V .* V' .* V3` contains about 1GB of data
 the writing of which is avoided by giving the option `lazy`:
 
 ```julia
+using LazyArrays # you must now load this package
 V = rand(500); V3 = reshape(V,1,1,:);
 
 @time @reduce W[i] := sum(j,k) V[i]*V[j]*V[k]; # 0.6 seconds, 950 MB
 @time @reduce W[i] := sum(j,k) V[i]*V[j]*V[k] lazy; # 0.025 s, 5 KB
 ```
+However, right now this gives `3.7 s (250 M allocations, 9 GB)`, something is broken!
 
 The package [Strided.jl](https://github.com/Jutho/Strided.jl) can apply multi-threading to
 broadcasting, and some other magic. You can enable it with the option `strided`, like this:
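
For context on this doc change: the `lazy` option now needs `using LazyArrays` because the lazy container lives in that package. Below is a rough hand-written analogue of what the option buys, a sketch using LazyArrays directly rather than TensorCast's actual expansion: the lazy version never materializes the 500×500×500 intermediate before reducing it.

```julia
using LazyArrays   # provides BroadcastArray

V = rand(500); V3 = reshape(V, 1, 1, :);

# Eager: allocates the full 500^3 Float64 array (about 1 GB) before summing.
W_eager = dropdims(sum(V .* V' .* V3; dims=(2, 3)); dims=(2, 3))

# Lazy: the product is only a broadcast recipe; the reduction reads it on the fly.
lazyprod = BroadcastArray(*, V, V', V3)
W_lazy = dropdims(sum(lazyprod; dims=(2, 3)); dims=(2, 3))

W_eager ≈ W_lazy   # true
```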

Diff for: src/TensorCast.jl (+28 -5)

@@ -1,6 +1,12 @@
 
 module TensorCast
 
+# This speeds up loading a bit, on Julia 1.5, about 1s in my test.
+# https://github.com/JuliaPlots/Plots.jl/pull/2544/files
+if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@optlevel"))
+@eval Base.Experimental.@optlevel 1
+end
+
 export @cast, @reduce, @matmul, @pretty
 
 using MacroTools, StaticArrays, Compat
@@ -10,11 +16,28 @@ include("macro.jl")
 include("pretty.jl")
 include("string.jl")
 
-include("slice.jl") # slice, glue, etc
-include("view.jl") # orient, Reverse{d} etc
-include("lazy.jl") # LazyCast
-include("static.jl") # StaticArrays
+module Fast # shield non-macro code from @optlevel 1
+using LinearAlgebra, StaticArrays, Compat
+
+include("slice.jl") # slice, glue, etc
+export sliceview, slicecopy, glue, glue!, red_glue, cat_glue, copy_glue, lazy_glue, iscodesorted, countcolons
+
+include("view.jl") # orient, Reverse{d} etc
+export diagview, orient, rview, mul!, star, PermuteDims, Reverse, Shuffle
+
+include("static.jl") # StaticArrays
+export static_slice, static_glue
+
+end
+using .Fast
+const mul! = Fast.mul!
+
+using Requires
+
+@init @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
+include("lazy.jl") # LazyCast # this costs about 3s in my test, 3.8s -> 7.7s
+end
 
-include("warm.jl")
+include("warm.jl") # worth 2s in my test
 
 end # module
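
For readers unfamiliar with the pattern above: Requires.jl registers a callback keyed on a package UUID, and the `@require ... begin ... end` body runs only once the user loads that package in the same session; inside that body the newly visible module must be referenced relatively (`.LazyArrays`), which is why `src/lazy.jl` changes below. A minimal self-contained sketch of the same idea, using a hypothetical demo module rather than TensorCast's own code:

```julia
# Sketch of the Requires.jl pattern (hypothetical module, not TensorCast itself).
module LazyLoadingDemo

using Requires

# Records whether the optional dependency has been hooked up:
const HAVE_LAZYARRAYS = Ref(false)

function __init__()
    # The begin...end block is evaluated only after `using LazyArrays`
    # happens in the running session, matched by its UUID.
    @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
        import .LazyArrays   # relative import: LazyArrays is not a hard dependency
        # Glue code mentioning LazyArrays would be defined or `include`d here.
        HAVE_LAZYARRAYS[] = true
    end
end

end # module
```

TensorCast's actual hook is the `@init @require ... include("lazy.jl")` block in the diff; `@init` is Requires.jl's shorthand for placing that code inside the module's `__init__`.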

Diff for: src/lazy.jl (+1 -1)

@@ -1,5 +1,5 @@
 
-import LazyArrays
+import .LazyArrays
 
 #=
 The macro option "lazy" always produces things like sum(@__dot__(lazy(x+y)))

Diff for: src/macro.jl (+12 -4)

@@ -234,6 +234,8 @@ This mostly aims to re-work the given expression into `some(steps(A))[i,j]`,
 but also pushes `A = f(x)` into `store.top`.
 """
 function standardise(ex, store::NamedTuple, call::CallInfo; LHS=false)
+@nospecialize ex
+
 # This acts only on single indexing expressions:
 if @capture(ex, A_{ijk__})
 static=true
@@ -378,6 +380,7 @@ target dims not correctly handled yet -- what do I want? TODO
 Simple glue / stand. does not permutedims, but broadcasting may have to... avoid twice?
 """
 function standardglue(ex, target, store::NamedTuple, call::CallInfo)
+@nospecialize ex
 
 # The sole target here is indexing expressions:
 if @capture(ex, A_[inner__])
@@ -469,6 +472,7 @@ This beings the expression to have target indices,
 by permutedims and if necessary broadcasting, always using `readycast()`.
 """
 function targetcast(ex, target, store::NamedTuple, call::CallInfo)
+@nospecialize ex
 
 # If just one naked expression, then we won't broadcast:
 if @capture(ex, A_[ijk__])
@@ -503,6 +507,7 @@ end
 This is walked over the expression to prepare for `@__dot__` etc, by `targetcast()`.
 """
 function readycast(ex, target, store::NamedTuple, call::CallInfo)
+@nospecialize ex
 
 # Scalar functions can be protected entirely from broadcasting:
 # TODO this means A[i,j] + rand()/10 doesn't work, /(...,10) is a function!
@@ -578,6 +583,7 @@ If there are more than two factors, it recurses, and you get `(A*B) * C`,
 or perhaps tuple `(A*B, C)`.
 """
 function matmultarget(ex, target, parsed, store::NamedTuple, call::CallInfo)
+@nospecialize ex
 
 @capture(ex, A_ * B_ * C__ | *(A_, B_, C__) ) || throw(MacroError("can't @matmul that!", call))
 
@@ -631,6 +637,7 @@ pushing calculation steps into store.
 Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).same[i']`.
 """
 function recursemacro(ex, store::NamedTuple, call::CallInfo)
+@nospecialize ex
 
 # Actually look for recursion
 if @capture(ex, @reduce(subex__) )
@@ -675,6 +682,8 @@ This saves to `store` the sizes of all input tensors, and their sub-slices if an
 however it should not destroy this so that `sz_j` can be got later.
 """
 function rightsizes(ex, store::NamedTuple, call::CallInfo)
+@nospecialize ex
+
 :recurse in call.flags && return nothing # outer version took care of this
 
 if @capture(ex, A_[outer__][inner__] | A_[outer__]{inner__} )
@@ -1115,8 +1124,7 @@ end
 
 tensorprimetidy(v::Vector) = Any[ tensorprimetidy(x) for x in v ]
 function tensorprimetidy(ex)
-MacroTools.postwalk(ex) do x
-
+MacroTools.postwalk(ex) do @nospecialize x
 @capture(x, ((ij__,) \ k_) ) && return :( ($(ij...),$k) )
 @capture(x, i_ \ j_ ) && return :( ($i,$j) )
 
@@ -1172,7 +1180,7 @@ containsindexing(s) = false
 function containsindexing(ex::Expr)
 flag = false
 # MacroTools.postwalk(x -> @capture(x, A_[ijk__]) && (flag=true), ex)
-MacroTools.postwalk(ex) do x
+MacroTools.postwalk(ex) do @nospecialize x
 # @capture(x, A_[ijk__]) && !(all(isconstant, ijk)) && (flag=true)
 if @capture(x, A_[ijk__])
 # @show x ijk # TODO this is a bit broken? @pretty @cast Z[i,j] := W[i] * exp(X[1][i] - X[2][j])
@@ -1185,7 +1193,7 @@ end
 listindices(s::Symbol) = []
 function listindices(ex::Expr)
 list = []
-MacroTools.postwalk(ex) do x
+MacroTools.postwalk(ex) do @nospecialize x
 if @capture(x, A_[ijk__])
 flat, _ = indexparse(nothing, ijk)
 push!(list, flat)
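
The repeated `@nospecialize ex` additions above are a compile-latency optimisation: these functions receive arbitrary expression fragments (`Expr`, `Symbol`, numbers, strings, ...), and without the annotation Julia compiles a separate specialization of each function for every concrete argument type it meets. A small illustrative sketch of the annotation, not TensorCast code:

```julia
# `ex` may be an Expr, a Symbol, or any literal. `@nospecialize` asks the
# compiler for one generic method body instead of a specialization per
# concrete type, trading a little runtime dispatch for less compilation.
function count_symbols(ex)
    @nospecialize ex
    ex isa Symbol && return 1
    ex isa Expr   && return sum(count_symbols(a) for a in ex.args; init=0)
    return 0   # numbers, strings, LineNumberNodes, ...
end

count_symbols(:(a + b * c))   # 5: the symbols +, a, *, b, c
```

The `do @nospecialize x` form used with `MacroTools.postwalk` applies the same annotation to the anonymous function's argument.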

Diff for: test/runtests.jl (+1)

@@ -6,6 +6,7 @@ using StaticArrays
 using OffsetArrays
 using Einsum
 using Strided
+using LazyArrays
 using Compat
 if VERSION >= v"1.1"
 using LoopVectorization
