Skip to content

Commit 0d470da

Browse files
authored
docs: move notes below the fold and highlight RFC 2119 keywords
PR-URL: #955 Ref: #397
1 parent a52285e commit 0d470da

File tree

1 file changed

+62
-56
lines changed

1 file changed

+62
-56
lines changed

src/array_api_stubs/_draft/linear_algebra_functions.py

Lines changed: 62 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,47 +8,57 @@ def matmul(x1: array, x2: array, /) -> array:
88
"""
99
Computes the matrix product.
1010
11-
.. note::
12-
The ``matmul`` function must implement the same semantics as the built-in ``@`` operator (see `PEP 465 <https://www.python.org/dev/peps/pep-0465>`_).
13-
1411
Parameters
1512
----------
1613
x1: array
17-
first input array. Should have a numeric data type. Must have at least one dimension. If ``x1`` is one-dimensional having shape ``(M,)`` and ``x2`` has more than one dimension, ``x1`` must be promoted to a two-dimensional array by prepending ``1`` to its dimensions (i.e., must have shape ``(1, M)``). After matrix multiplication, the prepended dimensions in the returned array must be removed. If ``x1`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x1)[:-2]`` must be compatible with ``shape(x2)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`). If ``x1`` has shape ``(..., M, K)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
18-
x2: array
19-
second input array. Should have a numeric data type. Must have at least one dimension. If ``x2`` is one-dimensional having shape ``(N,)`` and ``x1`` has more than one dimension, ``x2`` must be promoted to a two-dimensional array by appending ``1`` to its dimensions (i.e., must have shape ``(N, 1)``). After matrix multiplication, the appended dimensions in the returned array must be removed. If ``x2`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x2)[:-2]`` must be compatible with ``shape(x1)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`). If ``x2`` has shape ``(..., K, N)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
14+
first input array. **Should** have a numeric data type. **Must** have at least one dimension.
15+
16+
- If ``x1`` is a one-dimensional array having shape ``(M,)`` and ``x2`` has more than one dimension, ``x1`` **must** be promoted to a two-dimensional array by prepending ``1`` to its dimensions (i.e., **must** have shape ``(1, M)``). After matrix multiplication, the prepended dimensions in the returned array **must** be removed.
17+
- If ``x1`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x1)[:-2]`` **must** be compatible with ``shape(x2)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`).
18+
- If ``x1`` has shape ``(..., M, K)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
2019
20+
x2: array
21+
second input array. **Should** have a numeric data type. **Must** have at least one dimension.
2122
22-
.. note::
23-
If either ``x1`` or ``x2`` has a complex floating-point data type, neither argument must be complex-conjugated or transposed. If conjugation and/or transposition is desired, these operations should be explicitly performed prior to computing the matrix product.
23+
- If ``x2`` is one-dimensional array having shape ``(N,)`` and ``x1`` has more than one dimension, ``x2`` **must** be promoted to a two-dimensional array by appending ``1`` to its dimensions (i.e., **must** have shape ``(N, 1)``). After matrix multiplication, the appended dimensions in the returned array **must** be removed.
24+
- If ``x2`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x2)[:-2]`` **must** be compatible with ``shape(x1)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`).
25+
- If ``x2`` has shape ``(..., K, N)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
2426
2527
Returns
2628
-------
2729
out: array
28-
- if both ``x1`` and ``x2`` are one-dimensional arrays having shape ``(N,)``, a zero-dimensional array containing the inner product as its only element.
29-
- if ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, a two-dimensional array containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ and having shape ``(M, N)``.
30-
- if ``x1`` is a one-dimensional array having shape ``(K,)`` and ``x2`` is an array having shape ``(..., K, N)``, an array having shape ``(..., N)`` (i.e., prepended dimensions during vector-to-matrix promotion must be removed) and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
31-
- if ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a one-dimensional array having shape ``(K,)``, an array having shape ``(..., M)`` (i.e., appended dimensions during vector-to-matrix promotion must be removed) and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
32-
- if ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is an array having shape ``(..., K, N)``, an array having shape ``(..., M, N)`` and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
33-
- if ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, an array having shape ``(..., M, N)`` and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
34-
- if either ``x1`` or ``x2`` has more than two dimensions, an array having a shape determined by :ref:`broadcasting` ``shape(x1)[:-2]`` against ``shape(x2)[:-2]`` and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
30+
output array.
31+
32+
- If both ``x1`` and ``x2`` are one-dimensional arrays having shape ``(N,)``, the returned array **must** be a zero-dimensional array and **must** contain the inner product as its only element.
33+
- If ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, the returned array **must** be a two-dimensional array and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ and having shape ``(M, N)``.
34+
- If ``x1`` is a one-dimensional array having shape ``(K,)`` and ``x2`` is an array having shape ``(..., K, N)``, the returned array **must** be an array having shape ``(..., N)`` (i.e., prepended dimensions during vector-to-matrix promotion **must** be removed) and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
35+
- If ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a one-dimensional array having shape ``(K,)``, the returned array **must** be an array having shape ``(..., M)`` (i.e., appended dimensions during vector-to-matrix promotion **must** be removed) and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
36+
- If ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is an array having shape ``(..., K, N)``, the returned array **must** be an array having shape ``(..., M, N)`` and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
37+
- If ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, the returned array **must** be an array having shape ``(..., M, N)`` and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
38+
- If either ``x1`` or ``x2`` has more than two dimensions, the returned array **must** be an array having a shape determined by :ref:`broadcasting` ``shape(x1)[:-2]`` against ``shape(x2)[:-2]`` and **must** contain the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
39+
40+
The returned array **must** have a data type determined by :ref:`type-promotion`.
3541
36-
The returned array must have a data type determined by :ref:`type-promotion`.
42+
Raises
43+
------
44+
Exception
45+
an exception **should** be raised in the following circumstances:
46+
47+
- if either ``x1`` or ``x2`` is a zero-dimensional array.
48+
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
49+
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
50+
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
51+
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
3752
3853
Notes
3954
-----
4055
41-
.. versionchanged:: 2022.12
42-
Added complex data type support.
56+
- The ``matmul`` function **must** implement the same semantics as the built-in ``@`` operator (see `PEP 465 <https://www.python.org/dev/peps/pep-0465>`_).
4357
44-
**Raises**
45-
46-
- if either ``x1`` or ``x2`` is a zero-dimensional array.
47-
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
48-
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
49-
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
50-
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
58+
- If either ``x1`` or ``x2`` has a complex floating-point data type, the function **must not** complex-conjugate or tranpose either argument. If conjugation and/or transposition is desired, a user can explicitly perform these operations prior to computing the matrix product.
5159
60+
.. versionchanged:: 2022.12
61+
Added complex data type support.
5262
"""
5363

5464

@@ -64,7 +74,7 @@ def matrix_transpose(x: array, /) -> array:
6474
Returns
6575
-------
6676
out: array
67-
an array containing the transpose for each matrix and having shape ``(..., N, M)``. The returned array must have the same data type as ``x``.
77+
an array containing the transpose for each matrix. The returned array **must** have shape ``(..., N, M)``. The returned array **must** have the same data type as ``x``.
6878
"""
6979

7080

@@ -78,42 +88,37 @@ def tensordot(
7888
"""
7989
Returns a tensor contraction of ``x1`` and ``x2`` over specific axes.
8090
81-
.. note::
82-
The ``tensordot`` function corresponds to the generalized matrix product.
83-
8491
Parameters
8592
----------
8693
x1: array
87-
first input array. Should have a numeric data type.
94+
first input array. **Should** have a numeric data type.
8895
x2: array
89-
second input array. Should have a numeric data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal.
90-
91-
.. note::
92-
Contracted axes (dimensions) must not be broadcasted.
96+
second input array. **Should** have a numeric data type. Corresponding contracted axes of ``x1`` and ``x2`` **must** be equal.
9397
9498
axes: Union[int, Tuple[Sequence[int], Sequence[int]]]
95-
number of axes (dimensions) to contract or explicit sequences of axis (dimension) indices for ``x1`` and ``x2``, respectively.
99+
number of axes to contract or explicit sequences of axis indices for ``x1`` and ``x2``, respectively.
96100
97-
If ``axes`` is an ``int`` equal to ``N``, then contraction must be performed over the last ``N`` axes of ``x1`` and the first ``N`` axes of ``x2`` in order. The size of each corresponding axis (dimension) must match. Must be nonnegative.
101+
If ``axes`` is an ``int`` equal to ``N``, then contraction **must** be performed over the last ``N`` axes of ``x1`` and the first ``N`` axes of ``x2`` in order. The size of each corresponding axis **must** match. An integer ``axes`` value **must** be nonnegative.
98102
99-
- If ``N`` equals ``0``, the result is the tensor (outer) product.
100-
- If ``N`` equals ``1``, the result is the tensor dot product.
101-
- If ``N`` equals ``2``, the result is the tensor double contraction (default).
103+
- If ``N`` equals ``0``, the result **must** be the tensor (outer) product.
104+
- If ``N`` equals ``1``, the result **must** be the tensor dot product.
105+
- If ``N`` equals ``2``, the result **must** be the tensor double contraction (default).
102106
103-
If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence must apply to ``x1`` and the second sequence to ``x2``. Both sequences must have the same length. Each axis (dimension) ``x1_axes[i]`` for ``x1`` must have the same size as the respective axis (dimension) ``x2_axes[i]`` for ``x2``. Each index referred to in a sequence must be unique. If ``x1`` has rank (i.e, number of dimensions) ``N``, a valid ``x1`` axis must reside on the half-open interval ``[-N, N)``. If ``x2`` has rank ``M``, a valid ``x2`` axis must reside on the half-open interval ``[-M, M)``.
107+
If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence **must** apply to ``x1`` and the second sequence **must** apply to ``x2``. Both sequences **must** have the same length. Each axis ``x1_axes[i]`` for ``x1`` **must** have the same size as the respective axis ``x2_axes[i]`` for ``x2``. Each index referred to in a sequence **must** be unique. A valid axis **must** be an integer on the interval ``[-S, S)``, where ``S`` is the number of axes in respective array. Hence, if ``x1`` has ``N`` axes, a valid ``x1`` axes **must** be an integer on the interval ``[-N, N)``. If ``x2`` has ``M`` axes, a valid ``x2`` axes **must** be an integer on the interval ``[-M, M)``. If an axis is specified as a negative integer, the function **must** determine the axis along which to perform the operation by counting backward from the last axis (where ``-1`` refers to the last axis). If provided an invalid axis, the function **must** raise an exception.
104108
105109
106-
.. note::
107-
If either ``x1`` or ``x2`` has a complex floating-point data type, neither argument must be complex-conjugated or transposed. If conjugation and/or transposition is desired, these operations should be explicitly performed prior to computing the generalized matrix product.
108-
109110
Returns
110111
-------
111112
out: array
112-
an array containing the tensor contraction whose shape consists of the non-contracted axes (dimensions) of the first array ``x1``, followed by the non-contracted axes (dimensions) of the second array ``x2``. The returned array must have a data type determined by :ref:`type-promotion`.
113+
an array containing the tensor contraction. The returned array **must** have a shape which consists of the non-contracted axes of the first array ``x1``, followed by the non-contracted axes of the second array ``x2``. The returned array **must** have a data type determined by :ref:`type-promotion`.
113114
114115
Notes
115116
-----
116117
118+
- The ``tensordot`` function corresponds to the generalized matrix product.
119+
- Contracted axes **must** not be broadcasted.
120+
- If either ``x1`` or ``x2`` has a complex floating-point data type, the function **must not** complex-conjugate or transpose either argument. If conjugation and/or transposition is desired, a user can explicitly perform these operations prior to computing the generalized matrix product.
121+
117122
.. versionchanged:: 2022.12
118123
Added complex data type support.
119124
@@ -131,32 +136,33 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
131136
.. math::
132137
\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i
133138
134-
over the dimension specified by ``axis`` and where :math:`n` is the dimension size and :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i` is complex and the identity if :math:`a_i` is real-valued.
139+
over the axis specified by ``axis`` and where :math:`n` is the axis size and :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i` is complex and the identity if :math:`a_i` is real-valued.
135140
136141
Parameters
137142
----------
138143
x1: array
139-
first input array. Should have a floating-point data type.
144+
first input array. **Should** have a floating-point data type.
140145
x2: array
141-
second input array. Must be compatible with ``x1`` for all non-contracted axes (see :ref:`broadcasting`). The size of the axis over which to compute the dot product must be the same size as the respective axis in ``x1``. Should have a floating-point data type.
142-
143-
.. note::
144-
The contracted axis (dimension) must not be broadcasted.
145-
146+
second input array. **Must** be compatible with ``x1`` for all non-contracted axes (see :ref:`broadcasting`). The size of the axis over which to compute the dot product **must** be the same size as the respective axis in ``x1``. **Should** have a floating-point data type.
146147
axis: int
147-
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``.
148+
axis of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. **Should** be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function **must** determine the axis along which to perform the operation by counting backward from the last axis (where ``-1`` refers to the last axis). By default, the function **must** compute the dot product over the last axis. Default: ``-1``.
148149
149150
Returns
150151
-------
151152
out: array
152-
if ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional containing the dot product; otherwise, a non-zero-dimensional array containing the dot products and having rank ``N-1``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting` along the non-contracted axes. The returned array must have a data type determined by :ref:`type-promotion`.
153+
if ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional containing the dot product; otherwise, a non-zero-dimensional array containing the dot products and having ``N-1`` axes, where ``N`` is number of axes in the shape determined according to :ref:`broadcasting` along the non-contracted axes. The returned array **must** have a data type determined by :ref:`type-promotion`.
154+
155+
Raises
156+
------
157+
Exception
158+
an exception **should** be raised in the following circumstances:
159+
160+
- if the size of the axis over which to compute the dot product is not the same (before broadcasting) for both ``x1`` and ``x2``.
153161
154162
Notes
155163
-----
156164
157-
**Raises**
158-
159-
- if the size of the axis over which to compute the dot product is not the same (before broadcasting) for both ``x1`` and ``x2``.
165+
- The contracted axis **must** not be broadcasted.
160166
161167
.. versionchanged:: 2022.12
162168
Added complex data type support.

0 commit comments

Comments
 (0)