diff --git a/src/resnet.jl b/src/resnet.jl index bdea58e47..65a865b34 100644 --- a/src/resnet.jl +++ b/src/resnet.jl @@ -39,8 +39,7 @@ function Bottleneck(filters::Int, downsample::Bool = false, res_top::Bool = fals end end -function resnet50() - local layers = [3, 4, 6, 3] +function resnet(layers = [3,4,6,3]) local layer_arr = [] push!(layer_arr, Conv((7,7), 3=>64, pad = (3,3), stride = (2,2))) @@ -63,24 +62,24 @@ function resnet50() Chain(layer_arr...) end -function resnet_layers() - weight = Metalhead.weights("resnet.bson") - weights = Dict{Any ,Any}() - for ele in keys(weight) - weights[string(ele)] = convert(Array{Float64, N} where N, weight[ele]) - end - ls = resnet50() - ls[1].weight .= weights["gpu_0/conv1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:] - count = 2 - for j in [3:5, 6:9, 10:15, 16:18] - for p in j - ls[p].conv_layers[1].weight .= weights["gpu_0/res$(count)_$(p-j[1])_branch2a_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:] - ls[p].conv_layers[2].weight .= weights["gpu_0/res$(count)_$(p-j[1])_branch2b_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:] - ls[p].conv_layers[3].weight .= weights["gpu_0/res$(count)_$(p-j[1])_branch2c_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:] - end - count += 1 - end - ls[21].W .= transpose(weights["gpu_0/pred_w_0"]); ls[21].b .= weights["gpu_0/pred_b_0"] +function resnet_layers(layers = [3,4,6,3]) + # weight = Metalhead.weights("resnet.bson") + # weights = Dict{Any ,Any}() + # for ele in keys(weight) + # weights[string(ele)] = convert(Array{Float64, N} where N, weight[ele]) + # end + ls = resnet(layers) + # ls[1].weight .= weights["gpu_0/conv1_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:] + # count = 2 + # for j in [3:5, 6:9, 10:15, 16:18] + # for p in j + # ls[p].conv_layers[1].weight .= weights["gpu_0/res$(count)_$(p-j[1])_branch2a_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:] + # ls[p].conv_layers[2].weight .= weights["gpu_0/res$(count)_$(p-j[1])_branch2b_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:] + # ls[p].conv_layers[3].weight .= weights["gpu_0/res$(count)_$(p-j[1])_branch2c_w_0"][end:-1:1,:,:,:][:,end:-1:1,:,:] + # end + # count += 1 + # end + # ls[21].W .= transpose(weights["gpu_0/pred_w_0"]); ls[21].b .= weights["gpu_0/pred_b_0"] return ls end @@ -89,6 +88,9 @@ struct ResNet <: ClassificationModel{ImageNet.ImageNet1k} end ResNet() = ResNet(resnet_layers()) +ResNet50() = ResNet() +ResNet101() = ResNet(resnet_layers([3,4,23,3])) +ResNet152() = ResNet(resnet_layers([3,8,36,3])) Base.show(io::IO, ::ResNet) = print(io, "ResNet()")