From: Piotr Caban <piotr@codeweavers.com>

---
 dlls/combase/combase.c     | 170 +++++++++++++++++++++++++------------
 dlls/ole32/tests/compobj.c |  46 ++++++++++
 2 files changed, 164 insertions(+), 52 deletions(-)

diff --git a/dlls/combase/combase.c b/dlls/combase/combase.c
index a627d4779ec..8b42615fc0e 100644
--- a/dlls/combase/combase.c
+++ b/dlls/combase/combase.c
@@ -395,39 +395,6 @@ HRESULT WINAPI InternalTlsAllocData(struct tlsdata **data)
     return S_OK;
 }
 
-static void com_cleanup_tlsdata(void)
-{
-    struct tlsdata *tlsdata = NtCurrentTeb()->ReservedForOle;
-    struct init_spy *cursor, *cursor2;
-
-    if (!tlsdata)
-        return;
-
-    if (tlsdata->apt)
-        apartment_release(tlsdata->apt);
-    if (tlsdata->implicit_mta_cookie)
-        apartment_decrement_mta_usage(tlsdata->implicit_mta_cookie);
-
-    if (tlsdata->errorinfo)
-        IErrorInfo_Release(tlsdata->errorinfo);
-    if (tlsdata->state)
-        IUnknown_Release(tlsdata->state);
-
-    LIST_FOR_EACH_ENTRY_SAFE(cursor, cursor2, &tlsdata->spies, struct init_spy, entry)
-    {
-        list_remove(&cursor->entry);
-        if (cursor->spy)
-            IInitializeSpy_Release(cursor->spy);
-        free(cursor);
-    }
-
-    if (tlsdata->context_token)
-        IObjContext_Release(tlsdata->context_token);
-
-    free(tlsdata);
-    NtCurrentTeb()->ReservedForOle = NULL;
-}
-
 struct global_options
 {
     IGlobalOptions IGlobalOptions_iface;
@@ -2436,13 +2403,24 @@ HRESULT WINAPI CoRegisterPSClsid(REFIID riid, REFCLSID rclsid)
     return S_OK;
 }
 
-struct thread_context
+static CRITICAL_SECTION context_cs;
+static CRITICAL_SECTION_DEBUG context_cs_debug =
+{
+    0, 0, &context_cs,
+    { &context_cs_debug.ProcessLocksList, &context_cs_debug.ProcessLocksList },
+      0, 0, { (DWORD_PTR)(__FILE__ ": context_cs") }
+};
+static CRITICAL_SECTION context_cs = { &context_cs_debug, -1, 0, 0, 0, 0 };
+
+
+static struct thread_context
 {
     IComThreadingInfo IComThreadingInfo_iface;
     IContextCallback IContextCallback_iface;
     IObjContext IObjContext_iface;
-    LONG refcount;
-};
+    LONG ref;
+    LONG reported_ref;
+} *mta_context;
 
 static inline struct thread_context *impl_from_IComThreadingInfo(IComThreadingInfo *iface)
 {
@@ -2492,22 +2470,37 @@ static HRESULT WINAPI thread_context_info_QueryInterface(IComThreadingInfo *ifac
 static ULONG WINAPI thread_context_info_AddRef(IComThreadingInfo *iface)
 {
     struct thread_context *context = impl_from_IComThreadingInfo(iface);
-    return InterlockedIncrement(&context->refcount);
+    InterlockedIncrement(&context->ref);
+    return InterlockedIncrement(&context->reported_ref);
 }
 
-static ULONG WINAPI thread_context_info_Release(IComThreadingInfo *iface)
+static void context_token_release(IObjContext *iface)
 {
-    struct thread_context *context = impl_from_IComThreadingInfo(iface);
+    struct thread_context *context = impl_from_IObjContext(iface);
+    LONG ref;
 
-    /* Context instance is initially created with CoGetContextToken() with refcount set to 0,
-       releasing context while refcount is at 0 destroys it. */
-    if (!context->refcount)
+    /* Drop the reference under context_cs: CoGetContextToken() takes new references
+       to mta_context under the same lock, so once the count reaches zero here nobody
+       can resurrect the object before it is unpublished and freed. */
+    EnterCriticalSection(&context_cs);
+    ref = InterlockedDecrement(&context->ref);
+    if (!ref && context == mta_context) mta_context = NULL;
+    LeaveCriticalSection(&context_cs);
+
+    if (!ref)
     {
         free(context);
-        return 0;
     }
+}
 
-    return InterlockedDecrement(&context->refcount);
+static ULONG WINAPI thread_context_info_Release(IComThreadingInfo *iface)
+{
+    struct thread_context *context = impl_from_IComThreadingInfo(iface);
+    ULONG ret = InterlockedDecrement(&context->reported_ref);
+
+    context_token_release(&context->IObjContext_iface);
+    return ret;
 }
 
 static HRESULT WINAPI thread_context_info_GetCurrentApartmentType(IComThreadingInfo *iface, APTTYPE *apttype)
@@ -2708,21 +2701,44 @@ static const IObjContextVtbl thread_object_context_vtbl =
 HRESULT WINAPI CoGetContextToken(ULONG_PTR *token)
 {
     struct tlsdata *tlsdata;
+    struct apartment *apt;
     HRESULT hr;
 
     TRACE("%p\n", token);
 
-    if (!InternalIsProcessInitialized())
+    if (FAILED(hr = com_get_tlsdata(&tlsdata)))
+        return hr;
+
+    if (!token)
+        return E_POINTER;
+
+    if (!(apt = apartment_get_current_or_mta()))
     {
         ERR("apartment not initialised\n");
         return CO_E_NOTINITIALIZED;
     }
 
-    if (FAILED(hr = com_get_tlsdata(&tlsdata)))
-        return hr;
+    if (tlsdata->context_token)
+    {
+        struct thread_context *context = impl_from_IObjContext(tlsdata->context_token);
+        if ((apt->multi_threaded && context != mta_context) ||
+            (!apt->multi_threaded && context == mta_context))
+        {
+            context_token_release(tlsdata->context_token);
+            tlsdata->context_token = NULL;
+        }
+    }
 
-    if (!token)
-        return E_POINTER;
+    if (!tlsdata->context_token && apt->multi_threaded && mta_context)
+    {
+        EnterCriticalSection(&context_cs);
+        if (mta_context)
+        {
+            tlsdata->context_token = &mta_context->IObjContext_iface;
+            InterlockedIncrement(&mta_context->ref);
+        }
+        LeaveCriticalSection(&context_cs);
+    }
 
     if (!tlsdata->context_token)
     {
@@ -2730,14 +2746,30 @@ HRESULT WINAPI CoGetContextToken(ULONG_PTR *token)
 
         context = calloc(1, sizeof(*context));
         if (!context)
+        {
+            apartment_release(apt);
             return E_OUTOFMEMORY;
+        }
 
         context->IComThreadingInfo_iface.lpVtbl = &thread_context_info_vtbl;
         context->IContextCallback_iface.lpVtbl = &thread_context_callback_vtbl;
         context->IObjContext_iface.lpVtbl = &thread_object_context_vtbl;
-        /* Context token does not take a reference, it's always zero until the
-           interface is explicitly requested with CoGetObjectContext(). */
-        context->refcount = 0;
+        context->ref = 1;
+        context->reported_ref = 0;
+
+        if (apt->multi_threaded)
+        {
+            EnterCriticalSection(&context_cs);
+            if (!mta_context)
+                mta_context = context;
+            else
+            {
+                free(context);
+                context = mta_context;
+                InterlockedIncrement(&context->ref);
+            }
+            LeaveCriticalSection(&context_cs);
+        }
 
         tlsdata->context_token = &context->IObjContext_iface;
     }
@@ -2745,6 +2777,7 @@ HRESULT WINAPI CoGetContextToken(ULONG_PTR *token)
     *token = (ULONG_PTR)tlsdata->context_token;
     TRACE("context_token %p\n", tlsdata->context_token);
 
+    apartment_release(apt);
     return S_OK;
 }
 
@@ -3435,6 +3468,39 @@ HRESULT WINAPI CoRegisterActivationFilter(IActivationFilter *filter)
     return E_NOTIMPL;
 }
 
+static void com_cleanup_tlsdata(void)
+{
+    struct tlsdata *tlsdata = NtCurrentTeb()->ReservedForOle;
+    struct init_spy *cursor, *cursor2;
+
+    if (!tlsdata)
+        return;
+
+    if (tlsdata->apt)
+        apartment_release(tlsdata->apt);
+    if (tlsdata->implicit_mta_cookie)
+        apartment_decrement_mta_usage(tlsdata->implicit_mta_cookie);
+
+    if (tlsdata->errorinfo)
+        IErrorInfo_Release(tlsdata->errorinfo);
+    if (tlsdata->state)
+        IUnknown_Release(tlsdata->state);
+
+    LIST_FOR_EACH_ENTRY_SAFE(cursor, cursor2, &tlsdata->spies, struct init_spy, entry)
+    {
+        list_remove(&cursor->entry);
+        if (cursor->spy)
+            IInitializeSpy_Release(cursor->spy);
+        free(cursor);
+    }
+
+    if (tlsdata->context_token)
+        context_token_release(tlsdata->context_token);
+
+    free(tlsdata);
+    NtCurrentTeb()->ReservedForOle = NULL;
+}
+
 /***********************************************************************
  *           DllMain (combase.@)
  */
diff --git a/dlls/ole32/tests/compobj.c b/dlls/ole32/tests/compobj.c
index 2c4cc566f3a..1f6e2326129 100644
--- a/dlls/ole32/tests/compobj.c
+++ b/dlls/ole32/tests/compobj.c
@@ -2201,12 +2201,42 @@ static void test_CoGetCallContext(void)
     CoUninitialize();
 }
 
+static DWORD WINAPI get_context_token_thread(void *arg)
+{
+    ULONG_PTR mta_token = (ULONG_PTR)arg, token;
+    HRESULT hr;
+    ULONG refs;
+
+    test_apt_type(APTTYPE_MTA, APTTYPEQUALIFIER_IMPLICIT_MTA);
+    hr = pCoGetContextToken(&token);
+    ok(hr == S_OK, "Expected S_OK, got 0x%08lx\n", hr);
+    ok(token, "Expected token != 0\n");
+    ok(token == mta_token, "token != mta_token\n");
+
+    refs = IUnknown_AddRef((IUnknown *)token);
+    ok(refs == 1, "Expected 1, got %lu\n", refs);
+    IUnknown_Release((IUnknown *)token);
+
+    hr = CoInitialize(NULL);
+    ok(hr == S_OK, "CoInitialize() failed with error 0x%08lx\n", hr);
+    test_apt_type(APTTYPE_MAINSTA, APTTYPEQUALIFIER_NONE);
+
+    hr = pCoGetContextToken(&token);
+    ok(hr == S_OK, "Expected S_OK, got 0x%08lx\n", hr);
+    ok(token, "Expected token != 0\n");
+    ok(token != mta_token, "token == mta_token\n");
+
+    CoUninitialize();
+    return 0;
+}
+
 static void test_CoGetContextToken(void)
 {
     HRESULT hr;
     ULONG refs;
     ULONG_PTR token, token2;
     IObjContext *ctx;
+    HANDLE thread;
 
     if (!pCoGetContextToken)
     {
@@ -2267,6 +2297,22 @@ static void test_CoGetContextToken(void)
     refs = IObjContext_Release(ctx);
     ok(refs == 1, "Expected 0, got %lu\n", refs);
 
+    CoUninitialize();
+
+    hr = CoInitializeEx(NULL, COINIT_MULTITHREADED);
+    ok(hr == S_OK, "Expected S_OK, got 0x%08lx\n", hr);
+
+    token = 0;
+    hr = pCoGetContextToken(&token);
+    ok(hr == S_OK, "Expected S_OK, got 0x%08lx\n", hr);
+    ok(token, "Expected token != 0\n");
+    ok(token != token2, "token did not change\n");
+
+    thread = CreateThread(NULL, 0, get_context_token_thread, (void*)token, 0, NULL);
+    ok(thread != NULL, "CreateThread failed\n");
+    WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
+
     refs = IObjContext_Release(ctx);
     ok(refs == 0, "Expected 0, got %lu\n", refs);
 
-- 
GitLab

https://gitlab.winehq.org/wine/wine/-/merge_requests/10982