Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Extend the parametric keywords to include anything (#66) #67

Merged
merged 2 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CompositionalNetworks"
uuid = "4b67e4b5-442d-4ef5-b760-3f5df3a57537"
authors = ["Jean-François Baffier"]
version = "0.5.6"
version = "0.5.7"

[deps]
ConstraintCommons = "e37357d9-0691-492f-a822-e5ea6a920954"
Expand Down
8 changes: 4 additions & 4 deletions src/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function generate(c::Composition, name, ::Val{:Julia})
co = reduce_symbols(symbs[4], ", ", false; prefix = CN * "co_")

documentation = """\"\"\"
$name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
$name(x; X = zeros(length(x), $tr_length), params...)

Composition `$name` generated by CompositionalNetworks.jl.
```
Expand All @@ -85,10 +85,10 @@ function generate(c::Composition, name, ::Val{:Julia})
"""

output = """
function $name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
$(CN)tr_in(Tuple($tr), X, x, param)
function $name(x; X = zeros(length(x), $tr_length), dom_size, params...)
$(CN)tr_in(Tuple($tr), X, x; params)
X[1:length(x), 1] .= 1:length(x) .|> (i -> $ar(@view X[i, 1:$tr_length]))
return $ag(@view X[:, 1]) |> (y -> $co(y; param, dom_size, nvars=length(x)))
return $ag(@view X[:, 1]) |> (y -> $co(y; dom_size, nvars=length(x), params...))
end
"""
return documentation * format_text(output, BlueStyle(); pipe_to_function_call = false)
Expand Down
11 changes: 3 additions & 8 deletions src/icn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,11 @@ function _compose(icn::ICN)
end
end

function composition(
x;
X = zeros(length(x), length(funcs[1])),
param = nothing,
dom_size,
)
tr_in(Tuple(funcs[1]), X, x, param)
function composition(x; X = zeros(length(x), length(funcs[1])), dom_size, params...)
tr_in(Tuple(funcs[1]), X, x; params...)
X[1:length(x), 1] .=
1:length(x) .|> (i -> funcs[2][1](@view X[i, 1:length(funcs[1])]))
return (y -> funcs[4][1](y; param, dom_size, nvars = length(x)))(
return (y -> funcs[4][1](y; dom_size, nvars = length(x), params...))(
funcs[3][1](@view X[:, 1]),
)
end
Expand Down
2 changes: 1 addition & 1 deletion src/layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ end

"""
generate_exclusive_operation(max_op_number)
Generates the operations (weigths) of a layer with exclusive operations.
Generates the operations (weights) of a layer with exclusive operations.
"""
function generate_exclusive_operation(max_op_number)
op = rand(1:max_op_number)
Expand Down
80 changes: 39 additions & 41 deletions src/layers/comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,59 @@
co_identity(x)
Identity function. Already defined in Julia as `identity`, specialized for scalars in the `comparison` layer.
"""
co_identity(x; param = nothing, dom_size = 0, nvars = 0) = identity(x)
co_identity(x; params...) = identity(x)

"""
co_abs_diff_val_param(x; param)
Return the absolute difference between `x` and `param`.
co_abs_diff_var_val(x; val)
Return the absolute difference between `x` and `val`.
"""
co_abs_diff_val_param(x; param, dom_size = 0, nvars = 0) = abs(x - param)
co_abs_diff_var_val(x; val, params...) = abs(x - val)

"""
co_val_minus_param(x; param)
Return the difference `x - param` if positive, `0.0` otherwise.
co_var_minus_val(x; val)
Return the difference `x - val` if positive, `0.0` otherwise.
"""
co_val_minus_param(x; param, dom_size = 0, nvars = 0) = max(0.0, x - param)
co_var_minus_val(x; val, params...) = max(0.0, x - val)

"""
co_param_minus_val(x; param)
Return the difference `param - x` if positive, `0.0` otherwise.
co_val_minus_var(x; val)
Return the difference `val - x` if positive, `0.0` otherwise.
"""
co_param_minus_val(x; param, dom_size = 0, nvars = 0) = max(0.0, param - x)
co_val_minus_var(x; val, params...) = max(0.0, val - x)

"""
co_euclidean_param(x; param, dom_size)
Compute an euclidean norm with domain size `dom_size`, weighted by `param`, of a scalar.
co_euclidean_val(x; val, dom_size)
Compute an euclidean norm with domain size `dom_size`, weighted by `val`, of a scalar.
"""
function co_euclidean_param(x; param, dom_size, nvars = 0)
return x == param ? 0.0 : (1.0 + abs(x - param) / dom_size)
function co_euclidean_val(x; val, dom_size, params...)
return x == val ? 0.0 : (1.0 + abs(x - val) / dom_size)
end

"""
co_euclidean(x; dom_size)
Compute an euclidean norm with domain size `dom_size` of a scalar.
"""
function co_euclidean(x; param = nothing, dom_size, nvars = 0)
return co_euclidean_param(x; param = 0.0, dom_size = dom_size)
function co_euclidean(x; dom_size, params...)
return co_euclidean_val(x; val = 0.0, dom_size)
end

"""
co_abs_diff_val_vars(x; nvars)
co_abs_diff_var_vars(x; nvars)
Return the absolute difference between `x` and the number of variables `nvars`.
"""
co_abs_diff_val_vars(x; param = nothing, dom_size = 0, nvars) = abs(x - nvars)
co_abs_diff_var_vars(x; nvars, params...) = abs(x - nvars)

"""
co_val_minus_vars(x; nvars)
co_var_minus_vars(x; nvars)
Return the difference `x - nvars` if positive, `0.0` otherwise, where `nvars` denotes the numbers of variables.
"""
co_val_minus_vars(x; param = nothing, dom_size = 0, nvars) =
co_val_minus_param(x; param = nvars)
co_var_minus_vars(x; nvars, params...) = co_var_minus_val(x; val = nvars)

"""
co_vars_minus_val(x; nvars)
co_vars_minus_var(x; nvars)
Return the difference `nvars - x` if positive, `0.0` otherwise, where `nvars` denotes the numbers of variables.
"""
co_vars_minus_val(x; param = nothing, dom_size = 0, nvars) =
co_param_minus_val(x; param = nvars)
co_vars_minus_var(x; nvars, params...) = co_val_minus_var(x; val = nvars)


# Parametric layers
Expand All @@ -66,18 +64,18 @@ function make_comparisons(::Val{:none})
return LittleDict{Symbol,Function}(
:identity => co_identity,
:euclidean => co_euclidean,
:abs_diff_val_vars => co_abs_diff_val_vars,
:val_minus_vars => co_val_minus_vars,
:vars_minus_val => co_vars_minus_val,
:abs_diff_var_vars => co_abs_diff_var_vars,
:var_minus_vars => co_var_minus_vars,
:vars_minus_var => co_vars_minus_var,
)
end

function make_comparisons(::Val{:val})
return LittleDict{Symbol,Function}(
:abs_diff_val_param => co_abs_diff_val_param,
:val_minus_param => co_val_minus_param,
:param_minus_val => co_param_minus_val,
:euclidean_param => co_euclidean_param,
:abs_diff_var_val => co_abs_diff_var_val,
:var_minus_val => co_var_minus_val,
:val_minus_var => co_val_minus_var,
:euclidean_val => co_euclidean_val,
)
end

Expand Down Expand Up @@ -113,21 +111,21 @@ end
end

funcs_param = [
CN.co_abs_diff_val_param => [2, 5],
CN.co_val_minus_param => [2, 0],
CN.co_param_minus_val => [0, 5],
CN.co_abs_diff_var_val => [2, 5],
CN.co_var_minus_val => [2, 0],
CN.co_val_minus_var => [0, 5],
]

for (f, results) in funcs_param
for (key, vals) in enumerate(data)
@test f(vals.first; param = vals.second[1]) == results[key]
@test f(vals.first; val = vals.second[1]) == results[key]
end
end

funcs_vars = [
CN.co_abs_diff_val_vars => [2, 0],
CN.co_val_minus_vars => [0, 0],
CN.co_vars_minus_val => [2, 0],
CN.co_abs_diff_var_vars => [2, 0],
CN.co_var_minus_vars => [0, 0],
CN.co_vars_minus_var => [2, 0],
]

for (f, results) in funcs_vars
Expand All @@ -136,11 +134,11 @@ end
end
end

funcs_param_dom = [CN.co_euclidean_param => [1.4, 2.0]]
funcs_val_dom = [CN.co_euclidean_val => [1.4, 2.0]]

for (f, results) in funcs_param_dom
for (f, results) in funcs_val_dom
for (key, vals) in enumerate(data)
@test f(vals.first, param = vals.second[1], dom_size = vals.second[2]) ≈
@test f(vals.first, val = vals.second[1], dom_size = vals.second[2]) ≈
results[key]
end
end
Expand Down
Loading
Loading