[WIP] KnetLayers -> Knet.Layers20 #613

Open

wants to merge 1 commit into base: master
1 change: 1 addition & 0 deletions src/Knet.jl
@@ -18,6 +18,7 @@ include("ops21/Ops21.jl")
include("ops21_gpu/Ops21_gpu.jl")
include("fileio_gpu/FileIO_gpu.jl")
include("train20/Train20.jl")
include("layers20/Layers20.jl")
# include("layers21/Layers21.jl")

# See if we have a gpu at initialization:
42 changes: 42 additions & 0 deletions src/layers20/Layers20.jl
@@ -0,0 +1,42 @@
module Layers20
import CUDA, Knet
using Knet.KnetArrays
using Knet.Train20
using Knet.Ops20
import Knet.FileIO_gpu: _ser, JLDMODE


"""
Used for setting default underlying array type for layer parameters.

settype!(t::T) where T<:Type{KnetArray{V}} where V <: AbstractFloat = CUDA.functional() ? (global arrtype=t) : error("No GPU available")
settype!(t::T) where T<:Type{Array{V}} where V <: AbstractFloat = (global arrtype=t)
settype!(t::Union{Type{KnetArray},Type{Array}}) = settype!(t{Float32})

# Example
```julia
julia> KnetLayers.settype!(KnetArray) # on a GPU machine
KnetArray{Float32}
```
"""
settype!(t::T) where T<:Type{KnetArray{V}} where V <: AbstractFloat = CUDA.functional() ? (global arrtype=t) : error("No GPU available")
settype!(t::T) where T<:Type{Array{V}} where V <: AbstractFloat = (global arrtype=t)
settype!(t::Union{Type{KnetArray},Type{Array}}) = settype!(t{Float32})
arrtype = Array{Float32}

include("core.jl");
include("primitive.jl"); export Bias, Multiply, Embed, Linear, Dense, BatchNorm, Diagonal, LayerNorm
include("nonlinear.jl"); export NonAct, ReLU,Sigm,Tanh,LeakyReLU,ELU, Dropout, LogSoftMax, SoftMax, LogSumExp, GeLU
include("loss.jl"); export CrossEntropyLoss, BCELoss, LogisticLoss, SigmoidCrossEntropyLoss
include("cnn.jl"); export Pool,UnPool,DeConv,Conv
include("special.jl"); export MLP
include("rnn.jl"); export RNN,SRNN,LSTM,GRU,RNNOutput,PadRNNOutput,PadSequenceArray
include("chain.jl"); export Chain
include("attention.jl"); export MultiheadAttention
include("transformer.jl"); export Transformer, TransformerDecoder, PositionEmbedding, TransformerModel

function __init__()
global arrtype = CUDA.functional() ? KnetArray{Float32} : Array{Float32}
end

end # module
103 changes: 103 additions & 0 deletions src/layers20/README.md
@@ -0,0 +1,103 @@
# Knet.Layers20

Knet.Layers20 is a submodule that provides commonly used deep learning layers for [Knet](https://github.com/denizyuret/Knet.jl), making model development faster. It was originally developed as the standalone KnetLayers package by @ekinakyurek.

## Overview
```julia
model = Chain(Dense(input=768, output=128, activation=Sigm()),
Dense(input=128, output=10, activation=nothing))

loss(model, x, y) = nll(model(x), y)
```

## Getting Started: Train an MNIST model
```julia
using Knet, Knet.Layers20
using Statistics: mean # used when averaging the loss over a dataset below
import Knet: Data
#Data
include(Knet.dir("data","mnist.jl"))
dtrn,dtst = mnistdata(xsize=(784,:)); # dtrn and dtst = [ (x1,y1), (x2,y2), ... ] where xi,yi are input/label minibatches

#Model
HIDDEN_SIZES = [100,50]
(m::MLP)(x,y) = nll(m(x),y)
(m::MLP)(d::Data) = mean(m(x,y) for (x,y) in d)
model = MLP(784,HIDDEN_SIZES...,10)

#Train
EPOCH=10
progress!(sgd(model,repeat(dtrn,EPOCH)))

#Test
@show 100accuracy(model, dtst)
```

## Example Models

1) [MNIST-MLP](./examples/mnist.jl)

2) [MNIST-CNN](./examples/mnist-cnn.jl)

3) [GAN-MLP](./examples/gan-mlp.ipynb)

4) [ResNet: Residual Networks for Image Recognition](./examples/resnet.jl)

5) [S2S: Sequence to Sequence Recurrent Model](./examples/s2smodel.jl)

6) [Morse.jl: Morphological Analyzer+Lemmatizer](https://github.com/ekinakyurek/Morse.jl)

7) [MAC Network: Memory-Attention-Composition Network for Visual Question Answering](https://github.com/ekinakyurek/Mac-Network)

## [Exported Layers Reference](https://ekinakyurek.github.io/KnetLayers.jl/latest/reference.html#Function-Index-1)

## Example Layers and Usage
```julia
using Knet.Layers20

#Instantiate an MLP model with random parameters
mlp = MLP(100,50,20; activation=Sigm()) # input size=100, hidden=50 and output=20

#Do a prediction with the mlp model
prediction = mlp(randn(Float32,100,1))

#Instantiate a convolutional layer with random parameters
cnn = Conv(height=3, width=3, inout=3=>10, padding=1, stride=1) # A conv layer

#Filter your input with the convolutional layer
output = cnn(randn(Float32,224,224,3,1))

#Instantiate an LSTM model
lstm = LSTM(input=100, hidden=100, embed=50)

#You can use integers to represent one-hot vectors.
#Each integer corresponds to the vocabulary index of the corresponding element in your data.

#For example, a pass over a 5-length sequence:
rnnoutput = lstm([3,2,1,4,5];hy=true,cy=true)

#After you get the output, you may access the output sequence and the final
#hidden and memory states produced by the lstm model
rnnoutput.y
rnnoutput.hidden
rnnoutput.memory

#You can also use normal array inputs for low-level control
#One iteration of LSTM with a random input
rnnoutput = lstm(randn(Float32,100,1);hy=true,cy=true)

#Pass over a random 10-length sequence:
rnnoutput = lstm(randn(Float32,100,1,10);hy=true,cy=true)

#Pass over a mini-batch that includes unequal-length sequences
rnnoutput = lstm([[1,2,3,4],[5,6]];sorted=true,hy=true,cy=true)

#To see and modify rnn params in a structured view
lstm.gatesview
```
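
For GPU use, the unexported `Layers20.settype!` function switches the default array type used for newly created layer parameters. A minimal sketch, assuming a machine with a functional CUDA GPU (the layer sizes are illustrative):

```julia
using Knet, Knet.Layers20

Layers20.settype!(KnetArray)              # newly created layers use KnetArray{Float32} parameters
mlp = MLP(100,50,20; activation=Sigm())   # parameters are now allocated on the GPU
y = mlp(KnetArray(randn(Float32,100,1)))  # move inputs to the GPU as well
```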


## TO-DO
1) Examples
2) Special layers such as Google's `Inception`
3) Known embeddings such as `GloVe`
4) Pretrained Models
4 changes: 4 additions & 0 deletions src/layers20/TODO
@@ -0,0 +1,4 @@
- fix README examples
- integrate docs
- integrate examples
- fix __init__()
137 changes: 137 additions & 0 deletions src/layers20/attention.jl
@@ -0,0 +1,137 @@

abstract type AbstractAttention <: Layer end

struct MultiheadAttention <: AbstractAttention
head::Int
future::Bool
iqproj::Dense
ikproj::Dense
ivproj::Dense
oproj::Dense
drop::Dropout
end

"""
MultiheadAttention(head::Int, is::Int, hs::Int, os::Int; future::Bool=true, pdrop = 0.1)
Multihead dot product Attention Layer, `head` is the number of head, `is` is the input size, `hs` is the hidden size of input projection layer of each head,
`os` is the output size. When `future` is `false`, the k-th token can't see tokens at > k. `pdrop` is the dropout rate.
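
# Example

A minimal sketch; the layer sizes and input shapes below are illustrative assumptions:
```julia
mha = MultiheadAttention(4, 512, 64, 512)   # 4 heads, input size 512, 64 per head, output size 512
q = k = v = randn(Float32, 512, 10, 2)      # (input size, sequence length, batch)
y = mha(q, k, v)                            # expected size(y) == (512, 10, 2)
```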
"""
MultiheadAttention(head::Int,
is::Int,
hs::Int,
os::Int;
future::Bool=true, pdrop = 0.1) = MultiheadAttention(head,
future,
Dense(input=is,output=hs*head),
Dense(input=is,output=hs*head),
Dense(input=is,output=hs*head),
Dense(input=hs*head, output=os),
Dropout(pdrop),
)


function Base.show(io::IO, mh::MultiheadAttention)
hs = div(size(mh.iqproj)[1], mh.head)
is = size(mh.iqproj)[end]
os = size(mh.oproj)[1]

print(io, "MultiheadAttention(")
print(io, "head=$(mh.head), ")
print(io, "head_size=$(hs), ")
print(io, "$(is)=>$(os)")

print(io, ", dropout=$(mh.drop.p))")
end

function (mh::MultiheadAttention)(query,
key,
value;
mask=nothing)
qs = size(query)
ks = size(key)
vs = size(value)
if length(qs) == 3
#size(ipq) == (h, q_seq_len, batch)
ipq = mh.iqproj(query)
ipk = mh.ikproj(key)
ipv = mh.ivproj(value)

h = size(ipq, 1)
hs = div(h, mh.head)

#size(ipq) == (hs, q_seq_len, head, batch)
ipq = permutedims(reshape(ipq, hs, mh.head, qs[2], qs[3]), [1, 3, 2, 4])
ipk = permutedims(reshape(ipk, hs, mh.head, ks[2], ks[3]), [1, 3, 2, 4])
ipv = permutedims(reshape(ipv, hs, mh.head, vs[2], vs[3]), [1, 3, 2, 4])

#size(ipq) == (hs, q_seq_len, head * batch)
ipq = reshape(ipq, hs, qs[2], :)
ipk = reshape(ipk, hs, ks[2], :)
ipv = reshape(ipv, hs, vs[2], :)

atten = attention(ipq,ipk,ipv;
mask=mask,
future=mh.future,
dropout=mh.drop)

atten = permutedims(reshape(atten, hs, qs[2], mh.head, qs[3]), [1, 3, 2, 4]) #size(atten) == (hs, head, ql, b)
atten = reshape(atten, h, qs[2], qs[3]) #size(atten) == (h, ql, b)

return mh.oproj(atten)
else
ipq = mh.iqproj(query)
ipk = mh.ikproj(key)
ipv = mh.ivproj(value)

h = size(ipq)[1] #h == hs * head
hs = div(h, mh.head)

#size(hq) == (hs, seq_len, head)
hq = permutedims(reshape(ipq, hs, mh.head, :), [1, 3, 2])
hk = permutedims(reshape(ipk, hs, mh.head, :), [1, 3, 2])
hv = permutedims(reshape(ipv, hs, mh.head, :), [1, 3, 2])

atten = attention(hq, hk, hv;
mask=mask,
future=mh.future,
dropout=mh.drop)

# size(atten) == (head*hs, seq_len)
atten = reshape(permutedims(atten, [1, 3, 2]), h, :)

return mh.oproj(atten)
end
end

function attention(query,
key,
value;
mask=nothing,
future::Bool = false,
dropout=nothing)
T = eltype(query)
dk = size(key, 1)
score = bmm(key, query; transA=true)
score = score ./ convert(T , sqrt(dk))

s = size(score)

if mask !== nothing
#workaround: the in-place `@. mask = (1 - mask) * -1e9` caused mask to become -Inf, so compute it out of place
mask = (T(1.0) .- mask) .* T(-1e9)
ms = size(mask)
#score = score .+ mask; broadcast over heads instead of repeating the mask for each head
score = reshape(reshape(score, s[1:end-1]..., :, ms[end]) .+ reshape(mask, ms[1:end-1]..., 1, ms[end]), s)
end

if !future
#without converting to arrtype, the data would move back to the cpu
fmask = convert(arrtype,tril!(fill!(Matrix{T}(undef,s[1:end-1]...),T(-1e9)),-1))
#fmask = tril!(fill!(similar(score, s[1:end-1]...), convert(T, -1e9)), -1)
score = score .+ fmask
end

score = softmax(score;dims=1)
dropout !== nothing && (score = dropout(score))
bmm(value, score) #size(return) == (dims, q_seq_len, batch)
end
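
For reference, the `attention` helper above computes scaled dot-product attention: `bmm(value, softmax(bmm(key, query; transA=true) ./ sqrt(dk)))` with `dk = size(key, 1)`. A minimal CPU sketch of calling it directly; `attention` is unexported, so the qualified access path is an assumption, and the shapes are illustrative:

```julia
using Knet.Layers20: attention        # unexported helper; access path assumed

q = k = v = randn(Float32, 8, 5, 2)   # (feature size, sequence length, batch)
a = attention(q, k, v; future=true)   # future=true skips the causal mask; expected size(a) == (8, 5, 2)
```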
40 changes: 40 additions & 0 deletions src/layers20/chain.jl
@@ -0,0 +1,40 @@
# Implementation is taken from [Flux.jl](https://github.com/FluxML/Flux.jl)
"""
Chain(layers...)
Chain multiple layers / functions together, so that they are called in sequence
on a given input.
```julia
m = Chain(x -> x^2, x -> x+1)
m(5) == 26
m = Chain(Dense(input=10, output=5), Dense(input=5, output=2))
x = rand(10)
m(x) == m[2](m[1](x))
```
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
"""
struct Chain{T<:Tuple}
layers::T
Chain(xs...) = new{typeof(xs)}(xs)
end

children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(Base.tail(fs), first(fs)(x))

(c::Chain)(x) = applychain(c.layers, x)

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
Base.getindex(c::Chain, i::Integer) = c.layers[i]
Base.getindex(c::Chain, ::Colon) = c
Base.length(c::Chain) = length(c.layers)
Base.iterate(c::Chain) = Base.iterate(c.layers)
Base.iterate(c::Chain, state) = Base.iterate(c.layers, state)

function Base.show(io::IO, c::Chain)
print(io, "Chain(")
join(io, c.layers, ", ")
print(io, ")")
end