Skip to content

Commit 9b6911c

Browse files
committed
Initial conversion to const generics
1 parent fc71718 commit 9b6911c

File tree

5 files changed

+82
-95
lines changed

5 files changed

+82
-95
lines changed

crates/assert-instr-macro/src/lib.rs

+16-1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pub fn assert_instr(
6262
);
6363
let mut inputs = Vec::new();
6464
let mut input_vals = Vec::new();
65+
let mut const_vals = Vec::new();
6566
let ret = &func.sig.output;
6667
for arg in func.sig.inputs.iter() {
6768
let capture = match *arg {
@@ -82,6 +83,20 @@ pub fn assert_instr(
8283
input_vals.push(quote! { #ident });
8384
}
8485
}
86+
for arg in func.sig.generics.params.iter() {
87+
let c = match *arg {
88+
syn::GenericParam::Const(ref c) => c,
89+
ref v => panic!(
90+
"only const generics are allowed: `{:?}`",
91+
v.clone().into_token_stream()
92+
),
93+
};
94+
if let Some(&(_, ref tokens)) = invoc.args.iter().find(|a| c.ident == a.0) {
95+
const_vals.push(quote! { #tokens });
96+
} else {
97+
panic!("const generics must have a value for tests");
98+
}
99+
}
85100

86101
let attrs = func
87102
.attrs
@@ -133,7 +148,7 @@ pub fn assert_instr(
133148
std::mem::transmute(#shim_name_str.as_bytes().as_ptr()),
134149
std::sync::atomic::Ordering::Relaxed,
135150
);
136-
#name(#(#input_vals),*)
151+
#name::<#(#const_vals),*>(#(#input_vals),*)
137152
}
138153
};
139154

crates/core_arch/src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@
5353
clippy::shadow_reuse,
5454
clippy::cognitive_complexity,
5555
clippy::similar_names,
56-
clippy::many_single_char_names
56+
clippy::many_single_char_names,
57+
non_upper_case_globals
5758
)]
5859
#![cfg_attr(test, allow(unused_imports))]
5960
#![no_std]

crates/core_arch/src/x86/avx512f.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -22694,7 +22694,7 @@ pub unsafe fn _mm_mask_shuffle_ps(
2269422694
) -> __m128 {
2269522695
macro_rules! call {
2269622696
($imm8:expr) => {
22697-
_mm_shuffle_ps(a, b, $imm8)
22697+
_mm_shuffle_ps::<$imm8>(a, b)
2269822698
};
2269922699
}
2270022700
let r = constify_imm8_sae!(imm8, call);
@@ -22711,7 +22711,7 @@ pub unsafe fn _mm_mask_shuffle_ps(
2271122711
pub unsafe fn _mm_maskz_shuffle_ps(k: __mmask8, a: __m128, b: __m128, imm8: i32) -> __m128 {
2271222712
macro_rules! call {
2271322713
($imm8:expr) => {
22714-
_mm_shuffle_ps(a, b, $imm8)
22714+
_mm_shuffle_ps::<$imm8>(a, b)
2271522715
};
2271622716
}
2271722717
let r = constify_imm8_sae!(imm8, call);

crates/core_arch/src/x86/sse.rs

+33-65
Original file line numberDiff line numberDiff line change
@@ -1007,52 +1007,20 @@ pub const fn _MM_SHUFFLE(z: u32, y: u32, x: u32, w: u32) -> i32 {
10071007
#[inline]
10081008
#[target_feature(enable = "sse")]
10091009
#[cfg_attr(test, assert_instr(shufps, mask = 3))]
1010-
#[rustc_args_required_const(2)]
1011-
#[stable(feature = "simd_x86", since = "1.27.0")]
1012-
pub unsafe fn _mm_shuffle_ps(a: __m128, b: __m128, mask: i32) -> __m128 {
1013-
let mask = (mask & 0xFF) as u8;
1014-
1015-
macro_rules! shuffle_done {
1016-
($x01:expr, $x23:expr, $x45:expr, $x67:expr) => {
1017-
simd_shuffle4(a, b, [$x01, $x23, $x45, $x67])
1018-
};
1019-
}
1020-
macro_rules! shuffle_x67 {
1021-
($x01:expr, $x23:expr, $x45:expr) => {
1022-
match (mask >> 6) & 0b11 {
1023-
0b00 => shuffle_done!($x01, $x23, $x45, 4),
1024-
0b01 => shuffle_done!($x01, $x23, $x45, 5),
1025-
0b10 => shuffle_done!($x01, $x23, $x45, 6),
1026-
_ => shuffle_done!($x01, $x23, $x45, 7),
1027-
}
1028-
};
1029-
}
1030-
macro_rules! shuffle_x45 {
1031-
($x01:expr, $x23:expr) => {
1032-
match (mask >> 4) & 0b11 {
1033-
0b00 => shuffle_x67!($x01, $x23, 4),
1034-
0b01 => shuffle_x67!($x01, $x23, 5),
1035-
0b10 => shuffle_x67!($x01, $x23, 6),
1036-
_ => shuffle_x67!($x01, $x23, 7),
1037-
}
1038-
};
1039-
}
1040-
macro_rules! shuffle_x23 {
1041-
($x01:expr) => {
1042-
match (mask >> 2) & 0b11 {
1043-
0b00 => shuffle_x45!($x01, 0),
1044-
0b01 => shuffle_x45!($x01, 1),
1045-
0b10 => shuffle_x45!($x01, 2),
1046-
_ => shuffle_x45!($x01, 3),
1047-
}
1048-
};
1049-
}
1050-
match mask & 0b11 {
1051-
0b00 => shuffle_x23!(0),
1052-
0b01 => shuffle_x23!(1),
1053-
0b10 => shuffle_x23!(2),
1054-
_ => shuffle_x23!(3),
1055-
}
1010+
#[rustc_legacy_const_generics(2)]
1011+
#[stable(feature = "simd_x86", since = "1.27.0")]
1012+
pub unsafe fn _mm_shuffle_ps<const mask: i32>(a: __m128, b: __m128) -> __m128 {
1013+
assert!(mask >= 0 && mask <= 255);
1014+
simd_shuffle4(
1015+
a,
1016+
b,
1017+
[
1018+
mask as u32 & 0b11,
1019+
(mask as u32 >> 2) & 0b11,
1020+
((mask as u32 >> 4) & 0b11) + 4,
1021+
((mask as u32 >> 6) & 0b11) + 4,
1022+
],
1023+
)
10561024
}
10571025

10581026
/// Unpacks and interleave single-precision (32-bit) floating-point elements
@@ -1725,6 +1693,14 @@ pub const _MM_HINT_T2: i32 = 1;
17251693
#[stable(feature = "simd_x86", since = "1.27.0")]
17261694
pub const _MM_HINT_NTA: i32 = 0;
17271695

1696+
/// See [`_mm_prefetch`](fn._mm_prefetch.html).
1697+
#[stable(feature = "simd_x86", since = "1.27.0")]
1698+
pub const _MM_HINT_ET0: i32 = 7;
1699+
1700+
/// See [`_mm_prefetch`](fn._mm_prefetch.html).
1701+
#[stable(feature = "simd_x86", since = "1.27.0")]
1702+
pub const _MM_HINT_ET1: i32 = 6;
1703+
17281704
/// Fetch the cache line that contains address `p` using the given `strategy`.
17291705
///
17301706
/// The `strategy` must be one of:
@@ -1742,6 +1718,10 @@ pub const _MM_HINT_NTA: i32 = 0;
17421718
/// but outside of the cache hierarchy. This is used to reduce access latency
17431719
/// without polluting the cache.
17441720
///
1721+
/// * [`_MM_HINT_ET0`](constant._MM_HINT_ET0.html) and
1722+
/// [`_MM_HINT_ET1`](constant._MM_HINT_ET1.html) are similar to `_MM_HINT_T0`
1723+
/// and `_MM_HINT_T1` but indicate an anticipation to write to the address.
1724+
///
17451725
/// The actual implementation depends on the particular CPU. This instruction
17461726
/// is considered a hint, so the CPU is also free to simply ignore the request.
17471727
///
@@ -1769,24 +1749,12 @@ pub const _MM_HINT_NTA: i32 = 0;
17691749
#[cfg_attr(test, assert_instr(prefetcht1, strategy = _MM_HINT_T1))]
17701750
#[cfg_attr(test, assert_instr(prefetcht2, strategy = _MM_HINT_T2))]
17711751
#[cfg_attr(test, assert_instr(prefetchnta, strategy = _MM_HINT_NTA))]
1772-
#[rustc_args_required_const(1)]
1773-
#[stable(feature = "simd_x86", since = "1.27.0")]
1774-
pub unsafe fn _mm_prefetch(p: *const i8, strategy: i32) {
1775-
// The `strategy` must be a compile-time constant, so we use a short form
1776-
// of `constify_imm8!` for now.
1777-
// We use the `llvm.prefetch` instrinsic with `rw` = 0 (read), and
1778-
// `cache type` = 1 (data cache). `locality` is based on our `strategy`.
1779-
macro_rules! pref {
1780-
($imm8:expr) => {
1781-
match $imm8 {
1782-
0 => prefetch(p, 0, 0, 1),
1783-
1 => prefetch(p, 0, 1, 1),
1784-
2 => prefetch(p, 0, 2, 1),
1785-
_ => prefetch(p, 0, 3, 1),
1786-
}
1787-
};
1788-
}
1789-
pref!(strategy)
1752+
#[rustc_legacy_const_generics(1)]
1753+
#[stable(feature = "simd_x86", since = "1.27.0")]
1754+
pub unsafe fn _mm_prefetch<const strategy: i32>(p: *const i8) {
1755+
// We use the `llvm.prefetch` instrinsic with `cache type` = 1 (data cache).
1756+
// `locality` and `rw` are based on our `strategy`.
1757+
prefetch(p, (strategy >> 2) & 1, strategy & 3, 1);
17901758
}
17911759

17921760
/// Returns vector of type __m128 with undefined elements.
@@ -2976,7 +2944,7 @@ mod tests {
29762944
unsafe fn test_mm_shuffle_ps() {
29772945
let a = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
29782946
let b = _mm_setr_ps(5.0, 6.0, 7.0, 8.0);
2979-
let r = _mm_shuffle_ps(a, b, 0b00_01_01_11);
2947+
let r = _mm_shuffle_ps::<0b00_01_01_11>(a, b);
29802948
assert_eq_m128(r, _mm_setr_ps(4.0, 2.0, 6.0, 5.0));
29812949
}
29822950

crates/stdarch-verify/src/lib.rs

+29-26
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,21 @@ fn functions(input: TokenStream, dirs: &[&str]) -> TokenStream {
8181
let name = &f.sig.ident;
8282
// println!("{}", name);
8383
let mut arguments = Vec::new();
84+
let mut const_arguments = Vec::new();
8485
for input in f.sig.inputs.iter() {
8586
let ty = match *input {
8687
syn::FnArg::Typed(ref c) => &c.ty,
8788
_ => panic!("invalid argument on {}", name),
8889
};
8990
arguments.push(to_type(ty));
9091
}
92+
for generic in f.sig.generics.params.iter() {
93+
let ty = match *generic {
94+
syn::GenericParam::Const(ref c) => &c.ty,
95+
_ => panic!("invalid generic argument on {}", name),
96+
};
97+
const_arguments.push(to_type(ty));
98+
}
9199
let ret = match f.sig.output {
92100
syn::ReturnType::Default => quote! { None },
93101
syn::ReturnType::Type(_, ref t) => {
@@ -101,7 +109,20 @@ fn functions(input: TokenStream, dirs: &[&str]) -> TokenStream {
101109
} else {
102110
quote! { None }
103111
};
104-
let required_const = find_required_const(&f.attrs);
112+
113+
let required_const = find_required_const("rustc_args_required_const", &f.attrs);
114+
let mut legacy_const_generics =
115+
find_required_const("rustc_legacy_const_generics", &f.attrs);
116+
if !required_const.is_empty() && !legacy_const_generics.is_empty() {
117+
panic!(
118+
"Can't have both #[rustc_args_required_const] and \
119+
#[rustc_legacy_const_generics]"
120+
);
121+
}
122+
legacy_const_generics.sort();
123+
for (idx, ty) in legacy_const_generics.into_iter().zip(const_arguments.into_iter()) {
124+
arguments.insert(idx, ty);
125+
}
105126

106127
// strip leading underscore from fn name when building a test
107128
// _mm_foo -> mm_foo such that the test name is test_mm_foo.
@@ -238,16 +259,8 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream {
238259

239260
s => panic!("unsupported type: \"{}\"", s),
240261
},
241-
syn::Type::Ptr(syn::TypePtr {
242-
ref elem,
243-
ref mutability,
244-
..
245-
})
246-
| syn::Type::Reference(syn::TypeReference {
247-
ref elem,
248-
ref mutability,
249-
..
250-
}) => {
262+
syn::Type::Ptr(syn::TypePtr { ref elem, ref mutability, .. })
263+
| syn::Type::Reference(syn::TypeReference { ref elem, ref mutability, .. }) => {
251264
// Both pointers and references can have a mut token (*mut and &mut)
252265
if mutability.is_some() {
253266
let tokens = to_type(&elem);
@@ -278,11 +291,7 @@ fn extract_path_ident(path: &syn::Path) -> syn::Ident {
278291
syn::PathArguments::None => {}
279292
_ => panic!("unsupported path that has path arguments"),
280293
}
281-
path.segments
282-
.first()
283-
.expect("segment not found")
284-
.ident
285-
.clone()
294+
path.segments.first().expect("segment not found").ident.clone()
286295
}
287296

288297
fn walk(root: &Path, files: &mut Vec<(syn::File, String)>) {
@@ -359,11 +368,7 @@ fn find_instrs(attrs: &[syn::Attribute]) -> Vec<String> {
359368
attrs
360369
.iter()
361370
.filter(|a| a.path.is_ident("cfg_attr"))
362-
.filter_map(|a| {
363-
syn::parse2::<AssertInstr>(a.tokens.clone())
364-
.ok()
365-
.map(|a| a.instr)
366-
})
371+
.filter_map(|a| syn::parse2::<AssertInstr>(a.tokens.clone()).ok().map(|a| a.instr))
367372
.collect()
368373
}
369374

@@ -390,14 +395,12 @@ fn find_target_feature(attrs: &[syn::Attribute]) -> Option<syn::Lit> {
390395
})
391396
}
392397

393-
fn find_required_const(attrs: &[syn::Attribute]) -> Vec<usize> {
398+
fn find_required_const(name: &str, attrs: &[syn::Attribute]) -> Vec<usize> {
394399
attrs
395400
.iter()
396401
.flat_map(|a| {
397-
if a.path.segments[0].ident == "rustc_args_required_const" {
398-
syn::parse::<RustcArgsRequiredConst>(a.tokens.clone().into())
399-
.unwrap()
400-
.args
402+
if a.path.segments[0].ident == name {
403+
syn::parse::<RustcArgsRequiredConst>(a.tokens.clone().into()).unwrap().args
401404
} else {
402405
Vec::new()
403406
}

0 commit comments

Comments
 (0)