From: Piotr Caban <piotr@codeweavers.com>

Test IContextCallback::ContextCallback() in two setups:
- from the thread that owns the object context;
- from a second thread that is first initialized as STA and then as MTA.

Both setups are run twice: once for an object context obtained while
is_mta is FALSE, and once while is_mta is TRUE. In each case the callback
is invoked with riid IID_IContextCallback and with
IID_IEnterActivityWithNoLock.

The callback checks two things:
- the logical thread id it observes;
- that it runs in the target context, by comparing CoGetContextToken()
  against the IObjContext of that context.

Wine still differs from native on the logical thread id when an STA thread
calls into the MTA context; that check is marked todo_wine.
---
 dlls/ole32/tests/compobj.c | 163 +++++++++++++++++++++++++++++++++++--
 1 file changed, 156 insertions(+), 7 deletions(-)

diff --git a/dlls/ole32/tests/compobj.c b/dlls/ole32/tests/compobj.c
index 5ec5febb1a9..b4e12b38fe1 100644
--- a/dlls/ole32/tests/compobj.c
+++ b/dlls/ole32/tests/compobj.c
@@ -32,6 +32,7 @@
 #include "urlmon.h" /* for CLSID_FileProtocol */
 #include "dde.h"
 #include "cguid.h"
+#include "comsvcs.h"
 
 #include "ctxtcall.h"
 
@@ -69,6 +70,7 @@ DEFINE_EXPECT(PreInitialize);
 DEFINE_EXPECT(PostInitialize);
 DEFINE_EXPECT(PreUninitialize);
 DEFINE_EXPECT(PostUninitialize);
+DEFINE_EXPECT(context_callback_func);
 
 /* functions that are not present on all versions of Windows */
 static HRESULT (WINAPI * pCoGetObjectContext)(REFIID riid, LPVOID *ppv);
@@ -2000,15 +2002,98 @@ static void test_CoFreeUnusedLibraries(void)
     CoUninitialize();
 }
 
+struct context_callback_arg
+{
+    IContextCallback *context_callback;
+    ULONG_PTR token;
+    BOOL is_mta;
+    GUID logical_thread_id;
+    BOOL todo_thread_id;
+};
+
+static HRESULT WINAPI context_callback_func(ComCallData *arg)
+{
+    struct context_callback_arg *ctx = (struct context_callback_arg *)arg;
+    ULONG_PTR token;
+    GUID thread_id;
+    HRESULT hr;
+
+    CHECK_EXPECT(context_callback_func);
+
+    hr = CoGetCurrentLogicalThreadId(&thread_id);
+    ok_ole_success(hr, "CoGetCurrentLogicalThreadId");
+    todo_wine_if(ctx->todo_thread_id) ok(IsEqualGUID(&thread_id, &ctx->logical_thread_id),
+            "thread_id = %s\n", wine_dbgstr_guid(&thread_id));
+
+    hr = pCoGetContextToken(&token);
+    ok_ole_success(hr, "CoGetContextToken");
+    ok(token == ctx->token, "executed in different context\n");
+    return S_FALSE;
+}
+
+static DWORD WINAPI context_callback_thread(void *arg)
+{
+    struct context_callback_arg *ctx = arg;
+    HRESULT hr;
+    GUID id;
+
+    CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
+
+    hr = CoGetCurrentLogicalThreadId(&id);
+    ok_ole_success(hr, "CoGetCurrentLogicalThreadId");
+    hr = IContextCallback_QueryInterface(ctx->context_callback, &IID_IObjContext, (void **)&ctx->token);
+    ok_ole_success(hr, "IContextCallback_QueryInterface");
+    IObjContext_Release((IObjContext *)ctx->token);
+
+    SET_EXPECT(context_callback_func);
+    ctx->logical_thread_id = id;
+    /* TODO: native calls the callback in current thread after temporarily converting it to MTA */
+    if (ctx->is_mta) ctx->todo_thread_id = TRUE;
+    hr = IContextCallback_ContextCallback(ctx->context_callback, context_callback_func,
+            (ComCallData *)ctx, &IID_IContextCallback, 2, NULL);
+    ctx->todo_thread_id = FALSE;
+    CHECK_CALLED(context_callback_func, 1);
+    ok(hr == S_FALSE, "got 0x%08lx\n", hr);
+
+    SET_EXPECT(context_callback_func);
+    ctx->logical_thread_id = IID_IEnterActivityWithNoLock;
+    hr = IContextCallback_ContextCallback(ctx->context_callback, context_callback_func,
+            (ComCallData *)ctx, &IID_IEnterActivityWithNoLock, 2, NULL);
+    CHECK_CALLED(context_callback_func, 1);
+    ok(hr == S_FALSE, "got 0x%08lx\n", hr);
+
+    CoUninitialize();
+
+    CoInitializeEx(NULL, COINIT_MULTITHREADED);
+
+    SET_EXPECT(context_callback_func);
+    ctx->logical_thread_id = id;
+    hr = IContextCallback_ContextCallback(ctx->context_callback, context_callback_func,
+            (ComCallData *)ctx, &IID_IContextCallback, 2, NULL);
+    CHECK_CALLED(context_callback_func, 1);
+    ok(hr == S_FALSE, "got 0x%08lx\n", hr);
+
+    SET_EXPECT(context_callback_func);
+    ctx->logical_thread_id = IID_IEnterActivityWithNoLock;
+    hr = IContextCallback_ContextCallback(ctx->context_callback, context_callback_func,
+            (ComCallData *)ctx, &IID_IEnterActivityWithNoLock, 2, NULL);
+    CHECK_CALLED(context_callback_func, 1);
+    ok(hr == S_FALSE, "got 0x%08lx\n", hr);
+
+    CoUninitialize();
+    return 0;
+}
+
 static void test_CoGetObjectContext(void)
 {
     HRESULT hr;
     ULONG refs;
     IComThreadingInfo *pComThreadingInfo, *threadinginfo2;
+    struct context_callback_arg callback_arg;
     IContextCallback *pContextCallback;
-    IObjContext *pObjContext;
     APTTYPE apttype;
     THDTYPE thdtype;
+    HANDLE thread;
     GUID id, id2;
 
     if (!pCoGetObjectContext)
@@ -2073,6 +2158,50 @@ static void test_CoGetObjectContext(void)
     hr = pCoGetObjectContext(&IID_IContextCallback, (void **)&pContextCallback);
     ok_ole_success(hr, "CoGetObjectContext(ContextCallback)");
 
+    callback_arg.context_callback = pContextCallback;
+    hr = pCoGetObjectContext(&IID_IObjContext, (void **)&callback_arg.token);
+    ok_ole_success(hr, "CoGetObjectContext");
+    IObjContext_Release((IObjContext *)callback_arg.token);
+    callback_arg.is_mta = FALSE;
+    callback_arg.logical_thread_id = id2;
+    callback_arg.todo_thread_id = FALSE;
+
+    SET_EXPECT(context_callback_func);
+    hr = IContextCallback_ContextCallback(pContextCallback, context_callback_func,
+            (ComCallData *)&callback_arg, &IID_IContextCallback, 2, NULL);
+    CHECK_CALLED(context_callback_func, 1);
+    ok(hr == S_FALSE, "got 0x%08lx\n", hr);
+
+    SET_EXPECT(context_callback_func);
+    callback_arg.logical_thread_id = IID_IEnterActivityWithNoLock;
+    hr = IContextCallback_ContextCallback(pContextCallback, context_callback_func,
+            (ComCallData *)&callback_arg, &IID_IEnterActivityWithNoLock, 2, NULL);
+    CHECK_CALLED(context_callback_func, 1);
+    ok(hr == S_FALSE, "got 0x%08lx\n", hr);
+
+    thread = CreateThread(NULL, 0, context_callback_thread, &callback_arg, 0, NULL);
+    ok(thread != NULL, "CreateThread failed\n");
+    while (1)
+    {
+        MSG msg;
+
+        switch(MsgWaitForMultipleObjects(1, &thread, FALSE, INFINITE, QS_ALLPOSTMESSAGE))
+        {
+        case WAIT_OBJECT_0:
+            break;
+        case WAIT_OBJECT_0 + 1:
+            while (PeekMessageA(&msg, 0, 0, 0, PM_REMOVE))
+                DispatchMessageA(&msg);
+            continue;
+        default:
+            ok(0, "MsgWaitForMultipleObjects failed\n");
+            break;
+        }
+
+        break;
+    }
+    CloseHandle(thread);
+
     refs = IContextCallback_Release(pContextCallback);
     ok(refs == 0, "pContextCallback should have 0 refs instead of %ld refs\n", refs);
 
@@ -2097,14 +2226,34 @@ static void test_CoGetObjectContext(void)
     hr = pCoGetObjectContext(&IID_IContextCallback, (void **)&pContextCallback);
     ok_ole_success(hr, "CoGetObjectContext(ContextCallback)");
 
-    refs = IContextCallback_Release(pContextCallback);
-    ok(refs == 0, "pContextCallback should have 0 refs instead of %ld refs\n", refs);
-
-    hr = pCoGetObjectContext(&IID_IObjContext, (void **)&pObjContext);
+    callback_arg.context_callback = pContextCallback;
+    hr = pCoGetObjectContext(&IID_IObjContext, (void **)&callback_arg.token);
     ok_ole_success(hr, "CoGetObjectContext");
+    IObjContext_Release((IObjContext *)callback_arg.token);
+    callback_arg.is_mta = TRUE;
+    callback_arg.logical_thread_id = id2;
+    callback_arg.todo_thread_id = FALSE;
+
+    SET_EXPECT(context_callback_func);
+    hr = IContextCallback_ContextCallback(pContextCallback, context_callback_func,
+            (ComCallData *)&callback_arg, &IID_IContextCallback, 2, NULL);
+    CHECK_CALLED(context_callback_func, 1);
+    ok(hr == S_FALSE, "got 0x%08lx\n", hr);
+
+    SET_EXPECT(context_callback_func);
+    callback_arg.logical_thread_id = IID_IEnterActivityWithNoLock;
+    hr = IContextCallback_ContextCallback(pContextCallback, context_callback_func,
+            (ComCallData *)&callback_arg, &IID_IEnterActivityWithNoLock, 2, NULL);
+    CHECK_CALLED(context_callback_func, 1);
+    ok(hr == S_FALSE, "got 0x%08lx\n", hr);
+
+    thread = CreateThread(NULL, 0, context_callback_thread, &callback_arg, 0, NULL);
+    ok(thread != NULL, "CreateThread failed\n");
+    WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
 
-    refs = IObjContext_Release(pObjContext);
-    ok(refs == 0, "pObjContext should have 0 refs instead of %ld refs\n", refs);
+    refs = IContextCallback_Release(pContextCallback);
+    ok(refs == 0, "pContextCallback should have 0 refs instead of %ld refs\n", refs);
 
     CoUninitialize();
 }
-- 
GitLab

https://gitlab.winehq.org/wine/wine/-/merge_requests/10995