Signed-off-by: Zebediah Figura <z.figura12@gmail.com>
---
 dlls/ntdll/tests/threadpool.c | 228 ++++++++++++++++++++++++++++------
 1 file changed, 190 insertions(+), 38 deletions(-)
diff --git a/dlls/ntdll/tests/threadpool.c b/dlls/ntdll/tests/threadpool.c
index 32d4c3eac2..2984983a06 100644
--- a/dlls/ntdll/tests/threadpool.c
+++ b/dlls/ntdll/tests/threadpool.c
@@ -22,7 +22,6 @@
 #define NONAMELESSUNION
 #include "ntdll_test.h"
 
-static HMODULE hntdll = 0;
 static NTSTATUS (WINAPI *pTpAllocCleanupGroup)(TP_CLEANUP_GROUP **);
 static NTSTATUS (WINAPI *pTpAllocIoCompletion)(TP_IO **,HANDLE,PTP_IO_CALLBACK,void *,TP_CALLBACK_ENVIRON *);
 static NTSTATUS (WINAPI *pTpAllocPool)(TP_POOL **,PVOID);
@@ -52,51 +51,58 @@ static VOID (WINAPI *pTpWaitForTimer)(TP_TIMER *,BOOL);
 static VOID (WINAPI *pTpWaitForWait)(TP_WAIT *,BOOL);
 static VOID (WINAPI *pTpWaitForWork)(TP_WORK *,BOOL);
 
-#define NTDLL_GET_PROC(func) \
+static void (WINAPI *pCancelThreadpoolIo)(TP_IO *); /* kernel32 exports, filled in by init_threadpool() */
+static void (WINAPI *pCloseThreadpoolIo)(TP_IO *);
+static TP_IO *(WINAPI *pCreateThreadpoolIo)(HANDLE, PTP_WIN32_IO_CALLBACK, void *, TP_CALLBACK_ENVIRON *);
+static void (WINAPI *pStartThreadpoolIo)(TP_IO *);
+static void (WINAPI *pWaitForThreadpoolIoCallbacks)(TP_IO *, BOOL);
+
+#define GET_PROC(func) \
     do \
     { \
-        p ## func = (void *)GetProcAddress(hntdll, #func); \
+        p ## func = (void *)GetProcAddress(module, #func); /* "module" is an HMODULE local of the caller */ \
         if (!p ## func) trace("Failed to get address for %s\n", #func); \
     } \
     while (0)
 
 static BOOL init_threadpool(void)
 {
-    hntdll = GetModuleHandleA("ntdll");
-    if (!hntdll)
-    {
-        win_skip("Could not load ntdll\n");
-        return FALSE;
-    }
-
-    NTDLL_GET_PROC(TpAllocCleanupGroup);
-    NTDLL_GET_PROC(TpAllocIoCompletion);
-    NTDLL_GET_PROC(TpAllocPool);
-    NTDLL_GET_PROC(TpAllocTimer);
-    NTDLL_GET_PROC(TpAllocWait);
-    NTDLL_GET_PROC(TpAllocWork);
-    NTDLL_GET_PROC(TpCancelAsyncIoOperation);
-    NTDLL_GET_PROC(TpCallbackMayRunLong);
-    NTDLL_GET_PROC(TpCallbackReleaseSemaphoreOnCompletion);
-    NTDLL_GET_PROC(TpDisassociateCallback);
-    NTDLL_GET_PROC(TpIsTimerSet);
-    NTDLL_GET_PROC(TpPostWork);
-    NTDLL_GET_PROC(TpReleaseCleanupGroup);
-    NTDLL_GET_PROC(TpReleaseCleanupGroupMembers);
-    NTDLL_GET_PROC(TpReleaseIoCompletion);
-    NTDLL_GET_PROC(TpReleasePool);
-    NTDLL_GET_PROC(TpReleaseTimer);
-    NTDLL_GET_PROC(TpReleaseWait);
-    NTDLL_GET_PROC(TpReleaseWork);
-    NTDLL_GET_PROC(TpSetPoolMaxThreads);
-    NTDLL_GET_PROC(TpSetTimer);
-    NTDLL_GET_PROC(TpSetWait);
-    NTDLL_GET_PROC(TpSimpleTryPost);
-    NTDLL_GET_PROC(TpStartAsyncIoOperation);
-    NTDLL_GET_PROC(TpWaitForIoCompletion);
-    NTDLL_GET_PROC(TpWaitForTimer);
-    NTDLL_GET_PROC(TpWaitForWait);
-    NTDLL_GET_PROC(TpWaitForWork);
+    HMODULE module = GetModuleHandleA("ntdll"); /* read by each GET_PROC() below */
+    GET_PROC(TpAllocCleanupGroup);
+    GET_PROC(TpAllocIoCompletion);
+    GET_PROC(TpAllocPool);
+    GET_PROC(TpAllocTimer);
+    GET_PROC(TpAllocWait);
+    GET_PROC(TpAllocWork);
+    GET_PROC(TpCallbackMayRunLong);
+    GET_PROC(TpCallbackReleaseSemaphoreOnCompletion);
+    GET_PROC(TpCancelAsyncIoOperation);
+    GET_PROC(TpDisassociateCallback);
+    GET_PROC(TpIsTimerSet);
+    GET_PROC(TpPostWork);
+    GET_PROC(TpReleaseCleanupGroup);
+    GET_PROC(TpReleaseCleanupGroupMembers);
+    GET_PROC(TpReleaseIoCompletion);
+    GET_PROC(TpReleasePool);
+    GET_PROC(TpReleaseTimer);
+    GET_PROC(TpReleaseWait);
+    GET_PROC(TpReleaseWork);
+    GET_PROC(TpSetPoolMaxThreads);
+    GET_PROC(TpSetTimer);
+    GET_PROC(TpSetWait);
+    GET_PROC(TpSimpleTryPost);
+    GET_PROC(TpStartAsyncIoOperation);
+    GET_PROC(TpWaitForIoCompletion);
+    GET_PROC(TpWaitForTimer);
+    GET_PROC(TpWaitForWait);
+    GET_PROC(TpWaitForWork);
+
+    module = GetModuleHandleA("kernel32"); /* a missing export is only trace()d, leaving its pointer NULL */
+    GET_PROC(CancelThreadpoolIo);
+    GET_PROC(CloseThreadpoolIo);
+    GET_PROC(CreateThreadpoolIo);
+    GET_PROC(StartThreadpoolIo);
+    GET_PROC(WaitForThreadpoolIoCallbacks);
 
if (!pTpAllocPool) { @@ -2084,6 +2090,151 @@ static void test_tp_io(void) pTpReleasePool(pool); }
+static void CALLBACK kernel32_io_cb(TP_CALLBACK_INSTANCE *instance, void *userdata, + void *ovl, ULONG ret, ULONG_PTR length, TP_IO *io) +{ + struct io_cb_ctx *ctx = userdata; + ++ctx->count; + ctx->ovl = ovl; + ctx->ret = ret; + ctx->length = length; + ctx->io = io; +} + +static void test_kernel32_tp_io(void) +{ + TP_CALLBACK_ENVIRON environment = {.Version = 1}; + OVERLAPPED ovl = {}, ovl2 = {}; + HANDLE client, server, thread; + struct io_cb_ctx userdata; + char in[1], in2[1]; + const char out[1]; + NTSTATUS status; + DWORD ret_size; + TP_POOL *pool; + TP_IO *io; + BOOL ret; + + ovl.hEvent = CreateEventW(NULL, TRUE, FALSE, NULL); + + status = pTpAllocPool(&pool, NULL); + ok(!status, "failed to allocate pool, status %#x\n", status); + + server = CreateNamedPipeA("\\.\pipe\wine_tp_test", + PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, 0, 1, 1024, 1024, 0, NULL); + ok(server != INVALID_HANDLE_VALUE, "Failed to create server pipe, error %u.\n", GetLastError()); + client = CreateFileA("\\.\pipe\wine_tp_test", GENERIC_READ | GENERIC_WRITE, + 0, NULL, OPEN_EXISTING, 0, 0); + ok(client != INVALID_HANDLE_VALUE, "Failed to create client pipe, error %u.\n", GetLastError()); + + environment.Pool = pool; + io = NULL; + io = pCreateThreadpoolIo(server, kernel32_io_cb, &userdata, &environment); + todo_wine ok(!!io, "expected non-NULL TP_IO\n"); + if (!io) + return; + + pWaitForThreadpoolIoCallbacks(io, FALSE); + + userdata.count = 0; + pStartThreadpoolIo(io); + + thread = CreateThread(NULL, 0, io_wait_thread, io, 0, NULL); + ok(WaitForSingleObject(thread, 100) == WAIT_TIMEOUT, "TpWaitForIoCompletion() should not return\n"); + + ret = ReadFile(server, in, sizeof(in), NULL, &ovl); + ok(!ret, "wrong ret %d\n", ret); + ok(GetLastError() == ERROR_IO_PENDING, "wrong error %u\n", GetLastError()); + + ret = WriteFile(client, out, sizeof(out), &ret_size, NULL); + ok(ret, "WriteFile() failed, error %u\n", GetLastError()); + + pWaitForThreadpoolIoCallbacks(io, FALSE); + ok(userdata.count 
== 1, "callback ran %u times\n", userdata.count); + ok(userdata.ovl == &ovl, "expected %p, got %p\n", &ovl, userdata.ovl); + ok(userdata.ret == ERROR_SUCCESS, "got status %#x\n", userdata.ret); + ok(userdata.length == 1, "got length %lu\n", userdata.length); + ok(userdata.io == io, "expected %p, got %p\n", io, userdata.io); + + ok(!WaitForSingleObject(thread, 1000), "wait timed out\n"); + CloseHandle(thread); + + userdata.count = 0; + pStartThreadpoolIo(io); + pStartThreadpoolIo(io); + + ret = ReadFile(server, in, sizeof(in), NULL, &ovl); + ok(!ret, "wrong ret %d\n", ret); + ok(GetLastError() == ERROR_IO_PENDING, "wrong error %u\n", GetLastError()); + ret = ReadFile(server, in2, sizeof(in2), NULL, &ovl2); + ok(!ret, "wrong ret %d\n", ret); + ok(GetLastError() == ERROR_IO_PENDING, "wrong error %u\n", GetLastError()); + + ret = WriteFile(client, out, sizeof(out), &ret_size, NULL); + ok(ret, "WriteFile() failed, error %u\n", GetLastError()); + ret = WriteFile(client, out, sizeof(out), &ret_size, NULL); + ok(ret, "WriteFile() failed, error %u\n", GetLastError()); + + pWaitForThreadpoolIoCallbacks(io, FALSE); + ok(userdata.count == 2, "callback ran %u times\n", userdata.count); + ok(userdata.ret == STATUS_SUCCESS, "got status %#x\n", userdata.ret); + ok(userdata.length == 1, "got length %lu\n", userdata.length); + ok(userdata.io == io, "expected %p, got %p\n", io, userdata.io); + + userdata.count = 0; + pStartThreadpoolIo(io); + pWaitForThreadpoolIoCallbacks(io, TRUE); + ok(!userdata.count, "callback ran %u times\n", userdata.count); + + pStartThreadpoolIo(io); + + ret = WriteFile(client, out, sizeof(out), &ret_size, NULL); + ok(ret, "WriteFile() failed, error %u\n", GetLastError()); + + ret = ReadFile(server, in, sizeof(in), NULL, &ovl); + ok(ret, "wrong ret %d\n", ret); + + pWaitForThreadpoolIoCallbacks(io, FALSE); + ok(userdata.count == 1, "callback ran %u times\n", userdata.count); + ok(userdata.ovl == &ovl, "expected %p, got %p\n", &ovl, userdata.ovl); + 
ok(userdata.ret == ERROR_SUCCESS, "got status %#x\n", userdata.ret); + ok(userdata.length == 1, "got length %lu\n", userdata.length); + ok(userdata.io == io, "expected %p, got %p\n", io, userdata.io); + + userdata.count = 0; + pStartThreadpoolIo(io); + + ret = ReadFile(server, NULL, 1, NULL, &ovl); + ok(!ret, "wrong ret %d\n", ret); + ok(GetLastError() == ERROR_NOACCESS, "wrong error %u\n", GetLastError()); + + pCancelThreadpoolIo(io); + pWaitForThreadpoolIoCallbacks(io, FALSE); + ok(!userdata.count, "callback ran %u times\n", userdata.count); + + userdata.count = 0; + pStartThreadpoolIo(io); + + ret = ReadFile(server, in, sizeof(in), NULL, &ovl); + ok(!ret, "wrong ret %d\n", ret); + ok(GetLastError() == ERROR_IO_PENDING, "wrong error %u\n", GetLastError()); + ret = CancelIo(server); + ok(ret, "CancelIo() failed, error %u\n", GetLastError()); + + pWaitForThreadpoolIoCallbacks(io, FALSE); + ok(userdata.count == 1, "callback ran %u times\n", userdata.count); + ok(userdata.ovl == &ovl, "expected %p, got %p\n", &ovl, userdata.ovl); + ok(userdata.ret == ERROR_OPERATION_ABORTED, "got status %#x\n", userdata.ret); + ok(!userdata.length, "got length %lu\n", userdata.length); + ok(userdata.io == io, "expected %p, got %p\n", io, userdata.io); + + CloseHandle(ovl.hEvent); + CloseHandle(client); + CloseHandle(server); + pCloseThreadpoolIo(io); + pTpReleasePool(pool); +} + START_TEST(threadpool) { test_RtlQueueWorkItem(); @@ -2104,4 +2255,5 @@ START_TEST(threadpool) test_tp_wait(); test_tp_multi_wait(); test_tp_io(); + test_kernel32_tp_io(); }