Skip to content

Commit 027a922

Browse files
committed
Merge branch 'strides'
2 parents 0ba5ce4 + c724237 commit 027a922

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/tracker/array.jl

+10-10
Original file line numberDiff line numberDiff line change
@@ -261,21 +261,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad)
261261
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
262262
end
263263

264-
_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
264+
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
265265

266-
maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
267-
track(_maxpool, x, k, pad)
266+
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
267+
track(_maxpool, x, k, pad, stride)
268268

269-
back_(::typeof(_maxpool), y, Δ, x, k, pad) =
270-
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
269+
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
270+
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
271271

272-
_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
272+
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
273273

274-
meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
275-
track(_meanpool, x, k, pad)
274+
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
275+
track(_meanpool, x, k, pad, stride)
276276

277-
back_(::typeof(_meanpool), y, Δ, x, k, pad) =
278-
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
277+
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
278+
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
279279

280280
# Broadcasting
281281

0 commit comments

Comments
 (0)