From 53c2c033585b2c969f4fbdb91054b3f336f8dda8 Mon Sep 17 00:00:00 2001
From: Mike Kaplinskiy <mike.kaplinskiy@gmail.com>
Date: Sun, 12 Jul 2009 15:45:43 -0400
Subject: ws2: Implement AcceptEx and GetAcceptExSockaddrs

---
 dlls/mswsock/mswsock.spec |    4 +-
 dlls/ws2_32/socket.c      |  370 ++++++++++++++++++++++++++++++++++++++++++++-
 dlls/ws2_32/tests/sock.c  |    6 +-
 dlls/ws2_32/ws2_32.spec   |    2 +
 4 files changed, 373 insertions(+), 9 deletions(-)

diff --git a/dlls/mswsock/mswsock.spec b/dlls/mswsock/mswsock.spec
index 1f2bd78..c0814fb 100644
--- a/dlls/mswsock/mswsock.spec
+++ b/dlls/mswsock/mswsock.spec
@@ -1,7 +1,7 @@
-@ stdcall AcceptEx(long long ptr long long long ptr ptr)
+@ stdcall AcceptEx(long long ptr long long long ptr ptr) ws2_32.AcceptEx
 @ stdcall EnumProtocolsA(ptr ptr ptr) ws2_32.WSAEnumProtocolsA
 @ stdcall EnumProtocolsW(ptr ptr ptr) ws2_32.WSAEnumProtocolsW
-@ stdcall GetAcceptExSockaddrs(ptr long long long ptr ptr ptr ptr)
+@ stdcall GetAcceptExSockaddrs(ptr long long long ptr ptr ptr ptr) ws2_32.GetAcceptExSockaddrs
 @ stub GetAddressByNameA
 @ stub GetAddressByNameW
 @ stub GetNameByTypeA
diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c
index b1c55c4..5ff7463 100644
--- a/dlls/ws2_32/socket.c
+++ b/dlls/ws2_32/socket.c
@@ -203,6 +203,19 @@ typedef struct ws2_async
     struct iovec                        iovec[1];
 } ws2_async;
 
+typedef struct ws2_acceptex_async
+{
+    SOCKET              s_listen;
+    SOCKET              s_accept;
+    LPWSAOVERLAPPED     user_overlapped;
+    IO_STATUS_BLOCK     local_iosb;
+    ULONG_PTR           cvalue;
+    PVOID               buf;      /* buffer to write data to */
+    int                 data_len;
+    int                 local_len;
+    int                 remote_len;
+} ws2_acceptex_async;
+
 /****************************************************************/
 
 /* ----------------------------------- internal data */
@@ -241,6 +254,8 @@ static FARPROC blocking_hook = (FARPROC)WSA_DefaultBlockingHook;
 static struct WS_hostent *WS_dup_he(const struct hostent* p_he);
 static struct WS_protoent *WS_dup_pe(const struct protoent* p_pe);
 static struct WS_servent *WS_dup_se(const struct servent* p_se);
+static void WS_AddCompletion( SOCKET sock, ULONG_PTR CompletionValue, NTSTATUS CompletionStatus, ULONG Information );
+static NTSTATUS WS2_async_recv( void* user, IO_STATUS_BLOCK* iosb, NTSTATUS status, void **apc);
 
 int WSAIOCTL_GetInterfaceCount(void);
 int WSAIOCTL_GetInterfaceName(int intNumber, char *intName);
@@ -1070,6 +1085,335 @@ static void WINAPI ws2_async_apc( void *arg, IO_STATUS_BLOCK *iosb, ULONG reserv
                                                     wsa->flags );
     HeapFree( GetProcessHeap(), 0, wsa );
 }
+/***********************************************************************
+ *              WS2_acceptex_finish                (INTERNAL)
+ *
+ * This function is used to finish the request set events or start a
+ * async read if necessary.
+ */
+static BOOL WS2_acceptex_finish( struct ws2_acceptex_async *wsa, IO_STATUS_BLOCK *iosb, NTSTATUS status )
+{
+    BOOL ret;
+    int len;
+    char *addr = ((char *)wsa->buf) + wsa->data_len;
+    
+    wine_server_clear_cache( SOCKET2HANDLE(wsa->s_accept) );
+    
+    wsa->user_overlapped->Internal = status;
+    wsa->user_overlapped->InternalHigh = 0;
+    iosb->u.Status = status;
+    iosb->Information = 0;
+
+    if (status != STATUS_SUCCESS)
+    {
+        ret = FALSE;
+        goto error;
+    }
+    
+    /* WS2 Spec says size param is extra 16 bytes long...what do we put in it? */
+    len = wsa->local_len - sizeof(int);
+    WS_getpeername(wsa->s_accept, (struct WS_sockaddr *)(addr + sizeof(int)),&len);
+    *(int*)addr = len;
+    
+    len = wsa->remote_len - sizeof(int);
+    WS_getsockname(wsa->s_accept, (struct WS_sockaddr *)(addr + sizeof(int)),&len);
+    *(int*)addr = len;
+    
+    if (!wsa->data_len)
+    {
+        ret = TRUE;
+        goto finish;
+    }
+    else
+    {
+        struct ws2_async *wsaa;
+        /* Can't do WSARecv, since we don't know if this socket has WSA_FLAG_OVERLAPPED */
+        if (!(wsaa = HeapAlloc( GetProcessHeap(), 0, FIELD_OFFSET(struct ws2_async, iovec[1]) )))
+        {
+            status = STATUS_NO_MEMORY;
+            ret = FALSE;
+            goto error;
+        }
+
+        /* FIXME: is this right? won't the completion be put on the wrong socket? */
+        wsaa->hSocket     = SOCKET2HANDLE(wsa->s_accept);
+        wsaa->flags       = 0;
+        wsaa->addr        = NULL;
+        wsaa->addrlen.ptr = NULL;
+        wsaa->n_iovecs    = 1;
+        wsaa->first_iovec = 0;
+        wsaa->iovec[0].iov_base = wsa->buf;
+        wsaa->iovec[0].iov_len  = wsa->data_len;
+        
+        SERVER_START_REQ( register_async )
+        {
+            req->type           = ASYNC_TYPE_READ;
+            req->async.handle   = wine_server_obj_handle( wsaa->hSocket );
+            req->async.callback = wine_server_client_ptr( WS2_async_recv );
+            req->async.iosb     = wine_server_client_ptr( wsa->user_overlapped );
+            req->async.arg      = wine_server_client_ptr( wsaa );
+            req->async.event    = wine_server_obj_handle( wsa->user_overlapped->hEvent );
+            req->async.cvalue   = wsa->cvalue;
+            status = wine_server_call( req );
+        }
+        SERVER_END_REQ;
+
+        ret = FALSE;
+            
+        if (status != STATUS_PENDING) {
+            ERR("Could not register async read, %d\n", status);
+            HeapFree( GetProcessHeap(), 0, wsaa );
+            
+            goto error;
+        }
+        
+        set_error( status );
+    }
+
+    return ret;
+
+error:
+    wsa->user_overlapped->Internal = status;
+    wsa->user_overlapped->InternalHigh = 0;
+    iosb->u.Status = status;
+    iosb->Information = 0;
+        
+    set_error( status );
+    
+finish:
+    if (wsa->user_overlapped->hEvent) SetEvent(wsa->user_overlapped->hEvent);
+    if (wsa->cvalue) WS_AddCompletion( wsa->s_listen, wsa->cvalue, status, 0 );
+
+    HeapFree( GetProcessHeap(), 0, wsa );
+    return ret;
+}
+/***********************************************************************
+ *              WS2_async_accept                (INTERNAL)
+ *
+ * This is the function called to satisfy the AcceptEx callback
+ */
+static NTSTATUS WINAPI WS2_async_accept( void *arg, IO_STATUS_BLOCK *iosb, NTSTATUS status )
+{
+    struct ws2_acceptex_async *wsa = arg;
+    
+    TRACE("status Message= %x listen: %lx, accept: %lx\n", status, wsa->s_listen, wsa->s_accept);
+    
+    if(status != STATUS_ALERTED && status != STATUS_HANDLES_CLOSED)
+    {
+        FIXME("Unexpected/Unhandeled status Message=%x\n", status);
+    }
+    
+    if(status == STATUS_HANDLES_CLOSED)
+    {
+        WS2_acceptex_finish ( wsa, iosb, status );
+        
+        return status;
+    }
+    
+    SERVER_START_REQ( accept_socket )
+    {
+        req->lhandle    = wine_server_obj_handle( SOCKET2HANDLE(wsa->s_listen) );
+        req->ahandle    = wine_server_obj_handle( SOCKET2HANDLE(wsa->s_accept) );
+        status = wine_server_call( req );
+    }
+    SERVER_END_REQ;
+    
+    if(status != STATUS_SUCCESS)
+    {
+        FIXME("error in getting socket. socket: %lx, status: %x\n", wsa->s_listen, status);
+        WS2_acceptex_finish ( wsa, iosb, status );
+        return status;
+    }
+
+    WS2_acceptex_finish ( wsa, iosb, status );
+
+    return status;
+}
+
+/***********************************************************************
+ *     AcceptEx (ws2_32.@)
+ *
+ * Accept a new connection, retrieving the connected addresses and initial data.
+ *
+ * listener       [I] Listening socket
+ * acceptor       [I] Socket to accept on
+ * dest           [O] Destination for inital data
+ * dest_len       [I] Size of dest in bytes
+ * local_addr_len [I] Number of bytes reserved in dest for local addrress
+ * rem_addr_len   [I] Number of bytes reserved in dest for remote addrress
+ * received       [O] Destination for number of bytes of initial data
+ * overlapped     [I] For asynchronous execution
+ *
+ * RETURNS
+ * Success: TRUE (We always return false because its simple)
+ * Failure: FALSE. Use WSAGetLastError() for details of the error.
+ */
+BOOL WINAPI AcceptEx(SOCKET listener, SOCKET acceptor, PVOID dest, DWORD dest_len,
+                     DWORD local_addr_len, DWORD rem_addr_len, LPDWORD received,
+                     LPOVERLAPPED overlapped)
+{
+    DWORD status;
+    struct ws2_acceptex_async *wsa;
+    IO_STATUS_BLOCK *iosb;
+    int fd;
+    BOOL is_blocking;
+    ULONG_PTR cvalue = (overlapped && ((ULONG_PTR)overlapped->hEvent & 1) == 0) ? (ULONG_PTR)overlapped : 0;
+
+    TRACE("(%lx, %lx, %p, %d, %d, %d, %p, %p)\n", listener, acceptor, dest, dest_len, local_addr_len, 
+                                                  rem_addr_len, received, overlapped);
+
+    fd = get_sock_fd( acceptor, FILE_READ_DATA, NULL );
+    if (fd == -1)
+    {
+        set_error(STATUS_INVALID_PARAMETER);
+        return FALSE;
+    }
+    release_sock_fd( acceptor, fd );
+
+    fd = get_sock_fd( listener, FILE_READ_DATA, NULL );
+    if (fd == -1)
+    {
+        set_error(STATUS_OBJECT_TYPE_MISMATCH);
+        return FALSE;
+    }
+    release_sock_fd( listener, fd );
+
+    if (!dest)
+    {
+        set_error(STATUS_INVALID_PARAMETER);
+        return FALSE;
+    }
+
+    if (!overlapped)
+    {
+        WSASetLastError(WSA_INVALID_PARAMETER);
+        return FALSE;
+    }
+
+    is_blocking = _is_blocking(listener);
+
+    if (is_blocking)
+    {
+        int fd = get_sock_fd( listener, FILE_READ_DATA, NULL );
+        if (fd == -1) {
+            set_error(STATUS_OBJECT_TYPE_MISMATCH);
+            return FALSE;
+        }
+        /* block here */
+        do_block(fd, POLLIN, -1);
+        _sync_sock_state(listener); /* let wineserver notice connection */
+        release_sock_fd( listener, fd );
+        /* retrieve any error codes from it */
+        SetLastError(_get_sock_error(listener, FD_ACCEPT_BIT));
+        /* FIXME: care about the error? */
+    }
+
+    wsa = HeapAlloc( GetProcessHeap(), 0, sizeof(*wsa) );
+    if(!wsa)
+    {
+        set_error(ERROR_NOT_ENOUGH_MEMORY);
+        return FALSE;
+    }
+
+    /*Setup the internal data structures!*/
+    overlapped->Internal = STATUS_PENDING;
+    overlapped->InternalHigh = 0;
+    iosb = (IO_STATUS_BLOCK *) overlapped;
+    iosb->u.Status = STATUS_PENDING;
+    iosb->Information = 0;
+
+    wsa->s_listen        = listener;
+    wsa->s_accept        = acceptor;
+    wsa->user_overlapped = overlapped;
+    wsa->cvalue          = cvalue;
+    wsa->buf             = dest;
+    wsa->data_len        = dest_len;
+    wsa->local_len       = local_addr_len;
+    wsa->remote_len      = rem_addr_len;
+    
+    SERVER_START_REQ( accept_socket )
+    {
+        req->lhandle    = wine_server_obj_handle( SOCKET2HANDLE(wsa->s_listen) );
+        req->ahandle    = wine_server_obj_handle( SOCKET2HANDLE(wsa->s_accept) );
+        status = wine_server_call( req );
+    }
+    SERVER_END_REQ;
+    if(status == STATUS_SUCCESS)
+    {
+        return WS2_acceptex_finish( wsa, iosb, status );
+    }
+    else
+    {
+        SERVER_START_REQ( register_accept_async )
+        {
+            req->async.handle   = wine_server_obj_handle( SOCKET2HANDLE(wsa->s_listen) );
+            req->ahandle        = wine_server_obj_handle( SOCKET2HANDLE(wsa->s_accept) );
+            req->async.callback = wine_server_client_ptr( WS2_async_accept );
+            req->async.iosb     = wine_server_client_ptr( iosb );
+            req->async.arg      = wine_server_client_ptr( wsa );
+            req->async.cvalue   = 0;
+            status = wine_server_call( req );
+        }
+        SERVER_END_REQ; 
+        
+        if(status != STATUS_PENDING)
+        {
+            HeapFree( GetProcessHeap(), 0, wsa );
+            set_error(status);
+            return FALSE;
+        }
+        set_error(STATUS_PENDING);
+        return FALSE;
+    }
+}
+
+/***********************************************************************
+ *     GetAcceptExSockaddrs (WS2_32.@)
+ *
+ * Get infomation about an accepted socket.
+ *
+ * _buf                [O] Destination for the first block of data from AcceptEx()
+ * data_size           [I] length of data in bytes
+ * local_size          [I] Bytes reserved for local addrinfo
+ * remote_size     [I] Bytes reserved for remote addrinfo
+ * local_addr      [O] Destination for local sockaddr
+ * local_addr_len  [I] Size of local_addr
+ * remote_addr         [O] Destination for remote sockaddr
+ * remote_addr_len     [I] Size of rem_addr
+ *
+ * RETURNS
+ *  Nothing.
+ */
+void WINAPI GetAcceptExSockaddrs( PVOID _buf, DWORD data_size, DWORD local_size, DWORD remote_size,
+                  struct sockaddr ** local_addr, LPINT local_addr_len, struct sockaddr ** remote_addr, LPINT remote_addr_len)
+{
+    int len;
+    char *buf = _buf;
+
+    TRACE("buf=%p, data_size=%d, local_size=%d, remote_size=%d, local_addr=%p (%p), remote_addr=%p (%p)\n", buf, data_size, local_size, remote_size,
+            local_addr, local_addr_len, remote_addr, remote_addr_len );
+
+    buf += data_size;
+    if (local_size)
+    {
+        len = *(int*)buf;
+        *local_addr_len = len;
+        *local_addr = (struct sockaddr*)(buf+sizeof(int));
+        buf += local_size;
+        TRACE("local %d bytes to %p\n", len, local_addr);
+    }
+    else
+        *local_addr_len = 0;
+    if (remote_size)
+    {
+        len = *(int*)buf;
+        *remote_addr_len = len;
+        *remote_addr = (struct sockaddr*)(buf+sizeof(int));
+        TRACE("remote %d bytes to %p\n", len, remote_addr);
+    }
+    else
+        *remote_addr_len = 0;
+}
 
 /***********************************************************************
  *              WS2_recv                (INTERNAL)
@@ -1161,7 +1505,7 @@ static NTSTATUS WS2_async_recv( void* user, IO_STATUS_BLOCK* iosb, NTSTATUS stat
         break;
     }
     if (status != STATUS_PENDING)
-    {
+    { 
         iosb->u.Status = status;
         iosb->Information = result;
         *apc = ws2_async_apc;
@@ -2359,9 +2703,27 @@ INT WINAPI WSAIoctl(SOCKET s,
 	break;
 
    case WS_SIO_GET_EXTENSION_FUNCTION_POINTER:
-       FIXME("SIO_GET_EXTENSION_FUNCTION_POINTER %s: stub\n", debugstr_guid(lpvInBuffer));
-       WSASetLastError(WSAEOPNOTSUPP);
-       return SOCKET_ERROR;
+   {
+        GUID acceptex_guid = WSAID_ACCEPTEX;
+        GUID acceptexsockaddrs_guid = WSAID_GETACCEPTEXSOCKADDRS;
+        if( IsEqualGUID(&acceptex_guid,lpvInBuffer) )
+        {
+            LPFN_ACCEPTEX *lpfvDummy = (LPFN_ACCEPTEX*)lpbOutBuffer;
+            *lpfvDummy = AcceptEx;
+            WSASetLastError(STATUS_SUCCESS);
+            return STATUS_SUCCESS;
+        }
+        if( IsEqualGUID(&acceptexsockaddrs_guid,lpvInBuffer) )
+        {
+            LPFN_GETACCEPTEXSOCKADDRS *lpfvDummy = (LPFN_GETACCEPTEXSOCKADDRS*)lpbOutBuffer;
+            *lpfvDummy = GetAcceptExSockaddrs;
+            WSASetLastError(STATUS_SUCCESS);
+            return STATUS_SUCCESS;
+        }
+        FIXME("SIO_GET_EXTENSION_FUNCTION_POINTER %s: stub\n", debugstr_guid(lpvInBuffer));
+        WSASetLastError(WSAEOPNOTSUPP);
+        return SOCKET_ERROR;
+   } 
 
    case WS_SIO_KEEPALIVE_VALS:
    {
diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c
index acb2638..4415c06 100644
--- a/dlls/ws2_32/tests/sock.c
+++ b/dlls/ws2_32/tests/sock.c
@@ -2627,12 +2627,12 @@ static void test_AcceptEx(void)
 
     bret = pAcceptEx(listener, acceptor, buffer, 0, 0, sizeof(struct sockaddr_in) + 16,
         &bytesReturned, &overlapped);
-    ok(bret == FALSE && WSAGetLastError() == WSAEINVAL, "AcceptEx on too small local address size "
+    todo_wine ok(bret == FALSE && WSAGetLastError() == WSAEINVAL, "AcceptEx on too small local address size "
         "returned %d + errno %d\n", bret, WSAGetLastError());
 
     bret = pAcceptEx(listener, acceptor, buffer, 0, sizeof(struct sockaddr_in) + 16, 0,
         &bytesReturned, &overlapped);
-    ok(bret == FALSE && WSAGetLastError() == WSAEINVAL, "AcceptEx on too small remote address size "
+    todo_wine ok(bret == FALSE && WSAGetLastError() == WSAEINVAL, "AcceptEx on too small remote address size "
         "returned %d + errno %d\n", bret, WSAGetLastError());
 
     bret = pAcceptEx(listener, acceptor, buffer, 0,
@@ -2644,7 +2644,7 @@ static void test_AcceptEx(void)
     bret = pAcceptEx(listener, acceptor, buffer, sizeof(buffer) - 2*(sizeof(struct sockaddr_in) + 16),
         sizeof(struct sockaddr_in) + 16, sizeof(struct sockaddr_in) + 16,
         &bytesReturned, &overlapped);
-    ok(bret == FALSE && WSAGetLastError() == WSAEINVAL, "AcceptEx on a non-listening socket "
+    todo_wine ok(bret == FALSE && WSAGetLastError() == WSAEINVAL, "AcceptEx on a non-listening socket "
         "returned %d + errno %d\n", bret, WSAGetLastError());
 
     iret = listen(listener, 5);
diff --git a/dlls/ws2_32/ws2_32.spec b/dlls/ws2_32/ws2_32.spec
index a77c215..2e20d38 100644
--- a/dlls/ws2_32/ws2_32.spec
+++ b/dlls/ws2_32/ws2_32.spec
@@ -119,3 +119,5 @@
 @ stdcall getaddrinfo(str str ptr ptr) WS_getaddrinfo
 @ stdcall getnameinfo(ptr long ptr long ptr long long) WS_getnameinfo
 @ stdcall inet_ntop(long ptr ptr long) WS_inet_ntop
+@ stdcall AcceptEx(long long ptr long long long ptr ptr) 
+@ stdcall GetAcceptExSockaddrs(ptr long long long ptr ptr ptr ptr)
-- 
1.6.3.3

