@@ -242,21 +242,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad)
242
242
@back (w, NNlib.∇conv_filter (Δ, data (x), data (w); stride = stride, pad = pad))
243
243
end
244
244
245
- _maxpool (x, k, pad) = maxpool (x, k; pad = pad)
245
+ _maxpool (x, k, pad, stride ) = maxpool (x, k; pad = pad, stride = stride )
246
246
247
- maxpool (x:: TrackedArray , k; pad = map (_-> 0 ,k)) =
248
- track (_maxpool, x, k, pad)
247
+ maxpool (x:: TrackedArray , k; pad = map (_-> 0 ,k), stride = k ) =
248
+ track (_maxpool, x, k, pad, stride )
249
249
250
- back_ (:: typeof (_maxpool), y, Δ, x, k, pad) =
251
- back (x, NNlib.∇maxpool (Δ, y, data (x), k, pad= pad))
250
+ back_ (:: typeof (_maxpool), y, Δ, x, k, pad, stride ) =
251
+ back (x, NNlib.∇maxpool (Δ, y, data (x), k, pad= pad, stride = stride ))
252
252
253
- _meanpool (x, k, pad) = meanpool (x, k; pad = pad)
253
+ _meanpool (x, k, pad, stride ) = meanpool (x, k; pad = pad, stride = stride )
254
254
255
- meanpool (x:: TrackedArray , k; pad = map (_-> 0 ,k)) =
256
- track (_meanpool, x, k, pad)
255
+ meanpool (x:: TrackedArray , k; pad = map (_-> 0 ,k), stride = k ) =
256
+ track (_meanpool, x, k, pad, stride )
257
257
258
- back_ (:: typeof (_meanpool), y, Δ, x, k, pad) =
259
- back (x, NNlib.∇meanpool (Δ, y, data (x), k, pad= pad))
258
+ back_ (:: typeof (_meanpool), y, Δ, x, k, pad, stride ) =
259
+ back (x, NNlib.∇meanpool (Δ, y, data (x), k, pad= pad, stride = stride ))
260
260
261
261
# Broadcasting
262
262
0 commit comments