Skip to content

Commit c16f63f

Browse files
committed
Add an in-place function tensor_projection!
1 parent e102d7c commit c16f63f

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

src/nlp/api.jl

+26-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ export jth_hess_coord, jth_hess_coord!, jth_hess
1111
export jth_hprod, jth_hprod!, ghjvprod, ghjvprod!
1212
export hess_structure!, hess_structure, hess_coord!, hess_coord
1313
export hess, hprod, hprod!, hess_op, hess_op!
14-
export tensor_projection
14+
export tensor_projection, tensor_projection!
1515
export varscale, lagscale, conscale
1616

1717
"""
@@ -1306,13 +1306,33 @@ Returns the projection of the n-th derivative of the objective of `nlp` at `x` a
13061306
13071307
#### Input arguments
13081308
1309-
- `nlp::AbstractNLPModel`: An NLP model;
1310-
- `n::Int`: The order of the derivative to compute;
1311-
- `x::AbstractVector`: The point at which the derivative is evaluated;
1312-
- `directions::Tuple{Int, Vararg{Int}}`: A tuple of indices specifying the directions (e.g., `(1, 2)` for a tensor projection along the first and second axes);
1309+
- `nlp`: An NLP model;
1310+
- `n`: The order of the derivative to compute;
1311+
- `x`: The point at which the derivative is evaluated;
1312+
- `directions`: A tuple of indices specifying the directions (e.g., `(1, 2)` for a tensor projection along the first and second axes);
13131313
- `args...`: A list of vectors, one for each direction specified in `directions`.
13141314
"""
1315-
function tensor_projection end
1315+
function tensor_projection(
1316+
nlp::AbstractNLPModel{T, S},
1317+
n::Int,
1318+
x::AbstractVector,
1319+
directions::Tuple{Int, Vararg{Int}},
1320+
args...
1321+
) where {T, S}
1322+
@lencheck nlp.meta.nvar x
1323+
m = n - length(directions)
1324+
@assert m 1
1325+
dim = NTuple{m, Int}(nlp.meta.nvar for i = 1:m)
1326+
P = similar(x, dim)
1327+
return tensor_projection!(nlp, n, x, directions, P, args...)
1328+
end
1329+
1330+
"""
1331+
tensor_projection!(nlp, n, x, directions, P, args...)
1332+
1333+
In-place version of the function [`tensor_projection`](@ref) where the result is stored in `P`.
1334+
"""
1335+
function tensor_projection! end
13161336

13171337
function varscale end
13181338
function lagscale end

0 commit comments

Comments
 (0)