@@ -261,21 +261,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad)
261
261
@back (w, NNlib.∇conv_filter (Δ, data (x), data (w); stride = stride, pad = pad))
262
262
end
263
263
264
- _maxpool (x, k, pad) = maxpool (x, k; pad = pad)
264
+ _maxpool (x, k, pad, stride ) = maxpool (x, k; pad = pad, stride = stride )
265
265
266
- maxpool (x:: TrackedArray , k; pad = map (_-> 0 ,k)) =
267
- track (_maxpool, x, k, pad)
266
+ maxpool (x:: TrackedArray , k; pad = map (_-> 0 ,k), stride = k ) =
267
+ track (_maxpool, x, k, pad, stride )
268
268
269
- back_ (:: typeof (_maxpool), y, Δ, x, k, pad) =
270
- back (x, NNlib.∇maxpool (Δ, y, data (x), k, pad= pad))
269
+ back_ (:: typeof (_maxpool), y, Δ, x, k, pad, stride ) =
270
+ back (x, NNlib.∇maxpool (Δ, y, data (x), k, pad= pad, stride = stride ))
271
271
272
- _meanpool (x, k, pad) = meanpool (x, k; pad = pad)
272
+ _meanpool (x, k, pad, stride ) = meanpool (x, k; pad = pad, stride = stride )
273
273
274
- meanpool (x:: TrackedArray , k; pad = map (_-> 0 ,k)) =
275
- track (_meanpool, x, k, pad)
274
+ meanpool (x:: TrackedArray , k; pad = map (_-> 0 ,k), stride = k ) =
275
+ track (_meanpool, x, k, pad, stride )
276
276
277
- back_ (:: typeof (_meanpool), y, Δ, x, k, pad) =
278
- back (x, NNlib.∇meanpool (Δ, y, data (x), k, pad= pad))
277
+ back_ (:: typeof (_meanpool), y, Δ, x, k, pad, stride ) =
278
+ back (x, NNlib.∇meanpool (Δ, y, data (x), k, pad= pad, stride = stride ))
279
279
280
280
# Broadcasting
281
281
0 commit comments