Skip to content

Commit c724237

Browse files
committed
added stride for pooling in tracker
1 parent bdd8162 commit c724237

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/tracker/array.jl

+10-10
Original file line numberDiff line numberDiff line change
@@ -242,21 +242,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad)
242242
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
243243
end
244244

245-
_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
245+
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)
246246

247-
maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
248-
track(_maxpool, x, k, pad)
247+
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
248+
track(_maxpool, x, k, pad, stride)
249249

250-
back_(::typeof(_maxpool), y, Δ, x, k, pad) =
251-
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
250+
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
251+
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
252252

253-
_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
253+
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
254254

255-
meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
256-
track(_meanpool, x, k, pad)
255+
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
256+
track(_meanpool, x, k, pad, stride)
257257

258-
back_(::typeof(_meanpool), y, Δ, x, k, pad) =
259-
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
258+
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
259+
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
260260

261261
# Broadcasting
262262

0 commit comments

Comments
 (0)