diff --git a/Manifest.toml b/Manifest.toml
index 23989dc9..2b8a9b09 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -1,5 +1,11 @@
 # This file is machine-generated - editing it directly is not advised
 
+[[ArrayInterface]]
+deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays"]
+git-tree-sha1 = "ee07ae00e3cc277dcfa5507ce25be522313ecc3e"
+uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
+version = "3.1.1"
+
 [[Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
 
@@ -11,6 +17,11 @@ repo-url = "https://github.com/FluxML/Functors.jl"
 uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 version = "0.1.0"
 
+[[IfElse]]
+git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
+uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
+version = "0.1.0"
+
 [[Libdl]]
 uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 
diff --git a/Project.toml b/Project.toml
index 62b70e6f..5b95afd1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,6 +4,7 @@ authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
 version = "0.1.0"
 
 [deps]
+ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -12,6 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 Functors = "0"
 Requires = "0.5, 1"
+ArrayInterface = "3"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index c57aa5a7..61db7765 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -1,9 +1,11 @@
 module Optimisers
 
 using Functors: functor, fmap, isleaf
+using ArrayInterface
 
 include("interface.jl")
 include("rules.jl")
+include("mutating.jl")
 
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
        ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
diff --git a/src/interface.jl b/src/interface.jl
index b59dad53..669fca6a 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -10,7 +10,7 @@ function state(o, x)
 end
 
 function _update(o, x, x̄, st)
-  x̄, st = apply(o, x, x̄, st)
+  x̄, st = ismutable(x) ? apply!(o, x, x̄, st) : apply(o, x, x̄, st)
   return patch(x, x̄), st
 end
 
diff --git a/src/mutating.jl b/src/mutating.jl
new file mode 100644
index 00000000..d8b5106b
--- /dev/null
+++ b/src/mutating.jl
@@ -0,0 +1,169 @@
+# This file contains duplicated rules as in rules.jl
+# where all the operations are done in-place for a softer deprecation
+
+function apply!(o::Descent, x, dx, state)
+  η = convert(eltype(dx), o.eta)
+  dx .*= η
+
+  return dx, state
+end
+
+function apply!(o::Momentum, x, dx, state)
+  η, ρ, v = o.eta, o.rho, state
+  @. v = ρ * v - η * dx
+
+  return -v, v
+end
+
+function apply!(o::Nesterov, x, dx, state)
+  η, ρ, v = o.eta, o.rho, state
+  @. d = ρ^2 * v - (1+ρ) * η * dx
+  @. v = ρ * v - η * dx
+
+  return -d, v
+end
+
+function apply!(o::RMSProp, x, dx, state)
+  η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
+  @. acc = ρ * acc + (1 - ρ) * dx^2
+  @. dx = dx * (η / (sqrt(acc) + ϵ))
+
+  return dx, acc
+end
+
+function apply!(o::ADAM{T}, x, dx, state) where T
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+  mt, vt, βt = state
+
+  @. mt = β[1] * mt + (one(T) - β[1]) * dx
+  @. vt = β[2] * vt + (one(T) - β[2]) * dx ^ 2
+  @. dx = mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η
+
+  return dx, (mt, vt, βt .* β)
+end
+
+function apply!(o::RADAM, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+  ρ∞ = 2/(1-β[2])-1
+
+  mt, vt, βt, t = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. vt = β[2] * vt + (1 - β[2]) * dx^2
+  ρ = ρ∞ - 2*t * βt[2] / (1 - βt[2])
+  if ρ > 4
+    r = sqrt((ρ - 4) * (ρ - 2) * ρ∞/((ρ∞ - 4) * (ρ∞ - 2) * ρ))
+    @. dx = mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η * r
+  else
+    @. dx = mt / (1 - βt[1]) * η
+  end
+
+  return dx, (mt, vt, βt .* β, t + 1)
+end
+
+function apply!(o::AdaMax, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+
+  mt, ut, βt = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. ut = max(β[2] * ut, abs(dx))
+  @. dx = (η/(1 - βt[1])) * mt/(ut + ϵ)
+
+  return dx, (mt, ut, βt .* β)
+end
+
+function apply!(o::AdaMax, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+
+  mt, ut, βt = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. ut = max(β[2] * ut, abs(dx))
+  @. dx = (η/(1 - βt[1])) * mt/(ut + ϵ)
+
+  return dx, (mt, ut, βt .* β)
+end
+
+function apply!(o::OADAM, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+
+  mt, vt, βt, dx_ = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. vt = β[2] * vt + (1 - β[2]) * dx^2
+  @. dx = -dx_
+  @. dx_ = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
+  @. dx += 2*dx_
+
+  return dx, (mt, vt, βt .* β, dx_)
+end
+
+function apply!(o::ADAGrad, x, dx, state)
+  η, ϵ = o.eta, o.epsilon
+  acc, = state
+
+  @. acc += dx^2
+  @. dx *= η / (sqrt(acc) + ϵ)
+
+  return dx, (acc,)
+end
+
+function apply!(o::ADADelta, x, dx, state)
+  ρ, ϵ = o.rho, o.epsilon
+  acc, Δacc = state
+
+  @. acc = ρ * acc + (1 - ρ) * dx^2
+  # DON'T remove epsilon from numerator
+  # or even out of the square roots
+  @. dx *= sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
+  @. Δacc = ρ * Δacc + (1 - ρ) * dx^2
+
+  return dx, (acc, Δacc)
+end
+
+function apply!(o::AMSGrad, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+
+  mt, vt, v̂t = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
+  @. v̂t = max(v̂t, vt)
+  @. dx = η * mt / (sqrt(v̂t) + ϵ)
+
+  return dx, (mt, vt, v̂t)
+end
+
+function apply!(o::NADAM, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+
+  mt, vt, βt = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. vt = β[2] * vt + (1 - β[2]) * dx^2
+  @. dx = (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) / 
+          (sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η
+
+  return dx, (mt, vt, βt .* β)
+end
+
+function apply!(o::AdaBelief, x, dx, state)
+  η, β, ϵ = o.eta, o.beta, o.epsilon
+  mt, st = state
+
+  @. mt = β[1] * mt + (1 - β[1]) * dx
+  @. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
+  @. dx =  η * mt / (sqrt(st) + ϵ)
+
+  return dx, (mt, st)
+end
+
+function apply!(o::OptimiserChain, x, dx, states)
+  new_states = similar(states)
+  for (i, (opt, state)) in enumerate(zip(o.opts, states))
+    _, new_states[i] = apply!(opt, x, dx, state)
+  end
+
+  return dx, new_states
+end
diff --git a/src/rules.jl b/src/rules.jl
index 9c89ea28..d52c7596 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -17,7 +17,6 @@ init(o::Descent, x::AbstractArray) = nothing
 
 function apply(o::Descent, x, dx, state)
   η = convert(eltype(dx), o.eta)
-  dx .*= η
   
   return dx .* η, state
 end
@@ -513,4 +512,4 @@ function apply(o::OptimiserChain, x, dx, states)
   end
 
   return dx, new_states
-end
\ No newline at end of file
+end