Skip to content

Commit 9c866bc

Browse files
committed
Check for occupied niches
1 parent 39e02f1 commit 9c866bc

File tree

24 files changed

+490
-14
lines changed

24 files changed

+490
-14
lines changed

compiler/rustc_codegen_ssa/src/mir/block.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,17 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
12951295
) -> MergingSucc {
12961296
debug!("codegen_terminator: {:?}", terminator);
12971297

1298+
if bx.tcx().may_insert_niche_checks() {
1299+
if let mir::TerminatorKind::Return = terminator.kind {
1300+
let op = mir::Operand::Copy(mir::Place::return_place());
1301+
let ty = op.ty(self.mir, bx.tcx());
1302+
let ty = self.monomorphize(ty);
1303+
if let Some(niche) = bx.layout_of(ty).largest_niche {
1304+
self.codegen_niche_check(bx, op, niche, terminator.source_info);
1305+
}
1306+
}
1307+
}
1308+
12981309
let helper = TerminatorCodegenHelper { bb, terminator };
12991310

13001311
let mergeable_succ = || {
@@ -1582,7 +1593,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
15821593
tuple.layout.fields.count()
15831594
}
15841595

1585-
fn get_caller_location(
1596+
pub fn get_caller_location(
15861597
&mut self,
15871598
bx: &mut Bx,
15881599
source_info: mir::SourceInfo,

compiler/rustc_codegen_ssa/src/mir/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub mod coverageinfo;
2020
pub mod debuginfo;
2121
mod intrinsic;
2222
mod locals;
23+
mod niche_check;
2324
pub mod operand;
2425
pub mod place;
2526
mod rvalue;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
use rustc_hir::LangItem;
2+
use rustc_middle::mir;
3+
use rustc_middle::mir::visit::Visitor;
4+
use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext};
5+
use rustc_middle::ty::Mutability;
6+
use rustc_middle::ty::Ty;
7+
use rustc_middle::ty::TyCtxt;
8+
use rustc_span::def_id::LOCAL_CRATE;
9+
use rustc_span::Span;
10+
use rustc_target::abi::Float;
11+
use rustc_target::abi::Integer;
12+
use rustc_target::abi::Niche;
13+
use rustc_target::abi::Primitive;
14+
use rustc_target::abi::Size;
15+
16+
use super::FunctionCx;
17+
use crate::base;
18+
use crate::common;
19+
use crate::mir::place::PlaceValue;
20+
use crate::mir::OperandValue;
21+
use crate::traits::*;
22+
23+
pub struct NicheFinder<'s, 'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> {
24+
pub fx: &'s mut FunctionCx<'a, 'tcx, Bx>,
25+
pub bx: &'s mut Bx,
26+
pub places: Vec<(mir::Operand<'tcx>, Niche)>,
27+
}
28+
29+
impl<'s, 'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> Visitor<'tcx> for NicheFinder<'s, 'a, 'tcx, Bx> {
30+
fn visit_rvalue(&mut self, rvalue: &mir::Rvalue<'tcx>, location: mir::Location) {
31+
match rvalue {
32+
mir::Rvalue::Cast(mir::CastKind::Transmute, op, ty) => {
33+
let ty = self.fx.monomorphize(*ty);
34+
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
35+
self.places.push((op.clone(), niche));
36+
}
37+
}
38+
_ => self.super_rvalue(rvalue, location),
39+
}
40+
}
41+
42+
fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>, _location: mir::Location) {
43+
if let mir::TerminatorKind::Return = terminator.kind {
44+
let op = mir::Operand::Copy(mir::Place::return_place());
45+
let ty = op.ty(self.fx.mir, self.bx.tcx());
46+
let ty = self.fx.monomorphize(ty);
47+
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
48+
self.places.push((op, niche));
49+
}
50+
}
51+
}
52+
53+
fn visit_place(
54+
&mut self,
55+
place: &mir::Place<'tcx>,
56+
context: PlaceContext,
57+
_location: mir::Location,
58+
) {
59+
match context {
60+
PlaceContext::NonMutatingUse(
61+
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
62+
) => {}
63+
_ => {
64+
return;
65+
}
66+
}
67+
68+
let ty = place.ty(self.fx.mir, self.bx.tcx()).ty;
69+
let ty = self.fx.monomorphize(ty);
70+
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
71+
self.places.push((mir::Operand::Copy(*place), niche));
72+
};
73+
}
74+
}
75+
76+
use rustc_target::abi::Abi;
77+
use rustc_target::abi::Scalar;
78+
use rustc_target::abi::WrappingRange;
79+
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
80+
fn value_in_niche(
81+
&mut self,
82+
bx: &mut Bx,
83+
op: crate::mir::OperandRef<'tcx, Bx::Value>,
84+
niche: Niche,
85+
) -> Option<Bx::Value> {
86+
let niche_ty = niche.ty(bx.tcx());
87+
let niche_layout = bx.layout_of(niche_ty);
88+
89+
let (imm, from_scalar, from_backend_ty) = match op.val {
90+
OperandValue::Immediate(imm) => {
91+
let Abi::Scalar(from_scalar) = op.layout.abi else { unreachable!() };
92+
let from_backend_ty = bx.backend_type(op.layout);
93+
(imm, from_scalar, from_backend_ty)
94+
}
95+
OperandValue::Pair(first, second) => {
96+
let Abi::ScalarPair(first_scalar, second_scalar) = op.layout.abi else {
97+
unreachable!()
98+
};
99+
if niche.offset == Size::ZERO {
100+
(first, first_scalar, bx.scalar_pair_element_backend_type(op.layout, 0, true))
101+
} else {
102+
// yolo
103+
(second, second_scalar, bx.scalar_pair_element_backend_type(op.layout, 1, true))
104+
}
105+
}
106+
OperandValue::ZeroSized => unreachable!(),
107+
OperandValue::Ref(PlaceValue { llval: ptr, .. }) => {
108+
// General case: Load the niche primitive via pointer arithmetic.
109+
let niche_ptr_ty = Ty::new_ptr(bx.tcx(), niche_ty, Mutability::Not);
110+
let ptr = bx.pointercast(ptr, bx.backend_type(bx.layout_of(niche_ptr_ty)));
111+
112+
let offset = niche.offset.bytes() / niche_layout.size.bytes();
113+
let niche_backend_ty = bx.backend_type(bx.layout_of(niche_ty));
114+
let ptr = bx.inbounds_gep(niche_backend_ty, ptr, &[bx.const_usize(offset)]);
115+
let value = bx.load(niche_backend_ty, ptr, rustc_target::abi::Align::ONE);
116+
return Some(value);
117+
}
118+
};
119+
120+
// Any type whose ABI is a Scalar bool is turned into an i1, so it cannot contain a value
121+
// outside of its niche.
122+
if from_scalar.is_bool() {
123+
return None;
124+
}
125+
126+
let to_scalar = Scalar::Initialized {
127+
value: niche.value,
128+
valid_range: WrappingRange::full(niche.size(bx.tcx())),
129+
};
130+
let to_backend_ty = bx.backend_type(niche_layout);
131+
if from_backend_ty == to_backend_ty {
132+
return Some(imm);
133+
}
134+
let value = self.transmute_immediate(
135+
bx,
136+
imm,
137+
from_scalar,
138+
from_backend_ty,
139+
to_scalar,
140+
to_backend_ty,
141+
);
142+
Some(value)
143+
}
144+
145+
#[instrument(level = "debug", skip(self, bx))]
146+
pub fn codegen_niche_check(
147+
&mut self,
148+
bx: &mut Bx,
149+
mir_op: mir::Operand<'tcx>,
150+
niche: Niche,
151+
source_info: mir::SourceInfo,
152+
) {
153+
let tcx = bx.tcx();
154+
let op_ty = self.monomorphize(mir_op.ty(self.mir, tcx));
155+
if op_ty == tcx.types.bool {
156+
return;
157+
}
158+
159+
let op = self.codegen_operand(bx, &mir_op);
160+
161+
let Some(value_in_niche) = self.value_in_niche(bx, op, niche) else {
162+
return;
163+
};
164+
let size = niche.size(tcx);
165+
166+
let start = niche.scalar(niche.valid_range.start, bx);
167+
let end = niche.scalar(niche.valid_range.end, bx);
168+
169+
let binop_le = base::bin_op_to_icmp_predicate(mir::BinOp::Le.to_hir_binop(), false);
170+
let binop_ge = base::bin_op_to_icmp_predicate(mir::BinOp::Ge.to_hir_binop(), false);
171+
let is_valid = if niche.valid_range.start == 0 {
172+
bx.icmp(binop_le, value_in_niche, end)
173+
} else if niche.valid_range.end == (u128::MAX >> 128 - size.bits()) {
174+
bx.icmp(binop_ge, value_in_niche, start)
175+
} else {
176+
// We need to check if the value is within a *wrapping* range. We could do this:
177+
// (niche >= start) && (niche <= end)
178+
// But what we're going to actually do is this:
179+
// max = end - start
180+
// (niche - start) <= max
181+
// The latter is much more complicated conceptually, but is actually less operations
182+
// because we can compute max in codegen.
183+
let mut max = niche.valid_range.end.wrapping_sub(niche.valid_range.start);
184+
let size = niche.size(tcx);
185+
if size.bits() < 128 {
186+
let mask = (1 << size.bits()) - 1;
187+
max &= mask;
188+
}
189+
let max_adjusted_allowed_value = niche.scalar(max, bx);
190+
191+
let biased = bx.sub(value_in_niche, start);
192+
bx.icmp(binop_le, biased, max_adjusted_allowed_value)
193+
};
194+
195+
// Create destination blocks, branching on is_valid
196+
let panic = bx.append_sibling_block("panic");
197+
let success = bx.append_sibling_block("success");
198+
bx.cond_br(is_valid, success, panic);
199+
200+
// Switch to the failure block and codegen a call to the panic intrinsic
201+
bx.switch_to_block(panic);
202+
self.set_debug_loc(bx, source_info);
203+
let location = self.get_caller_location(bx, source_info).immediate();
204+
self.codegen_panic(
205+
bx,
206+
niche.lang_item(),
207+
&[value_in_niche, start, end, location],
208+
source_info.span,
209+
);
210+
211+
// Continue codegen in the success block.
212+
bx.switch_to_block(success);
213+
self.set_debug_loc(bx, source_info);
214+
}
215+
216+
#[instrument(level = "debug", skip(self, bx))]
217+
fn codegen_panic(&mut self, bx: &mut Bx, lang_item: LangItem, args: &[Bx::Value], span: Span) {
218+
if bx.tcx().is_compiler_builtins(LOCAL_CRATE) {
219+
bx.abort()
220+
} else {
221+
let (fn_abi, fn_ptr, instance) = common::build_langcall(bx, Some(span), lang_item);
222+
let fn_ty = bx.fn_decl_backend_type(&fn_abi);
223+
let fn_attrs = if bx.tcx().def_kind(self.instance.def_id()).has_codegen_attrs() {
224+
Some(bx.tcx().codegen_fn_attrs(self.instance.def_id()))
225+
} else {
226+
None
227+
};
228+
bx.call(fn_ty, fn_attrs, Some(&fn_abi), fn_ptr, args, None, Some(instance));
229+
}
230+
bx.unreachable();
231+
}
232+
}
233+
234+
pub trait NicheExt {
235+
fn ty<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx>;
236+
fn size<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Size;
237+
fn scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(&self, val: u128, bx: &mut Bx) -> Bx::Value;
238+
fn lang_item(&self) -> LangItem;
239+
}
240+
241+
impl NicheExt for Niche {
242+
fn lang_item(&self) -> LangItem {
243+
match self.value {
244+
Primitive::Int(Integer::I8, _) => LangItem::PanicOccupiedNicheU8,
245+
Primitive::Int(Integer::I16, _) => LangItem::PanicOccupiedNicheU16,
246+
Primitive::Int(Integer::I32, _) => LangItem::PanicOccupiedNicheU32,
247+
Primitive::Int(Integer::I64, _) => LangItem::PanicOccupiedNicheU64,
248+
Primitive::Int(Integer::I128, _) => LangItem::PanicOccupiedNicheU128,
249+
Primitive::Pointer(_) => LangItem::PanicOccupiedNichePtr,
250+
Primitive::Float(Float::F16) => LangItem::PanicOccupiedNicheU16,
251+
Primitive::Float(Float::F32) => LangItem::PanicOccupiedNicheU32,
252+
Primitive::Float(Float::F64) => LangItem::PanicOccupiedNicheU64,
253+
Primitive::Float(Float::F128) => LangItem::PanicOccupiedNicheU128,
254+
}
255+
}
256+
257+
fn ty<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
258+
let types = &tcx.types;
259+
match self.value {
260+
Primitive::Int(Integer::I8, _) => types.u8,
261+
Primitive::Int(Integer::I16, _) => types.u16,
262+
Primitive::Int(Integer::I32, _) => types.u32,
263+
Primitive::Int(Integer::I64, _) => types.u64,
264+
Primitive::Int(Integer::I128, _) => types.u128,
265+
Primitive::Pointer(_) => Ty::new_ptr(tcx, types.unit, Mutability::Not),
266+
Primitive::Float(Float::F16) => types.u16,
267+
Primitive::Float(Float::F32) => types.u32,
268+
Primitive::Float(Float::F64) => types.u64,
269+
Primitive::Float(Float::F128) => types.u128,
270+
}
271+
}
272+
273+
fn size<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Size {
274+
self.value.size(&tcx)
275+
}
276+
277+
fn scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(&self, val: u128, bx: &mut Bx) -> Bx::Value {
278+
use rustc_middle::mir::interpret::Pointer;
279+
use rustc_middle::mir::interpret::Scalar;
280+
281+
let tcx = bx.tcx();
282+
let niche_ty = self.ty(tcx);
283+
let value = if niche_ty.is_any_ptr() {
284+
Scalar::from_maybe_pointer(Pointer::from_addr_invalid(val as u64), &tcx)
285+
} else {
286+
Scalar::from_uint(val, self.size(tcx))
287+
};
288+
let layout = rustc_target::abi::Scalar::Initialized {
289+
value: self.value,
290+
valid_range: WrappingRange::full(self.size(tcx)),
291+
};
292+
bx.scalar_to_backend(value, layout, bx.backend_type(bx.layout_of(self.ty(tcx))))
293+
}
294+
}

compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
161161
}
162162
}
163163

164-
fn codegen_transmute(
164+
pub fn codegen_transmute(
165165
&mut self,
166166
bx: &mut Bx,
167167
src: OperandRef<'tcx, Bx::Value>,
@@ -196,7 +196,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
196196
///
197197
/// Returns `None` for cases that can't work in that framework, such as for
198198
/// `Immediate`->`Ref` that needs an `alloc` to get the location.
199-
fn codegen_transmute_operand(
199+
pub fn codegen_transmute_operand(
200200
&mut self,
201201
bx: &mut Bx,
202202
operand: OperandRef<'tcx, Bx::Value>,
@@ -286,7 +286,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
286286
///
287287
/// `to_backend_ty` must be the *non*-immediate backend type (so it will be
288288
/// `i8`, not `i1`, for `bool`-like types.)
289-
fn transmute_immediate(
289+
pub fn transmute_immediate(
290290
&self,
291291
bx: &mut Bx,
292292
mut imm: Bx::Value,

compiler/rustc_codegen_ssa/src/mir/statement.rs

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,35 @@
1-
use rustc_middle::mir::{self, NonDivergingIntrinsic};
1+
use rustc_middle::mir;
2+
use rustc_middle::mir::visit::Visitor;
3+
use rustc_middle::mir::NonDivergingIntrinsic;
24
use rustc_middle::span_bug;
35
use rustc_session::config::OptLevel;
46

57
use super::FunctionCx;
68
use super::LocalRef;
9+
use crate::mir::niche_check::NicheFinder;
710
use crate::traits::*;
811

912
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
13+
fn niches_to_check(
14+
&mut self,
15+
bx: &mut Bx,
16+
statement: &mir::Statement<'tcx>,
17+
) -> Vec<(mir::Operand<'tcx>, rustc_target::abi::Niche)> {
18+
let mut finder = NicheFinder { fx: self, bx, places: Vec::new() };
19+
finder.visit_statement(statement, rustc_middle::mir::Location::START);
20+
finder.places
21+
}
22+
1023
#[instrument(level = "debug", skip(self, bx))]
1124
pub fn codegen_statement(&mut self, bx: &mut Bx, statement: &mir::Statement<'tcx>) {
1225
self.set_debug_loc(bx, statement.source_info);
26+
27+
if bx.tcx().may_insert_niche_checks() {
28+
for (op, niche) in self.niches_to_check(bx, statement) {
29+
self.codegen_niche_check(bx, op, niche, statement.source_info);
30+
}
31+
}
32+
1333
match statement.kind {
1434
mir::StatementKind::Assign(box (ref place, ref rvalue)) => {
1535
if let Some(index) = place.as_local() {

0 commit comments

Comments
 (0)