From: Zhiyi Zhang <zzhang@codeweavers.com>
---
 dlls/ntdll/tests/file.c | 108 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 108 insertions(+)
diff --git a/dlls/ntdll/tests/file.c b/dlls/ntdll/tests/file.c
index a372bd170e1..4a308f3e577 100644
--- a/dlls/ntdll/tests/file.c
+++ b/dlls/ntdll/tests/file.c
@@ -54,6 +54,7 @@ static NTSTATUS (WINAPI *pRtlWow64EnableFsRedirectionEx)( ULONG, ULONG * );
 
 static NTSTATUS (WINAPI *pNtAllocateReserveObject)( HANDLE *, const OBJECT_ATTRIBUTES *, MEMORY_RESERVE_OBJECT_TYPE );
 static NTSTATUS (WINAPI *pNtAssociateWaitCompletionPacket)(HANDLE, HANDLE, HANDLE, PVOID, PVOID, NTSTATUS, ULONG_PTR, PBOOLEAN);
+static NTSTATUS (WINAPI *pNtCancelWaitCompletionPacket)(HANDLE, BOOLEAN);
 static NTSTATUS (WINAPI *pNtCreateEvent)(PHANDLE, ACCESS_MASK, const OBJECT_ATTRIBUTES *, EVENT_TYPE, BOOLEAN);
 static NTSTATUS (WINAPI *pNtCreateMailslotFile)( PHANDLE, ULONG, POBJECT_ATTRIBUTES, PIO_STATUS_BLOCK,
                                        ULONG, ULONG, ULONG, PLARGE_INTEGER );
@@ -6555,6 +6556,111 @@ static void test_associate_wait_completion_packet(void)
     ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
 }
 
+/* Exercise NtCancelWaitCompletionPacket() on unassociated, waiting and already queued packets. */
+static void test_cancel_wait_completion_packet(void)
+{
+    static const char pipe_name[] = "\\\\.\\pipe\\test_cancel_wait_completion_packet";
+    UNICODE_STRING packet_name = RTL_CONSTANT_STRING(L"\\BaseNamedObjects\\test_cancel_wait_completion_packet");
+    BYTE send_buf[TEST_BUF_LEN], recv_buf[TEST_BUF_LEN];
+    HANDLE completion, packet, server, client;
+    DWORD written_bytes;
+    OVERLAPPED overlapped = {0};
+    OBJECT_ATTRIBUTES attr;
+    ULONG completion_count;
+    BOOLEAN signaled;
+    NTSTATUS status;
+
+    if (!pNtCancelWaitCompletionPacket)
+    {
+        todo_wine
+        win_skip("NtCancelWaitCompletionPacket is unavailable.\n");
+        return;
+    }
+
+    status = pNtCreateIoCompletion(&completion, IO_COMPLETION_ALL_ACCESS, NULL, 0);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+
+    server = CreateNamedPipeA(pipe_name, PIPE_ACCESS_INBOUND | FILE_FLAG_OVERLAPPED,
+            PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_WAIT, 4, 1024, 1024, 1000, NULL);
+    ok(server != INVALID_HANDLE_VALUE, "CreateNamedPipe failed, error %lu.\n", GetLastError());
+    client = CreateFileA(pipe_name, GENERIC_WRITE, 0, NULL, OPEN_EXISTING,
+            FILE_FLAG_NO_BUFFERING | FILE_FLAG_OVERLAPPED, NULL);
+    ok(client != INVALID_HANDLE_VALUE, "CreateFile failed, error %lu.\n", GetLastError());
+
+    InitializeObjectAttributes(&attr, &packet_name, 0, NULL, NULL);
+    status = pNtCreateWaitCompletionPacket(&packet, GENERIC_ALL, &attr);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+
+    /* Removing a packet that was never associated returns STATUS_CANCELLED */
+    status = pNtCancelWaitCompletionPacket(packet, TRUE);
+    ok(status == STATUS_CANCELLED, "Got unexpected status %#lx.\n", status);
+
+    /* Issue ReadFile() before NtAssociateWaitCompletionPacket(). Otherwise, the server object
+     * stays signaled */
+    memset(send_buf, 0xa1, TEST_BUF_LEN);
+    memset(recv_buf, 0xb2, TEST_BUF_LEN);
+    ReadFile(server, recv_buf, TEST_BUF_LEN, NULL, &overlapped);
+
+    /* Cancel a packet whose target object is not signaled yet; nothing gets queued */
+    status = pNtAssociateWaitCompletionPacket(packet, completion, server, (void *)1, (void *)2,
+            STATUS_SUCCESS, (ULONG_PTR)3, &signaled);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+    ok(signaled == FALSE, "Got unexpected signaled %d.\n", signaled);
+    completion_count = get_pending_msgs(completion);
+    ok(!completion_count, "Got unexpected completion count %ld.\n", completion_count);
+
+    status = pNtCancelWaitCompletionPacket(packet, FALSE);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+
+    WriteFile(client, send_buf, TEST_BUF_LEN, &written_bytes, NULL);
+    completion_count = get_pending_msgs(completion);
+    ok(!completion_count, "Got unexpected completion count %ld.\n", completion_count);
+
+    /* Cancel a packet that was already queued because its target object got signaled */
+    ReadFile(server, recv_buf, TEST_BUF_LEN, NULL, &overlapped);
+
+    status = pNtAssociateWaitCompletionPacket(packet, completion, server, (void *)1, (void *)2,
+            STATUS_SUCCESS, (ULONG_PTR)3, &signaled);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+    ok(signaled == FALSE, "Got unexpected signaled %d.\n", signaled);
+    completion_count = get_pending_msgs(completion);
+    ok(!completion_count, "Got unexpected completion count %ld.\n", completion_count);
+
+    WriteFile(client, send_buf, TEST_BUF_LEN, &written_bytes, NULL);
+    completion_count = get_pending_msgs(completion);
+    ok(completion_count == 1, "Got unexpected completion count %ld.\n", completion_count);
+
+    status = pNtCancelWaitCompletionPacket(packet, FALSE);
+    ok(status == STATUS_PENDING, "Got unexpected status %#lx.\n", status);
+    completion_count = get_pending_msgs(completion);
+    ok(completion_count == 1, "Got unexpected completion count %ld.\n", completion_count);
+
+    status = pNtCancelWaitCompletionPacket(packet, TRUE);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+    completion_count = get_pending_msgs(completion);
+    ok(!completion_count, "Got unexpected completion count %ld.\n", completion_count);
+
+    /* Removing a packet that was already removed returns STATUS_CANCELLED */
+    status = pNtCancelWaitCompletionPacket(packet, TRUE);
+    ok(status == STATUS_CANCELLED, "Got unexpected status %#lx.\n", status);
+
+    /* Parameter checks */
+    status = pNtCancelWaitCompletionPacket(NULL, FALSE);
+    ok(status == STATUS_INVALID_HANDLE, "Got unexpected status %#lx.\n", status);
+
+    status = pNtCancelWaitCompletionPacket(INVALID_HANDLE_VALUE, FALSE);
+    ok(status == STATUS_OBJECT_TYPE_MISMATCH, "Got unexpected status %#lx.\n", status);
+
+    status = pNtClose(packet);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+    status = pNtClose(client);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+    status = pNtClose(server);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+    status = pNtClose(completion);
+    ok(status == STATUS_SUCCESS, "Got unexpected status %#lx.\n", status);
+}
+
 START_TEST(file)
 {
     HMODULE hkernel32 = GetModuleHandleA("kernel32.dll");
@@ -6574,6 +6680,7 @@ START_TEST(file)
     pRtlWow64EnableFsRedirectionEx = (void *)GetProcAddress(hntdll, "RtlWow64EnableFsRedirectionEx");
     pNtAllocateReserveObject= (void *)GetProcAddress(hntdll, "NtAllocateReserveObject");
     pNtAssociateWaitCompletionPacket = (void *)GetProcAddress(hntdll, "NtAssociateWaitCompletionPacket");
+    pNtCancelWaitCompletionPacket = (void *)GetProcAddress(hntdll, "NtCancelWaitCompletionPacket");
     pNtCreateEvent = (void *)GetProcAddress(hntdll, "NtCreateEvent");
     pNtCreateMailslotFile = (void *)GetProcAddress(hntdll, "NtCreateMailslotFile");
     pNtCreateFile = (void *)GetProcAddress(hntdll, "NtCreateFile");
@@ -6644,4 +6751,5 @@ START_TEST(file)
     test_reparse_points();
     test_create_wait_completion_packet();
     test_associate_wait_completion_packet();
+    test_cancel_wait_completion_packet();
 }