-- v20: rtworkq: Release cancelled work items. mfplat/tests: Validate MFCancelWorkItem releases the callback. rtworkq: Avoid use-after-free.
From: Yuxuan Shui yshui@codeweavers.com
queue_release_pending_item releases the work_item reference but then still dereferences `item->queue` to leave the critical section, which is a potential use-after-free if that was the last reference. --- dlls/rtworkq/queue.c | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/dlls/rtworkq/queue.c b/dlls/rtworkq/queue.c index eebb096ad31..00b77bf6953 100644 --- a/dlls/rtworkq/queue.c +++ b/dlls/rtworkq/queue.c @@ -734,9 +734,10 @@ static HRESULT invoke_async_callback(IRtwqAsyncResult *result) * removed from pending items when it got canceled. */ static BOOL queue_release_pending_item(struct work_item *item) { + struct queue *queue = item->queue; BOOL ret = FALSE;
- EnterCriticalSection(&item->queue->cs); + EnterCriticalSection(&queue->cs); if (item->key) { list_remove(&item->entry); @@ -744,7 +745,7 @@ static BOOL queue_release_pending_item(struct work_item *item) item->key = 0; IUnknown_Release(&item->IUnknown_iface); } - LeaveCriticalSection(&item->queue->cs); + LeaveCriticalSection(&queue->cs); return ret; }
From: Yuxuan Shui yshui@codeweavers.com
--- dlls/mfplat/tests/mfplat.c | 64 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+)
diff --git a/dlls/mfplat/tests/mfplat.c b/dlls/mfplat/tests/mfplat.c index 904e8808305..abad318f886 100644 --- a/dlls/mfplat/tests/mfplat.c +++ b/dlls/mfplat/tests/mfplat.c @@ -585,6 +585,51 @@ static const IMFAsyncCallbackVtbl test_async_callback_result_vtbl = test_async_callback_result_Invoke, };
+/* Test context for MFAddPeriodCallback. */ +struct test_context +{ + IUnknown IUnknown_iface; + LONG refcount; +}; + +static struct test_context *test_context_from_IUnknown(IUnknown *iface) +{ + return CONTAINING_RECORD(iface, struct test_context, IUnknown_iface); +} + +static HRESULT WINAPI test_context_QueryInterface(IUnknown *iface, REFIID riid, void **obj) +{ + if (IsEqualIID(riid, &IID_IUnknown)) + { + *obj = iface; + IUnknown_AddRef(iface); + return S_OK; + } + + *obj = NULL; + return E_NOINTERFACE; +} + +static ULONG WINAPI test_context_AddRef(IUnknown *iface) +{ + struct test_context *context = test_context_from_IUnknown(iface); + return InterlockedIncrement(&context->refcount); +} + +static ULONG WINAPI test_context_Release(IUnknown *iface) +{ + struct test_context *context = test_context_from_IUnknown(iface); + ULONG refcount = InterlockedDecrement(&context->refcount); + return refcount; +} + +static const IUnknownVtbl test_context_vtbl = +{ + test_context_QueryInterface, + test_context_AddRef, + test_context_Release, +}; + static DWORD wait_async_callback_result(IMFAsyncCallback *iface, DWORD timeout, IMFAsyncResult **result) { struct test_callback *callback = impl_from_IMFAsyncCallback(iface); @@ -3599,6 +3644,7 @@ static void test_scheduled_items(void) IMFAsyncResult *result; MFWORKITEM_KEY key, key2; HRESULT hr; + ULONG refcount;
callback = create_test_callback(NULL);
@@ -3611,6 +3657,9 @@ static void test_scheduled_items(void) hr = MFCancelWorkItem(key); ok(hr == S_OK, "Failed to cancel item, hr %#lx.\n", hr);
+ refcount = IMFAsyncCallback_Release(&callback->IMFAsyncCallback_iface); + todo_wine ok(refcount == 0, "Unexpected refcount %lu.\n", refcount); + hr = MFCancelWorkItem(key); ok(hr == MF_E_NOT_FOUND || broken(hr == S_OK) /* < win10 */, "Unexpected hr %#lx.\n", hr);
@@ -3620,6 +3669,8 @@ static void test_scheduled_items(void) return; }
+ callback = create_test_callback(NULL); + hr = MFCreateAsyncResult(NULL, &callback->IMFAsyncCallback_iface, NULL, &result); ok(hr == S_OK, "Failed to create result, hr %#lx.\n", hr);
@@ -3716,6 +3767,10 @@ static void test_periodic_callback(void) { DWORD period, key; HRESULT hr; + struct test_context context = { + .IUnknown_iface = { &test_context_vtbl }, + .refcount = 1, + };
hr = MFStartup(MF_VERSION, MFSTARTUP_FULL); ok(hr == S_OK, "Failed to start up, hr %#lx.\n", hr); @@ -3746,6 +3801,15 @@ static void test_periodic_callback(void)
ok(periodic_counter > 0, "Unexpected counter value %lu.\n", periodic_counter);
+ hr= pMFAddPeriodicCallback(periodic_callback, &context.IUnknown_iface, &key); + ok(hr == S_OK, "Failed to add periodic callback, hr %#lx.\n", hr); + ok(context.refcount == 2, "Unexpected refcount %ld.\n", context.refcount); + + hr = pMFRemovePeriodicCallback(key); + ok(hr == S_OK, "Failed to remove callback, hr %#lx.\n", hr); + Sleep(500); + todo_wine ok(context.refcount == 1, "Unexpected refcount %ld.\n", context.refcount); + hr = MFShutdown(); ok(hr == S_OK, "Failed to shut down, hr %#lx.\n", hr); }
From: Yuxuan Shui yshui@codeweavers.com
Usually the threadpool holds a reference to the `work_item`, which is released when the `work_item`'s callback is invoked. However, `queue_cancel_item` closes the threadpool object without releasing the `work_item`. Since the callback is not guaranteed to run after that, the `work_item` can be leaked.
The fix is not as simple as adding an `IUnknown_Release` to `queue_cancel_item`, because the `work_item`'s callback can still be called after CloseThreadpoolTimer/Wait has returned. In fact, the callback might be running at that very moment, in which case it would access freed memory if `queue_cancel_item` freed the `work_item`.
We have to stop any further callbacks from being queued and wait for any currently running callbacks to finish. Only then can we release the `work_item`, if it hasn't already been released during the wait. --- dlls/mf/tests/mf.c | 1 - dlls/mfplat/tests/mfplat.c | 4 ++-- dlls/rtworkq/queue.c | 48 +++++++++++++++++++++++++++++++------- 3 files changed, 41 insertions(+), 12 deletions(-)
diff --git a/dlls/mf/tests/mf.c b/dlls/mf/tests/mf.c index fece60d37a7..ee74014b63d 100644 --- a/dlls/mf/tests/mf.c +++ b/dlls/mf/tests/mf.c @@ -6376,7 +6376,6 @@ if (SUCCEEDED(hr)) check_sar_rate_support(sink);
ref = IMFMediaSink_Release(sink); - todo_wine ok(ref == 0, "Release returned %ld\n", ref);
/* Activation */ diff --git a/dlls/mfplat/tests/mfplat.c b/dlls/mfplat/tests/mfplat.c index abad318f886..ad4223685f0 100644 --- a/dlls/mfplat/tests/mfplat.c +++ b/dlls/mfplat/tests/mfplat.c @@ -3658,7 +3658,7 @@ static void test_scheduled_items(void) ok(hr == S_OK, "Failed to cancel item, hr %#lx.\n", hr);
refcount = IMFAsyncCallback_Release(&callback->IMFAsyncCallback_iface); - todo_wine ok(refcount == 0, "Unexpected refcount %lu.\n", refcount); + ok(refcount == 0, "Unexpected refcount %lu.\n", refcount);
hr = MFCancelWorkItem(key); ok(hr == MF_E_NOT_FOUND || broken(hr == S_OK) /* < win10 */, "Unexpected hr %#lx.\n", hr); @@ -3808,7 +3808,7 @@ static void test_periodic_callback(void) hr = pMFRemovePeriodicCallback(key); ok(hr == S_OK, "Failed to remove callback, hr %#lx.\n", hr); Sleep(500); - todo_wine ok(context.refcount == 1, "Unexpected refcount %ld.\n", context.refcount); + ok(context.refcount == 1, "Unexpected refcount %ld.\n", context.refcount);
hr = MFShutdown(); ok(hr == S_OK, "Failed to shut down, hr %#lx.\n", hr); diff --git a/dlls/rtworkq/queue.c b/dlls/rtworkq/queue.c index 00b77bf6953..9046a70b359 100644 --- a/dlls/rtworkq/queue.c +++ b/dlls/rtworkq/queue.c @@ -883,7 +883,8 @@ static HRESULT queue_submit_timer(struct queue *queue, IRtwqAsyncResult *result,
static HRESULT queue_cancel_item(struct queue *queue, RTWQWORKITEM_KEY key) { - HRESULT hr = RTWQ_E_NOT_FOUND; + TP_WAIT *wait_object; + TP_TIMER *timer_object; struct work_item *item;
EnterCriticalSection(&queue->cs); @@ -891,29 +892,58 @@ static HRESULT queue_cancel_item(struct queue *queue, RTWQWORKITEM_KEY key) { if (item->key == key) { + /* We can't immediately release the item here, because the callback could already be + * running somewhere else. And if we release it here, the callback will access freed memory. + * So instead we have to make sure the callback is really stopped, or has really finished + * running before we do that. And we can't do that in this critical section, which would be a + * deadlock. So we first keep an extra reference to it, then leave the critical section to + * wait for the thread-pool objects, finally we re-enter critical section to release it. */ key >>= 32; + IUnknown_AddRef(&item->IUnknown_iface); if ((key & WAIT_ITEM_KEY_MASK) == WAIT_ITEM_KEY_MASK) { - IRtwqAsyncResult_SetStatus(item->result, RTWQ_E_OPERATION_CANCELLED); - invoke_async_callback(item->result); - CloseThreadpoolWait(item->u.wait_object); + wait_object = item->u.wait_object; item->u.wait_object = NULL; + LeaveCriticalSection(&queue->cs); + + SetThreadpoolWait(wait_object, NULL, NULL); + WaitForThreadpoolWaitCallbacks(wait_object, TRUE); + CloseThreadpoolWait(wait_object); } else if ((key & SCHEDULED_ITEM_KEY_MASK) == SCHEDULED_ITEM_KEY_MASK) { - CloseThreadpoolTimer(item->u.timer_object); + timer_object = item->u.timer_object; item->u.timer_object = NULL; + LeaveCriticalSection(&queue->cs); + + SetThreadpoolTimer(timer_object, NULL, 0, 0); + WaitForThreadpoolTimerCallbacks(timer_object, TRUE); + CloseThreadpoolTimer(timer_object); } else + { WARN("Unknown item key mask %#I64x.\n", key); - queue_release_pending_item(item); - hr = S_OK; - break; + LeaveCriticalSection(&queue->cs); + } + + if (queue_release_pending_item(item)) + { + /* This means the callback wasn't run during our wait, so we can invoke the + * callback with a canceled status, and release the work item. 
*/ + if ((key & WAIT_ITEM_KEY_MASK) == WAIT_ITEM_KEY_MASK) + { + IRtwqAsyncResult_SetStatus(item->result, RTWQ_E_OPERATION_CANCELLED); + invoke_async_callback(item->result); + } + IUnknown_Release(&item->IUnknown_iface); + } + IUnknown_Release(&item->IUnknown_iface); + return S_OK; } } LeaveCriticalSection(&queue->cs);
- return hr; + return RTWQ_E_NOT_FOUND; }
static HRESULT alloc_user_queue(const struct queue_desc *desc, DWORD *queue_id)
Why can't you just use create_test_callback() here as well? This struct test_context is merely a stub interface, and there is nothing special about it.
You can use something like "EXPECT_REF(&callback->IMFAsyncCallback_iface, 2);" to check the ref count.