Skip to content

[SYCL][UR][L0 v2] fix use after free on queue #18101

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 2 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions unified-runtime/scripts/templates/queue_api.hpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ from templates import helper as th
struct ur_queue_t_ {
virtual ~ur_queue_t_();

virtual void deferEventFree(ur_event_handle_t hEvent) = 0;

%for obj in th.get_queue_related_functions(specs, n, tags):
%if not 'Release' in obj['name'] and not 'Retain' in obj['name']:
virtual ${x}_result_t ${th.transform_queue_related_function_name(n, tags, obj, format=["type"])} = 0;
Expand Down
84 changes: 38 additions & 46 deletions unified-runtime/source/adapters/level_zero/v2/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ uint64_t event_profiling_data_t::getEventEndTimestamp() {
return adjustedEventEndTimestamp;
}

void event_profiling_data_t::reset() {
// This ensures that the event is consider as not timestamped.
// We can't touch the recordEventEndTimestamp
// as it may still be overwritten by the driver.
// In case event is resued and recordStartTimestamp
// is called again, adjustedEventEndTimestamp will always be updated correctly
// to the new value as we wait for the event to be signaled.
// If the event is reused on another queue, this means that the original
// queue must have been destroyed (and the even pool released back to the
// context) and the timstamp is already wrriten, so there's no race-condition
// possible.
adjustedEventStartTimestamp = 0;
adjustedEventEndTimestamp = 0;
}

void event_profiling_data_t::recordStartTimestamp(ur_device_handle_t hDevice) {
zeTimerResolution = hDevice->ZeDeviceProperties->timerResolution;
timestampMaxValue = hDevice->getTimestampMask();
Expand Down Expand Up @@ -98,16 +113,22 @@ void ur_event_handle_t_::resetQueueAndCommand(ur_queue_t_ *hQueue,
ur_command_t commandType) {
this->hQueue = hQueue;
this->commandType = commandType;
profilingData = event_profiling_data_t(getZeEvent());

if (hQueue) {
UR_CALL_THROWS(hQueue->queueGetInfo(UR_QUEUE_INFO_DEVICE, sizeof(hDevice),
reinterpret_cast<void *>(&hDevice),
nullptr));
} else {
hDevice = nullptr;
}

profilingData.reset();
}

void ur_event_handle_t_::recordStartTimestamp() {
assert(hQueue); // queue must be set before calling this

ur_device_handle_t hDevice;
UR_CALL_THROWS(hQueue->queueGetInfo(UR_QUEUE_INFO_DEVICE, sizeof(hDevice),
reinterpret_cast<void *>(&hDevice),
nullptr));
// queue and device must be set before calling this
assert(hQueue);
assert(hDevice);

profilingData.recordStartTimestamp(hDevice);
}
Expand Down Expand Up @@ -141,33 +162,17 @@ ur_result_t ur_event_handle_t_::retain() {
return UR_RESULT_SUCCESS;
}

ur_result_t ur_event_handle_t_::releaseDeferred() {
assert(zeEventQueryStatus(getZeEvent()) == ZE_RESULT_SUCCESS);
assert(RefCount.load() == 0);

return this->forceRelease();
}

ur_result_t ur_event_handle_t_::release() {
if (!RefCount.decrementAndTest())
return UR_RESULT_SUCCESS;

// Need to take a lock before checking if the event is timestamped.
std::unique_lock<ur_shared_mutex> lock(Mutex);

if (isTimestamped() && !getEventEndTimestamp()) {
// L0 will write end timestamp to this event some time in the future,
// so we can't release it yet.
assert(hQueue);
hQueue->deferEventFree(this);
return UR_RESULT_SUCCESS;
if (event_pool) {
event_pool->free(this);
} else {
std::get<v2::raii::ze_event_handle_t>(hZeEvent).release();
delete this;
}

// Need to unlock now, as forceRelease might deallocate memory backing
// the Mutex.
lock.unlock();

return this->forceRelease();
return UR_RESULT_SUCCESS;
}

bool ur_event_handle_t_::isTimestamped() const {
Expand All @@ -189,6 +194,8 @@ ur_context_handle_t ur_event_handle_t_::getContext() const { return hContext; }

ur_command_t ur_event_handle_t_::getCommandType() const { return commandType; }

ur_device_handle_t ur_event_handle_t_::getDevice() const { return hDevice; }

ur_event_handle_t_::ur_event_handle_t_(
ur_context_handle_t hContext,
v2::raii::cache_borrowed_event eventAllocation, v2::event_pool *pool)
Expand All @@ -209,16 +216,6 @@ ur_event_handle_t_::ur_event_handle_t_(
,
nullptr) {}

ur_result_t ur_event_handle_t_::forceRelease() {
if (event_pool) {
event_pool->free(this);
} else {
std::get<v2::raii::ze_event_handle_t>(hZeEvent).release();
delete this;
}
return UR_RESULT_SUCCESS;
}

namespace ur::level_zero {
ur_result_t urEventRetain(ur_event_handle_t hEvent) try {
return hEvent->retain();
Expand Down Expand Up @@ -323,19 +320,14 @@ ur_result_t urEventGetProfilingInfo(
}
}

auto hQueue = hEvent->getQueue();
if (!hQueue) {
auto hDevice = hEvent->getDevice();
if (!hDevice) {
// no command has been enqueued with this event yet
return UR_RESULT_ERROR_PROFILING_INFO_NOT_AVAILABLE;
}

ze_kernel_timestamp_result_t tsResult;

ur_device_handle_t hDevice;
UR_CALL_THROWS(hQueue->queueGetInfo(UR_QUEUE_INFO_DEVICE, sizeof(hDevice),
reinterpret_cast<void *>(&hDevice),
nullptr));

auto zeTimerResolution = hDevice->ZeDeviceProperties->timerResolution;
auto timestampMaxValue = hDevice->getTimestampMask();

Expand Down
11 changes: 8 additions & 3 deletions unified-runtime/source/adapters/level_zero/v2/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ struct event_profiling_data_t {
bool recordingStarted() const;
bool recordingEnded() const;

// clear the profiling data, allowing the event to be reused
// for a new command
void reset();

private:
ze_event_handle_t hZeEvent;

Expand Down Expand Up @@ -64,9 +68,6 @@ struct ur_event_handle_t_ : _ur_object {
// Set the queue and command that this event is associated with
void resetQueueAndCommand(ur_queue_t_ *hQueue, ur_command_t commandType);

// releases event immediately
ur_result_t forceRelease();

void reset();
ze_event_handle_t getZeEvent() const;

Expand Down Expand Up @@ -95,6 +96,9 @@ struct ur_event_handle_t_ : _ur_object {
// Get the type of the command that this event is associated with
ur_command_t getCommandType() const;

// Get the device associated with this event
ur_device_handle_t getDevice() const;

// Record the start timestamp of the event, to be obtained by
// urEventGetProfilingInfo. resetQueueAndCommand should be
// called before this.
Expand Down Expand Up @@ -122,6 +126,7 @@ struct ur_event_handle_t_ : _ur_object {
// commands
ur_queue_t_ *hQueue = nullptr;
ur_command_t commandType = UR_COMMAND_FORCE_UINT32;
ur_device_handle_t hDevice = nullptr;

v2::event_flags_t flags;
event_profiling_data_t profilingData;
Expand Down
2 changes: 0 additions & 2 deletions unified-runtime/source/adapters/level_zero/v2/queue_api.hpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,6 @@ ur_queue_immediate_in_order_t::queueGetInfo(ur_queue_info_t propName,
return UR_RESULT_SUCCESS;
}

void ur_queue_immediate_in_order_t::deferEventFree(ur_event_handle_t hEvent) {
auto commandListLocked = commandListManager.lock();
deferredEvents.push_back(hEvent);
}

ur_result_t ur_queue_immediate_in_order_t::queueGetNativeHandle(
ur_queue_native_desc_t * /*pDesc*/, ur_native_handle_t *phNativeQueue) {
*phNativeQueue = reinterpret_cast<ur_native_handle_t>(
Expand All @@ -160,12 +155,6 @@ ur_result_t ur_queue_immediate_in_order_t::queueFinish() {
ZE2UR_CALL(zeCommandListHostSynchronize,
(commandListLocked->getZeCommandList(), UINT64_MAX));

// Free deferred events
for (auto &hEvent : deferredEvents) {
UR_CALL(hEvent->releaseDeferred());
}
deferredEvents.clear();

// Free deferred kernels
for (auto &hKernel : submittedKernels) {
UR_CALL(hKernel->release());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ struct ur_queue_immediate_in_order_t : _ur_object, public ur_queue_t_ {
ur_queue_flags_t flags;

lockable<ur_command_list_manager> commandListManager;
std::vector<ur_event_handle_t> deferredEvents;
std::vector<ur_kernel_handle_t> submittedKernels;

wait_list_view
Expand All @@ -46,8 +45,6 @@ struct ur_queue_immediate_in_order_t : _ur_object, public ur_queue_t_ {
ur_event_handle_t *hUserEvent,
ur_command_t commandType);

void deferEventFree(ur_event_handle_t hEvent) override;

ur_result_t enqueueGenericFillUnlocked(
ur_mem_buffer_t *hBuffer, size_t offset, size_t patternSize,
const void *pPattern, size_t size, uint32_t numEventsInWaitList,
Expand Down
29 changes: 4 additions & 25 deletions unified-runtime/test/adapters/level_zero/v2/event_pool_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,39 +272,18 @@ TEST_P(EventPoolTestWithQueue, WithTimestamp) {
&hDevice, nullptr));

ur_event_handle_t first;
ze_event_handle_t zeFirst;
{
ASSERT_SUCCESS(
urEnqueueTimestampRecordingExp(queue, false, 1, &hEvent, &first));
zeFirst = first->getZeEvent();

urEventRelease(first); // should not actually release the event until
// recording is completed
}
ur_event_handle_t second;
ze_event_handle_t zeSecond;
{
ASSERT_SUCCESS(urEnqueueEventsWaitWithBarrier(queue, 0, nullptr, &second));
zeSecond = second->getZeEvent();
ASSERT_SUCCESS(urEventRelease(second));
}
ASSERT_NE(first, second);
ASSERT_NE(zeFirst, zeSecond);
ASSERT_SUCCESS(urEnqueueEventsWaitWithBarrier(queue, 0, nullptr, &second));
// even if the event is reused, it should not be timestamped anymore
ASSERT_FALSE(second->isTimestamped());
ASSERT_SUCCESS(urEventRelease(second));

ASSERT_EQ(zeEventHostSignal(zeEvent.get()), ZE_RESULT_SUCCESS);

ASSERT_SUCCESS(urQueueFinish(queue));

// Now, the first event should be avilable for reuse
ur_event_handle_t third;
ze_event_handle_t zeThird;
{
ASSERT_SUCCESS(urEnqueueEventsWaitWithBarrier(queue, 0, nullptr, &third));
zeThird = third->getZeEvent();
ASSERT_SUCCESS(urEventRelease(third));

ASSERT_FALSE(third->isTimestamped());
}
ASSERT_EQ(first, third);
ASSERT_EQ(zeFirst, zeThird);
}
10 changes: 10 additions & 0 deletions unified-runtime/test/adapters/level_zero/ze_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,18 @@
#include <ur_api.h>
#include <uur/fixtures.h>

static bool ze_initialized = false;

std::unique_ptr<_ze_event_handle_t, std::function<void(ze_event_handle_t)>>
createZeEvent(ur_context_handle_t hContext, ur_device_handle_t hDevice) {
if (!ze_initialized) {
ze_result_t result = zeInit(ZE_INIT_FLAG_GPU_ONLY);
if (result != ZE_RESULT_SUCCESS) {
return nullptr;
}
ze_initialized = true;
}

ze_event_pool_desc_t desc;
desc.stype = ZE_STRUCTURE_TYPE_EVENT_POOL_DESC;
desc.pNext = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,61 @@ TEST_P(urEnqueueTimestampRecordingExpTest, SuccessBlocking) {
ASSERT_SUCCESS(urEventRelease(event));
}

TEST_P(urEnqueueTimestampRecordingExpTest,
ReleaseEventWhileTimestampWritePending) {
void *ptr;
ASSERT_SUCCESS(
urUSMSharedAlloc(context, device, nullptr, nullptr, 1024 * 1024, &ptr));

// Enqueue an operation to keep the device busy
uint8_t pattern = 0xFF;
ASSERT_SUCCESS(urEnqueueUSMFill(queue, ptr, sizeof(uint8_t), &pattern,
1024 * 1024, 0, nullptr, nullptr));

ur_event_handle_t event1 = nullptr;
ASSERT_SUCCESS(
urEnqueueTimestampRecordingExp(queue, false, 0, nullptr, &event1));
ASSERT_SUCCESS(urEventRelease(event1));

ur_event_handle_t event2 = nullptr;
ASSERT_SUCCESS(urEnqueueUSMFill(queue, ptr, sizeof(uint8_t), &pattern,
1024 * 1024, 0, nullptr, &event2));

// Make sure the new event does not contain profiling info (in case it's reused
// by the adapter)
ASSERT_EQ(urEventGetProfilingInfo(event2, UR_PROFILING_INFO_COMMAND_QUEUED,
sizeof(uint64_t), nullptr, nullptr),
UR_RESULT_ERROR_PROFILING_INFO_NOT_AVAILABLE);
ASSERT_SUCCESS(urEventRelease(event2));
ASSERT_SUCCESS(urUSMFree(context, ptr));
}

TEST_P(urEnqueueTimestampRecordingExpTest, ReleaseEventAfterQueueRelease) {
void *ptr;
ASSERT_SUCCESS(
urUSMSharedAlloc(context, device, nullptr, nullptr, 1024 * 1024, &ptr));

// Enqueue an operation to keep the device busy
uint8_t pattern = 0xFF;
ASSERT_SUCCESS(urEnqueueUSMFill(queue, ptr, sizeof(uint8_t), &pattern,
1024 * 1024, 0, nullptr, nullptr));

ur_event_handle_t event1 = nullptr;
ASSERT_SUCCESS(
urEnqueueTimestampRecordingExp(queue, false, 0, nullptr, &event1));

ASSERT_SUCCESS(urQueueRelease(queue));
queue = nullptr;

uint64_t queuedTime = 0;
ASSERT_SUCCESS(
urEventGetProfilingInfo(event1, UR_PROFILING_INFO_COMMAND_QUEUED,
sizeof(uint64_t), &queuedTime, nullptr));

ASSERT_SUCCESS(urEventRelease(event1));
ASSERT_SUCCESS(urUSMFree(context, ptr));
}

TEST_P(urEnqueueTimestampRecordingExpTest, InvalidNullHandleQueue) {
ur_event_handle_t event = nullptr;
ASSERT_EQ_RESULT(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,30 @@ TEST_P(urEventGetProfilingInfoTest, Success) {
UR_PROFILING_INFO_COMMAND_COMPLETE);
}

TEST_P(urEventGetProfilingInfoTest, ReleaseEventAfterQueueRelease) {
void *ptr;
ASSERT_SUCCESS(
urUSMSharedAlloc(context, device, nullptr, nullptr, 1024 * 1024, &ptr));

// Enqueue an operation to keep the device busy
uint8_t pattern = 0xFF;
ur_event_handle_t event1;
ASSERT_SUCCESS(urEnqueueUSMFill(queue, ptr, sizeof(uint8_t), &pattern,
1024 * 1024, 0, nullptr, &event1));

ASSERT_SUCCESS(urQueueRelease(queue));
queue = nullptr;

uint64_t queuedTime = 0;
auto ret = urEventGetProfilingInfo(event1, UR_PROFILING_INFO_COMMAND_QUEUED,
sizeof(uint64_t), &queuedTime, nullptr);
ASSERT_TRUE(ret == UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION ||
ret == UR_RESULT_SUCCESS);

ASSERT_SUCCESS(urEventRelease(event1));
ASSERT_SUCCESS(urUSMFree(context, ptr));
}

TEST_P(urEventGetProfilingInfoTest, InvalidNullHandle) {
UUR_KNOWN_FAILURE_ON(uur::NativeCPU{});

Expand Down