Skip to content

Commit 950d91b

Browse files
committed
addressing a ton of feedback, thank you
1 parent fe8b699 commit 950d91b

File tree

8 files changed

+73
-74
lines changed

8 files changed

+73
-74
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,9 @@ fn generate_enzyme_call<'ll>(
256256
trace!("no dbg info");
257257
}
258258
// Dump module:
259-
llvm::LLVMDumpModule(cx.llmod);
259+
//llvm::LLVMDumpModule(cx.llmod);
260260
// now print the last instruction:
261-
llvm::LLVMDumpValue(last_inst);
261+
//llvm::LLVMDumpValue(last_inst);
262262

263263
// Now that we copied the metadata, get rid of dummy code.
264264
llvm::LLVMRustEraseInstBefore(entry, last_inst);

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

+25-49
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
1818
};
1919
use rustc_middle::mir::mono::Linkage;
2020
use rustc_middle::query::Providers;
21+
use rustc_middle::span_bug;
2122
use rustc_middle::ty::{self as ty, TyCtxt};
2223
use rustc_session::parse::feature_err;
2324
use rustc_session::{Session, lint};
@@ -872,114 +873,89 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
872873

873874
// check for exactly one autodiff attribute on placeholder functions.
874875
// There should only be one, since we generate a new placeholder per ad macro.
875-
// TODO: re-enable this. We should fix that rustc_autodiff isn't applied multiple times to the
876-
// source function.
877-
let msg_once = "cg_ssa: implementation bug. Autodiff attribute can only be applied once";
876+
// FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but
877+
// looks strange e.g. under cargo-expand.
878878
let attr = match attrs.len() {
879879
0 => return AutoDiffAttrs::error(),
880880
1 => attrs.get(0).unwrap(),
881881
_ => {
882882
attrs.get(0).unwrap()
883-
//tcx.dcx().struct_span_err(attrs[1].span, msg_once).with_note("more than one").emit();
884-
//return AutoDiffAttrs::error();
883+
//FIXME(ZuseZ4): re-enable this check
884+
//span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source");
885885
}
886886
};
887887

888888
let list = attr.meta_item_list().unwrap_or_default();
889889

890890
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
891-
if list.len() == 0 {
891+
if list.is_empty() {
892892
return AutoDiffAttrs::source();
893893
}
894894

895895
let [mode, input_activities @ .., ret_activity] = &list[..] else {
896-
tcx.dcx()
897-
.struct_span_err(attr.span, msg_once)
898-
.with_note("Implementation bug in autodiff_attrs. Please report this!")
899-
.emit();
900-
return AutoDiffAttrs::error();
896+
span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities");
901897
};
902898
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
903899
p1.segments.first().unwrap().ident
904900
} else {
905-
let msg = "autodiff attribute must contain autodiff mode";
906-
tcx.dcx().struct_span_err(attr.span, msg).with_note("empty argument list").emit();
907-
return AutoDiffAttrs::error();
901+
span_bug!(attr.span, "rustc_autodiff attribute must contain mode");
908902
};
909903

910904
// parse mode
911-
let msg_mode = "mode should be either forward or reverse";
912905
let mode = match mode.as_str() {
913906
"Forward" => DiffMode::Forward,
914907
"Reverse" => DiffMode::Reverse,
915908
"ForwardFirst" => DiffMode::ForwardFirst,
916909
"ReverseFirst" => DiffMode::ReverseFirst,
917910
_ => {
918-
tcx.dcx().struct_span_err(attr.span, msg_mode).with_note("invalid mode").emit();
919-
return AutoDiffAttrs::error();
911+
span_bug!(attr.span, "rustc_autodiff attribute contains invalid mode");
920912
}
921913
};
922914

923915
// First read the ret symbol from the attribute
924916
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
925917
p1.segments.first().unwrap().ident
926918
} else {
927-
let msg = "autodiff attribute must contain the return activity";
928-
tcx.dcx().struct_span_err(attr.span, msg).with_note("missing return activity").emit();
929-
return AutoDiffAttrs::error();
919+
span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity");
930920
};
931921

932922
// Then parse it into an actual DiffActivity
933-
let msg_unknown_ret_activity = "unknown return activity";
934-
let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) {
935-
Ok(x) => x,
936-
Err(_) => {
937-
tcx.dcx()
938-
.struct_span_err(attr.span, msg_unknown_ret_activity)
939-
.with_note("invalid return activity")
940-
.emit();
941-
return AutoDiffAttrs::error();
942-
}
923+
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
924+
span_bug!(attr.span, "invalid return activity");
943925
};
944926

945927
// Now parse all the intermediate (input) activities
946-
let msg_arg_activity = "autodiff attribute must contain the return activity";
947928
let mut arg_activities: Vec<DiffActivity> = vec![];
948929
for arg in input_activities {
949930
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
950-
p2.segments.first().unwrap().ident
931+
match p2.segments.first() {
932+
Some(x) => x.ident,
933+
None => {
934+
span_bug!(
935+
attr.span,
936+
"rustc_autodiff attribute must contain the input activity"
937+
);
938+
}
939+
}
951940
} else {
952-
tcx.dcx()
953-
.struct_span_err(attr.span, msg_arg_activity)
954-
.with_note("Implementation bug, please report this!")
955-
.emit();
956-
return AutoDiffAttrs::error();
941+
span_bug!(attr.span, "rustc_autodiff attribute must contain the input activity");
957942
};
958943

959944
match DiffActivity::from_str(arg_symbol.as_str()) {
960945
Ok(arg_activity) => arg_activities.push(arg_activity),
961946
Err(_) => {
962-
tcx.dcx()
963-
.struct_span_err(attr.span, msg_unknown_ret_activity)
964-
.with_note("invalid input activity")
965-
.emit();
966-
return AutoDiffAttrs::error();
947+
span_bug!(attr.span, "invalid input activity");
967948
}
968949
}
969950
}
970951

971-
let mut msg = "".to_string();
972952
for &input in &arg_activities {
973953
if !valid_input_activity(mode, input) {
974-
msg = format!("Invalid input activity {} for {} mode", input, mode);
954+
span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode);
975955
}
976956
}
977957
if !valid_ret_activity(mode, ret_activity) {
978-
msg = format!("Invalid return activity {} for {} mode", ret_activity, mode);
979-
}
980-
if msg != "".to_string() {
981-
tcx.dcx().struct_span_err(attr.span, msg).with_note("invalid activity").emit();
982-
return AutoDiffAttrs::error();
958+
span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode);
983959
}
984960

985961
AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities }

compiler/rustc_codegen_ssa/src/traits/write.rs

-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ pub trait WriteBackendMethods: 'static + Sized + Clone {
1313
type ModuleBuffer: ModuleBufferMethods;
1414
type ThinData: Send + Sync;
1515
type ThinBuffer: ThinBufferMethods;
16-
//type TypeTree: Clone;
1716

1817
/// Merge all modules into main_module and returning it
1918
fn run_link(
@@ -38,7 +37,6 @@ pub trait WriteBackendMethods: 'static + Sized + Clone {
3837
) -> Result<(Vec<LtoModuleCodegen<Self>>, Vec<WorkProduct>), FatalError>;
3938
fn print_pass_timings(&self);
4039
fn print_statistics(&self);
41-
// does enzyme prep work, should do ad too.
4240
unsafe fn optimize(
4341
cgcx: &CodegenContext<Self>,
4442
dcx: DiagCtxtHandle<'_>,

compiler/rustc_middle/src/query/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1393,7 +1393,7 @@ rustc_queries! {
13931393
feedable
13941394
}
13951395

1396-
/// The list autodiff extern functions in current crate
1396+
/// List of autodiff extern functions in the current crate.
13971397
query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs {
13981398
desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) }
13991399
arena_cache

compiler/rustc_monomorphize/src/partitioning.rs

+15-16
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,13 @@ where
255255
);
256256

257257
// We can't differentiate something that got inlined.
258-
let autodiff_active = match characteristic_def_id {
259-
Some(def_id) => cx.tcx.autodiff_attrs(def_id).is_active(),
260-
None => false,
258+
let autodiff_active = if cfg!(llvm_enzyme) {
259+
match characteristic_def_id {
260+
Some(def_id) => cx.tcx.autodiff_attrs(def_id).is_active(),
261+
None => false,
262+
}
263+
} else {
264+
false
261265
};
262266

263267
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
@@ -1128,9 +1132,8 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
11281132
fn_ty: Ty<'tcx>,
11291133
da: &mut Vec<DiffActivity>,
11301134
) {
1131-
if !fn_ty.is_fn() {
1132-
// Error?
1133-
return;
1135+
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
1136+
bug!("expected fn def for autodiff, got {:?}", fn_ty);
11341137
}
11351138
let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
11361139

@@ -1149,7 +1152,7 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
11491152
}
11501153
let inner_ty = ty.builtin_deref(true).unwrap();
11511154
if inner_ty.is_slice() {
1152-
// We know that the lenght will be passed as extra arg.
1155+
// We know that the length will be passed as extra arg.
11531156
if !da.is_empty() {
11541157
// We are looking at a slice. The length of that slice will become an
11551158
// extra integer on llvm level. Integers are always const.
@@ -1161,12 +1164,11 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
11611164
| DiffActivity::DuplicatedOnly
11621165
| DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
11631166
DiffActivity::Const => DiffActivity::Const,
1164-
_ => panic!("unexpected activity for ptr/ref"),
1167+
_ => bug!("unexpected activity for ptr/ref"),
11651168
};
11661169
new_activities.push(activity);
11671170
new_positions.push(i + 1);
11681171
}
1169-
trace!("ABI MATCHING!");
11701172
continue;
11711173
}
11721174
}
@@ -1245,7 +1247,7 @@ fn collect_and_partition_mono_items(
12451247
})
12461248
.collect();
12471249

1248-
let autodiff_items2: Vec<_> = items
1250+
let autodiff_mono_items: Vec<_> = items
12491251
.iter()
12501252
.filter_map(|item| match *item {
12511253
MonoItem::Fn(ref instance) => Some((item, instance)),
@@ -1254,7 +1256,7 @@ fn collect_and_partition_mono_items(
12541256
.collect();
12551257
let mut autodiff_items: Vec<AutoDiffItem> = vec![];
12561258

1257-
for (item, instance) in autodiff_items2 {
1259+
for (item, instance) in autodiff_mono_items {
12581260
let target_id = instance.def_id();
12591261
let target_attrs: &AutoDiffAttrs = tcx.autodiff_attrs(target_id);
12601262
let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
@@ -1283,18 +1285,15 @@ fn collect_and_partition_mono_items(
12831285
None => continue,
12841286
};
12851287

1286-
println!("source_id: {:?}", inst.def_id());
1288+
debug!("source_id: {:?}", inst.def_id());
12871289
let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
1288-
//let fn_ty = inst.ty(tcx, ParamEnv::empty());ty::TypingEnv<'tcx>::fully_monomorphized()
12891290
assert!(fn_ty.is_fn());
12901291
adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
1291-
//let (inputs, output) = (fnc_tree.args, fnc_tree.ret);
12921292
let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
12931293

12941294
let mut new_target_attrs = target_attrs.clone();
12951295
new_target_attrs.input_activity = input_activities;
12961296
let itm = new_target_attrs.into_item(symb, target_symbol);
1297-
//let itm = new_target_attrs.into_item(symb, target_symbol, inputs, output);
12981297
autodiff_items.push(itm);
12991298
}
13001299
let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
@@ -1359,7 +1358,7 @@ fn collect_and_partition_mono_items(
13591358
}
13601359
}
13611360

1362-
if autodiff_items.len() > 0 {
1361+
if !autodiff_items.is_empty() {
13631362
trace!("AUTODIFF ITEMS EXIST");
13641363
for item in &mut *autodiff_items {
13651364
trace!("{}", &item);

compiler/rustc_session/src/config.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ pub enum CoverageLevel {
195195
// }
196196
//}
197197

198-
/// The different settings that the `-Z ad` flag can have.
198+
/// The different settings that the `-Z autodiff` flag can have.
199199
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
200200
pub enum AutoDiff {
201201
/// Print TypeAnalysis information

compiler/rustc_session/src/options.rs

+20-3
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ mod desc {
398398
pub(crate) const parse_list: &str = "a space-separated list of strings";
399399
pub(crate) const parse_list_with_polarity: &str =
400400
"a comma-separated list of strings, with elements beginning with + or -";
401-
pub(crate) const parse_autodiff: &str = "various values";
401+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Print`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfterOpts`, `PrintModAfterEnzyme`, `LooseTypes`, `NoModOptAfter`, `EnableFncOpt`, `NoVecUnroll`, `Inline`";
402402
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
403403
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
404404
pub(crate) const parse_number: &str = "a number";
@@ -1051,7 +1051,10 @@ pub mod parse {
10511051
"EnableFncOpt" => AutoDiff::EnableFncOpt,
10521052
"NoVecUnroll" => AutoDiff::NoVecUnroll,
10531053
"Inline" => AutoDiff::Inline,
1054-
_ => return false,
1054+
_ => {
1055+
// FIXME(ZuseZ4): print an error saying which value is not recognized
1056+
return false;
1057+
}
10551058
};
10561059
slot.push(variant);
10571060
}
@@ -1767,7 +1770,21 @@ options! {
17671770
assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
17681771
"make cfg(version) treat the current version as incomplete (default: no)"),
17691772
autodiff: Vec<crate::config::AutoDiff> = (Vec::new(), parse_autodiff, [TRACKED],
1770-
"a list autodiff flags to enable (comma separated)"),
1773+
"a list of optional autodiff flags to enable
1774+
Optional extra settings:
1775+
`=PrintTA`
1776+
`=PrintAA`
1777+
`=PrintPerf`
1778+
`=Print`
1779+
`=PrintModBefore`
1780+
`=PrintModAfterOpts`
1781+
`=PrintModAfterEnzyme`
1782+
`=LooseTypes`
1783+
`=NoModOptAfter`
1784+
`=EnableFncOpt`
1785+
`=NoVecUnroll`
1786+
`=Inline`
1787+
Multiple options can be combined with commas."),
17711788
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]
17721789
binary_dep_depinfo: bool = (false, parse_bool, [TRACKED],
17731790
"include artifacts (sysroot, crate dependencies) used during compilation in dep-info \
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# `autodiff`
2+
3+
The tracking issue for this feature is: [#124509](https://github.com/rust-lang/rust/issues/124509).
4+
5+
------------------------
6+
7+
This feature allows you to differentiate functions using automatic differentiation.
8+
Set the `-Zautodiff=<options>` compiler flag to adjust the behaviour of the autodiff feature.
9+

0 commit comments

Comments
 (0)