Signed-off-by: Paul Gofman <pgofman@codeweavers.com> --- dlls/netio.sys/netio.c | 162 +++++++++++++++++++++++++++-------------- 1 file changed, 106 insertions(+), 56 deletions(-)
diff --git a/dlls/netio.sys/netio.c b/dlls/netio.sys/netio.c index ed99a087a58..6cadff2e87c 100644 --- a/dlls/netio.sys/netio.c +++ b/dlls/netio.sys/netio.c @@ -53,6 +53,15 @@ struct listen_socket_callback_context SOCKET acceptor; };
+#define MAX_PENDING_IO 10 /* maximum number of outstanding asynchronous requests per socket */ + +struct wsk_pending_io /* one asynchronous request slot of a socket */ +{ + OVERLAPPED ovr; /* ovr.hEvent is the handle tp_wait is armed on */ + TP_WAIT *tp_wait; /* created on first use of the slot, then reused */ + IRP *irp; /* request occupying the slot; NULL while the slot is free */ +}; + struct wsk_socket_internal { WSK_SOCKET wsk_socket; @@ -63,12 +72,11 @@ struct wsk_socket_internal ADDRESS_FAMILY address_family; USHORT socket_type; ULONG protocol; - OVERLAPPED ovr; - TP_WAIT *tp_wait; - IRP *pending_irp;
CRITICAL_SECTION cs_socket;
+ struct wsk_pending_io pending_io[MAX_PENDING_IO]; /* slots are allocated, looked up and completed with cs_socket held */ + union { struct listen_socket_callback_context listen_socket_callback_context; @@ -134,14 +142,9 @@ static inline void unlock_socket(struct wsk_socket_internal *socket) LeaveCriticalSection(&socket->cs_socket); }
-static void socket_init(struct wsk_socket_internal *socket, PTP_WAIT_CALLBACK socket_async_callback) +static void socket_init(struct wsk_socket_internal *socket) /* only sets up the lock; io slots are set up by allocate_pending_io() */ { InitializeCriticalSection(&socket->cs_socket); - if (socket_async_callback) - { - socket->ovr.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL); - socket->tp_wait = CreateThreadpoolWait(socket_async_callback, socket, NULL); - } }
static void dispatch_irp(IRP *irp, NTSTATUS status) @@ -152,6 +155,56 @@ static void dispatch_irp(IRP *irp, NTSTATUS status) IoCompleteRequest(irp, IO_NO_INCREMENT); }
+static struct wsk_pending_io *allocate_pending_io(struct wsk_socket_internal *socket, + PTP_WAIT_CALLBACK socket_async_callback, IRP *irp) +{ + struct wsk_pending_io *io = socket->pending_io; + unsigned int i; + + for (i = 0; i < ARRAY_SIZE(socket->pending_io); ++i) + if (!io[i].irp) + break; + + if (i == ARRAY_SIZE(socket->pending_io)) + { + FIXME("Pending io requests count exceeds limit.\n"); + return NULL; + } + + io[i].irp = irp; + + if (io[i].tp_wait) + return &io[i]; + + io[i].ovr.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL); + io[i].tp_wait = CreateThreadpoolWait(socket_async_callback, socket, NULL); + + return &io[i]; +} + +static struct wsk_pending_io *find_pending_io(struct wsk_socket_internal *socket, TP_WAIT *tp_wait) +{ + unsigned int i; + + for (i = 0; i < ARRAY_SIZE(socket->pending_io); ++i) + { + if (socket->pending_io[i].tp_wait == tp_wait) + return &socket->pending_io[i]; + } + + FIXME("Pending io not found for tp_wait %p.\n", tp_wait); + return NULL; +} + +static void dispatch_pending_io(struct wsk_pending_io *io, NTSTATUS status, ULONG_PTR information) +{ + TRACE("io %p, status %#x, information %#lx.\n", io, status, information); + + io->irp->IoStatus.Information = information; + dispatch_irp(io->irp, status); + io->irp = NULL; +} + static NTSTATUS WINAPI wsk_control_socket(WSK_SOCKET *socket, WSK_CONTROL_SOCKET_TYPE request_type, ULONG control_code, ULONG level, SIZE_T input_size, void *input_buffer, SIZE_T output_size, void *output_buffer, SIZE_T *output_size_returned, IRP *irp) @@ -168,18 +221,29 @@ static NTSTATUS WINAPI wsk_close_socket(WSK_SOCKET *socket, IRP *irp) { struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(socket); NTSTATUS status; + unsigned int i;
TRACE("socket %p, irp %p.\n", socket, irp);
lock_socket(s);
- if (s->tp_wait) + for (i = 0; i < ARRAY_SIZE(s->pending_io); ++i) { - CancelIoEx((HANDLE)s->s, &s->ovr); - unlock_socket(s); - WaitForThreadpoolWaitCallbacks(s->tp_wait, FALSE); - lock_socket(s); - CloseThreadpoolWait(s->tp_wait); + struct wsk_pending_io *io = &s->pending_io[i]; + + if (io->tp_wait) + { + CancelIoEx((HANDLE)s->s, &io->ovr); + SetThreadpoolWait(io->tp_wait, NULL, NULL); + unlock_socket(s); + WaitForThreadpoolWaitCallbacks(io->tp_wait, FALSE); + lock_socket(s); + CloseThreadpoolWait(io->tp_wait); + CloseHandle(io->ovr.hEvent); + } + + if (io->irp) + dispatch_pending_io(io, STATUS_CANCELLED, 0); }
if (s->flags & WSK_FLAG_LISTEN_SOCKET && s->callback_context.listen_socket_callback_context.acceptor) @@ -187,15 +251,6 @@ static NTSTATUS WINAPI wsk_close_socket(WSK_SOCKET *socket, IRP *irp)
status = closesocket(s->s) ? sock_error_to_ntstatus(WSAGetLastError()) : STATUS_SUCCESS;
- if (s->ovr.hEvent) - CloseHandle(s->ovr.hEvent); - - if (s->pending_irp) - { - s->pending_irp->IoStatus.Information = 0; - dispatch_irp(s->pending_irp, STATUS_CANCELLED); - } - unlock_socket(s); DeleteCriticalSection(&s->cs_socket); heap_free(socket); @@ -230,18 +285,16 @@ static NTSTATUS WINAPI wsk_bind(WSK_SOCKET *socket, SOCKADDR *local_address, ULO return STATUS_PENDING; }
-static void create_accept_socket(struct wsk_socket_internal *socket) +static void create_accept_socket(struct wsk_socket_internal *socket, struct wsk_pending_io *io) /* Completes io->irp with the new accept socket, or with STATUS_NO_MEMORY if it cannot be allocated. */ { struct listen_socket_callback_context *context = &socket->callback_context.listen_socket_callback_context; struct wsk_socket_internal *accept_socket; - NTSTATUS status;
if (!(accept_socket = heap_alloc_zero(sizeof(*accept_socket)))) { ERR("No memory.\n"); - status = STATUS_NO_MEMORY; - socket->pending_irp->IoStatus.Information = 0; + dispatch_pending_io(io, STATUS_NO_MEMORY, 0); } else { @@ -254,15 +307,11 @@ static void create_accept_socket(struct wsk_socket_internal *socket) accept_socket->address_family = socket->address_family; accept_socket->protocol = socket->protocol; accept_socket->flags = WSK_FLAG_CONNECTION_SOCKET; - socket_init(accept_socket, NULL); + socket_init(accept_socket); /* TODO: fill local and remote addresses. */
- socket->pending_irp->IoStatus.Information = (ULONG_PTR)&accept_socket->wsk_socket; - status = STATUS_SUCCESS; + dispatch_pending_io(io, STATUS_SUCCESS, (ULONG_PTR)&accept_socket->wsk_socket); /* IoStatus.Information carries the new WSK_SOCKET pointer */ } - TRACE("status %#x.\n", status); - dispatch_irp(socket->pending_irp, status); - socket->pending_irp = NULL; }
static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_, TP_WAIT *wait, @@ -270,24 +319,24 @@ static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_ { struct listen_socket_callback_context *context; struct wsk_socket_internal *socket = socket_; + struct wsk_pending_io *io; DWORD size;
TRACE("instance %p, socket %p, wait %p, wait_result %#x.\n", instance, socket, wait, wait_result);
lock_socket(socket); context = &socket->callback_context.listen_socket_callback_context; + io = find_pending_io(socket, wait);
- if (GetOverlappedResult((HANDLE)socket->s, &socket->ovr, &size, FALSE)) + if (GetOverlappedResult((HANDLE)socket->s, &io->ovr, &size, FALSE)) { - create_accept_socket(socket); + create_accept_socket(socket, io); } else { closesocket(context->acceptor); context->acceptor = 0; - socket->pending_irp->IoStatus.Information = 0; - dispatch_irp(socket->pending_irp, socket->ovr.Internal); - socket->pending_irp = NULL; + dispatch_pending_io(io, io->ovr.Internal, 0); } unlock_socket(socket); } @@ -314,8 +363,8 @@ static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void * struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(listen_socket); static INIT_ONCE init_once = INIT_ONCE_STATIC_INIT; struct listen_socket_callback_context *context; + struct wsk_pending_io *io; SOCKET acceptor; - NTSTATUS status; DWORD size; int error;
@@ -329,44 +378,47 @@ static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void *
if (!InitOnceExecuteOnce(&init_once, init_accept_functions, (void *)s->s, NULL)) { - status = STATUS_UNSUCCESSFUL; - dispatch_irp(irp, status); - return status; + dispatch_irp(irp, STATUS_UNSUCCESSFUL); + return STATUS_PENDING; }
lock_socket(s); + if (!(io = allocate_pending_io(s, accept_callback, irp))) + { + irp->IoStatus.Information = 0; + dispatch_irp(irp, STATUS_UNSUCCESSFUL); + unlock_socket(s); + return STATUS_PENDING; + } + context = &s->callback_context.listen_socket_callback_context; if ((acceptor = WSASocketW(s->address_family, s->socket_type, s->protocol, NULL, 0, WSA_FLAG_OVERLAPPED)) == INVALID_SOCKET) { - status = sock_error_to_ntstatus(WSAGetLastError()); - dispatch_irp(irp, status); + dispatch_pending_io(io, sock_error_to_ntstatus(WSAGetLastError()), 0); unlock_socket(s); - return status; + return STATUS_PENDING; }
- s->pending_irp = irp; context->remote_address = remote_address; context->client_dispatch = accept_socket_dispatch; context->client_context = accept_socket_context; context->acceptor = acceptor;
if (pAcceptEx(s->s, acceptor, context->addr_buffer, 0, - sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16, &size, &s->ovr)) + sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16, &size, &io->ovr)) { - create_accept_socket(s); + create_accept_socket(s, io); } else if ((error = WSAGetLastError()) == ERROR_IO_PENDING) { - SetThreadpoolWait(s->tp_wait, s->ovr.hEvent, NULL); + SetThreadpoolWait(io->tp_wait, io->ovr.hEvent, NULL); } else { closesocket(acceptor); context->acceptor = 0; - irp->IoStatus.Information = 0; - dispatch_irp(irp, sock_error_to_ntstatus(error)); - s->pending_irp = NULL; + dispatch_pending_io(io, sock_error_to_ntstatus(error), 0); } unlock_socket(s);
@@ -490,7 +542,6 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam PETHREAD owning_thread, SECURITY_DESCRIPTOR *security_descriptor, IRP *irp) { struct wsk_socket_internal *socket; - PTP_WAIT_CALLBACK async_callback; NTSTATUS status; SOCKET s;
@@ -532,7 +583,6 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam { case WSK_FLAG_LISTEN_SOCKET: socket->wsk_socket.Dispatch = &wsk_provider_listen_dispatch; - async_callback = accept_callback; break;
case WSK_FLAG_CONNECTION_SOCKET: @@ -547,7 +597,7 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam goto done; }
- socket_init(socket, async_callback); + socket_init(socket); /* io slots and their callbacks are set up on demand by allocate_pending_io() */
irp->IoStatus.Information = (ULONG_PTR)&socket->wsk_socket; status = STATUS_SUCCESS;
Signed-off-by: Paul Gofman <pgofman@codeweavers.com> --- dlls/netio.sys/netio.c | 90 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 4 deletions(-)
diff --git a/dlls/netio.sys/netio.c b/dlls/netio.sys/netio.c
index 6cadff2e87c..b4f3b42ed04 100644
--- a/dlls/netio.sys/netio.c
+++ b/dlls/netio.sys/netio.c
@@ -467,18 +467,109 @@ static NTSTATUS WINAPI wsk_get_remote_address(WSK_SOCKET *socket, SOCKADDR *remo
     return STATUS_NOT_IMPLEMENTED;
 }
 
+/* Threadpool wait callback for send and receive requests: completes the IRP of
+ * the pending io owning this wait with the result of its overlapped operation. */
+static void WINAPI send_receive_callback(TP_CALLBACK_INSTANCE *instance, void *socket_, TP_WAIT *wait,
+        TP_WAIT_RESULT wait_result)
+{
+    struct wsk_socket_internal *socket = socket_;
+    struct wsk_pending_io *io;
+    DWORD length, flags;
+
+    TRACE("instance %p, socket %p, wait %p, wait_result %#x.\n", instance, socket, wait, wait_result);
+
+    lock_socket(socket);
+    if (!(io = find_pending_io(socket, wait)))
+    {
+        unlock_socket(socket);
+        return;
+    }
+
+    if (WSAGetOverlappedResult(socket->s, &io->ovr, &length, FALSE, &flags))
+        dispatch_pending_io(io, STATUS_SUCCESS, length);
+    else
+        dispatch_pending_io(io, io->ovr.Internal, 0);
+
+    unlock_socket(socket);
+}
+
+/* Common implementation of WskSend and WskReceive for a WSK_BUF described by a single
+ * MDL. Parameter errors are returned directly without completing the IRP; otherwise
+ * the IRP is completed, possibly asynchronously, and STATUS_PENDING is returned. */
+static NTSTATUS do_send_receive(WSK_SOCKET *socket, WSK_BUF *wsk_buf, ULONG flags, IRP *irp, BOOL is_send)
+{
+    struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(socket);
+    struct wsk_pending_io *io;
+    DWORD wsa_flags;
+    WSABUF wsa_buf;
+    DWORD length;
+    int error;
+
+    TRACE("socket %p, buffer %p, flags %#x, irp %p, is_send %#x.\n",
+            socket, wsk_buf, flags, irp, is_send);
+
+    if (!irp)
+        return STATUS_INVALID_PARAMETER;
+
+    if (!wsk_buf->Mdl && wsk_buf->Length)
+        return STATUS_INVALID_PARAMETER;
+
+    if (wsk_buf->Mdl && wsk_buf->Mdl->Next)
+    {
+        FIXME("Chained MDLs are not supported.\n");
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
+        return STATUS_PENDING;
+    }
+
+    if (flags)
+        FIXME("flags %#x not implemented.\n", flags);
+
+    lock_socket(s);
+    if (!(io = allocate_pending_io(s, send_receive_callback, irp)))
+    {
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
+        unlock_socket(s);
+        return STATUS_PENDING;
+    }
+
+    wsa_buf.len = wsk_buf->Length;
+    wsa_buf.buf = wsk_buf->Mdl ? (CHAR *)wsk_buf->Mdl->StartVa
+            + wsk_buf->Mdl->ByteOffset + wsk_buf->Offset : NULL;
+
+    wsa_flags = 0;
+
+    if (!(is_send ? WSASend(s->s, &wsa_buf, 1, &length, wsa_flags, &io->ovr, NULL)
+            : WSARecv(s->s, &wsa_buf, 1, &length, &wsa_flags, &io->ovr, NULL)))
+    {
+        dispatch_pending_io(io, STATUS_SUCCESS, length);
+    }
+    else if ((error = WSAGetLastError()) == WSA_IO_PENDING)
+    {
+        SetThreadpoolWait(io->tp_wait, io->ovr.hEvent, NULL);
+    }
+    else
+    {
+        dispatch_pending_io(io, sock_error_to_ntstatus(error), 0);
+    }
+    unlock_socket(s);
+
+    return STATUS_PENDING;
+}
+
 static NTSTATUS WINAPI wsk_send(WSK_SOCKET *socket, WSK_BUF *buffer, ULONG flags, IRP *irp)
 {
-    FIXME("socket %p, buffer %p, flags %#x, irp %p stub.\n", socket, buffer, flags, irp);
+    TRACE("socket %p, buffer %p, flags %#x, irp %p.\n", socket, buffer, flags, irp);
 
-    return STATUS_NOT_IMPLEMENTED;
+    return do_send_receive(socket, buffer, flags, irp, TRUE);
 }
 
 static NTSTATUS WINAPI wsk_receive(WSK_SOCKET *socket, WSK_BUF *buffer, ULONG flags, IRP *irp)
 {
-    FIXME("socket %p, buffer %p, flags %#x, irp %p stub.\n", socket, buffer, flags, irp);
+    TRACE("socket %p, buffer %p, flags %#x, irp %p.\n", socket, buffer, flags, irp);
 
-    return STATUS_NOT_IMPLEMENTED;
+    return do_send_receive(socket, buffer, flags, irp, FALSE);
 }
 
 static NTSTATUS WINAPI wsk_disconnect(WSK_SOCKET *socket, WSK_BUF *buffer, ULONG flags, IRP *irp)
Signed-off-by: Paul Gofman <pgofman@codeweavers.com> --- dlls/ntoskrnl.exe/tests/driver4.c | 51 ++++++++++++++++++++++++++++++ dlls/ntoskrnl.exe/tests/ntoskrnl.c | 9 ++++++ 2 files changed, 60 insertions(+)
diff --git a/dlls/ntoskrnl.exe/tests/driver4.c b/dlls/ntoskrnl.exe/tests/driver4.c index fd495a26e45..fedd5f25c2c 100644 --- a/dlls/ntoskrnl.exe/tests/driver4.c +++ b/dlls/ntoskrnl.exe/tests/driver4.c @@ -167,15 +167,22 @@ struct socket_context { };
+#define TEST_BUFFER_LENGTH 256 /* size of each send/receive test buffer and of its MDL */ + static void test_wsk_listen_socket(void) { + static const char test_receive_string[] = "Client test string 1."; const WSK_PROVIDER_LISTEN_DISPATCH *tcp_dispatch, *udp_dispatch; + static const char test_send_string[] = "Server test string 1."; static const WSK_CLIENT_LISTEN_DISPATCH client_listen_dispatch; const WSK_PROVIDER_CONNECTION_DISPATCH *accept_dispatch; WSK_SOCKET *tcp_socket, *udp_socket, *accept_socket; struct socket_context context; + WSK_BUF wsk_buf1, wsk_buf2; + void *buffer1, *buffer2; struct sockaddr_in addr; LARGE_INTEGER timeout; + MDL *mdl1, *mdl2; NTSTATUS status; KEVENT event; IRP *irp; @@ -183,6 +190,19 @@ static void test_wsk_listen_socket(void) irp = IoAllocateIrp(1, FALSE); KeInitializeEvent(&event, SynchronizationEvent, FALSE);
+ buffer1 = ExAllocatePool(NonPagedPool, TEST_BUFFER_LENGTH); + mdl1 = IoAllocateMdl(buffer1, TEST_BUFFER_LENGTH, FALSE, FALSE, NULL); + MmBuildMdlForNonPagedPool(mdl1); + buffer2 = ExAllocatePool(NonPagedPool, TEST_BUFFER_LENGTH); + mdl2 = IoAllocateMdl(buffer2, TEST_BUFFER_LENGTH, FALSE, FALSE, NULL); + MmBuildMdlForNonPagedPool(mdl2); + + wsk_buf1.Mdl = mdl1; + wsk_buf1.Offset = 0; + wsk_buf1.Length = TEST_BUFFER_LENGTH; + wsk_buf2 = wsk_buf1; + wsk_buf2.Mdl = mdl2; + status = provider_npi.Dispatch->WskSocket(NULL, AF_INET, SOCK_STREAM, IPPROTO_TCP, WSK_FLAG_LISTEN_SOCKET, &context, &client_listen_dispatch, NULL, NULL, NULL, NULL); ok(status == STATUS_INVALID_PARAMETER, "Got unexpected status %#x.\n", status); @@ -283,6 +303,32 @@ static void test_wsk_listen_socket(void) accept_socket = (WSK_SOCKET *)wsk_irp->IoStatus.Information; accept_dispatch = accept_socket->Dispatch;
+ IoReuseIrp(irp, STATUS_UNSUCCESSFUL); + IoSetCompletionRoutine(irp, irp_completion_routine, &event, TRUE, TRUE, TRUE); + status = accept_dispatch->WskReceive(accept_socket, &wsk_buf2, 0, irp); + ok(status == STATUS_PENDING, "Got unexpected status %#x.\n", status); + + IoReuseIrp(wsk_irp, STATUS_UNSUCCESSFUL); + IoSetCompletionRoutine(wsk_irp, irp_completion_routine, &irp_complete_event, TRUE, TRUE, TRUE); + strcpy(buffer1, test_send_string); + /* Setting Length in WSK_BUF greater than MDL allocation size BSODs Windows. + * wsk_buf1.Length = TEST_BUFFER_LENGTH * 2; */ + status = accept_dispatch->WskSend(accept_socket, &wsk_buf1, 0, wsk_irp); + ok(status == STATUS_PENDING, "Got unexpected status %#x.\n", status); + + status = KeWaitForSingleObject(&irp_complete_event, Executive, KernelMode, FALSE, &timeout); + ok(status == STATUS_SUCCESS, "Got unexpected status %#x.\n", status); + ok(wsk_irp->IoStatus.Status == STATUS_SUCCESS, "Got unexpected status %#x.\n", wsk_irp->IoStatus.Status); + ok(wsk_irp->IoStatus.Information == TEST_BUFFER_LENGTH, "Got unexpected Information %#lx.\n", + wsk_irp->IoStatus.Information); + + status = KeWaitForSingleObject(&event, Executive, KernelMode, FALSE, &timeout); + ok(status == STATUS_SUCCESS, "Got unexpected status %#x.\n", status); + ok(irp->IoStatus.Status == STATUS_SUCCESS, "Got unexpected status %#x.\n", irp->IoStatus.Status); + ok(irp->IoStatus.Information == sizeof(test_receive_string), "Got unexpected Information %#lx.\n", + irp->IoStatus.Information); + ok(!strcmp(buffer2, test_receive_string), "Received unexpected data.\n"); + IoReuseIrp(wsk_irp, STATUS_UNSUCCESSFUL); IoSetCompletionRoutine(wsk_irp, irp_completion_routine, &irp_complete_event, TRUE, TRUE, TRUE); status = accept_dispatch->Basic.WskCloseSocket(accept_socket, wsk_irp); @@ -317,6 +363,11 @@ static void test_wsk_listen_socket(void) ok(irp->IoStatus.Status == STATUS_CANCELLED, "Got unexpected status %#x.\n", irp->IoStatus.Status); ok(!irp->IoStatus.Information, "Got
unexpected Information %#lx.\n", irp->IoStatus.Information); IoFreeIrp(irp); + + IoFreeMdl(mdl1); + IoFreeMdl(mdl2); + ExFreePool(buffer1); + ExFreePool(buffer2); }
static NTSTATUS main_test(DEVICE_OBJECT *device, IRP *irp, IO_STACK_LOCATION *stack) diff --git a/dlls/ntoskrnl.exe/tests/ntoskrnl.c b/dlls/ntoskrnl.exe/tests/ntoskrnl.c index 4cdcda7d3ea..c53f97e9d19 100644 --- a/dlls/ntoskrnl.exe/tests/ntoskrnl.c +++ b/dlls/ntoskrnl.exe/tests/ntoskrnl.c @@ -516,8 +516,10 @@ static void test_driver3(void)
static DWORD WINAPI wsk_test_thread(void *parameter) { + static const char test_send_string[] = "Client test string 1."; static const WORD version = MAKEWORD(2, 2); struct sockaddr_in addr; + char buffer[256]; int ret, err; WSADATA data; SOCKET s; @@ -541,6 +543,13 @@ static DWORD WINAPI wsk_test_thread(void *parameter) } ok(!ret, "Error connecting, WSAGetLastError() %u.\n", WSAGetLastError());
+ ret = send(s, test_send_string, sizeof(test_send_string), 0); + ok(ret == sizeof(test_send_string), "Got unexpected ret %d.\n", ret); + + ret = recv(s, buffer, sizeof(buffer), 0); + ok(ret == sizeof(buffer), "Got unexpected ret %d.\n", ret); + ok(!strcmp(buffer, "Server test string 1."), "Received unexpected data.\n"); + closesocket(s); return TRUE; }