Skip to content

Commit de76e08

Browse files
bors[bot]CarloLucibello
andauthoredJun 17, 2021
Merge #1613
1613: use ArrayInterface.restructure in update! r=CarloLucibello a=CarloLucibello Suggestion coming from @ChrisRackauckas in FluxML/Zygote.jl#989. Now `update!` handles basically any gradient Zygote emits, e.g. FillArrays and Zygote.OneElement. Fix #1510 Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
2 parents 27c4c77 + a77b32f commit de76e08

File tree

6 files changed

+164
-68
lines changed

6 files changed

+164
-68
lines changed
 

‎Manifest.toml

+95-55
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,19 @@ version = "0.3.4"
1313

1414
[[Adapt]]
1515
deps = ["LinearAlgebra"]
16-
git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db"
16+
git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7"
1717
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
18-
version = "3.3.0"
18+
version = "3.3.1"
1919

2020
[[ArgTools]]
2121
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
2222

23+
[[ArrayInterface]]
24+
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
25+
git-tree-sha1 = "045ff5e1bc8c6fb1ecb28694abba0a0d55b5f4f5"
26+
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
27+
version = "3.1.17"
28+
2329
[[Artifacts]]
2430
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
2531

@@ -38,22 +44,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
3844
version = "0.4.1"
3945

4046
[[CUDA]]
41-
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "TimerOutputs"]
42-
git-tree-sha1 = "a6ce96dcf22fc4f1bfdfac02d54f0b77ecf2a4cc"
47+
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
48+
git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad"
4349
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
44-
version = "3.0.3"
50+
version = "3.2.1"
4551

4652
[[ChainRules]]
47-
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"]
48-
git-tree-sha1 = "1f410fba5c04d03ab712f348f1542e6059376547"
53+
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
54+
git-tree-sha1 = "720fa9a9ce61ff18842a40f501d6a1f8ba771c64"
4955
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
50-
version = "0.7.61"
56+
version = "0.8.6"
5157

5258
[[ChainRulesCore]]
5359
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
54-
git-tree-sha1 = "42e3c181483fbd2c416087a0a93838803e358358"
60+
git-tree-sha1 = "8b31cc69cbc38c5c826aaa1c890c694be3622d99"
5561
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
56-
version = "0.9.38"
62+
version = "0.10.3"
5763

5864
[[CodecZlib]]
5965
deps = ["TranscodingStreams", "Zlib_jll"]
@@ -63,15 +69,15 @@ version = "0.7.0"
6369

6470
[[ColorTypes]]
6571
deps = ["FixedPointNumbers", "Random"]
66-
git-tree-sha1 = "32a2b8af383f11cbb65803883837a149d10dfe8a"
72+
git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597"
6773
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
68-
version = "0.10.12"
74+
version = "0.11.0"
6975

7076
[[Colors]]
7177
deps = ["ColorTypes", "FixedPointNumbers", "Reexport"]
72-
git-tree-sha1 = "82f4e6ff9f847eca3e5ebc666ea2cd7b48e8b47e"
78+
git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40"
7379
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
74-
version = "0.12.7"
80+
version = "0.12.8"
7581

7682
[[CommonSubexpressions]]
7783
deps = ["MacroTools", "Test"]
@@ -81,9 +87,9 @@ version = "0.3.0"
8187

8288
[[Compat]]
8389
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
84-
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956"
90+
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
8591
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
86-
version = "3.27.0"
92+
version = "3.30.0"
8793

8894
[[CompilerSupportLibraries_jll]]
8995
deps = ["Artifacts", "Libdl"]
@@ -124,6 +130,12 @@ version = "1.0.2"
124130
deps = ["Random", "Serialization", "Sockets"]
125131
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
126132

133+
[[DocStringExtensions]]
134+
deps = ["LibGit2"]
135+
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
136+
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
137+
version = "0.8.5"
138+
127139
[[Downloads]]
128140
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
129141
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
@@ -158,23 +170,28 @@ uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
158170
version = "0.2.1"
159171

160172
[[GPUArrays]]
161-
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
162-
git-tree-sha1 = "9c95b2fd5c16bc7f97371e9f92f0fef77e0f5957"
173+
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
174+
git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086"
163175
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
164-
version = "6.2.2"
176+
version = "6.4.1"
165177

166178
[[GPUCompiler]]
167179
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
168-
git-tree-sha1 = "6eadd2321dc3ac0fc9d530ab01c2caa7fe5d74c6"
180+
git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a"
169181
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
170-
version = "0.11.4"
182+
version = "0.11.5"
171183

172184
[[IRTools]]
173185
deps = ["InteractiveUtils", "MacroTools", "Test"]
174186
git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510"
175187
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
176188
version = "0.4.2"
177189

190+
[[IfElse]]
191+
git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
192+
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
193+
version = "0.1.0"
194+
178195
[[InteractiveUtils]]
179196
deps = ["Markdown"]
180197
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -193,9 +210,9 @@ version = "0.8.4"
193210

194211
[[LLVM]]
195212
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
196-
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
213+
git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
197214
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
198-
version = "3.6.0"
215+
version = "3.7.1"
199216

200217
[[LazyArtifacts]]
201218
deps = ["Artifacts", "Pkg"]
@@ -224,6 +241,12 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
224241
deps = ["Libdl"]
225242
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
226243

244+
[[LogExpFunctions]]
245+
deps = ["DocStringExtensions", "LinearAlgebra"]
246+
git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885"
247+
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
248+
version = "0.2.4"
249+
227250
[[Logging]]
228251
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
229252

@@ -255,9 +278,9 @@ version = "0.4.4"
255278

256279
[[Missings]]
257280
deps = ["DataAPI"]
258-
git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c"
281+
git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7"
259282
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
260-
version = "0.4.5"
283+
version = "1.0.0"
261284

262285
[[Mmap]]
263286
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@@ -267,15 +290,15 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
267290

268291
[[NNlib]]
269292
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
270-
git-tree-sha1 = "80b8360670f445d88b3475e88b33bbcc92f7866e"
293+
git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c"
271294
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
272-
version = "0.7.19"
295+
version = "0.7.21"
273296

274297
[[NNlibCUDA]]
275298
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
276-
git-tree-sha1 = "4b368b466bcdd25d448a5b20de4b7e481d68b88e"
299+
git-tree-sha1 = "bd8b29bf75be7a6c2b288b4b9a4e8903d0376ac1"
277300
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
278-
version = "0.1.0"
301+
version = "0.1.3"
279302

280303
[[NaNMath]]
281304
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
@@ -287,24 +310,24 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
287310

288311
[[OpenSpecFun_jll]]
289312
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
290-
git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
313+
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
291314
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
292-
version = "0.5.3+4"
315+
version = "0.5.5+0"
293316

294317
[[OrderedCollections]]
295-
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
318+
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
296319
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
297-
version = "1.4.0"
320+
version = "1.4.1"
298321

299322
[[Pkg]]
300323
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
301324
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
302325

303326
[[Preferences]]
304327
deps = ["TOML"]
305-
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902"
328+
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
306329
uuid = "21216c6a-2e73-6563-6e65-726566657250"
307-
version = "1.2.1"
330+
version = "1.2.2"
308331

309332
[[Printf]]
310333
deps = ["Unicode"]
@@ -322,16 +345,22 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
322345
deps = ["Serialization"]
323346
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
324347

348+
[[Random123]]
349+
deps = ["Libdl", "Random", "RandomNumbers"]
350+
git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
351+
uuid = "74087812-796a-5b5d-8853-05524746bad3"
352+
version = "1.3.1"
353+
325354
[[RandomNumbers]]
326355
deps = ["Random", "Requires"]
327356
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
328357
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
329358
version = "1.4.0"
330359

331360
[[Reexport]]
332-
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
361+
git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220"
333362
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
334-
version = "1.0.0"
363+
version = "1.1.0"
335364

336365
[[Requires]]
337366
deps = ["UUIDs"]
@@ -344,9 +373,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
344373

345374
[[Scratch]]
346375
deps = ["Dates"]
347-
git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
376+
git-tree-sha1 = "0b4b7f1393cff97c33891da2a0bf69c6ed241fda"
348377
uuid = "6c6a2e73-6563-6170-7368-637461726353"
349-
version = "1.0.3"
378+
version = "1.1.0"
350379

351380
[[Serialization]]
352381
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -359,36 +388,47 @@ uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
359388
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
360389

361390
[[SortingAlgorithms]]
362-
deps = ["DataStructures", "Random", "Test"]
363-
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
391+
deps = ["DataStructures"]
392+
git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96"
364393
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
365-
version = "0.3.1"
394+
version = "1.0.0"
366395

367396
[[SparseArrays]]
368397
deps = ["LinearAlgebra", "Random"]
369398
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
370399

371400
[[SpecialFunctions]]
372-
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
373-
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
401+
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
402+
git-tree-sha1 = "a50550fa3164a8c46747e62063b4d774ac1bcf49"
374403
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
375-
version = "1.3.0"
404+
version = "1.5.1"
405+
406+
[[Static]]
407+
deps = ["IfElse"]
408+
git-tree-sha1 = "2740ea27b66a41f9d213561a04573da5d3823d4b"
409+
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
410+
version = "0.2.5"
376411

377412
[[StaticArrays]]
378413
deps = ["LinearAlgebra", "Random", "Statistics"]
379-
git-tree-sha1 = "e8cd1b100d37f5b4cfd2c83f45becf61c762eaf7"
414+
git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668"
380415
uuid = "90137ffa-7385-5640-81b9-e52037218182"
381-
version = "1.1.1"
416+
version = "1.2.2"
382417

383418
[[Statistics]]
384419
deps = ["LinearAlgebra", "SparseArrays"]
385420
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
386421

422+
[[StatsAPI]]
423+
git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
424+
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
425+
version = "1.0.0"
426+
387427
[[StatsBase]]
388-
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
389-
git-tree-sha1 = "4bc58880426274277a066de306ef19ecc22a6863"
428+
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
429+
git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d"
390430
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
391-
version = "0.33.5"
431+
version = "0.33.8"
392432

393433
[[TOML]]
394434
deps = ["Dates"]
@@ -403,10 +443,10 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
403443
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
404444

405445
[[TimerOutputs]]
406-
deps = ["Printf"]
407-
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
446+
deps = ["ExprTools", "Printf"]
447+
git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
408448
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
409-
version = "0.5.8"
449+
version = "0.5.9"
410450

411451
[[TranscodingStreams]]
412452
deps = ["Random", "Test"]
@@ -433,9 +473,9 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
433473

434474
[[Zygote]]
435475
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
436-
git-tree-sha1 = "927209c83efa62256788a9880c191774c07c5b51"
476+
git-tree-sha1 = "b1d95edd4e693066c38c13a10aab0a8f6a6e2f65"
437477
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
438-
version = "0.6.10"
478+
version = "0.6.12"
439479

440480
[[ZygoteRules]]
441481
deps = ["MacroTools"]

‎Project.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.12.4"
55
[deps]
66
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
89
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
910
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
1011
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
@@ -29,6 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2930
[compat]
3031
AbstractTrees = "0.3"
3132
Adapt = "3.0"
33+
ArrayInterface = "3.1"
3234
CUDA = "3"
3335
CodecZlib = "0.7"
3436
Colors = "0.12"
@@ -44,10 +46,12 @@ Zygote = "0.6"
4446
julia = "1.6"
4547

4648
[extras]
49+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
4750
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
51+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
4852
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
4953
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5054
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5155

5256
[targets]
53-
test = ["Test", "Documenter", "IterTools", "LinearAlgebra"]
57+
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]

‎src/Flux.jl

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using Zygote, MacroTools, Juno, Reexport
88
using MacroTools: @forward
99
@reexport using NNlib
1010
using Zygote: Params, @adjoint, gradient, pullback, @nograd
11-
1211
export gradient
1312

1413
export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,

‎src/optimise/Optimise.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module Optimise
22

33
using LinearAlgebra
4+
import ArrayInterface
45

56
export train!, update!,
67
Descent, ADAM, Momentum, Nesterov, RMSProp,

0 commit comments

Comments
 (0)