Skip to content

Commit b8d8ea6

Browse files
committed
Auto merge of #48300 - eddyb:split-aggregates, r=<try>
rustc_mir: add a pass for splitting locals into their fields (aka SROA). **DO NOT MERGE**: based on #48052.
2 parents b298607 + 9a18264 commit b8d8ea6

File tree

18 files changed

+632
-137
lines changed

18 files changed

+632
-137
lines changed

Diff for: src/librustc/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
#![feature(underscore_lifetimes)]
7070
#![feature(universal_impl_trait)]
7171
#![feature(trace_macros)]
72+
#![feature(trusted_len)]
7273
#![feature(catch_expr)]
7374
#![feature(test)]
7475

Diff for: src/librustc/mir/mod.rs

+62-3
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ use std::ascii;
3434
use std::borrow::{Cow};
3535
use std::cell::Ref;
3636
use std::fmt::{self, Debug, Formatter, Write};
37-
use std::{iter, u32};
37+
use std::{iter, mem, u32};
3838
use std::ops::{Index, IndexMut};
3939
use std::rc::Rc;
4040
use std::vec::IntoIter;
4141
use syntax::ast::{self, Name};
4242
use syntax::symbol::InternedString;
43-
use syntax_pos::Span;
43+
use syntax_pos::{Span, DUMMY_SP};
4444

4545
mod cache;
4646
pub mod tcx;
@@ -984,11 +984,62 @@ impl<'tcx> BasicBlockData<'tcx> {
984984
pub fn retain_statements<F>(&mut self, mut f: F) where F: FnMut(&mut Statement) -> bool {
985985
for s in &mut self.statements {
986986
if !f(s) {
987-
s.kind = StatementKind::Nop;
987+
s.make_nop();
988988
}
989989
}
990990
}
991991

992+
pub fn expand_statements<F, I>(&mut self, mut f: F)
993+
where F: FnMut(&mut Statement<'tcx>) -> Option<I>,
994+
I: iter::TrustedLen<Item = Statement<'tcx>>
995+
{
996+
// Gather all the iterators we'll need to splice in, and their positions.
997+
let mut splices: Vec<(usize, I)> = vec![];
998+
let mut extra_stmts = 0;
999+
for (i, s) in self.statements.iter_mut().enumerate() {
1000+
if let Some(mut new_stmts) = f(s) {
1001+
if let Some(first) = new_stmts.next() {
1002+
// We can already store the first new statement.
1003+
*s = first;
1004+
1005+
// Save the other statements for optimized splicing.
1006+
let remaining = new_stmts.size_hint().0;
1007+
if remaining > 0 {
1008+
splices.push((i + 1 + extra_stmts, new_stmts));
1009+
extra_stmts += remaining;
1010+
}
1011+
} else {
1012+
s.make_nop();
1013+
}
1014+
}
1015+
}
1016+
1017+
// Splice in the new statements, from the end of the block.
1018+
// FIXME(eddyb) This could be more efficient with a "gap buffer"
1019+
// where a range of elements ("gap") is left uninitialized, with
1020+
// splicing adding new elements to the end of that gap and moving
1021+
// existing elements from before the gap to the end of the gap.
1022+
// For now, this is safe code, emulating a gap but initializing it.
1023+
let mut gap = self.statements.len()..self.statements.len()+extra_stmts;
1024+
self.statements.resize(gap.end, Statement {
1025+
source_info: SourceInfo {
1026+
span: DUMMY_SP,
1027+
scope: ARGUMENT_VISIBILITY_SCOPE
1028+
},
1029+
kind: StatementKind::Nop
1030+
});
1031+
for (splice_start, new_stmts) in splices.into_iter().rev() {
1032+
let splice_end = splice_start + new_stmts.size_hint().0;
1033+
while gap.end > splice_end {
1034+
gap.start -= 1;
1035+
gap.end -= 1;
1036+
self.statements.swap(gap.start, gap.end);
1037+
}
1038+
self.statements.splice(splice_start..splice_end, new_stmts);
1039+
gap.end = splice_start;
1040+
}
1041+
}
1042+
9921043
pub fn visitable(&self, index: usize) -> &dyn MirVisitable<'tcx> {
9931044
if index < self.statements.len() {
9941045
&self.statements[index]
@@ -1157,6 +1208,14 @@ impl<'tcx> Statement<'tcx> {
11571208
pub fn make_nop(&mut self) {
11581209
self.kind = StatementKind::Nop
11591210
}
1211+
1212+
/// Changes a statement to a nop and returns the original statement.
1213+
pub fn replace_nop(&mut self) -> Self {
1214+
Statement {
1215+
source_info: self.source_info,
1216+
kind: mem::replace(&mut self.kind, StatementKind::Nop)
1217+
}
1218+
}
11601219
}
11611220

11621221
#[derive(Clone, Debug, RustcEncodable, RustcDecodable)]

Diff for: src/librustc_mir/transform/deaggregator.rs

+72-98
Original file line numberDiff line numberDiff line change
@@ -21,116 +21,90 @@ impl MirPass for Deaggregator {
2121
tcx: TyCtxt<'a, 'tcx, 'tcx>,
2222
source: MirSource,
2323
mir: &mut Mir<'tcx>) {
24-
let node_path = tcx.item_path_str(source.def_id);
25-
debug!("running on: {:?}", node_path);
26-
// we only run when mir_opt_level > 2
27-
if tcx.sess.opts.debugging_opts.mir_opt_level <= 2 {
28-
return;
29-
}
30-
3124
// Don't run on constant MIR, because trans might not be able to
3225
// evaluate the modified MIR.
3326
// FIXME(eddyb) Remove check after miri is merged.
3427
let id = tcx.hir.as_local_node_id(source.def_id).unwrap();
3528
match (tcx.hir.body_owner_kind(id), source.promoted) {
36-
(hir::BodyOwnerKind::Fn, None) => {},
37-
_ => return
38-
}
39-
// In fact, we might not want to trigger in other cases.
40-
// Ex: when we could use SROA. See issue #35259
29+
(_, Some(_)) |
30+
(hir::BodyOwnerKind::Const, _) |
31+
(hir::BodyOwnerKind::Static(_), _) => return,
4132

42-
for bb in mir.basic_blocks_mut() {
43-
let mut curr: usize = 0;
44-
while let Some(idx) = get_aggregate_statement_index(curr, &bb.statements) {
45-
// do the replacement
46-
debug!("removing statement {:?}", idx);
47-
let src_info = bb.statements[idx].source_info;
48-
let suffix_stmts = bb.statements.split_off(idx+1);
49-
let orig_stmt = bb.statements.pop().unwrap();
50-
let (lhs, rhs) = match orig_stmt.kind {
51-
StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs),
52-
_ => span_bug!(src_info.span, "expected assign, not {:?}", orig_stmt),
53-
};
54-
let (agg_kind, operands) = match rhs {
55-
&Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands),
56-
_ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs),
57-
};
58-
let (adt_def, variant, substs) = match **agg_kind {
59-
AggregateKind::Adt(adt_def, variant, substs, None)
60-
=> (adt_def, variant, substs),
61-
_ => span_bug!(src_info.span, "expected struct, not {:?}", rhs),
62-
};
63-
let n = bb.statements.len();
64-
bb.statements.reserve(n + operands.len() + suffix_stmts.len());
65-
for (i, op) in operands.iter().enumerate() {
66-
let ref variant_def = adt_def.variants[variant];
67-
let ty = variant_def.fields[i].ty(tcx, substs);
68-
let rhs = Rvalue::Use(op.clone());
33+
(hir::BodyOwnerKind::Fn, _) => {
34+
if tcx.is_const_fn(source.def_id) {
35+
// Don't run on const functions, as, again, trans might not be able to evaluate
36+
// the optimized IR.
37+
return
38+
}
39+
}
40+
}
6941

70-
let lhs_cast = if adt_def.is_enum() {
71-
Place::Projection(Box::new(PlaceProjection {
72-
base: lhs.clone(),
73-
elem: ProjectionElem::Downcast(adt_def, variant),
74-
}))
42+
let (basic_blocks, local_decls) = mir.basic_blocks_and_local_decls_mut();
43+
let local_decls = &*local_decls;
44+
for bb in basic_blocks {
45+
bb.expand_statements(|stmt| {
46+
// FIXME(eddyb) don't match twice on `stmt.kind` (post-NLL).
47+
if let StatementKind::Assign(_, ref rhs) = stmt.kind {
48+
if let Rvalue::Aggregate(ref kind, _) = *rhs {
49+
// FIXME(#48193) Deaggregate arrays when it's cheaper to do so.
50+
if let AggregateKind::Array(_) = **kind {
51+
return None;
52+
}
7553
} else {
76-
lhs.clone()
77-
};
78-
79-
let lhs_proj = Place::Projection(Box::new(PlaceProjection {
80-
base: lhs_cast,
81-
elem: ProjectionElem::Field(Field::new(i), ty),
82-
}));
83-
let new_statement = Statement {
84-
source_info: src_info,
85-
kind: StatementKind::Assign(lhs_proj, rhs),
86-
};
87-
debug!("inserting: {:?} @ {:?}", new_statement, idx + i);
88-
bb.statements.push(new_statement);
54+
return None;
55+
}
56+
} else {
57+
return None;
8958
}
9059

91-
// if the aggregate was an enum, we need to set the discriminant
92-
if adt_def.is_enum() {
93-
let set_discriminant = Statement {
94-
kind: StatementKind::SetDiscriminant {
95-
place: lhs.clone(),
96-
variant_index: variant,
97-
},
98-
source_info: src_info,
99-
};
100-
bb.statements.push(set_discriminant);
60+
let stmt = stmt.replace_nop();
61+
let source_info = stmt.source_info;
62+
let (mut lhs, kind, operands) = match stmt.kind {
63+
StatementKind::Assign(lhs, Rvalue::Aggregate(kind, operands))
64+
=> (lhs, kind, operands),
65+
_ => bug!()
10166
};
10267

103-
curr = bb.statements.len();
104-
bb.statements.extend(suffix_stmts);
105-
}
106-
}
107-
}
108-
}
68+
let mut set_discriminant = None;
69+
let active_field_index = match *kind {
70+
AggregateKind::Adt(adt_def, variant_index, _, active_field_index) => {
71+
if adt_def.is_enum() {
72+
set_discriminant = Some(Statement {
73+
kind: StatementKind::SetDiscriminant {
74+
place: lhs.clone(),
75+
variant_index,
76+
},
77+
source_info,
78+
});
79+
lhs = lhs.downcast(adt_def, variant_index);
80+
}
81+
active_field_index
82+
}
83+
_ => None
84+
};
10985

110-
fn get_aggregate_statement_index<'a, 'tcx, 'b>(start: usize,
111-
statements: &Vec<Statement<'tcx>>)
112-
-> Option<usize> {
113-
for i in start..statements.len() {
114-
let ref statement = statements[i];
115-
let rhs = match statement.kind {
116-
StatementKind::Assign(_, ref rhs) => rhs,
117-
_ => continue,
118-
};
119-
let (kind, operands) = match rhs {
120-
&Rvalue::Aggregate(ref kind, ref operands) => (kind, operands),
121-
_ => continue,
122-
};
123-
let (adt_def, variant) = match **kind {
124-
AggregateKind::Adt(adt_def, variant, _, None) => (adt_def, variant),
125-
_ => continue,
126-
};
127-
if operands.len() == 0 {
128-
// don't deaggregate ()
129-
continue;
86+
Some(operands.into_iter().enumerate().map(move |(i, op)| {
87+
let lhs_field = if let AggregateKind::Array(_) = *kind {
88+
// FIXME(eddyb) `offset` should be u64.
89+
let offset = i as u32;
90+
assert_eq!(offset as usize, i);
91+
lhs.clone().elem(ProjectionElem::ConstantIndex {
92+
offset,
93+
// FIXME(eddyb) `min_length` doesn't appear to be used.
94+
min_length: offset + 1,
95+
from_end: false
96+
})
97+
} else {
98+
let ty = op.ty(local_decls, tcx);
99+
let field = Field::new(active_field_index.unwrap_or(i));
100+
lhs.clone().field(field, ty)
101+
};
102+
Statement {
103+
source_info,
104+
kind: StatementKind::Assign(lhs_field, Rvalue::Use(op)),
105+
}
106+
}).chain(set_discriminant))
107+
});
130108
}
131-
debug!("getting variant {:?}", variant);
132-
debug!("for adt_def {:?}", adt_def);
133-
return Some(i);
134-
};
135-
None
109+
}
136110
}

Diff for: src/librustc_mir/transform/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub mod generator;
4545
pub mod inline;
4646
pub mod lower_128bit;
4747
pub mod uniform_array_move_out;
48+
pub mod split_local_fields;
4849

4950
pub(crate) fn provide(providers: &mut Providers) {
5051
self::qualify_consts::provide(providers);
@@ -258,8 +259,14 @@ fn optimized_mir<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>, def_id: DefId) -> &'tcx
258259

259260
// Optimizations begin.
260261
inline::Inline,
262+
263+
// Lowering generator control-flow and variables
264+
// has to happen before we do anything else to them.
265+
generator::StateTransform,
266+
261267
instcombine::InstCombine,
262268
deaggregator::Deaggregator,
269+
split_local_fields::SplitLocalFields,
263270
copy_prop::CopyPropagation,
264271
remove_noop_landing_pads::RemoveNoopLandingPads,
265272
simplify::SimplifyCfg::new("final"),

Diff for: src/librustc_mir/transform/simplify.rs

+13-4
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ use rustc_data_structures::indexed_vec::{Idx, IndexVec};
4242
use rustc::ty::TyCtxt;
4343
use rustc::mir::*;
4444
use rustc::mir::visit::{MutVisitor, Visitor, PlaceContext};
45+
use rustc::session::config::FullDebugInfo;
4546
use std::borrow::Cow;
4647
use transform::{MirPass, MirSource};
4748

@@ -281,16 +282,24 @@ pub struct SimplifyLocals;
281282

282283
impl MirPass for SimplifyLocals {
283284
fn run_pass<'a, 'tcx>(&self,
284-
_: TyCtxt<'a, 'tcx, 'tcx>,
285+
tcx: TyCtxt<'a, 'tcx, 'tcx>,
285286
_: MirSource,
286287
mir: &mut Mir<'tcx>) {
287288
let mut marker = DeclMarker { locals: BitVector::new(mir.local_decls.len()) };
288289
marker.visit_mir(mir);
289290
// Return pointer and arguments are always live
290-
marker.locals.insert(0);
291-
for idx in mir.args_iter() {
292-
marker.locals.insert(idx.index());
291+
marker.locals.insert(RETURN_PLACE.index());
292+
for arg in mir.args_iter() {
293+
marker.locals.insert(arg.index());
293294
}
295+
296+
// We may need to keep dead user variables live for debuginfo.
297+
if tcx.sess.opts.debuginfo == FullDebugInfo {
298+
for local in mir.vars_iter() {
299+
marker.locals.insert(local.index());
300+
}
301+
}
302+
294303
let map = make_local_map(&mut mir.local_decls, marker.locals);
295304
// Update references to all vars and tmps now
296305
LocalUpdater { map: map }.visit_mir(mir);

0 commit comments

Comments
 (0)