Signed-off-by: Zebediah Figura <z.figura12(a)gmail.com>
---
v2: Avoid making assumptions about alignment.
dlls/ws2_32/socket.c | 183 +++------------------
dlls/ws2_32/tests/sock.c | 27 ++--
include/wine/afd.h | 7 +
server/sock.c | 333 +++++++++++++++++++++++++++++++++++++--
4 files changed, 356 insertions(+), 194 deletions(-)
diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c
index 2eb1e1a7307..e93c4ccf589 100644
--- a/dlls/ws2_32/socket.c
+++ b/dlls/ws2_32/socket.c
@@ -2481,99 +2481,6 @@ static NTSTATUS WS2_async_recv( void *user, IO_STATUS_BLOCK *iosb, NTSTATUS stat
return status;
}
-/***********************************************************************
- * WS2_async_accept_recv (INTERNAL)
- *
- * This function is used to finish the read part of an accept request. It is
- * needed to place the completion on the correct socket (listener).
- */
-static NTSTATUS WS2_async_accept_recv( void *user, IO_STATUS_BLOCK *iosb, NTSTATUS status )
-{
- struct ws2_accept_async *wsa = user;
-
- status = WS2_async_recv( wsa->read, iosb, status );
- if (status == STATUS_PENDING)
- return status;
-
- if (wsa->cvalue)
- WS_AddCompletion( HANDLE2SOCKET(wsa->listen_socket), wsa->cvalue, iosb->u.Status, iosb->Information, TRUE );
-
- release_async_io( &wsa->io );
- return status;
-}
-
-/***********************************************************************
- * WS2_async_accept (INTERNAL)
- *
- * This is the function called to satisfy the AcceptEx callback
- */
-static NTSTATUS WS2_async_accept( void *user, IO_STATUS_BLOCK *iosb, NTSTATUS status )
-{
- struct ws2_accept_async *wsa = user;
- int len;
- char *addr;
-
- TRACE("status: 0x%x listen: %p, accept: %p\n", status, wsa->listen_socket, wsa->accept_socket);
-
- if (status == STATUS_ALERTED)
- {
- obj_handle_t accept_handle = wine_server_obj_handle( wsa->accept_socket );
- IO_STATUS_BLOCK io;
-
- status = NtDeviceIoControlFile( wsa->listen_socket, NULL, NULL, NULL, &io, IOCTL_AFD_ACCEPT_INTO,
- &accept_handle, sizeof(accept_handle), NULL, 0 );
-
- if (NtStatusToWSAError( status ) == WSAEWOULDBLOCK)
- return STATUS_PENDING;
-
- if (status == STATUS_INVALID_HANDLE)
- {
- FIXME("AcceptEx accepting socket closed but request was not cancelled\n");
- status = STATUS_CANCELLED;
- }
- }
- else if (status == STATUS_HANDLES_CLOSED)
- status = STATUS_CANCELLED; /* strange windows behavior */
-
- if (status != STATUS_SUCCESS)
- goto finish;
-
- /* WS2 Spec says size param is extra 16 bytes long...what do we put in it? */
- addr = ((char *)wsa->buf) + wsa->data_len;
- len = wsa->local_len - sizeof(int);
- WS_getsockname(HANDLE2SOCKET(wsa->accept_socket),
- (struct WS_sockaddr *)(addr + sizeof(int)), &len);
- *(int *)addr = len;
-
- addr += wsa->local_len;
- len = wsa->remote_len - sizeof(int);
- WS_getpeername(HANDLE2SOCKET(wsa->accept_socket),
- (struct WS_sockaddr *)(addr + sizeof(int)), &len);
- *(int *)addr = len;
-
- if (!wsa->read)
- goto finish;
-
- wsa->io.callback = WS2_async_accept_recv;
- status = register_async( ASYNC_TYPE_READ, wsa->accept_socket, &wsa->io,
- wsa->user_overlapped->hEvent, NULL, NULL, iosb);
-
- if (status != STATUS_PENDING)
- goto finish;
-
- /* The APC has finished but no completion should be sent for the operation yet, additional processing
- * needs to be performed by WS2_async_accept_recv() first. */
- return STATUS_MORE_PROCESSING_REQUIRED;
-
-finish:
- iosb->u.Status = status;
- iosb->Information = 0;
-
- if (wsa->read) release_async_io( &wsa->read->io );
- release_async_io( &wsa->io );
- return status;
-}
-
/***********************************************************************
* WS2_send (INTERNAL)
*
@@ -2820,23 +2727,30 @@ error:
/***********************************************************************
* AcceptEx
*/
-static BOOL WINAPI WS2_AcceptEx(SOCKET listener, SOCKET acceptor, PVOID dest, DWORD dest_len,
- DWORD local_addr_len, DWORD rem_addr_len, LPDWORD received,
- LPOVERLAPPED overlapped)
+static BOOL WINAPI WS2_AcceptEx( SOCKET listener, SOCKET acceptor, void *dest, DWORD recv_len,
+ DWORD local_len, DWORD remote_len, DWORD *ret_len, OVERLAPPED *overlapped)
{
- DWORD status;
- struct ws2_accept_async *wsa;
- int fd;
+ struct afd_accept_into_params params =
+ {
+ .accept_handle = acceptor,
+ .recv_len = recv_len,
+ .local_len = local_len,
+ };
+ void *cvalue = NULL;
+ NTSTATUS status;
- TRACE("(%04lx, %04lx, %p, %d, %d, %d, %p, %p)\n", listener, acceptor, dest, dest_len, local_addr_len,
- rem_addr_len, received, overlapped);
+ TRACE( "listener %#lx, acceptor %#lx, dest %p, recv_len %u, local_len %u, remote_len %u, ret_len %p, "
+ "overlapped %p\n", listener, acceptor, dest, recv_len, local_len, remote_len, ret_len, overlapped );
if (!overlapped)
{
SetLastError(WSA_INVALID_PARAMETER);
return FALSE;
}
+
+ /* overlapped is passed as the completion value unless the low bit of hEvent is set */
+ if (!((ULONG_PTR)overlapped->hEvent & 1)) cvalue = overlapped;
overlapped->Internal = STATUS_PENDING;
+ overlapped->InternalHigh = 0;
if (!dest)
{
@@ -2844,72 +2758,19 @@ static BOOL WINAPI WS2_AcceptEx(SOCKET listener, SOCKET acceptor, PVOID dest, DW
return FALSE;
}
- if (!rem_addr_len)
+ if (!remote_len)
{
SetLastError(WSAEFAULT);
return FALSE;
}
- fd = get_sock_fd( listener, FILE_READ_DATA, NULL );
- if (fd == -1) return FALSE;
- release_sock_fd( listener, fd );
-
- fd = get_sock_fd( acceptor, FILE_READ_DATA, NULL );
- if (fd == -1) return FALSE;
- release_sock_fd( acceptor, fd );
+ /* dest receives recv_len bytes of data followed by the local and remote address areas */
+ status = NtDeviceIoControlFile( SOCKET2HANDLE(listener), overlapped->hEvent, NULL, cvalue,
+ (IO_STATUS_BLOCK *)overlapped, IOCTL_AFD_ACCEPT_INTO, &params, sizeof(params),
+ dest, recv_len + local_len + remote_len );
- wsa = (struct ws2_accept_async *)alloc_async_io( sizeof(*wsa), WS2_async_accept );
- if(!wsa)
- {
- SetLastError(WSAEFAULT);
- return FALSE;
- }
-
- wsa->listen_socket = SOCKET2HANDLE(listener);
- wsa->accept_socket = SOCKET2HANDLE(acceptor);
- wsa->user_overlapped = overlapped;
- wsa->cvalue = !((ULONG_PTR)overlapped->hEvent & 1) ? (ULONG_PTR)overlapped : 0;
- wsa->buf = dest;
- wsa->data_len = dest_len;
- wsa->local_len = local_addr_len;
- wsa->remote_len = rem_addr_len;
- wsa->read = NULL;
-
- if (wsa->data_len)
- {
- /* set up a read request if we need it */
- wsa->read = (struct ws2_async *)alloc_async_io( offsetof(struct ws2_async, iovec[1]), WS2_async_accept_recv );
- if (!wsa->read)
- {
- HeapFree( GetProcessHeap(), 0, wsa );
- SetLastError(WSAEFAULT);
- return FALSE;
- }
-
- wsa->read->hSocket = wsa->accept_socket;
- wsa->read->flags = 0;
- wsa->read->lpFlags = &wsa->read->flags;
- wsa->read->addr = NULL;
- wsa->read->addrlen.ptr = NULL;
- wsa->read->control = NULL;
- wsa->read->n_iovecs = 1;
- wsa->read->first_iovec = 0;
- wsa->read->completion_func = NULL;
- wsa->read->iovec[0].iov_base = wsa->buf;
- wsa->read->iovec[0].iov_len = wsa->data_len;
- }
-
- status = register_async( ASYNC_TYPE_READ, SOCKET2HANDLE(listener), &wsa->io,
- overlapped->hEvent, NULL, (void *)wsa->cvalue, (IO_STATUS_BLOCK *)overlapped );
-
- if(status != STATUS_PENDING)
- {
- HeapFree( GetProcessHeap(), 0, wsa->read );
- HeapFree( GetProcessHeap(), 0, wsa );
- }
-
- SetLastError( NtStatusToWSAError(status) );
- return FALSE;
+ if (ret_len) *ret_len = overlapped->InternalHigh;
+ WSASetLastError( NtStatusToWSAError(status) );
+ return !status;
}
/***********************************************************************
diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c
index 5fb85bfaa60..c20a83b4261 100644
--- a/dlls/ws2_32/tests/sock.c
+++ b/dlls/ws2_32/tests/sock.c
@@ -7379,6 +7379,8 @@ todo_wine
ok(bret == FALSE && WSAGetLastError() == WSAEINVAL, "AcceptEx on a non-listening socket "
"returned %d + errno %d\n", bret, WSAGetLastError());
ok(overlapped.Internal == STATUS_PENDING, "got %08x\n", (ULONG)overlapped.Internal);
+ if (!bret && WSAGetLastError() == ERROR_IO_PENDING)
+ CancelIo((HANDLE)listener);
iret = listen(listener, 5);
ok(!iret, "failed to listen, error %u\n", GetLastError());
@@ -7452,9 +7454,9 @@ todo_wine
bytesReturned = 0xdeadbeef;
SetLastError(0xdeadbeef);
bret = GetOverlappedResult((HANDLE)listener, &overlapped, &bytesReturned, FALSE);
- todo_wine ok(!bret, "expected failure\n");
- todo_wine ok(GetLastError() == ERROR_INSUFFICIENT_BUFFER, "got error %u\n", GetLastError());
- todo_wine ok((NTSTATUS)overlapped.Internal == STATUS_BUFFER_TOO_SMALL, "got %#lx\n", overlapped.Internal);
+ ok(!bret, "expected failure\n");
+ ok(GetLastError() == ERROR_INSUFFICIENT_BUFFER, "got error %u\n", GetLastError());
+ ok((NTSTATUS)overlapped.Internal == STATUS_BUFFER_TOO_SMALL, "got %#lx\n", overlapped.Internal);
ok(!bytesReturned, "got size %u\n", bytesReturned);
closesocket(acceptor);
@@ -7796,19 +7798,10 @@ todo_wine
closesocket(acceptor);
dwret = WaitForSingleObject(overlapped.hEvent, 1000);
- todo_wine ok(dwret == WAIT_OBJECT_0,
+ ok(dwret == WAIT_OBJECT_0,
"Waiting for accept event failed with %d + errno %d\n", dwret, GetLastError());
-
- if (dwret != WAIT_TIMEOUT) {
- bret = GetOverlappedResult((HANDLE)listener, &overlapped, &bytesReturned, FALSE);
- ok(!bret && GetLastError() == ERROR_OPERATION_ABORTED, "GetOverlappedResult failed, error %d\n", GetLastError());
- }
- else {
- bret = CancelIo((HANDLE) listener);
- ok(bret, "Failed to cancel failed test. Bailing...\n");
- if (!bret) return;
- WaitForSingleObject(overlapped.hEvent, 0);
- }
+ bret = GetOverlappedResult((HANDLE)listener, &overlapped, &bytesReturned, FALSE);
+ ok(!bret && GetLastError() == ERROR_OPERATION_ABORTED, "GetOverlappedResult failed, error %d\n", GetLastError());
acceptor = socket(AF_INET, SOCK_STREAM, 0);
ok(acceptor != INVALID_SOCKET, "failed to create socket, error %u\n", GetLastError());
@@ -9381,12 +9374,12 @@ static void test_completion_port(void)
bret = GetQueuedCompletionStatus(io_port, &num_bytes, &key, &olp, 100);
ok(bret == FALSE, "failed to get completion status %u\n", bret);
- todo_wine ok(GetLastError() == ERROR_OPERATION_ABORTED
+ ok(GetLastError() == ERROR_OPERATION_ABORTED
|| GetLastError() == ERROR_CONNECTION_ABORTED, "got error %u\n", GetLastError());
ok(key == 125, "Key is %lu\n", key);
ok(num_bytes == 0, "Number of bytes transferred is %u\n", num_bytes);
ok(olp == &ov, "Overlapped structure is at %p\n", olp);
- todo_wine ok((NTSTATUS)olp->Internal == STATUS_CANCELLED
+ ok((NTSTATUS)olp->Internal == STATUS_CANCELLED
|| (NTSTATUS)olp->Internal == STATUS_CONNECTION_ABORTED, "got status %#lx\n", olp->Internal);
SetLastError(0xdeadbeef);
diff --git a/include/wine/afd.h b/include/wine/afd.h
index 5a994084e16..07320e7bab5 100644
--- a/include/wine/afd.h
+++ b/include/wine/afd.h
@@ -22,6 +22,7 @@
#define __WINE_WINE_AFD_H
#include <winioctl.h>
+#include "wine/server_protocol.h"
#define IOCTL_AFD_CREATE CTL_CODE(FILE_DEVICE_NETWORK, 200, METHOD_BUFFERED, FILE_WRITE_ACCESS)
#define IOCTL_AFD_ACCEPT CTL_CODE(FILE_DEVICE_NETWORK, 201, METHOD_BUFFERED, FILE_WRITE_ACCESS)
@@ -35,4 +36,10 @@ struct afd_create_params
unsigned int flags;
};
+/* Input buffer of IOCTL_AFD_ACCEPT_INTO.  The output buffer holds recv_len
+ * bytes of received data, then local_len bytes for the local address; the
+ * remote address gets whatever space remains after those two areas. */
+struct afd_accept_into_params
+{
+    obj_handle_t accept_handle;   /* handle of the socket to accept the connection into */
+    unsigned int recv_len;        /* bytes of data to receive once accepted */
+    unsigned int local_len;       /* bytes of output reserved for the local address */
+};
+
#endif
diff --git a/server/sock.c b/server/sock.c
index 4f97fe72080..e09b0ee8ed0 100644
--- a/server/sock.c
+++ b/server/sock.c
@@ -84,6 +84,7 @@
#include "winerror.h"
#define USE_WS_PREFIX
#include "winsock2.h"
+#include "ws2tcpip.h"
#include "wsipx.h"
#include "wine/afd.h"
@@ -120,6 +121,15 @@
#define FD_WINE_RAW 0x80000000
#define FD_WINE_INTERNAL 0xFFFF0000
+/* A pending accept-into request.  It is linked into the listening socket's
+ * accept_list, and the accepting socket points back at it through its
+ * accept_recv_req field. */
+struct accept_req
+{
+    struct list   entry;       /* entry in the listening socket's accept_list */
+    struct async *async;       /* the request's async (a reference is held) */
+    struct sock  *acceptsock;  /* socket accepted into (no reference is held) */
+    int           accepted;    /* set once accepted while still waiting for data */
+    unsigned int  recv_len;    /* bytes of data to receive into the output */
+    unsigned int  local_len;   /* bytes of output reserved for the local address */
+};
+
struct sock
{
struct object obj; /* object header */
@@ -143,8 +153,11 @@ struct sock
struct async_queue read_q; /* queue for asynchronous reads */
struct async_queue write_q; /* queue for asynchronous writes */
struct async_queue ifchange_q; /* queue for interface change notifications */
+ struct async_queue accept_q; /* queue for asynchronous accepts */
struct object *ifchange_obj; /* the interface change notification object */
struct list ifchange_entry; /* entry in ifchange notification list */
+ struct list accept_list; /* list of pending accept requests */
+ struct accept_req *accept_recv_req; /* pending accept-into request which will recv on this socket */
};
static void sock_dump( struct object *obj, int verbose );
@@ -161,6 +174,7 @@ static int sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async );
static void sock_queue_async( struct fd *fd, struct async *async, int type, int count );
static void sock_reselect_async( struct fd *fd, struct async_queue *queue );
+static int accept_into_socket( struct sock *sock, struct sock *acceptsock );
static int sock_get_ntstatus( int err );
static unsigned int sock_get_error( int err );
@@ -203,6 +217,98 @@ static const struct fd_ops sock_fd_ops =
sock_reselect_async /* reselect_async */
};
+/* Storage for any socket address family handled by sockaddr_from_unix().
+ * getsockname()/getpeername() fill it through the generic "addr" member; the
+ * family-specific members are then read according to addr.sa_family. */
+union unix_sockaddr
+{
+ struct sockaddr addr;
+ struct sockaddr_in in;
+ struct sockaddr_in6 in6;
+#ifdef HAS_IPX
+ struct sockaddr_ipx ipx;
+#endif
+#ifdef HAS_IRDA
+ struct sockaddr_irda irda;
+#endif
+};
+
+/* Convert a Unix socket address into its Windows equivalent.
+ *
+ * All wsaddrlen bytes of wsaddr are zeroed first.  Returns the number of bytes
+ * written, 0 for AF_UNSPEC, or -1 if the family is unsupported or wsaddrlen is
+ * too small for it; errno is not set on failure.  Each Windows structure is
+ * built on the stack and copied out with memcpy(), so the destination is only
+ * ever accessed bytewise (it may be unaligned inside an output buffer). */
+static int sockaddr_from_unix( const union unix_sockaddr *uaddr, struct WS_sockaddr *wsaddr, socklen_t wsaddrlen )
+{
+    memset( wsaddr, 0, wsaddrlen );
+
+    switch (uaddr->addr.sa_family)
+    {
+    case AF_INET:
+    {
+        struct WS_sockaddr_in win = {0};
+
+        if (wsaddrlen < sizeof(win)) return -1;
+        win.sin_family = WS_AF_INET;
+        win.sin_port = uaddr->in.sin_port;
+        memcpy( &win.sin_addr, &uaddr->in.sin_addr, sizeof(win.sin_addr) );
+        memcpy( wsaddr, &win, sizeof(win) );
+        return sizeof(win);
+    }
+
+    case AF_INET6:
+    {
+        struct WS_sockaddr_in6 win = {0};
+
+        /* a buffer that only fits the shorter WS_sockaddr_in6_old gets that layout */
+        if (wsaddrlen < sizeof(struct WS_sockaddr_in6_old)) return -1;
+        win.sin6_family = WS_AF_INET6;
+        win.sin6_port = uaddr->in6.sin6_port;
+        win.sin6_flowinfo = uaddr->in6.sin6_flowinfo;
+        memcpy( &win.sin6_addr, &uaddr->in6.sin6_addr, sizeof(win.sin6_addr) );
+#ifdef HAVE_STRUCT_SOCKADDR_IN6_SIN6_SCOPE_ID
+        win.sin6_scope_id = uaddr->in6.sin6_scope_id;
+#endif
+        if (wsaddrlen >= sizeof(struct WS_sockaddr_in6))
+        {
+            memcpy( wsaddr, &win, sizeof(struct WS_sockaddr_in6) );
+            return sizeof(struct WS_sockaddr_in6);
+        }
+        memcpy( wsaddr, &win, sizeof(struct WS_sockaddr_in6_old) );
+        return sizeof(struct WS_sockaddr_in6_old);
+    }
+
+#ifdef HAS_IPX
+    case AF_IPX:
+    {
+        struct WS_sockaddr_ipx win = {0};
+
+        if (wsaddrlen < sizeof(win)) return -1;
+        win.sa_family = WS_AF_IPX;
+        memcpy( win.sa_netnum, &uaddr->ipx.sipx_network, sizeof(win.sa_netnum) );
+        memcpy( win.sa_nodenum, &uaddr->ipx.sipx_node, sizeof(win.sa_nodenum) );
+        win.sa_socket = uaddr->ipx.sipx_port;
+        memcpy( wsaddr, &win, sizeof(win) );
+        return sizeof(win);
+    }
+#endif
+
+#ifdef HAS_IRDA
+    case AF_IRDA:
+    {
+        /* zero-initialized like the other cases, so that no uninitialized
+         * padding bytes are copied into the caller's buffer */
+        SOCKADDR_IRDA win = {0};
+
+        if (wsaddrlen < sizeof(win)) return -1;
+        win.irdaAddressFamily = WS_AF_IRDA;
+        memcpy( win.irdaDeviceID, &uaddr->irda.sir_addr, sizeof(win.irdaDeviceID) );
+        if (uaddr->irda.sir_lsap_sel != LSAP_ANY)
+            snprintf( win.irdaServiceName, sizeof(win.irdaServiceName), "LSAP-SEL%u", uaddr->irda.sir_lsap_sel );
+        else
+            memcpy( win.irdaServiceName, uaddr->irda.sir_name, sizeof(win.irdaServiceName) );
+        memcpy( wsaddr, &win, sizeof(win) );
+        return sizeof(win);
+    }
+#endif
+
+    case AF_UNSPEC:
+        return 0;
+
+    default:
+        return -1;
+    }
+}
/* Permutation of 0..FD_MAX_EVENTS - 1 representing the order in which
* we post messages if there are multiple events. Used to send
@@ -339,8 +445,134 @@ static inline int sock_error( struct fd *fd )
return optval;
}
+/* Unlink an accept request from the listening socket's accept_list, drop its
+ * reference to the async and free it. */
+static void free_accept_req( struct accept_req *req )
+{
+    list_remove( &req->entry );
+    /* NOTE(review): req->acceptsock is stored without a reference (see
+     * alloc_accept_req), and sock_destroy() only terminates the pending async
+     * without detaching the request.  If the accepting socket is destroyed
+     * first, the dereference below touches freed memory; the request should
+     * either hold a reference or be detached in sock_destroy(). */
+    /* the accepting socket's back-pointer may already have been replaced by a
+     * later request (the ioctl overwrites it); only clear it if it is ours */
+    if (req->acceptsock->accept_recv_req == req)
+        req->acceptsock->accept_recv_req = NULL;
+    release_object( req->async );
+    free( req );
+}
+
+/* Produce the output of an accept-into request whose connection has been
+ * accepted into req->acceptsock.
+ *
+ * The output buffer (iosb->out_size bytes) is laid out as:
+ *   recv_len bytes    data received on the accepted socket
+ *   local_len bytes   int length + local address (only if local_len != 0)
+ *   remaining bytes   int length + remote address
+ * The length prefixes are memcpy'd since they need not be aligned.
+ *
+ * The outcome is reported through the current error status:
+ *   STATUS_ALERTED - success; iosb now owns the output buffer
+ *   STATUS_PENDING - data was requested but none has arrived yet; the request
+ *                    is marked accepted and the accepted socket is reselected
+ *   anything else  - the request failed with that status
+ */
+static void fill_accept_output( struct accept_req *req, struct iosb *iosb )
+{
+    const data_size_t out_size = iosb->out_size;
+    union unix_sockaddr unix_addr;
+    struct WS_sockaddr *win_addr;
+    unsigned int remote_len;
+    socklen_t unix_len;
+    int fd, size = 0;
+    char *out_data;
+    int win_len;
+
+    /* Check the layout against the real buffer size: the areas must fit
+     * without unsigned wrap-around and the remote area needs at least its
+     * length prefix, otherwise recv() and the address writes below would run
+     * past the end of out_data. */
+    if (req->recv_len > out_size || req->local_len > out_size - req->recv_len ||
+        out_size - req->recv_len - req->local_len < sizeof(int))
+    {
+        set_error( STATUS_BUFFER_TOO_SMALL );
+        return;
+    }
+    remote_len = out_size - req->recv_len - req->local_len;
+
+    /* NOTE(review): relies on mem_alloc() setting the error status on failure */
+    if (!(out_data = mem_alloc( out_size ))) return;
+
+    fd = get_unix_fd( req->acceptsock->fd );
+
+    if (req->recv_len && (size = recv( fd, out_data, req->recv_len, 0 )) < 0)
+    {
+        if (!req->accepted && errno == EWOULDBLOCK)
+        {
+            /* accepted, but no data yet: wait for the accepted socket to become readable */
+            req->accepted = 1;
+            sock_reselect( req->acceptsock );
+            set_error( STATUS_PENDING );
+        }
+        else set_win32_error( sock_get_error( errno ) );
+        goto failed;
+    }
+
+    if (req->local_len)
+    {
+        if (req->local_len < sizeof(int))
+        {
+            set_error( STATUS_BUFFER_TOO_SMALL );
+            goto failed;
+        }
+
+        unix_len = sizeof(unix_addr);
+        if (getsockname( fd, &unix_addr.addr, &unix_len ) < 0)
+        {
+            set_win32_error( sock_get_error( errno ) );
+            goto failed;
+        }
+        /* the address only gets the space left after its length prefix */
+        win_addr = (struct WS_sockaddr *)(out_data + req->recv_len + sizeof(int));
+        if ((win_len = sockaddr_from_unix( &unix_addr, win_addr, req->local_len - sizeof(int) )) < 0)
+        {
+            /* sockaddr_from_unix() fails (without setting errno) when the
+             * area cannot hold the address or the family is unsupported */
+            set_error( STATUS_BUFFER_TOO_SMALL );
+            goto failed;
+        }
+        memcpy( out_data + req->recv_len, &win_len, sizeof(int) );
+    }
+
+    unix_len = sizeof(unix_addr);
+    if (getpeername( fd, &unix_addr.addr, &unix_len ) < 0)
+    {
+        set_win32_error( sock_get_error( errno ) );
+        goto failed;
+    }
+    /* likewise, the remote address must stay within remote_len including its prefix */
+    win_addr = (struct WS_sockaddr *)(out_data + req->recv_len + req->local_len + sizeof(int));
+    if ((win_len = sockaddr_from_unix( &unix_addr, win_addr, remote_len - sizeof(int) )) < 0)
+    {
+        set_error( STATUS_BUFFER_TOO_SMALL );
+        goto failed;
+    }
+    memcpy( out_data + req->recv_len + req->local_len, &win_len, sizeof(int) );
+
+    iosb->status = STATUS_SUCCESS;
+    iosb->result = size;
+    iosb->out_data = out_data;
+    set_error( STATUS_ALERTED );
+    return;
+
+failed:
+    free( out_data );
+}
+
+/* Accept a connection from the listening socket into the request's accepting
+ * socket via accept_into_socket(), then fill the request's output.  The caller
+ * reads the outcome from the current error status; fill_accept_output() is
+ * only reached when accept_into_socket() succeeds. */
+static void complete_async_accept( struct sock *sock, struct accept_req *req )
+{
+    if (debug_level) fprintf( stderr, "completing accept request for socket %p\n", sock );
+
+    if (accept_into_socket( sock, req->acceptsock ))
+    {
+        struct iosb *req_iosb = async_get_iosb( req->async );
+
+        fill_accept_output( req, req_iosb );
+        release_object( req_iosb );
+    }
+}
+
+/* Called when the accepting socket of an accept-into request has input:
+ * retry receiving the request's initial data and filling its output.  The
+ * outcome is left in the current error status; STATUS_PENDING means there is
+ * nothing to complete yet. */
+static void complete_async_accept_recv( struct accept_req *req )
+{
+    struct iosb *iosb;
+
+    if (debug_level) fprintf( stderr, "completing accept recv request for socket %p\n", req->acceptsock );
+
+    iosb = async_get_iosb( req->async );
+
+    /* The accepting socket points at its request from the moment the ioctl is
+     * issued, so input may be dispatched here before the connection has been
+     * accepted, after the request has already completed, or for a request that
+     * asked for no data (recv_len == 0).  Only an accepted request that is
+     * still pending is waiting for data; leave everything else alone. */
+    if (!req->accepted || iosb->status != STATUS_PENDING)
+    {
+        release_object( iosb );
+        set_error( STATUS_PENDING );
+        return;
+    }
+
+    /* "accepted" is only set when data was requested (see fill_accept_output) */
+    assert( req->recv_len );
+
+    fill_accept_output( req, iosb );
+    release_object( iosb );
+}
+
static int sock_dispatch_asyncs( struct sock *sock, int event, int error )
{
+ if (event & (POLLIN | POLLPRI))
+ {
+ struct accept_req *req;
+
+ LIST_FOR_EACH_ENTRY( req, &sock->accept_list, struct accept_req, entry )
+ {
+ if (!req->accepted)
+ {
+ complete_async_accept( sock, req );
+ if (get_error() != STATUS_PENDING)
+ async_terminate( req->async, get_error() );
+ break;
+ }
+ }
+
+ if (sock->accept_recv_req)
+ {
+ complete_async_accept_recv( sock->accept_recv_req );
+ if (get_error() != STATUS_PENDING)
+ async_terminate( sock->accept_recv_req->async, get_error() );
+ }
+ }
+
if (is_fd_overlapped( sock->fd ))
{
if (event & (POLLIN|POLLPRI) && async_waiting( &sock->read_q ))
@@ -355,16 +587,25 @@ static int sock_dispatch_asyncs( struct sock *sock, int event, int error )
async_wake_up( &sock->write_q, STATUS_ALERTED );
event &= ~POLLOUT;
}
- if ( event & (POLLERR|POLLHUP) )
- {
- int status = sock_get_ntstatus( error );
+ }
- if ( !(sock->state & FD_READ) )
- async_wake_up( &sock->read_q, status );
- if ( !(sock->state & FD_WRITE) )
- async_wake_up( &sock->write_q, status );
- }
+ if (event & (POLLERR | POLLHUP))
+ {
+ int status = sock_get_ntstatus( error );
+ struct accept_req *req, *next;
+
+ if (!(sock->state & FD_READ))
+ async_wake_up( &sock->read_q, status );
+ if (!(sock->state & FD_WRITE))
+ async_wake_up( &sock->write_q, status );
+
+ LIST_FOR_EACH_ENTRY_SAFE( req, next, &sock->accept_list, struct accept_req, entry )
+ async_terminate( req->async, status );
+
+ if (sock->accept_recv_req)
+ async_terminate( sock->accept_recv_req->async, status );
}
+
return event;
}
@@ -539,7 +780,11 @@ static int sock_get_poll_events( struct fd *fd )
/* connecting, wait for writable */
return POLLOUT;
- if (async_queued( &sock->read_q ))
+ if (!list_empty( &sock->accept_list ) || sock->accept_recv_req )
+ {
+ ev |= POLLIN | POLLPRI;
+ }
+ else if (async_queued( &sock->read_q ))
{
if (async_waiting( &sock->read_q )) ev |= POLLIN | POLLPRI;
}
@@ -601,6 +846,16 @@ static void sock_queue_async( struct fd *fd, struct async *async, int type, int
static void sock_reselect_async( struct fd *fd, struct async_queue *queue )
{
struct sock *sock = get_fd_user( fd );
+ struct accept_req *req, *next;
+
+ LIST_FOR_EACH_ENTRY_SAFE( req, next, &sock->accept_list, struct accept_req, entry )
+ {
+ struct iosb *iosb = async_get_iosb( req->async );
+ if (iosb->status != STATUS_PENDING)
+ free_accept_req( req );
+ release_object( iosb );
+ }
+
/* ignore reselect on ifchange queue */
if (&sock->ifchange_q != queue)
sock_reselect( sock );
@@ -615,6 +870,8 @@ static struct fd *sock_get_fd( struct object *obj )
static void sock_destroy( struct object *obj )
{
struct sock *sock = (struct sock *)obj;
+ struct accept_req *req, *next;
+
assert( obj->ops == &sock_ops );
/* FIXME: special socket shutdown stuff? */
@@ -622,11 +879,18 @@ static void sock_destroy( struct object *obj )
if ( sock->deferred )
release_object( sock->deferred );
+ if (sock->accept_recv_req)
+ async_terminate( sock->accept_recv_req->async, STATUS_CANCELLED );
+
+ LIST_FOR_EACH_ENTRY_SAFE( req, next, &sock->accept_list, struct accept_req, entry )
+ async_terminate( req->async, STATUS_CANCELLED );
+
async_wake_up( &sock->ifchange_q, STATUS_CANCELLED );
sock_release_ifchange( sock );
free_async_queue( &sock->read_q );
free_async_queue( &sock->write_q );
free_async_queue( &sock->ifchange_q );
+ free_async_queue( &sock->accept_q );
if (sock->event) release_object( sock->event );
if (sock->fd)
{
@@ -658,10 +922,13 @@ static struct sock *create_socket(void)
sock->connect_time = 0;
sock->deferred = NULL;
sock->ifchange_obj = NULL;
+ sock->accept_recv_req = NULL;
init_async_queue( &sock->read_q );
init_async_queue( &sock->write_q );
init_async_queue( &sock->ifchange_q );
+ init_async_queue( &sock->accept_q );
memset( sock->errors, 0, sizeof(sock->errors) );
+ list_init( &sock->accept_list );
return sock;
}
@@ -1065,6 +1332,24 @@ static int sock_get_ntstatus( int err )
}
}
+/* Allocate an accept-into request for the given async and parameters.  The
+ * request takes a reference to the async but not to acceptsock; the caller is
+ * responsible for linking req->entry into a list.  Returns NULL if the
+ * allocation fails. */
+static struct accept_req *alloc_accept_req( struct sock *acceptsock, struct async *async,
+                                            const struct afd_accept_into_params *params )
+{
+    struct accept_req *req = mem_alloc( sizeof(*req) );
+
+    if (!req) return NULL;
+
+    req->async = (struct async *)grab_object( async );
+    req->acceptsock = acceptsock;
+    req->accepted = 0;
+    req->recv_len = params->recv_len;
+    req->local_len = params->local_len;
+    return req;
+}
+
static int sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async )
{
struct sock *sock = get_fd_user( fd );
@@ -1111,22 +1396,38 @@ static int sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async )
case IOCTL_AFD_ACCEPT_INTO:
{
static const int access = FILE_READ_ATTRIBUTES | FILE_WRITE_ATTRIBUTES | FILE_READ_DATA;
+ const struct afd_accept_into_params *params = get_req_data();
struct sock *acceptsock;
- obj_handle_t handle;
+ unsigned int remote_len;
+ struct accept_req *req;
- if (get_req_data_size() != sizeof(handle))
+ if (get_req_data_size() != sizeof(*params) ||
+ get_reply_max_size() < params->recv_len + params->local_len)
{
set_error( STATUS_BUFFER_TOO_SMALL );
return 0;
}
- handle = *(obj_handle_t *)get_req_data();
- if (!(acceptsock = (struct sock *)get_handle_obj( current->process, handle, access, &sock_ops )))
+ remote_len = get_reply_max_size() - params->recv_len - params->local_len;
+ if (remote_len < sizeof(int))
+ {
+ set_error( STATUS_INVALID_PARAMETER );
+ return 0;
+ }
+
+ if (!(acceptsock = (struct sock *)get_handle_obj( current->process, params->accept_handle, access, &sock_ops )))
return 0;
- if (accept_into_socket( sock, acceptsock ))
- acceptsock->wparam = handle;
+
+ if (!(req = alloc_accept_req( acceptsock, async, params ))) return 0;
+ list_add_tail( &sock->accept_list, &req->entry );
+ acceptsock->accept_recv_req = req;
release_object( acceptsock );
- return 0;
+
+ acceptsock->wparam = params->accept_handle;
+ queue_async( &sock->accept_q, async );
+ sock_reselect( sock );
+ set_error( STATUS_PENDING );
+ return 1;
}
case IOCTL_AFD_ADDRESS_LIST_CHANGE:
--
2.28.0