Skip to content

Commit e709f83

Browse files
committed
replace all matmul's by gemm
1 parent 06ce735 commit e709f83

File tree

3 files changed

+121
-36
lines changed

3 files changed

+121
-36
lines changed

example/intrinsics/example_matmul.f90

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ program example_matmul
44
real :: r1(50, 100), r2(100, 40), r3(40, 50)
55
real, allocatable :: res(:, :)
66
x = reshape([(0, 0), (1, 0), (1, 0), (0, 0)], [2, 2])
7-
y = reshape([(0, 0), (0, -1), (0, 1), (0, 0)], [2, 2]) ! pauli y-matrix
7+
y = reshape([(0, 0), (0, 1), (0, -1), (0, 0)], [2, 2]) ! pauli y-matrix
88

9-
print *, stdlib_matmul(y, y, y, y, y) ! should be y
9+
print *, stdlib_matmul(y, y, y) ! should be y
1010
print *, stdlib_matmul(x, x, y, x) ! should be -i * sigma_z
1111

1212
call random_seed()

src/stdlib_intrinsics.fypp

+2-2
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,10 @@ module stdlib_intrinsics
158158
!!
159159
!! matrix multiply more than two matrices with a single function call
160160
!! the multiplication with the optimal parenthesization for efficiency of computation is done automatically
161-
!! Supported data types are `real`, `integer` and `complex`.
161+
!! Supported data types are `real` and `complex`.
162162
!!
163163
!! Note: The matrices must be of compatible shapes to be multiplied
164-
#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
164+
#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
165165
pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
166166
${t}$, intent(in) :: m1(:,:), m2(:,:)
167167
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)

src/stdlib_intrinsics_matmul.fypp

+117-32
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
55

66
submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
7+
use stdlib_linalg_blas, only: gemm
8+
use stdlib_constants
79
implicit none
810

911
contains
@@ -36,38 +38,84 @@ contains
3638
end do
3739
end function matmul_chain_order
3840

39-
#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
41+
#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
4042

41-
pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r)
43+
pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s, p) result(r)
    !! Multiply the three matrices m1*m2*m3 with two `gemm` calls, following
    !! the parenthesization recorded in `s`.
    !!
    !! The dimensions are read from `p(start : start + 3)`:
    !! m1 is used as p(start)   x p(start+1),
    !! m2 as         p(start+1) x p(start+2),
    !! m3 as         p(start+2) x p(start+3).
    !! `s(start, start + 2)` selects the split point: `start` means
    !! m1*(m2*m3), `start + 1` means (m1*m2)*m3. Any other value stops the
    !! program with an error.
    ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
    integer, intent(in) :: start, s(:,2:), p(:)
    ${t}$, allocatable :: r(:,:), temp(:,:)
    integer :: ord, m, n, k

    ord = s(start, start + 2)
    allocate(r(p(start), p(start + 3)))

    if (ord == start) then
        ! m1*(m2*m3)
        ! temp = m2*m3 : (p(start+1) x p(start+2)) * (p(start+2) x p(start+3))
        m = p(start + 1)
        n = p(start + 3)
        k = p(start + 2)
        allocate(temp(m, n))
        call gemm('N', 'N', m, n, k, one_${s}$, m2, m, m3, k, zero_${s}$, temp, m)
        ! r = m1*temp : (p(start) x p(start+1)) * (p(start+1) x p(start+3))
        m = p(start)
        n = p(start + 3)
        k = p(start + 1)
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
    else if (ord == start + 1) then
        ! (m1*m2)*m3
        ! temp = m1*m2 : (p(start) x p(start+1)) * (p(start+1) x p(start+2))
        m = p(start)
        n = p(start + 2)
        k = p(start + 1)
        allocate(temp(m, n))
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
        ! r = temp*m3 : (p(start) x p(start+2)) * (p(start+2) x p(start+3)).
        ! The contracted dimension (and the leading dimension of m3) is
        ! p(start+2), not p(start+1).
        m = p(start)
        n = p(start + 3)
        k = p(start + 2)
        call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m)
    else
        error stop "stdlib_matmul: error: unexpected s(i,j)"
    end if

end function matmul_chain_mult_${s}$_3
5778

58-
pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
79+
pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s, p) result(r)
5980
${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
60-
integer, intent(in) :: start, s(:,2:)
61-
${t}$, allocatable :: r(:,:)
62-
integer :: tmp
63-
tmp = s(start, start + 3)
64-
65-
if (tmp == start) then
66-
r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s))
67-
else if (tmp == start + 1) then
68-
r = matmul(matmul(m1, m2), matmul(m3, m4))
69-
else if (tmp == start + 2) then
70-
r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
81+
integer, intent(in) :: start, s(:,2:), p(:)
82+
${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
83+
integer :: ord, m, n, k
84+
ord = s(start, start + 3)
85+
allocate(r(p(start), p(start + 4)))
86+
87+
if (ord == start) then
88+
! m1*(m2*m3*m4)
89+
temp = matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s, p)
90+
m = p(start)
91+
n = p(start + 4)
92+
k = p(start + 1)
93+
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
94+
else if (ord == start + 1) then
95+
! (m1*m2)*(m3*m4)
96+
m = p(start)
97+
n = p(start + 2)
98+
k = p(start + 1)
99+
allocate(temp(m,n))
100+
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
101+
102+
m = p(start + 2)
103+
n = p(start + 4)
104+
k = p(start + 3)
105+
allocate(temp1(m,n))
106+
call gemm('N', 'N', m, n, k, one_${s}$, m3, m, m4, k, zero_${s}$, temp1, m)
107+
108+
m = p(start)
109+
n = p(start + 4)
110+
k = p(start + 2)
111+
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
112+
else if (ord == start + 2) then
113+
! (m1*m2*m3)*m4
114+
temp = matmul_chain_mult_${s}$_3(m1, m2, m3, start, s, p)
115+
m = p(start)
116+
n = p(start + 4)
117+
k = p(start + 3)
118+
call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m)
71119
else
72120
error stop "stdlib_matmul: error: unexpected s(i,j)"
73121
end if
@@ -77,8 +125,8 @@ contains
77125
pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
78126
${t}$, intent(in) :: m1(:,:), m2(:,:)
79127
${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
80-
${t}$, allocatable :: r(:,:)
81-
integer :: p(6), num_present
128+
${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
129+
integer :: p(6), num_present, m, n, k
82130
integer, allocatable :: s(:,:)
83131

84132
p(1) = size(m1, 1)
@@ -102,8 +150,13 @@ contains
102150
num_present = num_present + 1
103151
end if
104152

153+
allocate(r(p(1), p(num_present + 1)))
154+
105155
if (num_present == 2) then
106-
r = matmul(m1, m2)
156+
m = p(1)
157+
n = p(3)
158+
k = p(2)
159+
call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m)
107160
return
108161
end if
109162

@@ -113,24 +166,56 @@ contains
113166
s = matmul_chain_order(p(1: num_present + 1))
114167

115168
if (num_present == 3) then
116-
r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s)
169+
r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
117170
return
118171
else if (num_present == 4) then
119-
r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s)
172+
r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
120173
return
121174
end if
122175

123176
! Now num_present is 5
124177

125178
! Five matrices: s(1, 5) is the top-level split point of m1*m2*m3*m4*m5.
! Each case builds the left and right factors (directly with gemm, or via
! the 3-/4-matrix helpers) and combines them with a final gemm into r.
! The `start` argument passed to a helper is the position in the chain of
! the first matrix handed to it, so that the helper indexes s and p
! consistently with the full chain.
select case (s(1, 5))
    case (1)
        ! m1*(m2*m3*m4*m5)
        temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p)
        m = p(1)
        n = p(6)
        k = p(2)
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
    case (2)
        ! (m1*m2)*(m3*m4*m5)
        m = p(1)
        n = p(3)
        k = p(2)
        allocate(temp(m,n))
        call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)

        temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p)

        ! r = temp*temp1 : (p(1) x p(3)) * (p(3) x p(6))
        m = p(1)
        n = p(6)
        k = p(3)
        call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
    case (3)
        ! (m1*m2*m3)*(m4*m5)
        ! The leading triple begins at chain position 1; passing 3 here
        ! would make the helper read the split and dimensions of m3*m4*m5.
        temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p)

        m = p(4)
        n = p(6)
        k = p(5)
        allocate(temp1(m,n))
        call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m)

        ! r = temp*temp1 : (p(1) x p(4)) * (p(4) x p(6))
        m = p(1)
        n = p(6)
        k = p(4)
        call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
    case (4)
        ! (m1*m2*m3*m4)*m5
        temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
        m = p(1)
        n = p(6)
        k = p(5)
        call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m)
    case default
        error stop "stdlib_matmul: error: unexpected s(i,j)"
end select

0 commit comments

Comments
 (0)