4
4
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
5
5
6
6
submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
7
+ use stdlib_linalg_blas, only: gemm
8
+ use stdlib_constants
7
9
implicit none
8
10
9
11
contains
@@ -36,38 +38,84 @@ contains
36
38
end do
37
39
end function matmul_chain_order
38
40
39
- #:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
41
+ #:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
40
42
41
! Multiply a chain of three matrices, m1*m2*m3, in the order selected by the
! split table `s`, using BLAS `gemm` for each pairwise product.
!
! Arguments:
!   m1, m2, m3 : the three factors, in chain order.
!   start      : index of m1 within the overall chain; used to index `s` and `p`.
!   s          : split table; s(start, start + 2) selects the parenthesisation:
!                  == start     ->  m1*(m2*m3)
!                  == start + 1 -> (m1*m2)*m3
!                (presumably produced by matmul_chain_order - confirm at caller).
!   p          : chain dimensions; the code treats factor i (i = start..start+2)
!                as a p(i) x p(i+1) matrix.
!
! Result:
!   r          : the p(start) x p(start + 3) product.
!
! Stops with an error message if s(start, start + 2) is neither of the two
! values listed above.
pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s, p) result(r)
    ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
    integer, intent(in) :: start, s(:,2:), p(:)
    ${t}$, allocatable :: r(:,:), temp(:,:)
    integer :: ord, m, n, k

    ord = s(start, start + 2)
    allocate(r(p(start), p(start + 3)))

    if (ord == start) then
        ! m1*(m2*m3)
        ! temp = m2*m3 : (p(start+1) x p(start+2)) * (p(start+2) x p(start+3))
        m = p(start + 1)
        n = p(start + 3)
        k = p(start + 2)
        allocate(temp(m, n))
        call gemm('N', 'N', m, n, k, one_${s}$, m2, m, m3, k, zero_${s}$, temp, m)

        ! r = m1*temp : (p(start) x p(start+1)) * (p(start+1) x p(start+3))
        m = p(start)
        n = p(start + 3)
        k = p(start + 1)
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
    else if (ord == start + 1) then
        ! (m1*m2)*m3
        ! temp = m1*m2 : (p(start) x p(start+1)) * (p(start+1) x p(start+2))
        m = p(start)
        n = p(start + 2)
        k = p(start + 1)
        allocate(temp(m, n))
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)

        ! r = temp*m3 : (p(start) x p(start+2)) * (p(start+2) x p(start+3))
        ! The inner dimension (and the leading dimension of m3) is p(start+2),
        ! not p(start+1): temp has p(start+2) columns and m3 has p(start+2) rows.
        m = p(start)
        n = p(start + 3)
        k = p(start + 2)
        call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m)
    else
        error stop "stdlib_matmul: error: unexpected s(i,j)"
    end if

end function matmul_chain_mult_${s}$_3
57
78
58
- pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
79
+ pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s, p ) result(r)
59
80
${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
60
- integer, intent(in) :: start, s(:,2:)
61
- ${t}$, allocatable :: r(:,:)
62
- integer :: tmp
63
- tmp = s(start, start + 3)
64
-
65
- if (tmp == start) then
66
- r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s))
67
- else if (tmp == start + 1) then
68
- r = matmul(matmul(m1, m2), matmul(m3, m4))
69
- else if (tmp == start + 2) then
70
- r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
81
+ integer, intent(in) :: start, s(:,2:), p(:)
82
+ ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
83
+ integer :: ord, m, n, k
84
+ ord = s(start, start + 3)
85
+ allocate(r(p(start), p(start + 4)))
86
+
87
+ if (ord == start) then
88
+ ! m1*(m2*m3*m4)
89
+ temp = matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s, p)
90
+ m = p(start)
91
+ n = p(start + 4)
92
+ k = p(start + 1)
93
+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
94
+ else if (ord == start + 1) then
95
+ ! (m1*m2)*(m3*m4)
96
+ m = p(start)
97
+ n = p(start + 2)
98
+ k = p(start + 1)
99
+ allocate(temp(m,n))
100
+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
101
+
102
+ m = p(start + 2)
103
+ n = p(start + 4)
104
+ k = p(start + 3)
105
+ allocate(temp1(m,n))
106
+ call gemm('N', 'N', m, n, k, one_${s}$, m3, m, m4, k, zero_${s}$, temp1, m)
107
+
108
+ m = p(start)
109
+ n = p(start + 4)
110
+ k = p(start + 2)
111
+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
112
+ else if (ord == start + 2) then
113
+ ! (m1*m2*m3)*m4
114
+ temp = matmul_chain_mult_${s}$_3(m1, m2, m3, start, s, p)
115
+ m = p(start)
116
+ n = p(start + 4)
117
+ k = p(start + 3)
118
+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m)
71
119
else
72
120
error stop "stdlib_matmul: error: unexpected s(i,j)"
73
121
end if
@@ -77,8 +125,8 @@ contains
77
125
pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
78
126
${t}$, intent(in) :: m1(:,:), m2(:,:)
79
127
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
80
- ${t}$, allocatable :: r(:,:)
81
- integer :: p(6), num_present
128
+ ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
129
+ integer :: p(6), num_present, m, n, k
82
130
integer, allocatable :: s(:,:)
83
131
84
132
p(1) = size(m1, 1)
@@ -102,8 +150,13 @@ contains
102
150
num_present = num_present + 1
103
151
end if
104
152
153
+ allocate(r(p(1), p(num_present + 1)))
154
+
105
155
if (num_present == 2) then
106
- r = matmul(m1, m2)
156
+ m = p(1)
157
+ n = p(3)
158
+ k = p(2)
159
+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m)
107
160
return
108
161
end if
109
162
@@ -113,24 +166,56 @@ contains
113
166
s = matmul_chain_order(p(1: num_present + 1))
114
167
115
168
if (num_present == 3) then
116
- r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s)
169
+ r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4) )
117
170
return
118
171
else if (num_present == 4) then
119
- r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s)
172
+ r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5) )
120
173
return
121
174
end if
122
175
123
176
! Now num_present is 5
124
177
125
178
select case (s(1, 5))
126
179
case (1)
127
- r = matmul(m1, matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s))
180
+ ! m1*(m2*m3*m4*m5)
181
+ temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p)
182
+ m = p(1)
183
+ n = p(6)
184
+ k = p(2)
185
+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
128
186
case (2)
129
- r = matmul(matmul(m1, m2), matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s))
187
+ ! (m1*m2)*(m3*m4*m5)
188
+ m = p(1)
189
+ n = p(3)
190
+ k = p(2)
191
+ allocate(temp(m,n))
192
+ call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
193
+
194
+ temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p)
195
+
196
+ k = n
197
+ n = p(6)
198
+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
130
199
case (3)
131
- r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s), matmul(m4, m5))
200
+ ! (m1*m2*m3)*(m4*m5)
201
+ temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p)
202
+
203
+ m = p(4)
204
+ n = p(6)
205
+ k = p(5)
206
+ allocate(temp1(m,n))
207
+ call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m)
208
+
209
+ k = m
210
+ m = p(1)
211
+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
132
212
case (4)
133
- r = matmul(matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s), m5)
213
+ ! (m1*m2*m3*m4)*m5
214
+ temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
215
+ m = p(1)
216
+ n = p(6)
217
+ k = p(5)
218
+ call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m)
134
219
case default
135
220
error stop "stdlib_matmul: error: unexpected s(i,j)"
136
221
end select
0 commit comments