Skip to content

Commit 49f62df

Browse files
committed
Allows customization of builtin functions under FatLTO
1 parent c6b2967 commit 49f62df

File tree

5 files changed

+151
-28
lines changed

5 files changed

+151
-28
lines changed

Diff for: compiler/rustc_codegen_llvm/src/back/lto.rs

+44-11
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl SerializedModuleInfo {
6060
fn prepare_lto(
6161
cgcx: &CodegenContext<LlvmCodegenBackend>,
6262
dcx: &DiagCtxt,
63-
) -> Result<(Vec<CString>, Vec<SerializedModuleInfo>), FatalError> {
63+
) -> Result<(Vec<CString>, Vec<SerializedModuleInfo>, Vec<CString>), FatalError> {
6464
let export_threshold = match cgcx.lto {
6565
// We're just doing LTO for our one crate
6666
Lto::ThinLocal => SymbolExportLevel::Rust,
@@ -85,6 +85,17 @@ fn prepare_lto(
8585
};
8686
info!("{} symbols to preserve in this crate", symbols_below_threshold.len());
8787

88+
let compiler_builtins_exported_symbols = match cgcx.compiler_builtins {
89+
Some(crate_num) => {
90+
if let Some(exported_symbols) = exported_symbols.get(&crate_num) {
91+
exported_symbols.iter().filter_map(symbol_filter).collect::<Vec<CString>>()
92+
} else {
93+
Vec::new()
94+
}
95+
}
96+
None => Vec::new(),
97+
};
98+
8899
// If we're performing LTO for the entire crate graph, then for each of our
89100
// upstream dependencies, find the corresponding rlib and load the bitcode
90101
// from the archive.
@@ -167,7 +178,7 @@ fn prepare_lto(
167178
// __llvm_profile_runtime, therefore we won't know until link time if this symbol
168179
// should have default visibility.
169180
symbols_below_threshold.push(CString::new("__llvm_profile_counter_bias").unwrap());
170-
Ok((symbols_below_threshold, upstream_modules))
181+
Ok((symbols_below_threshold, upstream_modules, compiler_builtins_exported_symbols))
171182
}
172183

173184
fn get_bitcode_slice_from_object_data<'a>(
@@ -218,10 +229,21 @@ pub(crate) fn run_fat(
218229
cached_modules: Vec<(SerializedModule<ModuleBuffer>, WorkProduct)>,
219230
) -> Result<LtoModuleCodegen<LlvmCodegenBackend>, FatalError> {
220231
let dcx = cgcx.create_dcx();
221-
let (symbols_below_threshold, upstream_modules) = prepare_lto(cgcx, &dcx)?;
232+
let (symbols_below_threshold, upstream_modules, compiler_builtins_exported_symbols) =
233+
prepare_lto(cgcx, &dcx)?;
222234
let symbols_below_threshold =
223235
symbols_below_threshold.iter().map(|c| c.as_ptr()).collect::<Vec<_>>();
224-
fat_lto(cgcx, &dcx, modules, cached_modules, upstream_modules, &symbols_below_threshold)
236+
let compiler_builtins_exported_symbols =
237+
compiler_builtins_exported_symbols.iter().map(|c| c.as_ptr()).collect::<Vec<_>>();
238+
fat_lto(
239+
cgcx,
240+
&dcx,
241+
modules,
242+
cached_modules,
243+
upstream_modules,
244+
&symbols_below_threshold,
245+
&compiler_builtins_exported_symbols,
246+
)
225247
}
226248

227249
/// Performs thin LTO by performing necessary global analysis and returning two
@@ -233,7 +255,7 @@ pub(crate) fn run_thin(
233255
cached_modules: Vec<(SerializedModule<ModuleBuffer>, WorkProduct)>,
234256
) -> Result<(Vec<LtoModuleCodegen<LlvmCodegenBackend>>, Vec<WorkProduct>), FatalError> {
235257
let dcx = cgcx.create_dcx();
236-
let (symbols_below_threshold, upstream_modules) = prepare_lto(cgcx, &dcx)?;
258+
let (symbols_below_threshold, upstream_modules, _) = prepare_lto(cgcx, &dcx)?;
237259
let symbols_below_threshold =
238260
symbols_below_threshold.iter().map(|c| c.as_ptr()).collect::<Vec<_>>();
239261
if cgcx.opts.cg.linker_plugin_lto.enabled() {
@@ -258,6 +280,7 @@ fn fat_lto(
258280
cached_modules: Vec<(SerializedModule<ModuleBuffer>, WorkProduct)>,
259281
mut serialized_modules: Vec<SerializedModuleInfo>,
260282
symbols_below_threshold: &[*const libc::c_char],
283+
compiler_builtins_exported_symbols: &[*const libc::c_char],
261284
) -> Result<LtoModuleCodegen<LlvmCodegenBackend>, FatalError> {
262285
let _timer = cgcx.prof.generic_activity("LLVM_fat_lto_build_monolithic_module");
263286
info!("going for a fat lto");
@@ -372,17 +395,19 @@ fn fat_lto(
372395
// above, this is all mostly handled in C++. Like above, though, we don't
373396
// know much about the memory management here so we err on the side of being
374397
// save and persist everything with the original module.
375-
let mut linker = Linker::new(llmod);
398+
let mut linker = Linker::new(llmod, compiler_builtins_exported_symbols);
376399
for serialized_module in serialized_modules {
377-
let SerializedModuleInfo { module, name, .. } = serialized_module;
400+
let SerializedModuleInfo { module, name, compiler_builtins } = serialized_module;
378401
let _timer = cgcx
379402
.prof
380403
.generic_activity_with_arg_recorder("LLVM_fat_lto_link_module", |recorder| {
381404
recorder.record_arg(format!("{name:?}"))
382405
});
383406
info!("linking {:?}", name);
384407
let data = module.data();
385-
linker.add(data).map_err(|()| write::llvm_err(dcx, LlvmError::LoadBitcode { name }))?;
408+
linker
409+
.add(data, compiler_builtins)
410+
.map_err(|()| write::llvm_err(dcx, LlvmError::LoadBitcode { name }))?;
386411
serialized_bitcode.push(module);
387412
}
388413
drop(linker);
@@ -406,16 +431,24 @@ fn fat_lto(
406431
pub(crate) struct Linker<'a>(&'a mut llvm::Linker<'a>);
407432

408433
impl<'a> Linker<'a> {
409-
pub(crate) fn new(llmod: &'a llvm::Module) -> Self {
410-
unsafe { Linker(llvm::LLVMRustLinkerNew(llmod)) }
434+
pub(crate) fn new(llmod: &'a llvm::Module, builtin_syms: &[*const libc::c_char]) -> Self {
435+
let ptr = builtin_syms.as_ptr();
436+
unsafe {
437+
Linker(llvm::LLVMRustLinkerNew(
438+
llmod,
439+
ptr as *const *const libc::c_char,
440+
builtin_syms.len() as libc::size_t,
441+
))
442+
}
411443
}
412444

413-
pub(crate) fn add(&mut self, bytecode: &[u8]) -> Result<(), ()> {
445+
pub(crate) fn add(&mut self, bytecode: &[u8], compiler_builtins: bool) -> Result<(), ()> {
414446
unsafe {
415447
if llvm::LLVMRustLinkerAdd(
416448
self.0,
417449
bytecode.as_ptr() as *const libc::c_char,
418450
bytecode.len(),
451+
compiler_builtins,
419452
) {
420453
Ok(())
421454
} else {

Diff for: compiler/rustc_codegen_llvm/src/back/write.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -634,12 +634,12 @@ pub(crate) fn link(
634634
let (first, elements) =
635635
modules.split_first().expect("Bug! modules must contain at least one module.");
636636

637-
let mut linker = Linker::new(first.module_llvm.llmod());
637+
let mut linker = Linker::new(first.module_llvm.llmod(), &[]);
638638
for module in elements {
639639
let _timer = cgcx.prof.generic_activity_with_arg("LLVM_link_module", &*module.name);
640640
let buffer = ModuleBuffer::new(module.module_llvm.llmod());
641641
linker
642-
.add(buffer.data())
642+
.add(buffer.data(), false)
643643
.map_err(|()| llvm_err(dcx, LlvmError::SerializeModule { name: &module.name }))?;
644644
}
645645
drop(linker);

Diff for: compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -2356,11 +2356,16 @@ extern "C" {
23562356
out_len: &mut usize,
23572357
) -> *const u8;
23582358

2359-
pub fn LLVMRustLinkerNew(M: &Module) -> &mut Linker<'_>;
2359+
pub fn LLVMRustLinkerNew(
2360+
M: &Module,
2361+
builtin_syms: *const *const c_char,
2362+
len: size_t,
2363+
) -> &mut Linker<'_>;
23602364
pub fn LLVMRustLinkerAdd(
23612365
linker: &Linker<'_>,
23622366
bytecode: *const c_char,
23632367
bytecode_len: usize,
2368+
compiler_builtins: bool,
23642369
) -> bool;
23652370
pub fn LLVMRustLinkerFree<'a>(linker: &'a mut Linker<'a>);
23662371
#[allow(improper_ctypes)]

Diff for: compiler/rustc_llvm/llvm-wrapper/Linker.cpp

+54-14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/IR/DiagnosticInfo.h"
1515
#include "llvm/IR/DiagnosticPrinter.h"
1616
#include "llvm/Linker/IRMover.h"
17+
#include "llvm/Object/ModuleSymbolTable.h"
1718
#include "llvm/Support/Error.h"
1819

1920
#include "LLVMWrapper.h"
@@ -44,7 +45,10 @@ enum class LinkFrom { Dst, Src, Both };
4445
/// entrypoint for this file.
4546
class ModuleLinker {
4647
IRMover &Mover;
48+
const StringSet<> &CompilerBuiltinsSymbols;
49+
StringSet<> UserBuiltinsSymbols;
4750
std::unique_ptr<Module> SrcM;
51+
bool SrcIsCompilerBuiltins;
4852

4953
SetVector<GlobalValue *> ValuesToLink;
5054

@@ -122,11 +126,14 @@ class ModuleLinker {
122126
bool linkIfNeeded(GlobalValue &GV, SmallVectorImpl<GlobalValue *> &GVToClone);
123127

124128
public:
125-
ModuleLinker(IRMover &Mover, std::unique_ptr<Module> SrcM, unsigned Flags,
129+
ModuleLinker(IRMover &Mover, const StringSet<> &CompilerBuiltinsSymbols,
130+
std::unique_ptr<Module> SrcM, bool SrcIsCompilerBuiltins,
131+
unsigned Flags,
126132
std::function<void(Module &, const StringSet<> &)>
127133
InternalizeCallback = {})
128-
: Mover(Mover), SrcM(std::move(SrcM)), Flags(Flags),
129-
InternalizeCallback(std::move(InternalizeCallback)) {}
134+
: Mover(Mover), CompilerBuiltinsSymbols(CompilerBuiltinsSymbols),
135+
SrcM(std::move(SrcM)), SrcIsCompilerBuiltins(SrcIsCompilerBuiltins),
136+
Flags(Flags), InternalizeCallback(std::move(InternalizeCallback)) {}
130137

131138
bool run();
132139
};
@@ -342,6 +349,10 @@ bool ModuleLinker::shouldLinkFromSource(bool &LinkFromSrc,
342349

343350
bool ModuleLinker::linkIfNeeded(GlobalValue &GV,
344351
SmallVectorImpl<GlobalValue *> &GVToClone) {
352+
// If a builtin symbol is defined in a non-compiler-builtins, the symbol of
353+
// compiler-builtins is a non-prevailing symbol.
354+
if (SrcIsCompilerBuiltins && UserBuiltinsSymbols.contains(GV.getName()))
355+
return false;
345356
GlobalValue *DGV = getLinkedToGlobal(&GV);
346357

347358
if (shouldLinkOnlyNeeded()) {
@@ -501,6 +512,27 @@ bool ModuleLinker::run() {
501512
ReplacedDstComdats.insert(DstC);
502513
}
503514

515+
if (SrcIsCompilerBuiltins) {
516+
ModuleSymbolTable SymbolTable;
517+
SymbolTable.addModule(&DstM);
518+
for (auto &Sym : SymbolTable.symbols()) {
519+
uint32_t Flags = SymbolTable.getSymbolFlags(Sym);
520+
if ((Flags & object::BasicSymbolRef::SF_Weak) ||
521+
!(Flags & object::BasicSymbolRef::SF_Global))
522+
continue;
523+
if (GlobalValue *GV = dyn_cast_if_present<GlobalValue *>(Sym)) {
524+
if (CompilerBuiltinsSymbols.contains(GV->getName()))
525+
UserBuiltinsSymbols.insert(GV->getName());
526+
} else if (auto *AS =
527+
dyn_cast_if_present<ModuleSymbolTable::AsmSymbol *>(Sym)) {
528+
if (CompilerBuiltinsSymbols.contains(AS->first))
529+
UserBuiltinsSymbols.insert(AS->first);
530+
} else {
531+
llvm::report_fatal_error("unknown symbol type");
532+
}
533+
}
534+
}
535+
504536
// Alias have to go first, since we are not able to find their comdats
505537
// otherwise.
506538
for (GlobalAlias &GV : llvm::make_early_inc_range(DstM.aliases()))
@@ -617,6 +649,7 @@ namespace {
617649
struct RustLinker {
618650
IRMover Mover;
619651
LLVMContext &Ctx;
652+
StringSet<> CompilerBuiltinsSymbols;
620653

621654
enum Flags {
622655
None = 0,
@@ -634,37 +667,44 @@ struct RustLinker {
634667
/// callback.
635668
///
636669
/// Returns true on error.
637-
bool linkInModule(std::unique_ptr<Module> Src, unsigned Flags = Flags::None,
670+
bool linkInModule(std::unique_ptr<Module> Src, bool SrcIsCompilerBuiltins,
671+
unsigned Flags = Flags::None,
638672
std::function<void(Module &, const StringSet<> &)>
639673
InternalizeCallback = {});
640674

641-
RustLinker(Module &M) : Mover(M), Ctx(M.getContext()) {}
675+
RustLinker(Module &M, StringSet<> CompilerBuiltinsSymbols)
676+
: Mover(M), Ctx(M.getContext()),
677+
CompilerBuiltinsSymbols(CompilerBuiltinsSymbols) {}
642678
};
643679

644680
} // namespace
645681

646682
bool RustLinker::linkInModule(
647-
std::unique_ptr<Module> Src, unsigned Flags,
683+
std::unique_ptr<Module> Src, bool SrcIsCompilerBuiltins, unsigned Flags,
648684
std::function<void(Module &, const StringSet<> &)> InternalizeCallback) {
649-
ModuleLinker ModLinker(Mover, std::move(Src), Flags,
685+
ModuleLinker ModLinker(Mover, CompilerBuiltinsSymbols, std::move(Src),
686+
SrcIsCompilerBuiltins, Flags,
650687
std::move(InternalizeCallback));
651688
return ModLinker.run();
652689
}
653690

654-
extern "C" RustLinker*
655-
LLVMRustLinkerNew(LLVMModuleRef DstRef) {
691+
extern "C" RustLinker *LLVMRustLinkerNew(LLVMModuleRef DstRef, char **Symbols,
692+
size_t Len) {
656693
Module *Dst = unwrap(DstRef);
657-
658-
return new RustLinker(*Dst);
694+
StringSet<> CompilerBuiltinsSymbols;
695+
for (size_t I = 0; I < Len; I++) {
696+
CompilerBuiltinsSymbols.insert(Symbols[I]);
697+
}
698+
return new RustLinker(*Dst, CompilerBuiltinsSymbols);
659699
}
660700

661701
extern "C" void
662702
LLVMRustLinkerFree(RustLinker *L) {
663703
delete L;
664704
}
665705

666-
extern "C" bool
667-
LLVMRustLinkerAdd(RustLinker *L, char *BC, size_t Len) {
706+
extern "C" bool LLVMRustLinkerAdd(RustLinker *L, char *BC, size_t Len,
707+
bool CompilerBuiltins) {
668708
std::unique_ptr<MemoryBuffer> Buf =
669709
MemoryBuffer::getMemBufferCopy(StringRef(BC, Len));
670710

@@ -677,7 +717,7 @@ LLVMRustLinkerAdd(RustLinker *L, char *BC, size_t Len) {
677717

678718
auto Src = std::move(*SrcOrError);
679719

680-
if (L->linkInModule(std::move(Src))) {
720+
if (L->linkInModule(std::move(Src), CompilerBuiltins)) {
681721
LLVMRustSetLastError("");
682722
return false;
683723
}

Diff for: tests/assembly/lto-custom-builtins.rs

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// assembly-output: emit-asm
2+
// compile-flags: --crate-type cdylib -C lto=fat -C prefer-dynamic=no
3+
// only-x86_64-unknown-linux-gnu
4+
5+
#![feature(lang_items, linkage)]
6+
#![no_std]
7+
#![no_main]
8+
9+
#![crate_type = "bin"]
10+
11+
// We want to use customized __subdf3.
12+
// CHECK: .globl __subdf3
13+
// CHECK-NEXT: __subdf3:
14+
// CHECK-NEXT: movq $2, %rax
15+
core::arch::global_asm!(".global __subdf3", "__subdf3:", "mov rax, 2");
16+
17+
// We want to use __addsf3 of compiler-builtins.
18+
// CHECK: .globl __addsf3
19+
// CHECK: __addsf3:
20+
// CHECK: xorl %eax, %eax
21+
// CHECK-NEXT retq
22+
#[no_mangle]
23+
pub extern "C" fn __addsf3() -> i32 {
24+
0
25+
}
26+
27+
// We want to use customized __adddf3.
28+
// CHECK: .globl __adddf3
29+
// CHECK: __adddf3:
30+
// CHECK-NEXT: .cfi_startproc
31+
// CHECK-NOT: movl $1, %eax
32+
// CHECK: movq %xmm0, %rdx
33+
#[no_mangle]
34+
#[linkage = "weak"]
35+
pub extern "C" fn __adddf3() -> i32 {
36+
1
37+
}
38+
39+
#[panic_handler]
40+
fn panic(_panic: &core::panic::PanicInfo<'_>) -> ! {
41+
loop {}
42+
}
43+
44+
#[lang = "eh_personality"]
45+
extern "C" fn eh_personality() {}

0 commit comments

Comments
 (0)