Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Allow coercing safe-to-call target_feature functions to safe fn pointers #135504

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1654,7 +1654,20 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
match *cast_kind {
CastKind::PointerCoercion(PointerCoercion::ReifyFnPointer, coercion_source) => {
let is_implicit_coercion = coercion_source == CoercionSource::Implicit;
let src_sig = op.ty(body, tcx).fn_sig(tcx);
let src_ty = op.ty(body, tcx);
let mut src_sig = src_ty.fn_sig(tcx);
if let ty::FnDef(def_id, _) = src_ty.kind()
&& let ty::FnPtr(_, target_hdr) = *ty.kind()
&& tcx.codegen_fn_attrs(def_id).safe_target_features
&& target_hdr.safety.is_safe()
&& let Some(safe_sig) = tcx.adjust_target_feature_sig(
*def_id,
src_sig,
body.source.def_id(),
)
{
src_sig = safe_sig;
}

// HACK: This shouldn't be necessary... We can remove this when we actually
// get binders with where clauses, then elaborate implied bounds into that
Expand Down
23 changes: 12 additions & 11 deletions compiler/rustc_hir_typeck/src/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {

match b.kind() {
ty::FnPtr(_, b_hdr) => {
let a_sig = a.fn_sig(self.tcx);
let mut a_sig = a.fn_sig(self.tcx);
if let ty::FnDef(def_id, _) = *a.kind() {
// Intrinsics are not coercible to function pointers
if self.tcx.intrinsic(def_id).is_some() {
Expand All @@ -932,19 +932,20 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
return Err(TypeError::ForceInlineCast);
}

let fn_attrs = self.tcx.codegen_fn_attrs(def_id);
if matches!(fn_attrs.inline, InlineAttr::Force { .. }) {
return Err(TypeError::ForceInlineCast);
}

// FIXME(target_feature): Safe `#[target_feature]` functions could be cast to safe fn pointers (RFC 2396),
// as you can already write that "cast" in user code by wrapping a target_feature fn call in a closure,
// which is safe. This is sound because you already need to be executing code that is satisfying the target
// feature constraints..
if b_hdr.safety.is_safe()
&& self.tcx.codegen_fn_attrs(def_id).safe_target_features
{
return Err(TypeError::TargetFeatureCast(def_id));
// Allow the coercion if the current function has all the features that would be
// needed to call the coercee safely.
if let Some(safe_sig) = self.tcx.adjust_target_feature_sig(
def_id,
a_sig,
self.fcx.body_id.into(),
) {
a_sig = safe_sig;
} else {
return Err(TypeError::TargetFeatureCast(def_id));
}
}
}

Expand Down
33 changes: 32 additions & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ use crate::dep_graph::{DepGraph, DepKindStruct};
use crate::infer::canonical::{CanonicalParamEnvCache, CanonicalVarInfo, CanonicalVarInfos};
use crate::lint::lint_level;
use crate::metadata::ModChild;
use crate::middle::codegen_fn_attrs::CodegenFnAttrs;
use crate::middle::codegen_fn_attrs::{CodegenFnAttrs, TargetFeature};
use crate::middle::{resolve_bound_vars, stability};
use crate::mir::interpret::{self, Allocation, ConstAllocation};
use crate::mir::{Body, Local, Place, PlaceElem, ProjectionKind, Promoted};
Expand Down Expand Up @@ -1776,6 +1776,37 @@ impl<'tcx> TyCtxt<'tcx> {
pub fn dcx(self) -> DiagCtxtHandle<'tcx> {
self.sess.dcx()
}

pub fn is_target_feature_call_safe(
self,
callee_features: &[TargetFeature],
body_features: &[TargetFeature],
) -> bool {
// If the called function has target features the calling function hasn't,
// the call requires `unsafe`. Don't check this on wasm
// targets, though. For more information on wasm see the
// is_like_wasm check in hir_analysis/src/collect.rs
self.sess.target.options.is_like_wasm
|| callee_features
.iter()
.all(|feature| body_features.iter().any(|f| f.name == feature.name))
}

/// Returns the safe version of the signature of the given function, if calling it
/// would be safe in the context of the given caller.
pub fn adjust_target_feature_sig(
self,
fun_def: DefId,
fun_sig: ty::Binder<'tcx, ty::FnSig<'tcx>>,
caller: DefId,
) -> Option<ty::Binder<'tcx, ty::FnSig<'tcx>>> {
let fun_features = &self.codegen_fn_attrs(fun_def).target_features;
let callee_features = &self.codegen_fn_attrs(caller).target_features;
if self.is_target_feature_call_safe(&fun_features, &callee_features) {
return Some(fun_sig.map_bound(|sig| ty::FnSig { safety: hir::Safety::Safe, ..sig }));
}
None
}
}

impl<'tcx> TyCtxtAt<'tcx> {
Expand Down
11 changes: 3 additions & 8 deletions compiler/rustc_mir_build/src/check_unsafety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,14 +495,9 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
};
self.requires_unsafe(expr.span, CallToUnsafeFunction(func_id));
} else if let &ty::FnDef(func_did, _) = fn_ty.kind() {
// If the called function has target features the calling function hasn't,
// the call requires `unsafe`. Don't check this on wasm
// targets, though. For more information on wasm see the
// is_like_wasm check in hir_analysis/src/collect.rs
if !self.tcx.sess.target.options.is_like_wasm
&& !callee_features.iter().all(|feature| {
self.body_target_features.iter().any(|f| f.name == feature.name)
})
if !self
.tcx
.is_target_feature_call_safe(callee_features, self.body_target_features)
{
let missing: Vec<_> = callee_features
.iter()
Expand Down
14 changes: 14 additions & 0 deletions tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,23 @@

#![feature(target_feature_11)]

#[target_feature(enable = "avx")]
fn foo_avx() {}

#[target_feature(enable = "sse2")]
fn foo() {}

#[target_feature(enable = "sse2")]
fn bar() {
let foo: fn() = foo; // this is OK, as we have the necessary target features.
let foo: fn() = foo_avx; //~ ERROR mismatched types
}

fn main() {
if std::is_x86_feature_detected!("sse2") {
unsafe {
bar();
}
}
let foo: fn() = foo; //~ ERROR mismatched types
}
19 changes: 17 additions & 2 deletions tests/ui/rfcs/rfc-2396-target_feature-11/fn-ptr.stderr
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
error[E0308]: mismatched types
--> $DIR/fn-ptr.rs:9:21
--> $DIR/fn-ptr.rs:14:21
|
LL | #[target_feature(enable = "avx")]
| --------------------------------- `#[target_feature]` added here
...
LL | let foo: fn() = foo_avx;
| ---- ^^^^^^^ cannot coerce functions with `#[target_feature]` to safe function pointers
| |
| expected due to this
|
= note: expected fn pointer `fn()`
found fn item `#[target_features] fn() {foo_avx}`
= note: functions with `#[target_feature]` can only be coerced to `unsafe` function pointers

error[E0308]: mismatched types
--> $DIR/fn-ptr.rs:23:21
|
LL | #[target_feature(enable = "sse2")]
| ---------------------------------- `#[target_feature]` added here
Expand All @@ -13,6 +28,6 @@ LL | let foo: fn() = foo;
found fn item `#[target_features] fn() {foo}`
= note: functions with `#[target_feature]` can only be coerced to `unsafe` function pointers

error: aborting due to 1 previous error
error: aborting due to 2 previous errors

For more information about this error, try `rustc --explain E0308`.
22 changes: 22 additions & 0 deletions tests/ui/rfcs/rfc-2396-target_feature-11/return-fn-ptr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//@ only-x86_64
//@ run-pass

#![feature(target_feature_11)]

#[target_feature(enable = "sse2")]
fn foo() -> bool {
true
}

#[target_feature(enable = "sse2")]
fn bar() -> fn() -> bool {
foo
}

fn main() {
if !std::is_x86_feature_detected!("sse2") {
return;
}
let f = unsafe { bar() };
assert!(f());
}
Loading