Add in-place destructure!
#165
base: master
Conversation
Thank you! Weird that it is so slow.
Could the in-place version be tripping some aliasing heuristic and hitting a slow path? I guess a profile would be illuminating.
Thank you for initiating and implementing this idea. I think this is a great idea and would be very useful; I was trying it out because I am interested in in-place copying of parameters into a model from a flat vector. From my comparisons, I suspect that one of the reasons for the slowness of the in-place version is cache issues. Additionally, the following version (just utilising `view` instead of the default `getindex` when reading from the flat vector) does better than the usual `destructure`:

```julia
function _rebuild_alt!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
    len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
    fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
        vecy = vec(y)
        copyto!(y, _getat_alt(vecy, o, flat, view))
    end
    x
end

_getat_alt(y::AbstractVector, o::Int, flat::AbstractVector, get = getindex) =
    ProjectTo(y)(get(flat, o .+ (1:length(y))))
```

Benchmarking against the existing versions:

```julia
using Flux, Optimisers, Zygote, BenchmarkTools

N = 1024
model = Chain(Dense(28^2 => N, relu), Dense(N => 10));

params, re = destructure(model)
params!, re! = destructure!(model)
params_alt!, re_alt! = destructure_alt!(model)  # using the alternatives above

@btime $re($params)
@btime $re!($params)
@btime $re_alt!($params)
```

```
106.964 μs (44 allocations: 3.11 MiB)
250.546 μs (35 allocations: 1.53 KiB)
156.664 μs (39 allocations: 1.69 KiB)
```

When I choose `N = 100`:

```
12.184 μs (43 allocations: 312.61 KiB)
21.374 μs (35 allocations: 1.53 KiB)
7.651 μs (39 allocations: 1.69 KiB)
```
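The key difference in `_getat_alt` above is passing `view` instead of the default `getindex`, so the slice of the flat vector is never materialised before `copyto!`. A minimal sketch of that trick in plain Base Julia (stand-in arrays, no Flux required):

```julia
# Stand-ins for one parameter array `y` and the flat parameter vector `flat`,
# with `y`'s entries stored at offset `o` inside `flat`.
flat = collect(1.0:12.0)
y = zeros(3, 2)
o = 4
idx = o .+ (1:length(y))

# Default path: `flat[idx]` materialises a 6-element temporary before the copy.
copyto!(y, flat[idx])

# `view` path: `copyto!` reads straight out of `flat`, no temporary.
copyto!(y, view(flat, idx))

y == reshape(5.0:10.0, 3, 2)   # true either way; only the allocation differs
```

Both paths fill `y` column-major with `flat[5:10]`; the `view` path just skips the intermediate allocation, which matters when rebuilding many parameter arrays per call.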
Ah, that looks great, thanks for digging! For me, with the example at the top:

```julia
julia> @btime $re($params);  # this is the reconstruction cost
  min 92.167 μs, mean 301.432 μs (44 allocations, 3.11 MiB)

julia> @btime copy($params);  # ... and it's mostly allocation, same mean:
  min 97.333 μs, mean 309.699 μs (2 allocations, 3.11 MiB)

julia> @btime $re!($params);  # new version without reshape
  min 58.333 μs, mean 62.932 μs (39 allocations, 1.69 KiB)
```

and likewise with `N = 100`. I think the mean times are probably a better indication of the cost in actual use when allocations differ so much, although possibly not perfect.
Nice, that's good to know; the in-place version seems to be pretty stable in its timings. How do I make `@btime` report the mean? And is the PR good to go?
Mean is from JuliaCI/BenchmarkTools.jl#258, which I should eventually re-write. I see I did write some tests of this; it could all use one more look over. There's a commented-out method.
Interesting! That would be very useful! I'll see if I can take some pirate code out of it to start using locally :) I didn't take a look at it yet.
(force-pushed from 89c8d43 to 058a25b)
I commented the method out just to focus on getting one thing working first. I believe it still needs tests, but otherwise this is nearly done. Maybe I should check that my scary warning is true; I think something like that will return zero gradient.
That's great! Yes, you are right, it returns zero.
This adds a variant of `destructure`, with minimal changes, such that it writes back into the original model instead of creating a copy. This may close #146, cc @glatteis.

Marked draft as it seems surprisingly slow -- why?