Gemm update #1518

Merged: 11 commits into main from gemm-update on Oct 31, 2024

jagrit06
Member

Proposed changes

Just some restructuring of the steel primitives, the first of many updates to come.
No notable performance regression was seen on M2 Ultra or M3 Max, and certain shapes show slight improvements.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@jagrit06
Member Author

Here are some benchmarks (on M2 Ultra). The numbers are about the same, within the range of run-to-run variation.

  B,     M,     N,     K,   dtype,  t, gflops_mx, gflops_mx_updated
  1,  1024,  4096, 11008, float32, nn, 19117.517, 19068.989
  1,  2048,  4096, 11008, float32, nn, 19217.543, 19189.014
  1,  4096,  4096, 11008, float32, nn, 19131.385, 19139.942
  1,  8192,  4096, 11008, float32, nn, 19118.967, 19113.533
  1,  1024, 11008,  4096, float32, nn, 19397.411, 19469.591
  1,  2048, 11008,  4096, float32, nn, 19901.953, 19875.592
  1,  4096, 11008,  4096, float32, nn, 20224.684, 20175.094
  1,  8192, 11008,  4096, float32, nn, 20370.913, 20381.435
 16,  1024,  1024,  1024, float32, nn, 20106.087, 20122.314
  4,  1024,  1024,  4096, float32, nn, 19918.126, 19941.598
  4,  1024,  4096,  1024, float32, nn, 19956.352, 20185.274
  1,  1024, 11008,  4096, float32, nn, 19615.095, 19595.299
  1,  1024,  4096, 11008, float32, nn, 19243.008, 19246.350
  1,  4096,  4096,  4096, float32, nn, 20587.211, 20540.867
  1,  4095,  4095,  4095, float32, nn, 17565.664, 17731.251
  1,  4097,  4097,  4097, float32, nn, 17311.481, 17385.684
  1,  1333,  4096,  4096, float32, nn, 18111.473, 18077.965
  1,  1024,  4096, 11008, float32, nt, 17736.890, 17693.308
  1,  2048,  4096, 11008, float32, nt, 17723.575, 17650.992
  1,  4096,  4096, 11008, float32, nt, 17341.110, 17572.027
  1,  8192,  4096, 11008, float32, nt, 17704.210, 17737.410
  1,  1024, 11008,  4096, float32, nt, 18679.322, 18692.565
  1,  2048, 11008,  4096, float32, nt, 19124.050, 19094.897
  1,  4096, 11008,  4096, float32, nt, 19373.042, 19339.954
  1,  8192, 11008,  4096, float32, nt, 19469.665, 19446.607
 16,  1024,  1024,  1024, float32, nt, 19386.563, 19354.704
  4,  1024,  1024,  4096, float32, nt, 19140.685, 19129.418
  4,  1024,  4096,  1024, float32, nt, 19184.135, 19046.593
  1,  1024, 11008,  4096, float32, nt, 18667.955, 18625.957
  1,  1024,  4096, 11008, float32, nt, 17713.658, 17618.049
  1,  4096,  4096,  4096, float32, nt, 19705.528, 19690.371
  1,  4095,  4095,  4095, float32, nt, 17072.711, 17500.869
  1,  4097,  4097,  4097, float32, nt, 16738.796, 17068.665
  1,  1333,  4096,  4096, float32, nt, 17373.873, 17403.931
  1,  1024,  4096, 11008, float16, nn, 21899.427, 21887.580
  1,  2048,  4096, 11008, float16, nn, 22463.645, 22442.668
  1,  4096,  4096, 11008, float16, nn, 22681.248, 22641.161
  1,  8192,  4096, 11008, float16, nn, 22833.954, 22780.235
  1,  1024, 11008,  4096, float16, nn, 22200.684, 22184.197
  1,  2048, 11008,  4096, float16, nn, 22603.335, 22594.786
  1,  4096, 11008,  4096, float16, nn, 22825.676, 22810.456
  1,  8192, 11008,  4096, float16, nn, 23014.366, 22983.508
 16,  1024,  1024,  1024, float16, nn, 21990.062, 21999.520
  4,  1024,  1024,  4096, float16, nn, 21681.365, 21707.751
  4,  1024,  4096,  1024, float16, nn, 21936.512, 21915.507
  1,  1024, 11008,  4096, float16, nn, 22149.157, 22127.942
  1,  1024,  4096, 11008, float16, nn, 21929.383, 21907.708
  1,  4096,  4096,  4096, float16, nn, 22726.208, 22721.541
  1,  4095,  4095,  4095, float16, nn, 21613.586, 21850.698
  1,  4097,  4097,  4097, float16, nn, 20953.416, 21178.990
  1,  1333,  4096,  4096, float16, nn, 21177.043, 21630.681
  1,  1024,  4096, 11008, float16, nt, 22426.045, 22347.487
  1,  2048,  4096, 11008, float16, nt, 22729.592, 22668.456
  1,  4096,  4096, 11008, float16, nt, 22948.186, 22881.767
  1,  8192,  4096, 11008, float16, nt, 23014.582, 22941.651
  1,  1024, 11008,  4096, float16, nt, 22729.657, 22669.781
  1,  2048, 11008,  4096, float16, nt, 23040.536, 23003.921
  1,  4096, 11008,  4096, float16, nt, 23185.512, 23166.241
  1,  8192, 11008,  4096, float16, nt, 23274.393, 23247.154
 16,  1024,  1024,  1024, float16, nt, 22187.723, 22175.466
  4,  1024,  1024,  4096, float16, nt, 22154.835, 22129.534
  4,  1024,  4096,  1024, float16, nt, 22182.794, 22147.250
  1,  1024, 11008,  4096, float16, nt, 22726.162, 22682.880
  1,  1024,  4096, 11008, float16, nt, 22415.393, 22350.551
  1,  4096,  4096,  4096, float16, nt, 23006.383, 22972.706
  1,  4095,  4095,  4095, float16, nt, 21369.480, 22061.653
  1,  4097,  4097,  4097, float16, nt, 20792.476, 21447.853
  1,  1333,  4096,  4096, float16, nt, 21672.848, 21631.082
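For reference, a throughput figure like those above is conventionally derived from the matmul shape and a measured wall-clock time, counting 2·B·M·N·K floating-point operations per batched GEMM. The sketch below assumes that convention. The helper name `gemm_gflops` and the example timing are hypothetical and are not taken from the benchmark script used in this PR.

```python
def gemm_gflops(B: int, M: int, N: int, K: int, seconds: float) -> float:
    """Throughput in GFLOP/s for a batched (B, M, K) @ (B, K, N) matmul.

    Assumes the usual count of one multiply and one add per inner-product
    term, i.e. 2 * B * M * N * K floating-point operations in total.
    """
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    flops = 2 * B * M * N * K
    return flops / seconds / 1e9


# Hypothetical timing (not measured): a 4096^3 matmul finishing in 6.7 ms.
example = gemm_gflops(1, 4096, 4096, 4096, 6.7e-3)
print(f"{example:.1f} GFLOP/s")
```

Under this assumption, the shapes with dimensions 4095 and 4097 do essentially the same amount of work as the 4096 case, so their lower figures reflect longer run times rather than a different operation count.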

jagrit06 force-pushed the gemm-update branch 2 times, most recently from 066a9e6 to a5bfec9 (October 29, 2024 17:07)
@awni left a comment
Member

LGTM!

@jagrit06
Member Author

Added some tuning updates. Here are the results now on the M2 Ultra:

  B,     M,     N,     K,   dtype,  t, gflops_mx, gflops_mx_updated
  1,  1024,  4096, 11008, float32, nn, 19117.517, 19274.840
  1,  2048,  4096, 11008, float32, nn, 19217.543, 19747.169
  1,  4096,  4096, 11008, float32, nn, 19131.385, 19872.386
  1,  8192,  4096, 11008, float32, nn, 19118.967, 20006.589
  1,  1024, 11008,  4096, float32, nn, 19397.411, 19676.072
  1,  2048, 11008,  4096, float32, nn, 19901.953, 20322.518
  1,  4096, 11008,  4096, float32, nn, 20224.684, 20665.435
  1,  8192, 11008,  4096, float32, nn, 20370.913, 20832.300
 16,  1024,  1024,  1024, float32, nn, 20106.087, 20153.883
  4,  1024,  1024,  4096, float32, nn, 19918.126, 19873.585
  4,  1024,  4096,  1024, float32, nn, 19956.352, 20186.173
  1,  1024, 11008,  4096, float32, nn, 19615.095, 19668.497
  1,  1024,  4096, 11008, float32, nn, 19243.008, 19265.997
  1,  4096,  4096,  4096, float32, nn, 20587.211, 20745.885
  1,  4095,  4095,  4095, float32, nn, 17565.664, 18285.369
  1,  4097,  4097,  4097, float32, nn, 17311.481, 17805.171
  1,  1333,  4096,  4096, float32, nn, 18111.473, 18126.697
  1,  1024,  4096, 11008, float32, nt, 17736.890, 18268.362
  1,  2048,  4096, 11008, float32, nt, 17723.575, 18671.449
  1,  4096,  4096, 11008, float32, nt, 17341.110, 18559.295
  1,  8192,  4096, 11008, float32, nt, 17704.210, 18913.444
  1,  1024, 11008,  4096, float32, nt, 18679.322, 18635.907
  1,  2048, 11008,  4096, float32, nt, 19124.050, 19341.598
  1,  4096, 11008,  4096, float32, nt, 19373.042, 19660.922
  1,  8192, 11008,  4096, float32, nt, 19469.665, 19808.260
 16,  1024,  1024,  1024, float32, nt, 19386.563, 19260.996
  4,  1024,  1024,  4096, float32, nt, 19140.685, 19043.653
  4,  1024,  4096,  1024, float32, nt, 19184.135, 19251.135
  1,  1024, 11008,  4096, float32, nt, 18667.955, 18673.967
  1,  1024,  4096, 11008, float32, nt, 17713.658, 18393.045
  1,  4096,  4096,  4096, float32, nt, 19705.528, 19697.074
  1,  4095,  4095,  4095, float32, nt, 17072.711, 17659.478
  1,  4097,  4097,  4097, float32, nt, 16738.796, 17177.833
  1,  1333,  4096,  4096, float32, nt, 17373.873, 17093.802
  1,  1024,  4096, 11008, float16, nn, 21899.427, 22535.881
  1,  2048,  4096, 11008, float16, nn, 22463.645, 23085.590
  1,  4096,  4096, 11008, float16, nn, 22681.248, 23275.723
  1,  8192,  4096, 11008, float16, nn, 22833.954, 23585.688
  1,  1024, 11008,  4096, float16, nn, 22200.684, 22912.559
  1,  2048, 11008,  4096, float16, nn, 22603.335, 23414.237
  1,  4096, 11008,  4096, float16, nn, 22825.676, 23702.309
  1,  8192, 11008,  4096, float16, nn, 23014.366, 23889.071
 16,  1024,  1024,  1024, float16, nn, 21990.062, 22488.405
  4,  1024,  1024,  4096, float16, nn, 21681.365, 22144.282
  4,  1024,  4096,  1024, float16, nn, 21936.512, 22582.744
  1,  1024, 11008,  4096, float16, nn, 22149.157, 22903.193
  1,  1024,  4096, 11008, float16, nn, 21929.383, 22514.098
  1,  4096,  4096,  4096, float16, nn, 22726.208, 23467.754
  1,  4095,  4095,  4095, float16, nn, 21613.586, 22480.942
  1,  4097,  4097,  4097, float16, nn, 20953.416, 21753.930
  1,  1333,  4096,  4096, float16, nn, 21177.043, 21997.329
  1,  1024,  4096, 11008, float16, nt, 22426.045, 22170.076
  1,  2048,  4096, 11008, float16, nt, 22729.592, 22626.339
  1,  4096,  4096, 11008, float16, nt, 22948.186, 22918.676
  1,  8192,  4096, 11008, float16, nt, 23014.582, 23269.988
  1,  1024, 11008,  4096, float16, nt, 22729.657, 22689.532
  1,  2048, 11008,  4096, float16, nt, 23040.536, 23151.335
  1,  4096, 11008,  4096, float16, nt, 23185.512, 23432.139
  1,  8192, 11008,  4096, float16, nt, 23274.393, 23709.022
 16,  1024,  1024,  1024, float16, nt, 22187.723, 22190.262
  4,  1024,  1024,  4096, float16, nt, 22154.835, 21793.520
  4,  1024,  4096,  1024, float16, nt, 22182.794, 22293.168
  1,  1024, 11008,  4096, float16, nt, 22726.162, 22725.356
  1,  1024,  4096, 11008, float16, nt, 22415.393, 22454.936
  1,  4096,  4096,  4096, float16, nt, 23006.383, 23182.257
  1,  4095,  4095,  4095, float16, nt, 21369.480, 22501.117
  1,  4097,  4097,  4097, float16, nt, 20792.476, 21871.788
  1,  1333,  4096,  4096, float16, nt, 21672.848, 21776.690
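To make the effect of the tuning easier to read off, the relative change between the two throughput columns can be computed per row. The sketch below is a minimal illustration using four rows copied from the table above. The function name `percent_change` is hypothetical, and `t` is assumed to denote the transpose mode of the inputs.

```python
# (B, M, N, K, dtype, t, gflops_mx, gflops_mx_updated), copied from the table above.
ROWS = [
    (1, 4096, 4096, 11008, "float32", "nt", 17341.110, 18559.295),
    (1, 1333, 4096, 4096, "float32", "nt", 17373.873, 17093.802),
    (1, 4095, 4095, 4095, "float16", "nt", 21369.480, 22501.117),
    (4, 1024, 1024, 4096, "float16", "nt", 22154.835, 21793.520),
]


def percent_change(before: float, after: float) -> float:
    """Relative change from `before` to `after`, in percent."""
    return (after - before) / before * 100.0


for B, M, N, K, dtype, t, before, after in ROWS:
    delta = percent_change(before, after)
    print(f"{B:>3}, {M:>5}, {N:>5}, {K:>5}, {dtype}, {t}: {delta:+.2f}%")
```

On these four rows the change ranges from roughly -1.6% to about +7%, so a per-row view is useful for separating the shapes that benefit from the tuning from those that sit within noise or regress slightly.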

@@ -88,6 +88,83 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Steel matmul fallback
///////////////////////////////////////////////////////////////////////////////

#define GEMM_TPARAM_MACRO(devc) \
Member

❤️

angeloskath merged commit 960e3f0 into main on Oct 31, 2024
5 checks passed
angeloskath deleted the gemm-update branch on October 31, 2024 02:30