Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add CTC loss to new Losses module #1287

Merged
merged 36 commits into from
Jan 20, 2021
Merged
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9e31e53
Add CTC loss and tests
maetshju Jul 19, 2020
37efaa0
Add ctc to Losses module
maetshju Jul 20, 2020
f471337
General updates
maetshju Oct 13, 2020
b19b88c
Reverting bad merge
maetshju Oct 13, 2020
da3564b
Revert "General updates"
maetshju Oct 13, 2020
d8242c0
General ctc updates
maetshju Oct 13, 2020
5bf2635
Get test cases working
maetshju Oct 15, 2020
1707255
Merge branch 'master' into ctc
maetshju Oct 17, 2020
e4123ab
Update NEWS.md
maetshju Oct 17, 2020
3045ed2
Re-pull from main Flux repo
maetshju Dec 7, 2020
46898ed
Change ctc to ctc_loss
maetshju Dec 14, 2020
110a608
Fix typo
maetshju Dec 14, 2020
5145222
Remove camel-casing
maetshju Dec 14, 2020
e002027
Remove some whitespace from functions
maetshju Dec 14, 2020
d0bd3bd
Adding info to comply with Apache license
maetshju Dec 14, 2020
5dafa05
Use logsumexp in CPU CTC
maetshju Dec 18, 2020
3950464
Change logsum to logsumexp
maetshju Dec 18, 2020
282cb23
Re-add logaddexp function to CPU ctc
maetshju Dec 19, 2020
50fc561
Merge branch 'master' into ctc
maetshju Dec 19, 2020
d9caac4
Regnerate Manifest.toml
maetshju Dec 19, 2020
c043855
Remove time indexing from ctc gradchecks
maetshju Dec 20, 2020
3eb9e51
Revert "Remove time indexing from ctc gradchecks"
maetshju Dec 20, 2020
bccef7a
Change typedZero to typed_zero
maetshju Dec 20, 2020
00d4125
Update gradcheck comment; remove some whitespace
maetshju Dec 22, 2020
9b37c8f
Reduce allocations for better performance
maetshju Dec 24, 2020
75bb3c1
Transpose alpha and beta to match GPU kernel
maetshju Dec 24, 2020
b00487a
Update add_blanks to use fill
maetshju Dec 24, 2020
d96a53f
Split CPU loss and gradient calculation
maetshju Dec 31, 2020
6ca07e2
Rejig CPU CTC API
maetshju Jan 3, 2021
bc7ab03
Split GPU CTC kernel and update API
maetshju Jan 3, 2021
9bb245d
Merge branch 'master' into ctc
maetshju Jan 3, 2021
807aefa
Move to onecold representation for ctc input
maetshju Jan 14, 2021
6f108c6
Merge branch 'master' into ctc
maetshju Jan 14, 2021
e1e8cc8
Apply suggestions from code review
maetshju Jan 16, 2021
6e5fb17
Remove F in ctc tests; update ctc-gpu test syntax
maetshju Jan 16, 2021
bc94a16
Fix indentation in ctc.jl
maetshju Jan 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Rejig CPU CTC API
  • Loading branch information
maetshju committed Jan 3, 2021
commit 6ca07e2df2de53b4f5610c4eb6d957863cbd94ea
19 changes: 12 additions & 7 deletions src/losses/ctc.jl
Original file line number Diff line number Diff line change
@@ -34,11 +34,12 @@ function F(A, blank)
prev = A[1]
z = [prev]
for curr in A[2:end]
if curr != prev && curr != blank
if curr != prev
push!(z, curr)
end
prev = curr
end
filter!(x -> x != blank, z)
return z
end

@@ -53,7 +54,7 @@ function add_blanks(z, blank)
return z′
end

function ctc_(ŷ, y)
function ctc_alpha(ŷ::AbstractArray, y)
typed_zero = zero(ŷ[1])
ŷ = logsoftmax(ŷ)
blank = size(ŷ, 1)
@@ -84,8 +85,8 @@ function ctc_(ŷ, y)
return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ)
end

@adjoint function ctc_(ŷ, y)
loss, α, z′, ŷ = ctc_(ŷ, y)
function ∇ctc_loss(ŷ::AbstractArray, y, out)
loss, α, z′, ŷ = out
U′, T = size(α)
blank = U′
typed_zero = zero(first(α))
@@ -124,7 +125,7 @@ end
end
end
grads = exp.(ŷ) .- exp.(accum .+ loss)
return loss, g -> (g .* grads, nothing)
return grads
end

"""
@@ -147,6 +148,10 @@ solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves
or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7)
for mathematical details.
"""
function ctc_loss(ŷ::Array, y::Array)
return ctc_(ŷ, y)[1]
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss

@adjoint function ctc_loss(ŷ, y)
out = ctc_alpha(ŷ, y)
ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing)
return out.loss, ctc_loss_pullback
end