Module: wine
Branch: master
Commit: 28b916b26eceb35a071e9aad76da1fbbaa857ba1
URL: http://source.winehq.org/git/wine.git/?a=commit;h=28b916b26eceb35a071e9aad76da1fbbaa857ba1
Author: Huw Davies <huw@codeweavers.com>
Date: Tue Oct 27 14:42:34 2015 +0000
ole32: Fix ref counting in GetDataHere Proxy.
Signed-off-by: Huw Davies <huw@codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard@winehq.org>
---
 dlls/ole32/tests/usrmarshal.c | 236 +++++++++++++++++++++++++++++++++++++++++-
 dlls/ole32/usrmarshal.c       |  38 +++++--
 2 files changed, 267 insertions(+), 7 deletions(-)
diff --git a/dlls/ole32/tests/usrmarshal.c b/dlls/ole32/tests/usrmarshal.c index c68689c..529ad75 100644 --- a/dlls/ole32/tests/usrmarshal.c +++ b/dlls/ole32/tests/usrmarshal.c @@ -85,6 +85,91 @@ static void init_user_marshal_cb(USER_MARSHAL_CB *umcb, umcb->CBType = buffer ? USER_MARSHAL_CB_UNMARSHALL : USER_MARSHAL_CB_BUFFER_SIZE; }
+#define RELEASEMARSHALDATA WM_USER + /* Parameters handed from start_host_object2() to host_object_proc(); the hosting thread frees this block with HeapFree() before it exits. */ +struct host_object_data +{ + IStream *stream; + IID iid; + IUnknown *object; + MSHLFLAGS marshal_flags; + HANDLE marshal_event; + IMessageFilter *filter; +}; + /* Hosting thread: optionally registers data->filter, marshals data->object into data->stream, signals data->marshal_event, then pumps messages until WM_QUIT. A RELEASEMARSHALDATA thread message calls CoReleaseMarshalData() on the stream and signals the event handle passed in lParam. The thread's exit code is the CoMarshalInterface() result. */ +static DWORD CALLBACK host_object_proc(LPVOID p) +{ + struct host_object_data *data = p; + HRESULT hr; + MSG msg; + + CoInitializeEx(NULL, COINIT_APARTMENTTHREADED); + + if (data->filter) + { + IMessageFilter * prev_filter = NULL; + hr = CoRegisterMessageFilter(data->filter, &prev_filter); + if (prev_filter) IMessageFilter_Release(prev_filter); + ok(hr == S_OK, "got %08x\n", hr); + } + + hr = CoMarshalInterface(data->stream, &data->iid, data->object, MSHCTX_INPROC, NULL, data->marshal_flags); + ok(hr == S_OK, "got %08x\n", hr); + + /* force the message queue to be created before signaling parent thread */ + PeekMessageA(&msg, NULL, WM_USER, WM_USER, PM_NOREMOVE); + + SetEvent(data->marshal_event); + + while (GetMessageA(&msg, NULL, 0, 0)) + { + if (msg.hwnd == NULL && msg.message == RELEASEMARSHALDATA) + { + CoReleaseMarshalData(data->stream); + SetEvent((HANDLE)msg.lParam); + } + else + DispatchMessageA(&msg); + } + + HeapFree(GetProcessHeap(), 0, data); + + CoUninitialize(); + + return hr; +} + /* Starts host_object_proc() on a new thread and waits up to 10 seconds for marshaling to complete. Returns the thread id; *thread receives the thread handle, which end_host_object() closes. NOTE(review): the HeapAlloc() result is not checked before use. */ +static DWORD start_host_object2(IStream *stream, REFIID riid, IUnknown *object, MSHLFLAGS marshal_flags, IMessageFilter *filter, HANDLE *thread) +{ + DWORD tid = 0; + HANDLE marshal_event = CreateEventA(NULL, FALSE, FALSE, NULL); + struct host_object_data *data = HeapAlloc(GetProcessHeap(), 0, sizeof(*data)); + + data->stream = stream; + data->iid = *riid; + data->object = object; + data->marshal_flags = marshal_flags; + data->marshal_event = marshal_event; + data->filter = filter; + + *thread = CreateThread(NULL, 0, host_object_proc, data, 0, &tid); + + /* wait for marshaling to complete before returning */ + ok( !WaitForSingleObject(marshal_event, 10000), "wait timed out\n" ); + CloseHandle(marshal_event); + + return tid; +} + /* Posts WM_QUIT to the hosting thread, waits up to 10 seconds for it to exit, then closes its thread handle. */ +static void end_host_object(DWORD tid, HANDLE
thread) +{ + BOOL ret = PostThreadMessageA(tid, WM_QUIT, 0, 0); + ok(ret, "PostThreadMessage failed with error %d\n", GetLastError()); + /* be careful of races - don't return until hosting thread has terminated */ + ok( !WaitForSingleObject(thread, 10000), "wait timed out\n" ); + CloseHandle(thread); +} + static const char cf_marshaled[] = { 0x9, 0x0, 0x0, 0x0, @@ -1105,9 +1190,156 @@ static void test_marshal_HBRUSH(void) DeleteObject(hBrush); }
+struct obj /* Minimal IDataObject used as the server-side object in test_GetDataHere_Proxy(); only QueryInterface, AddRef, Release and GetDataHere are implemented. */ +{ + IDataObject IDataObject_iface; +}; + +static HRESULT WINAPI obj_QueryInterface(IDataObject *iface, REFIID iid, void **obj) +{ + *obj = NULL; + + if (IsEqualGUID(iid, &IID_IUnknown) || + IsEqualGUID(iid, &IID_IDataObject)) + *obj = iface; + + if (*obj) + { + IDataObject_AddRef(iface); + return S_OK; + } + + return E_NOINTERFACE; +} + /* obj is a static object: AddRef and Release return fixed counts and track nothing. */ +static ULONG WINAPI obj_AddRef(IDataObject *iface) +{ + return 2; +} + +static ULONG WINAPI obj_Release(IDataObject *iface) +{ + return 1; +} + /* Server side of GetDataHere: checks that pUnkForRelease arrives as NULL and, for cfFormat 2, releases the incoming stream and returns Test_Stream2 in its place. */ +static HRESULT WINAPI obj_DO_GetDataHere(IDataObject *iface, FORMATETC *fmt, + STGMEDIUM *med) +{ + ok( med->pUnkForRelease == NULL, "got %p\n", med->pUnkForRelease ); + + if (fmt->cfFormat == 2) + { + IStream_Release(U(med)->pstm); + U(med)->pstm = &Test_Stream2.IStream_iface; + } + + return S_OK; +} + +static const IDataObjectVtbl obj_data_object_vtbl = +{ + obj_QueryInterface, + obj_AddRef, + obj_Release, + NULL, /* GetData */ + obj_DO_GetDataHere, + NULL, /* QueryGetData */ + NULL, /* GetCanonicalFormatEtc */ + NULL, /* SetData */ + NULL, /* EnumFormatEtc */ + NULL, /* DAdvise */ + NULL, /* DUnadvise */ + NULL /* EnumDAdvise */ +}; + +static struct obj obj = +{ + {&obj_data_object_vtbl} +}; + /* Calls GetDataHere through a proxy to obj hosted on another thread, checking tymed validation and that the caller's pstm and pUnkForRelease pointers and reference counts come back unchanged. */ +static void test_GetDataHere_Proxy(void) +{ + HRESULT hr; + IStream *stm; + HANDLE thread; + DWORD tid; + static const LARGE_INTEGER zero; + IDataObject *data; + FORMATETC fmt; + STGMEDIUM med; + + hr = CreateStreamOnHGlobal( NULL, TRUE, &stm ); + ok( hr == S_OK, "got %08x\n", hr ); + tid = start_host_object2( stm, &IID_IDataObject, (IUnknown *)&obj.IDataObject_iface, MSHLFLAGS_NORMAL, NULL, &thread ); + + IStream_Seek( stm, zero, STREAM_SEEK_SET, NULL ); + hr = CoUnmarshalInterface( stm, &IID_IDataObject, (void **)&data ); + ok( hr == S_OK, "got %08x\n", hr ); + IStream_Release( stm ); + + Test_Stream.refs = 1; + Test_Stream2.refs = 1; + Test_Unknown.refs = 1; + + fmt.cfFormat = 1; + fmt.ptd = NULL; + fmt.dwAspect = DVASPECT_CONTENT; + fmt.lindex = -1; + U(med).pstm = NULL; + 
med.pUnkForRelease = &Test_Unknown.IUnknown_iface; + /* TYMED_NULL is rejected */ + fmt.tymed = med.tymed = TYMED_NULL; + hr = IDataObject_GetDataHere( data, &fmt, &med ); + ok( hr == DV_E_TYMED, "got %08x\n", hr ); + /* tymed values up to TYMED_ISTORAGE succeed, larger ones up to TYMED_ENHMF return DV_E_TYMED; Test_Unknown's count must not change */ + for (fmt.tymed = TYMED_HGLOBAL; fmt.tymed <= TYMED_ENHMF; fmt.tymed <<= 1) + { + med.tymed = fmt.tymed; + hr = IDataObject_GetDataHere( data, &fmt, &med ); + ok( hr == (fmt.tymed <= TYMED_ISTORAGE ? S_OK : DV_E_TYMED), "got %08x for tymed %d\n", hr, fmt.tymed ); + ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs ); + } + /* fmt and med disagree on tymed */ + fmt.tymed = TYMED_ISTREAM; + med.tymed = TYMED_ISTORAGE; + hr = IDataObject_GetDataHere( data, &fmt, &med ); + ok( hr == DV_E_TYMED, "got %08x\n", hr ); + /* the caller's stream and pUnkForRelease must come back unchanged, with unchanged reference counts */ + fmt.tymed = med.tymed = TYMED_ISTREAM; + U(med).pstm = &Test_Stream.IStream_iface; + med.pUnkForRelease = &Test_Unknown.IUnknown_iface; + + hr = IDataObject_GetDataHere( data, &fmt, &med ); + ok( hr == S_OK, "got %08x\n", hr ); + + ok( U(med).pstm == &Test_Stream.IStream_iface, "stm changed\n" ); + ok( med.pUnkForRelease == &Test_Unknown.IUnknown_iface, "punk changed\n" ); + + ok( Test_Stream.refs == 1, "got %d\n", Test_Stream.refs ); + ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs ); + /* with cfFormat 2 the server swaps in Test_Stream2; the caller must still get its own stream back and Test_Stream2's count must drop to 0 */ + fmt.cfFormat = 2; + fmt.tymed = med.tymed = TYMED_ISTREAM; + U(med).pstm = &Test_Stream.IStream_iface; + med.pUnkForRelease = &Test_Unknown.IUnknown_iface; + + hr = IDataObject_GetDataHere( data, &fmt, &med ); + ok( hr == S_OK, "got %08x\n", hr ); + + ok( U(med).pstm == &Test_Stream.IStream_iface, "stm changed\n" ); + ok( med.pUnkForRelease == &Test_Unknown.IUnknown_iface, "punk changed\n" ); + + ok( Test_Stream.refs == 1, "got %d\n", Test_Stream.refs ); + ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs ); + ok( Test_Stream2.refs == 0, "got %d\n", Test_Stream2.refs ); + + IDataObject_Release( data ); + end_host_object( tid, thread ); +} + START_TEST(usrmarshal) { - CoInitialize(NULL); + CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
test_marshal_CLIPFORMAT(); test_marshal_HWND(); @@ -1122,5 +1354,7 @@ START_TEST(usrmarshal) test_marshal_HICON(); test_marshal_HBRUSH();
+ test_GetDataHere_Proxy(); + CoUninitialize(); } diff --git a/dlls/ole32/usrmarshal.c b/dlls/ole32/usrmarshal.c index 1a3f6af..89d0675 100644 --- a/dlls/ole32/usrmarshal.c +++ b/dlls/ole32/usrmarshal.c @@ -2783,13 +2783,39 @@ HRESULT __RPC_STUB IDataObject_GetData_Stub( return IDataObject_GetData(This, pformatetcIn, pRemoteMedium); }
-HRESULT CALLBACK IDataObject_GetDataHere_Proxy( - IDataObject* This, - FORMATETC *pformatetc, - STGMEDIUM *pmedium) +HRESULT CALLBACK IDataObject_GetDataHere_Proxy(IDataObject *iface, FORMATETC *fmt, STGMEDIUM *med) /* Caller-side wrapper for IDataObject::GetDataHere: returns DV_E_TYMED if med->tymed has none of the HGLOBAL/FILE/ISTREAM/ISTORAGE bits or differs from fmt->tymed; otherwise forwards to IDataObject_RemoteGetDataHere_Proxy() and hands the caller's pUnkForRelease and stream/storage pointer back unchanged. */ { - TRACE("(%p)->(%p, %p)\n", This, pformatetc, pmedium); - return IDataObject_RemoteGetDataHere_Proxy(This, pformatetc, pmedium); + IUnknown *release; + IStorage *stg = NULL; + HRESULT hr; + + TRACE("(%p)->(%p, %p)\n", iface, fmt, med); + + if ((med->tymed & (TYMED_HGLOBAL | TYMED_FILE | TYMED_ISTREAM | TYMED_ISTORAGE)) == 0) + return DV_E_TYMED; + if (med->tymed != fmt->tymed) + return DV_E_TYMED; + /* pUnkForRelease is not passed to the remote call; it is restored afterwards */ + release = med->pUnkForRelease; + med->pUnkForRelease = NULL; + /* hold an extra reference on the caller's stream/storage so it can be handed back whatever pointer the remote call leaves in med->u.pstg */ + if (med->tymed == TYMED_ISTREAM || med->tymed == TYMED_ISTORAGE) + { + stg = med->u.pstg; /* This may actually be a stream, but that's ok */ + if (stg) IStorage_AddRef( stg ); + } + + hr = IDataObject_RemoteGetDataHere_Proxy(iface, fmt, med); + /* restore pUnkForRelease, then release the pointer now in med->u.pstg (when it is still the caller's own pointer this balances the AddRef above) and put the caller's pointer back */ + med->pUnkForRelease = release; + if (stg) + { + if (med->u.pstg) + IStorage_Release( med->u.pstg ); + med->u.pstg = stg; + } + + return hr; }
HRESULT __RPC_STUB IDataObject_GetDataHere_Stub(