Skip to content

Commit

Permalink
fix(core): Backtrace Filtering (#11)
Browse files Browse the repository at this point in the history
This PR improves the stacktrace by:
1) Filter out unnecessary stack frames in `func.h`, `func_details.h`,
etc;
2) Fix a bug when handling errors that have no error message, e.g.
`NotImplementedError`
  • Loading branch information
potatomashed authored Jan 23, 2025
1 parent 7e5eefa commit b8a7366
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 32 deletions.
22 changes: 18 additions & 4 deletions cpp/traceback.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef _MSC_VER
#include "./traceback.h"
#include <backtrace.h>
#include <cstring>
#include <cxxabi.h>
#include <iostream>
#include <mlc/c_api.h>
Expand All @@ -11,7 +12,8 @@ namespace {

static int32_t MLC_TRACEBACK_LIMIT = GetTracebackLimit();
static backtrace_state *_bt_state = backtrace_create_state(
/*filename=*/nullptr,
/*filename=*/
nullptr,
/*threaded=*/1,
/*error_callback=*/
+[](void * /*data*/, const char *msg, int /*errnum*/) -> void {
Expand Down Expand Up @@ -43,7 +45,8 @@ MLCByteArray TracebackImpl() {
}
if (!symbol) {
backtrace_syminfo(
/*state=*/_bt_state, /*addr=*/pc, /*callback=*/
/*state=*/
_bt_state, /*addr=*/pc, /*callback=*/
+[](void *data, uintptr_t /*pc*/, const char *symname, uintptr_t /*symval*/, uintptr_t /*symsize*/) {
*reinterpret_cast<const char **>(data) = symname;
},
Expand All @@ -54,11 +57,22 @@ MLCByteArray TracebackImpl() {
if (IsForeignFrame(filename, lineno, symbol)) {
return 1;
}
reinterpret_cast<TracebackStorage *>(data)->Append(filename)->Append(lineno)->Append(symbol);
if (!EndsWith(filename, "core.pyx") && //
!EndsWith(filename, "func.h") && //
!EndsWith(filename, "func_details.h") && //
!EndsWith(filename, "mlc/core/all.h") && //
!EndsWith(filename, "mlc/base/all.h") //
) {
TracebackStorage *storage = reinterpret_cast<TracebackStorage *>(data);
storage->Append(filename);
storage->Append(lineno);
storage->Append(symbol);
}
return 0;
};
backtrace_full(
/*state=*/_bt_state, /*skip=*/1, /*callback=*/callback,
/*state=*/
_bt_state, /*skip=*/1, /*callback=*/callback,
/*error_callback=*/[](void * /*data*/, const char * /*msg*/, int /*errnum*/) {},
/*data=*/&storage);
return {static_cast<int64_t>(storage.buffer.size()), storage.buffer.data()};
Expand Down
12 changes: 8 additions & 4 deletions include/mlc/printer/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,14 @@ struct IRPrinterObj : public Object {
Optional<Str> name = (*it).second->name;
return Id(name.value());
}
// Legalize characters in the name
for (char &c : name_hint) {
if (c != '_' && !std::isalnum(c)) {
c = '_';
bool needs_normalize =
std::any_of(name_hint->begin(), name_hint->end(), [](char c) { return c != '_' && !std::isalnum(c); });
if (needs_normalize) {
name_hint = Str(name_hint->c_str());
for (char &c : name_hint) {
if (c != '_' && !std::isalnum(c)) {
c = '_';
}
}
}
// Find a unique name
Expand Down
1 change: 1 addition & 0 deletions python/mlc/_cython/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def _bytes_info() -> bytes:
bytes_info.append(str_py2c(code.co_filename))
tb = tb.tb_next
bytes_info.append(str_py2c(str(exception)))
bytes_info = [b if b else b"<null>" for b in bytes_info]
bytes_info.reverse()
bytes_info.append(b"")
return b"\0".join(bytes_info)
Expand Down
56 changes: 32 additions & 24 deletions tests/python/test_cython_traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
from io import StringIO

import mlc
import pytest


def test_throw_exception_from_c() -> None:
func = mlc.Func.get("mlc.testing.throw_exception_from_c")
try:
with pytest.raises(ValueError) as exc_info:
func()
except ValueError:
msg = traceback.format_exc().strip().splitlines()
assert "Traceback (most recent call last)" in msg[0]
assert "in test_throw_exception_from_c" in msg[1]
assert "ValueError: This is an error message" in msg[-1]
if mlc._cython.SYSTEM != "Darwin":
# FIXME: for some reason, `c_api.cc` is not in the traceback on macOS
# TODO: fix macOS libbacktrace integration on macOS
assert "c_api.cc" in msg[-3]

msg = traceback.format_exception(exc_info.type, exc_info.value, exc_info.tb)
msg = "".join(msg).strip().splitlines()
assert "Traceback (most recent call last)" in msg[0]
assert "in test_throw_exception_from_c" in msg[1]
assert "ValueError: This is an error message" in msg[-1]
if mlc._cython.SYSTEM != "Darwin":
assert "c_api.cc" in msg[-3]


def test_throw_exception_from_ffi() -> None:
Expand All @@ -41,18 +41,26 @@ def _inner() -> None:

_inner()

try:
with pytest.raises(ValueError) as exc_info:
mlc.Func.get("mlc.testing.throw_exception_from_ffi_in_c")(throw_ValueError)

msg = traceback.format_exception(exc_info.type, exc_info.value, exc_info.tb)
msg = "".join(msg).strip().splitlines()
assert "Traceback (most recent call last)" in msg[0]
assert "in test_throw_exception_from_ffi_in_c" in msg[1]
assert "ValueError: This is a ValueError" in msg[-1]
if mlc._cython.SYSTEM != "Darwin":
idx_c_api_tests = next(i for i, line in enumerate(msg) if "c_api.cc" in line)
idx_handle_error = next(i for i, line in enumerate(msg) if "_func_safe_call_impl" in line)
assert idx_c_api_tests < idx_handle_error


def test_throw_NotImplementedError_from_ffi_in_c() -> None:
def throw_ValueError() -> None:
def _inner() -> None:
raise NotImplementedError

_inner()

with pytest.raises(NotImplementedError):
mlc.Func.get("mlc.testing.throw_exception_from_ffi_in_c")(throw_ValueError)
except ValueError:
msg = traceback.format_exc().strip().splitlines()
assert "Traceback (most recent call last)" in msg[0]
assert "in test_throw_exception_from_ffi_in_c" in msg[1]
assert "ValueError: This is a ValueError" in msg[-1]
if mlc._cython.SYSTEM != "Darwin":
# FIXME: for some reason, `c_api.cc` is not in the traceback on macOS
# TODO: fix macOS libbacktrace integration on macOS
idx_c_api_tests = next(i for i, line in enumerate(msg) if "c_api.cc" in line)
idx_handle_error = next(
i for i, line in enumerate(msg) if "_func_safe_call_impl" in line
)
assert idx_c_api_tests < idx_handle_error
6 changes: 6 additions & 0 deletions tests/python/test_printer_ir_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ def test_var_print() -> None:
assert mlcp.to_python(a) == "a"


def test_var_print_name_normalize() -> None:
a = Var(name="a/0/b")
assert mlcp.to_python(a) == "a_0_b"
assert mlcp.to_python(a) == "a_0_b"


def test_add_print() -> None:
a = Var(name="a")
b = Var(name="b")
Expand Down

0 comments on commit b8a7366

Please # to comment.