[BOLT] Explicitly check for returns when extending call continuation profile #143295

Open. Wants to merge 2 commits into base branch users/aaupov/spr/main.bolt-explicitly-check-for-returns-when-extending-call-continuation-profile.
Conversation

aaupov (Contributor) commented Jun 8, 2025

Call continuation logic relies on assumptions about fall-through origin:

  • the branch is external to the function,
  • fall-through start is at the beginning of the block,
  • the block is not an entry point or a landing pad.

Leverage trace information to explicitly check whether the origin is a
return instruction, and fall back to the checks above only when the
branch source is DSO-external.

This covers both regular and BAT cases, addressing call continuation
fall-through undercounting in the latter mode.

Depends on #143289.

Test Plan: updated callcont-fallthru.s

aaupov added 2 commits June 7, 2025 21:18
Created using spr 1.3.4
@aaupov aaupov marked this pull request as ready for review June 10, 2025 20:35
@llvmbot llvmbot added the BOLT label Jun 10, 2025
llvmbot (Member) commented Jun 10, 2025

@llvm/pr-subscribers-bolt

Author: Amir Ayupov (aaupov)



Full diff: https://github.com/llvm/llvm-project/pull/143295.diff

5 Files Affected:

  • (modified) bolt/include/bolt/Profile/BoltAddressTranslation.h (+2-1)
  • (modified) bolt/include/bolt/Profile/DataAggregator.h (+9-3)
  • (modified) bolt/lib/Profile/BoltAddressTranslation.cpp (+8-2)
  • (modified) bolt/lib/Profile/DataAggregator.cpp (+41-30)
  • (modified) bolt/test/X86/callcont-fallthru.s (+40-32)
diff --git a/bolt/include/bolt/Profile/BoltAddressTranslation.h b/bolt/include/bolt/Profile/BoltAddressTranslation.h
index fcc578f35e322..917531964e9b6 100644
--- a/bolt/include/bolt/Profile/BoltAddressTranslation.h
+++ b/bolt/include/bolt/Profile/BoltAddressTranslation.h
@@ -103,7 +103,8 @@ class BoltAddressTranslation {
   /// otherwise.
   std::optional<FallthroughListTy> getFallthroughsInTrace(uint64_t FuncAddress,
                                                           uint64_t From,
-                                                          uint64_t To) const;
+                                                          uint64_t To,
+                                                          bool IsReturn) const;
 
   /// If available, fetch the address of the hot part linked to the cold part
   /// at \p Address. Return 0 otherwise.
diff --git a/bolt/include/bolt/Profile/DataAggregator.h b/bolt/include/bolt/Profile/DataAggregator.h
index 10d96fbeca3e2..96969cf53baca 100644
--- a/bolt/include/bolt/Profile/DataAggregator.h
+++ b/bolt/include/bolt/Profile/DataAggregator.h
@@ -132,6 +132,9 @@ class DataAggregator : public DataReader {
   /// and use them later for processing and assigning profile.
   std::unordered_map<Trace, TakenBranchInfo, TraceHash> TraceMap;
   std::vector<std::pair<Trace, TakenBranchInfo>> Traces;
+  /// Pre-populated addresses of returns, coming from pre-aggregated data or
+  /// disassembly. Used to disambiguate call-continuation fall-throughs.
+  std::unordered_set<uint64_t> Returns;
   std::unordered_map<uint64_t, uint64_t> BasicSamples;
   std::vector<PerfMemSample> MemSamples;
 
@@ -204,8 +207,8 @@ class DataAggregator : public DataReader {
   /// Return a vector of offsets corresponding to a trace in a function
   /// if the trace is valid, std::nullopt otherwise.
   std::optional<SmallVector<std::pair<uint64_t, uint64_t>, 16>>
-  getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace,
-                         uint64_t Count) const;
+  getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace, uint64_t Count,
+                         bool IsReturn) const;
 
   /// Record external entry into the function \p BF.
   ///
@@ -265,11 +268,14 @@ class DataAggregator : public DataReader {
                      uint64_t From, uint64_t To, uint64_t Count,
                      uint64_t Mispreds);
 
+  /// Checks if \p Addr corresponds to a return instruction.
+  bool checkReturn(uint64_t Addr);
+
   /// Register a \p Branch.
   bool doBranch(uint64_t From, uint64_t To, uint64_t Count, uint64_t Mispreds);
 
   /// Register a trace between two LBR entries supplied in execution order.
-  bool doTrace(const Trace &Trace, uint64_t Count);
+  bool doTrace(const Trace &Trace, uint64_t Count, bool IsReturn);
 
   /// Parser helpers
   /// Return false if we exhausted our parser buffer and finished parsing
diff --git a/bolt/lib/Profile/BoltAddressTranslation.cpp b/bolt/lib/Profile/BoltAddressTranslation.cpp
index a253522e4fb15..732229b52c221 100644
--- a/bolt/lib/Profile/BoltAddressTranslation.cpp
+++ b/bolt/lib/Profile/BoltAddressTranslation.cpp
@@ -511,8 +511,8 @@ uint64_t BoltAddressTranslation::translate(uint64_t FuncAddress,
 
 std::optional<BoltAddressTranslation::FallthroughListTy>
 BoltAddressTranslation::getFallthroughsInTrace(uint64_t FuncAddress,
-                                               uint64_t From,
-                                               uint64_t To) const {
+                                               uint64_t From, uint64_t To,
+                                               bool IsReturn) const {
   SmallVector<std::pair<uint64_t, uint64_t>, 16> Res;
 
   // Filter out trivial case
@@ -530,6 +530,12 @@ BoltAddressTranslation::getFallthroughsInTrace(uint64_t FuncAddress,
   auto FromIter = Map.upper_bound(From);
   if (FromIter == Map.begin())
     return Res;
+
+  // For fall-throughs originating at returns, go back one entry to cover call
+  // site.
+  if (IsReturn)
+    --FromIter;
+
   // Skip instruction entries, to create fallthroughs we are only interested in
   // BB boundaries
   do {
diff --git a/bolt/lib/Profile/DataAggregator.cpp b/bolt/lib/Profile/DataAggregator.cpp
index e852a38acaf04..11d282e98413b 100644
--- a/bolt/lib/Profile/DataAggregator.cpp
+++ b/bolt/lib/Profile/DataAggregator.cpp
@@ -721,50 +721,54 @@ bool DataAggregator::doInterBranch(BinaryFunction *FromFunc,
   return true;
 }
 
+bool DataAggregator::checkReturn(uint64_t Addr) {
+  auto isReturn = [&](auto MI) { return MI && BC->MIB->isReturn(*MI); };
+  if (llvm::is_contained(Returns, Addr))
+    return true;
+
+  BinaryFunction *Func = getBinaryFunctionContainingAddress(Addr);
+  if (!Func)
+    return false;
+
+  const uint64_t Offset = Addr - Func->getAddress();
+  if (Func->hasInstructions()
+          ? isReturn(Func->getInstructionAtOffset(Offset))
+          : isReturn(Func->disassembleInstructionAtOffset(Offset))) {
+    Returns.emplace(Addr);
+    return true;
+  }
+  return false;
+}
+
 bool DataAggregator::doBranch(uint64_t From, uint64_t To, uint64_t Count,
                               uint64_t Mispreds) {
-  // Returns whether \p Offset in \p Func contains a return instruction.
-  auto checkReturn = [&](const BinaryFunction &Func, const uint64_t Offset) {
-    auto isReturn = [&](auto MI) { return MI && BC->MIB->isReturn(*MI); };
-    return Func.hasInstructions()
-               ? isReturn(Func.getInstructionAtOffset(Offset))
-               : isReturn(Func.disassembleInstructionAtOffset(Offset));
-  };
-
   // Mutates \p Addr to an offset into the containing function, performing BAT
   // offset translation and parent lookup.
   //
-  // Returns the containing function (or BAT parent) and whether the address
-  // corresponds to a return (if \p IsFrom) or a call continuation (otherwise).
+  // Returns the containing function (or BAT parent).
   auto handleAddress = [&](uint64_t &Addr, bool IsFrom) {
     BinaryFunction *Func = getBinaryFunctionContainingAddress(Addr);
     if (!Func) {
       Addr = 0;
-      return std::pair{Func, false};
+      return Func;
     }
 
     Addr -= Func->getAddress();
 
-    bool IsRet = IsFrom && checkReturn(*Func, Addr);
-
     if (BAT)
       Addr = BAT->translate(Func->getAddress(), Addr, IsFrom);
 
     if (BinaryFunction *ParentFunc = getBATParentFunction(*Func))
-      Func = ParentFunc;
+      return ParentFunc;
 
-    return std::pair{Func, IsRet};
+    return Func;
   };
 
-  auto [FromFunc, IsReturn] = handleAddress(From, /*IsFrom*/ true);
-  auto [ToFunc, _] = handleAddress(To, /*IsFrom*/ false);
+  BinaryFunction *FromFunc = handleAddress(From, /*IsFrom*/ true);
+  BinaryFunction *ToFunc = handleAddress(To, /*IsFrom*/ false);
   if (!FromFunc && !ToFunc)
     return false;
 
-  // Ignore returns.
-  if (IsReturn)
-    return true;
-
   // Treat recursive control transfers as inter-branches.
   if (FromFunc == ToFunc && To != 0) {
     recordBranch(*FromFunc, From, To, Count, Mispreds);
@@ -774,7 +778,8 @@ bool DataAggregator::doBranch(uint64_t From, uint64_t To, uint64_t Count,
   return doInterBranch(FromFunc, ToFunc, From, To, Count, Mispreds);
 }
 
-bool DataAggregator::doTrace(const Trace &Trace, uint64_t Count) {
+bool DataAggregator::doTrace(const Trace &Trace, uint64_t Count,
+                             bool IsReturn) {
   const uint64_t From = Trace.From, To = Trace.To;
   BinaryFunction *FromFunc = getBinaryFunctionContainingAddress(From);
   BinaryFunction *ToFunc = getBinaryFunctionContainingAddress(To);
@@ -798,8 +803,8 @@ bool DataAggregator::doTrace(const Trace &Trace, uint64_t Count) {
   const uint64_t FuncAddress = FromFunc->getAddress();
   std::optional<BoltAddressTranslation::FallthroughListTy> FTs =
       BAT && BAT->isBATFunction(FuncAddress)
-          ? BAT->getFallthroughsInTrace(FuncAddress, From, To)
-          : getFallthroughsInTrace(*FromFunc, Trace, Count);
+          ? BAT->getFallthroughsInTrace(FuncAddress, From, To, IsReturn)
+          : getFallthroughsInTrace(*FromFunc, Trace, Count, IsReturn);
   if (!FTs) {
     LLVM_DEBUG(dbgs() << "Invalid trace " << Trace << '\n');
     NumInvalidTraces += Count;
@@ -821,7 +826,7 @@ bool DataAggregator::doTrace(const Trace &Trace, uint64_t Count) {
 
 std::optional<SmallVector<std::pair<uint64_t, uint64_t>, 16>>
 DataAggregator::getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace,
-                                       uint64_t Count) const {
+                                       uint64_t Count, bool IsReturn) const {
   SmallVector<std::pair<uint64_t, uint64_t>, 16> Branches;
 
   BinaryContext &BC = BF.getBinaryContext();
@@ -855,9 +860,13 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace,
 
   // Adjust FromBB if the first LBR is a return from the last instruction in
   // the previous block (that instruction should be a call).
-  if (Trace.Branch != Trace::FT_ONLY && !BF.containsAddress(Trace.Branch) &&
-      From == FromBB->getOffset() && !FromBB->isEntryPoint() &&
-      !FromBB->isLandingPad()) {
+  if (IsReturn) {
+    if (From)
+      FromBB = BF.getBasicBlockContainingOffset(From - 1);
+    else
+      LLVM_DEBUG(dbgs() << "return to the function start: " << Trace << '\n');
+  } else if (Trace.Branch == Trace::EXTERNAL && From == FromBB->getOffset() &&
+             !FromBB->isEntryPoint() && !FromBB->isLandingPad()) {
     const BinaryBasicBlock *PrevBB =
         BF.getLayout().getBlock(FromBB->getIndex() - 1);
     if (PrevBB->getSuccessor(FromBB->getLabel())) {
@@ -1545,11 +1554,13 @@ void DataAggregator::processBranchEvents() {
                      TimerGroupName, TimerGroupDesc, opts::TimeAggregator);
 
   for (const auto &[Trace, Info] : Traces) {
-    if (Trace.Branch != Trace::FT_ONLY &&
+    bool IsReturn = checkReturn(Trace.Branch);
+    // Ignore returns.
+    if (!IsReturn && Trace.Branch != Trace::FT_ONLY &&
         Trace.Branch != Trace::FT_EXTERNAL_ORIGIN)
       doBranch(Trace.Branch, Trace.From, Info.TakenCount, Info.MispredCount);
     if (Trace.To != Trace::BR_ONLY)
-      doTrace(Trace, Info.TakenCount);
+      doTrace(Trace, Info.TakenCount, IsReturn);
   }
   printBranchSamplesDiagnostics();
 }
diff --git a/bolt/test/X86/callcont-fallthru.s b/bolt/test/X86/callcont-fallthru.s
index 4994cfb541eef..c2ef024db9475 100644
--- a/bolt/test/X86/callcont-fallthru.s
+++ b/bolt/test/X86/callcont-fallthru.s
@@ -4,29 +4,43 @@
 # RUN: %clang %cflags -fpic -shared -xc /dev/null -o %t.so
 ## Link against a DSO to ensure PLT entries.
 # RUN: %clangxx %cxxflags %s %t.so -o %t -Wl,-q -nostdlib
-# RUN: link_fdata %s %t %t.pat PREAGGT1
-# RUN: link_fdata %s %t %t.pat2 PREAGGT2
-# RUN-DISABLED: link_fdata %s %t %t.patplt PREAGGPLT
+# Trace to a call continuation, not a landing pad/entry point
+# RUN: link_fdata %s %t %t.pa-base PREAGG-BASE
+# Trace from a return to a landing pad/entry point call continuation
+# RUN: link_fdata %s %t %t.pa-ret PREAGG-RET
+# Trace from an external location to a landing pad/entry point call continuation
+# RUN: link_fdata %s %t %t.pa-ext PREAGG-EXT
+# RUN-DISABLED: link_fdata %s %t %t.pa-plt PREAGG-PLT
 
 # RUN: llvm-strip --strip-unneeded %t -o %t.strip
 # RUN: llvm-objcopy --remove-section=.eh_frame %t.strip %t.noeh
 
 ## Check pre-aggregated traces attach call continuation fallthrough count
-# RUN: llvm-bolt %t.noeh --pa -p %t.pat -o %t.out \
-# RUN:   --print-cfg --print-only=main | FileCheck %s
-
-## Check pre-aggregated traces don't attach call continuation fallthrough count
-## to secondary entry point (unstripped)
-# RUN: llvm-bolt %t --pa -p %t.pat2 -o %t.out \
-# RUN:   --print-cfg --print-only=main | FileCheck %s --check-prefix=CHECK3
-## Check pre-aggregated traces don't attach call continuation fallthrough count
-## to landing pad (stripped, LP)
-# RUN: llvm-bolt %t.strip --pa -p %t.pat2 -o %t.out \
-# RUN:   --print-cfg --print-only=main | FileCheck %s --check-prefix=CHECK3
+## in the basic case (not an entry point, not a landing pad).
+# RUN: llvm-bolt %t.noeh --pa -p %t.pa-base -o %t.out \
+# RUN:   --print-cfg --print-only=main | FileCheck %s --check-prefix=CHECK-BASE
+
+## Check pre-aggregated traces from a return attach call continuation
+## fallthrough count to secondary entry point (unstripped)
+# RUN: llvm-bolt %t --pa -p %t.pa-ret -o %t.out \
+# RUN:   --print-cfg --print-only=main | FileCheck %s --check-prefix=CHECK-ATTACH
+## Check pre-aggregated traces from a return attach call continuation
+## fallthrough count to landing pad (stripped, landing pad)
+# RUN: llvm-bolt %t.strip --pa -p %t.pa-ret -o %t.out \
+# RUN:   --print-cfg --print-only=main | FileCheck %s --check-prefix=CHECK-ATTACH
+
+## Check pre-aggregated traces from external location don't attach call
+## continuation fallthrough count to secondary entry point (unstripped)
+# RUN: llvm-bolt %t --pa -p %t.pa-ext -o %t.out \
+# RUN:   --print-cfg --print-only=main | FileCheck %s --check-prefix=CHECK-SKIP
+## Check pre-aggregated traces from external location don't attach call
+## continuation fallthrough count to landing pad (stripped, landing pad)
+# RUN: llvm-bolt %t.strip --pa -p %t.pa-ext -o %t.out \
+# RUN:   --print-cfg --print-only=main | FileCheck %s --check-prefix=CHECK-SKIP
 
 ## Check pre-aggregated traces don't report zero-sized PLT fall-through as
 ## invalid trace
-# RUN-DISABLED: llvm-bolt %t.strip --pa -p %t.patplt -o %t.out | FileCheck %s \
+# RUN-DISABLED: llvm-bolt %t.strip --pa -p %t.pa-plt -o %t.out | FileCheck %s \
 # RUN-DISABLED:   --check-prefix=CHECK-PLT
 # CHECK-PLT: traces mismatching disassembled function contents: 0
 
@@ -56,11 +70,11 @@ main:
 Ltmp0_br:
 	callq	puts@PLT
 ## Check PLT traces are accepted
-# PREAGGPLT: T #Ltmp0_br# #puts@plt# #puts@plt# 3
+# PREAGG-PLT: T #Ltmp0_br# #puts@plt# #puts@plt# 3
 ## Target is an external-origin call continuation
-# PREAGGT1: T X:0 #Ltmp1# #Ltmp4_br# 2
-# CHECK:      callq puts@PLT
-# CHECK-NEXT: count: 2
+# PREAGG-BASE: T X:0 #Ltmp1# #Ltmp4_br# 2
+# CHECK-BASE:      callq puts@PLT
+# CHECK-BASE-NEXT: count: 2
 
 Ltmp1:
 	movq	-0x10(%rbp), %rax
@@ -71,24 +85,18 @@ Ltmp4:
 	cmpl	$0x0, -0x14(%rbp)
 Ltmp4_br:
 	je	Ltmp0
-# CHECK2:      je .Ltmp0
-# CHECK2-NEXT: count: 3
 
 	movl	$0xa, -0x18(%rbp)
 	callq	foo
 ## Target is a binary-local call continuation
-# PREAGGT1: T #Lfoo_ret# #Ltmp3# #Ltmp3_br# 1
-# CHECK:      callq foo
-# CHECK-NEXT: count: 1
-
-## PLT call continuation fallthrough spanning the call
-# CHECK2:      callq foo
-# CHECK2-NEXT: count: 3
-
+# PREAGG-RET: T #Lfoo_ret# #Ltmp3# #Ltmp3_br# 1
 ## Target is a secondary entry point (unstripped) or a landing pad (stripped)
-# PREAGGT2: T X:0 #Ltmp3# #Ltmp3_br# 2
-# CHECK3:      callq foo
-# CHECK3-NEXT: count: 0
+# PREAGG-EXT: T X:0 #Ltmp3# #Ltmp3_br# 1
+
+# CHECK-ATTACH:      callq foo
+# CHECK-ATTACH-NEXT: count: 1
+# CHECK-SKIP:        callq foo
+# CHECK-SKIP-NEXT:   count: 0
 
 Ltmp3:
 	cmpl	$0x0, -0x18(%rbp)
