This fixes an issue I ran into in UI Automation using an interface proxy marshaled with `MSHCTX_INPROC`. `CoUnmarshalInterface` always passes `MSHCTX_LOCAL` when using the standard marshaler, regardless of what was passed to `CoMarshalInterface`.
When passing an interface that uses the free-threaded marshaler as an argument to a method on the proxy retrieved from `CoUnmarshalInterface`, it passes `MSHCTX_LOCAL` when trying to marshal, which then results in the free-threaded marshaler trying to create a proxy/stub, which fails.
-- v3: combase: Use correct destination context in CoUnmarshalInterface when using the standard marshaler. ole32/tests: Extend test_marshal_channel_buffer() test to include IRpcProxyBufferWrapper checks.
From: Connor McAdams cmcadams@codeweavers.com
Signed-off-by: Connor McAdams cmcadams@codeweavers.com --- dlls/ole32/tests/marshal.c | 102 +++++++++++++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 5 deletions(-)
diff --git a/dlls/ole32/tests/marshal.c b/dlls/ole32/tests/marshal.c index 1344cf1d163..f76505ea1bf 100644 --- a/dlls/ole32/tests/marshal.c +++ b/dlls/ole32/tests/marshal.c @@ -351,10 +351,12 @@ static const IClassFactoryVtbl TestClassFactory_Vtbl = static IClassFactory Test_ClassFactory = { &TestClassFactory_Vtbl };
DEFINE_EXPECT(Invoke); +DEFINE_EXPECT(Connect); DEFINE_EXPECT(CreateStub); DEFINE_EXPECT(CreateProxy); DEFINE_EXPECT(GetWindow); -DEFINE_EXPECT(Disconnect); +DEFINE_EXPECT(RpcStubBuffer_Disconnect); +DEFINE_EXPECT(RpcProxyBuffer_Disconnect);
static HRESULT WINAPI OleWindow_QueryInterface(IOleWindow *iface, REFIID riid, void **ppv) { @@ -476,7 +478,7 @@ static HRESULT WINAPI RpcStubBuffer_Connect(IRpcStubBuffer *iface, IUnknown *pUn
static void WINAPI RpcStubBuffer_Disconnect(IRpcStubBuffer *iface) { - CHECK_EXPECT(Disconnect); + CHECK_EXPECT(RpcStubBuffer_Disconnect); }
static HRESULT WINAPI RpcStubBuffer_Invoke(IRpcStubBuffer *iface, RPCOLEMESSAGE *_prpcmsg, @@ -533,6 +535,79 @@ static const IRpcStubBufferVtbl RpcStubBufferVtbl = { RpcStubBuffer_DebugServerRelease };
+typedef struct { + IRpcProxyBuffer IRpcProxyBuffer_iface; + LONG ref; + IRpcProxyBuffer *buffer; +} ProxyBufferWrapper; + +static ProxyBufferWrapper *impl_from_IRpcProxyBuffer(IRpcProxyBuffer *iface) +{ + return CONTAINING_RECORD(iface, ProxyBufferWrapper, IRpcProxyBuffer_iface); +} + +static HRESULT WINAPI RpcProxyBuffer_QueryInterface(IRpcProxyBuffer *iface, REFIID riid, void **ppv) +{ + ProxyBufferWrapper *This = impl_from_IRpcProxyBuffer(iface); + + if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IRpcProxyBuffer, riid)) { + *ppv = &This->IRpcProxyBuffer_iface; + }else { + *ppv = NULL; + return E_NOINTERFACE; + } + + IUnknown_AddRef((IUnknown*)*ppv); + return S_OK; +} + +static ULONG WINAPI RpcProxyBuffer_AddRef(IRpcProxyBuffer *iface) +{ + ProxyBufferWrapper *This = impl_from_IRpcProxyBuffer(iface); + return InterlockedIncrement(&This->ref); +} + +static ULONG WINAPI RpcProxyBuffer_Release(IRpcProxyBuffer *iface) +{ + ProxyBufferWrapper *This = impl_from_IRpcProxyBuffer(iface); + LONG ref = InterlockedDecrement(&This->ref); + if(!ref) { + IRpcProxyBuffer_Release(This->buffer); + heap_free(This); + } + return ref; +} + +static HRESULT WINAPI RpcProxyBuffer_Connect(IRpcProxyBuffer *iface, IRpcChannelBuffer *pRpcChannelBuffer) +{ + ProxyBufferWrapper *This = impl_from_IRpcProxyBuffer(iface); + void *dest_context_data; + DWORD dest_context; + HRESULT hr; + + CHECK_EXPECT(Connect); + + hr = IRpcChannelBuffer_GetDestCtx(pRpcChannelBuffer, &dest_context, &dest_context_data); + ok(hr == S_OK, "GetDestCtx failed: %08lx\n", hr); + todo_wine ok(dest_context == MSHCTX_INPROC, "desc_context = %lx\n", dest_context); + ok(!dest_context_data, "desc_context_data = %p\n", dest_context_data); + + return IRpcProxyBuffer_Connect(This->buffer, pRpcChannelBuffer); +} + +static void WINAPI RpcProxyBuffer_Disconnect(IRpcProxyBuffer *iface) +{ + CHECK_EXPECT(RpcProxyBuffer_Disconnect); +} + +static const IRpcProxyBufferVtbl RpcProxyBufferVtbl = { + 
RpcProxyBuffer_QueryInterface, + RpcProxyBuffer_AddRef, + RpcProxyBuffer_Release, + RpcProxyBuffer_Connect, + RpcProxyBuffer_Disconnect, +}; + static IPSFactoryBuffer *ps_factory_buffer;
static HRESULT WINAPI PSFactoryBuffer_QueryInterface(IPSFactoryBuffer *iface, REFIID riid, void **ppv) @@ -561,8 +636,20 @@ static ULONG WINAPI PSFactoryBuffer_Release(IPSFactoryBuffer *iface) static HRESULT WINAPI PSFactoryBuffer_CreateProxy(IPSFactoryBuffer *iface, IUnknown *outer, REFIID riid, IRpcProxyBuffer **ppProxy, void **ppv) { + ProxyBufferWrapper *proxy; + HRESULT hr; + CHECK_EXPECT(CreateProxy); - return IPSFactoryBuffer_CreateProxy(ps_factory_buffer, outer, riid, ppProxy, ppv); + proxy = heap_alloc(sizeof(*proxy)); + proxy->IRpcProxyBuffer_iface.lpVtbl = &RpcProxyBufferVtbl; + proxy->ref = 1; + + hr = IPSFactoryBuffer_CreateProxy(ps_factory_buffer, outer, riid, &proxy->buffer, ppv); + ok(hr == S_OK, "CreateProxy failed: %08lx\n", hr); + + *ppProxy = &proxy->IRpcProxyBuffer_iface; + + return S_OK; }
static HRESULT WINAPI PSFactoryBuffer_CreateStub(IPSFactoryBuffer *iface, REFIID riid, @@ -1327,10 +1414,12 @@ static void test_marshal_channel_buffer(void)
SET_EXPECT(CreateStub); SET_EXPECT(CreateProxy); + SET_EXPECT(Connect); hr = IUnknown_QueryInterface(proxy, &IID_IOleWindow, (void**)&ole_window); ok(hr == S_OK, "Could not get IOleWindow iface: %08lx\n", hr); CHECK_CALLED(CreateStub); CHECK_CALLED(CreateProxy); + CHECK_CALLED(Connect);
SET_EXPECT(Invoke); SET_EXPECT(GetWindow); @@ -1342,10 +1431,13 @@ static void test_marshal_channel_buffer(void)
IOleWindow_Release(ole_window);
- SET_EXPECT(Disconnect); + SET_EXPECT(RpcStubBuffer_Disconnect); + SET_EXPECT(RpcProxyBuffer_Disconnect); IUnknown_Release(proxy); todo_wine - CHECK_CALLED(Disconnect); + CHECK_CALLED(RpcStubBuffer_Disconnect); + todo_wine + CHECK_CALLED(RpcProxyBuffer_Disconnect);
hr = CoRevokeClassObject(registration_key); ok(hr == S_OK, "CoRevokeClassObject failed: %08lx\n", hr);
From: Connor McAdams cmcadams@codeweavers.com
Wine-Bug: https://bugs.winehq.org/show_bug.cgi?id=54609 Signed-off-by: Connor McAdams cmcadams@codeweavers.com --- dlls/combase/combase_private.h | 1 + dlls/combase/marshal.c | 8 +++++--- dlls/combase/stubmanager.c | 25 +++++++++++++++++++++++++ dlls/ole32/tests/marshal.c | 2 +- 4 files changed, 32 insertions(+), 4 deletions(-)
diff --git a/dlls/combase/combase_private.h b/dlls/combase/combase_private.h index f9c349c3e20..53932e9a357 100644 --- a/dlls/combase/combase_private.h +++ b/dlls/combase/combase_private.h @@ -256,4 +256,5 @@ struct ifstub * stub_manager_new_ifstub(struct stub_manager *m, IRpcStubBuffer * HRESULT ipid_get_dispatch_params(const IPID *ipid, struct apartment **stub_apt, struct stub_manager **manager, IRpcStubBuffer **stub, IRpcChannelBuffer **chan, IID *iid, IUnknown **iface); +HRESULT ipid_get_dest_context(const IPID *ipid, MSHCTX *dest_context, void **dest_context_data); HRESULT start_apartment_remote_unknown(struct apartment *apt); diff --git a/dlls/combase/marshal.c b/dlls/combase/marshal.c index e2f6a57d8e1..84f57b8c1c8 100644 --- a/dlls/combase/marshal.c +++ b/dlls/combase/marshal.c @@ -698,7 +698,7 @@ HRESULT WINAPI CoReleaseMarshalData(IStream *stream) }
static HRESULT std_unmarshal_interface(MSHCTX dest_context, void *dest_context_data, - IStream *stream, REFIID riid, void **ppv) + IStream *stream, REFIID riid, void **ppv, BOOL dest_context_known) { struct stub_manager *stubmgr = NULL; struct OR_STANDARD obj; @@ -757,6 +757,8 @@ static HRESULT std_unmarshal_interface(MSHCTX dest_context, void *dest_context_d { if (!stub_manager_notify_unmarshal(stubmgr, &obj.std.ipid)) hres = CO_E_OBJNOTCONNECTED; + if (SUCCEEDED(hres) && !dest_context_known) + hres = ipid_get_dest_context(&obj.std.ipid, &dest_context, &dest_context_data); } else { @@ -803,7 +805,7 @@ HRESULT WINAPI CoUnmarshalInterface(IStream *stream, REFIID riid, void **ppv) hr = get_unmarshaler_from_stream(stream, &marshal, &iid); if (hr == S_FALSE) { - hr = std_unmarshal_interface(0, NULL, stream, &iid, (void **)&object); + hr = std_unmarshal_interface(0, NULL, stream, &iid, (void **)&object, FALSE); if (hr != S_OK) ERR("StdMarshal UnmarshalInterface failed, hr %#lx\n", hr); } @@ -2183,7 +2185,7 @@ static HRESULT WINAPI StdMarshalImpl_UnmarshalInterface(IMarshal *iface, IStream return E_NOTIMPL; }
- return std_unmarshal_interface(marshal->dest_context, marshal->dest_context_data, stream, riid, ppv); + return std_unmarshal_interface(marshal->dest_context, marshal->dest_context_data, stream, riid, ppv, TRUE); }
static HRESULT WINAPI StdMarshalImpl_ReleaseMarshalData(IMarshal *iface, IStream *stream) diff --git a/dlls/combase/stubmanager.c b/dlls/combase/stubmanager.c index ff8937d7bd7..1b079f22fef 100644 --- a/dlls/combase/stubmanager.c +++ b/dlls/combase/stubmanager.c @@ -554,6 +554,31 @@ HRESULT ipid_get_dispatch_params(const IPID *ipid, struct apartment **stub_apt, return S_OK; }
+HRESULT ipid_get_dest_context(const IPID *ipid, MSHCTX *dest_context, void **dest_context_data) +{ + struct stub_manager *stubmgr; + struct ifstub *ifstub; + struct apartment *apt; + void *data; + HRESULT hr; + DWORD ctx; + + hr = ipid_to_ifstub(ipid, &apt, &stubmgr, &ifstub); + if (hr != S_OK) return RPC_E_DISCONNECTED; + + hr = IRpcChannelBuffer_GetDestCtx(ifstub->chan, &ctx, &data); + if (SUCCEEDED(hr)) + { + *dest_context = ctx; + *dest_context_data = data; + } + + stub_manager_int_release(stubmgr); + apartment_release(apt); + + return hr; +} + /* returns TRUE if it is possible to unmarshal, FALSE otherwise. */ BOOL stub_manager_notify_unmarshal(struct stub_manager *m, const IPID *ipid) { diff --git a/dlls/ole32/tests/marshal.c b/dlls/ole32/tests/marshal.c index f76505ea1bf..cbad8c46117 100644 --- a/dlls/ole32/tests/marshal.c +++ b/dlls/ole32/tests/marshal.c @@ -589,7 +589,7 @@ static HRESULT WINAPI RpcProxyBuffer_Connect(IRpcProxyBuffer *iface, IRpcChannel
hr = IRpcChannelBuffer_GetDestCtx(pRpcChannelBuffer, &dest_context, &dest_context_data); ok(hr == S_OK, "GetDestCtx failed: %08lx\n", hr); - todo_wine ok(dest_context == MSHCTX_INPROC, "desc_context = %lx\n", dest_context); + ok(dest_context == MSHCTX_INPROC, "desc_context = %lx\n", dest_context); ok(!dest_context_data, "desc_context_data = %p\n", dest_context_data);
return IRpcProxyBuffer_Connect(This->buffer, pRpcChannelBuffer);
Huw Davies (@huw) commented about dlls/combase/marshal.c:
hr = get_unmarshaler_from_stream(stream, &marshal, &iid); if (hr == S_FALSE) {
- hr = std_unmarshal_interface(0, NULL, stream, &iid, (void **)&object);
+ hr = std_unmarshal_interface(0, NULL, stream, &iid, (void **)&object, FALSE);
Could we not call `ipid_get_dest_context()` here so that we can pass the correct context to `std_unmarshal_interface()` without the need for the additional final arg?
On Tue Jul 11 08:22:32 2023 +0000, Huw Davies wrote:
Could we not call `ipid_get_dest_context()` here so that we can pass the correct context to `std_unmarshal_interface()` without the need for the additional final arg?
We could, but we'd need to pull the ipid from the IStream, which is already done in `std_unmarshal_interface()`. It would probably be cleaner to do it without the extra argument, but then we'd need to pull the data from the IStream twice, once in `CoUnmarshalInterface()`, and then again in `std_unmarshal_interface()`.
On Tue Jul 11 11:27:34 2023 +0000, Connor McAdams wrote:
We could, but we'd need to pull the ipid from the IStream, which is already done in `std_unmarshal_interface()`. It would probably be cleaner to do it without the extra argument, but then we'd need to pull the data from the IStream twice, once in `CoUnmarshalInterface()`, and then again in `std_unmarshal_interface()`.
I guess we could add a helper function like: ``` static HRESULT std_unmarshal_get_dest_context(IStream *stream, MSHCTX *dest_context, void **dest_context_data) { struct OR_STANDARD obj; struct apartment *apt; IStream *stream2; HRESULT hr; ULONG res;
*dest_context = 0; *dest_context_data = NULL;
if (!(apt = apartment_get_current_or_mta())) return CO_E_NOTINITIALIZED;
hr = IStream_Clone(stream, &stream2); if (FAILED(hr)) { apartment_release(apt); return hr; }
hr = IStream_Read(stream2, &obj, FIELD_OFFSET(struct OR_STANDARD, saResAddr.aStringArray), &res); IStream_Release(stream2); if (hr != S_OK) { apartment_release(apt); return STG_E_READFAULT; }
hr = ipid_get_dest_context(&obj.std.ipid, dest_context, dest_context_data); apartment_release(apt);
return hr; } ```
Which we would call from `CoUnmarshalInterface()` prior to calling `std_unmarshal_interface()`.
On Tue Jul 11 12:50:06 2023 +0000, Connor McAdams wrote:
I guess we could add a helper function like:
static HRESULT std_unmarshal_get_dest_context(IStream *stream, MSHCTX *dest_context, void **dest_context_data) { struct OR_STANDARD obj; struct apartment *apt; IStream *stream2; HRESULT hr; ULONG res; *dest_context = 0; *dest_context_data = NULL; if (!(apt = apartment_get_current_or_mta())) return CO_E_NOTINITIALIZED; hr = IStream_Clone(stream, &stream2); if (FAILED(hr)) { apartment_release(apt); return hr; } hr = IStream_Read(stream2, &obj, FIELD_OFFSET(struct OR_STANDARD, saResAddr.aStringArray), &res); IStream_Release(stream2); if (hr != S_OK) { apartment_release(apt); return STG_E_READFAULT; } hr = ipid_get_dest_context(&obj.std.ipid, dest_context, dest_context_data); apartment_release(apt); return hr; }
Which we would call from `CoUnmarshalInterface()` prior to calling `std_unmarshal_interface()`.
Yeah, that's not great.
We could overload `dest_context` with an 'unknown' value, but your current version is fine.
This merge request was approved by Huw Davies.