Signed-off-by: Paul Gofman <pgofman@codeweavers.com>
---
 dlls/netio.sys/netio.c | 92 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 90 insertions(+), 2 deletions(-)
diff --git a/dlls/netio.sys/netio.c b/dlls/netio.sys/netio.c
index cac5b8ceb51..dd6d2f48140 100644
--- a/dlls/netio.sys/netio.c
+++ b/dlls/netio.sys/netio.c
@@ -53,6 +53,15 @@ struct listen_socket_callback_context
     SOCKET acceptor;
 };

+struct connect_socket_callback_context /* NOTE(review): unused; nothing references this struct yet. */
+{
+    struct wsk_socket_internal *socket;
+    SOCKADDR *remote_address;
+    const void *client_dispatch;
+    void *client_context;
+    IRP *pending_irp;
+};
+
 #define MAX_PENDING_IO 10

 struct wsk_pending_io
@@ -73,6 +82,7 @@ struct wsk_socket_internal
     ADDRESS_FAMILY address_family;
     USHORT socket_type;
     ULONG protocol;
+    BOOL bound; /* Set by wsk_bind() on success; wsk_connect() fails with STATUS_INVALID_DEVICE_STATE while unset. */

     CRITICAL_SECTION cs_socket;

@@ -86,6 +96,8 @@ struct wsk_socket_internal
 };

 static LPFN_ACCEPTEX pAcceptEx;
+static LPFN_CONNECTEX pConnectEx; /* Filled in once by init_connect_functions() via WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER). */
+
 static const WSK_PROVIDER_CONNECTION_DISPATCH wsk_provider_connection_dispatch;

 static inline struct wsk_socket_internal *wsk_socket_internal_from_wsk_socket(WSK_SOCKET *wsk_socket)
@@ -294,6 +306,9 @@ static NTSTATUS WINAPI wsk_bind(WSK_SOCKET *socket, SOCKADDR *local_address, ULO
     else
         status = STATUS_SUCCESS;
 
+    if (status == STATUS_SUCCESS)
+        s->bound = TRUE;
+
     TRACE("status %#x.\n", status);
     irp->IoStatus.Information = 0;
     dispatch_irp(irp, status);
@@ -468,11 +483,103 @@ static const WSK_PROVIDER_LISTEN_DISPATCH wsk_provider_listen_dispatch =
     wsk_get_local_address,
 };

+/* Thread pool wait callback for a ConnectEx() that wsk_connect() left pending: passes the
+ * status stored in the OVERLAPPED structure on to dispatch_pending_io(). */
+static void WINAPI connect_callback(TP_CALLBACK_INSTANCE *instance, void *socket_, TP_WAIT *wait,
+        TP_WAIT_RESULT wait_result)
+{
+    struct wsk_socket_internal *socket = socket_;
+    struct wsk_pending_io *io;
+    DWORD size;
+
+    TRACE("instance %p, socket %p, wait %p, wait_result %#x.\n", instance, socket, wait, wait_result);
+
+    lock_socket(socket);
+    io = find_pending_io(socket, wait);
+
+    /* The status comes from ovr.Internal; GetOverlappedResult()'s return value and size are unused. */
+    GetOverlappedResult((HANDLE)socket->s, &io->ovr, &size, FALSE);
+    dispatch_pending_io(io, io->ovr.Internal, 0);
+    unlock_socket(socket);
+}
+
+/* INIT_ONCE callback: stores the ConnectEx() extension function pointer in pConnectEx,
+ * querying it through the socket handle passed in param. */
+static BOOL WINAPI init_connect_functions(INIT_ONCE *once, void *param, void **context)
+{
+    GUID connectex_guid = WSAID_CONNECTEX;
+    SOCKET s = (SOCKET)param;
+    DWORD size;
+
+    if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &connectex_guid, sizeof(connectex_guid),
+            &pConnectEx, sizeof(pConnectEx), &size, NULL, NULL))
+    {
+        ERR("Could not get ConnectEx address, error %u.\n", WSAGetLastError());
+        return FALSE;
+    }
+    return TRUE;
+}
+
+/* Connects socket to remote_address with ConnectEx(); flags is only traced.
+ * Returns STATUS_INVALID_PARAMETER if irp is NULL and STATUS_INVALID_DEVICE_STATE if the socket
+ * has not been successfully bound; otherwise the result is delivered through irp (directly or
+ * via its pending I/O entry) and STATUS_PENDING is returned. */
 static NTSTATUS WINAPI wsk_connect(WSK_SOCKET *socket, SOCKADDR *remote_address, ULONG flags, IRP *irp)
 {
-    FIXME("socket %p, remote_address %p, flags %#x, irp %p stub.\n", socket, remote_address, flags, irp);
+    struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(socket);
+    static INIT_ONCE init_once = INIT_ONCE_STATIC_INIT;
+    struct wsk_pending_io *io;
+    int addr_len;
+    int error;

-    return STATUS_NOT_IMPLEMENTED;
+    TRACE("socket %p, remote_address %p, flags %#x, irp %p.\n",
+            socket, remote_address, flags, irp);
+
+    if (!irp)
+        return STATUS_INVALID_PARAMETER;
+
+    if (!InitOnceExecuteOnce(&init_once, init_connect_functions, (void *)s->s, NULL))
+    {
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
+        return STATUS_PENDING;
+    }
+
+    lock_socket(s);
+
+    if (!(io = allocate_pending_io(s, connect_callback, irp)))
+    {
+        irp->IoStatus.Information = 0;
+        dispatch_irp(irp, STATUS_UNSUCCESSFUL);
+        unlock_socket(s);
+        return STATUS_PENDING;
+    }
+
+    if (!s->bound)
+    {
+        /* Unlike the other failure paths, this status is also returned directly. */
+        dispatch_pending_io(io, STATUS_INVALID_DEVICE_STATE, 0);
+        unlock_socket(s);
+        return STATUS_INVALID_DEVICE_STATE;
+    }
+
+    /* sizeof(SOCKADDR) is too small for an IPv6 address, so pass the size of the structure
+     * that matches the address family. */
+    addr_len = remote_address && remote_address->sa_family == AF_INET6
+            ? sizeof(SOCKADDR_IN6) : sizeof(*remote_address);
+
+    /* On ERROR_IO_PENDING the result is delivered by connect_callback() (registered through
+     * allocate_pending_io()) once io->ovr.hEvent is signaled. */
+    if (pConnectEx(s->s, remote_address, addr_len, NULL, 0, NULL, &io->ovr))
+        dispatch_pending_io(io, STATUS_SUCCESS, 0);
+    else if ((error = WSAGetLastError()) == ERROR_IO_PENDING)
+        SetThreadpoolWait(io->tp_wait, io->ovr.hEvent, NULL);
+    else
+        dispatch_pending_io(io, sock_error_to_ntstatus(error), 0);
+
+    unlock_socket(s);
+
+    return STATUS_PENDING;
 }

 static NTSTATUS WINAPI wsk_get_remote_address(WSK_SOCKET *socket, SOCKADDR *remote_address, IRP *irp)