ws2_32: Close sockets that are still open on the final WSACleanup().

Sockets returned by WSASocketW(), WS_accept() and the asynchronous accept
path are recorded in a list.  WS_closesocket() removes them again, and the
WSACleanup() call that drops the startup count to zero closes the rest.

Whitespace inside context lines is reconstructed; apply with
"git apply --ignore-whitespace" (or "patch -l").

diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c
index 0391554..5bdf10f 100644
--- a/dlls/ws2_32/socket.c
+++ b/dlls/ws2_32/socket.c
@@ -28,4 +28,5 @@
 
 #include "config.h"
 #include "wine/port.h"
+#include "wine/list.h"
 
@@ -176,6 +177,18 @@
 WINE_DEFAULT_DEBUG_CHANNEL(winsock);
 WINE_DECLARE_DEBUG_CHANNEL(winediag);
 
+/* One node per socket handed out to the application, so that WSACleanup()
+ * can close whatever is still open. */
+struct sockentry
+{
+    struct list entry;
+    SOCKET socket;
+};
+
+/* Tracked sockets, guarded by critical_queue.  Statically initialised so the
+ * list is valid no matter when the first socket gets tracked. */
+static struct list socklist = LIST_INIT(socklist);
+
 /* names of the protocols */
 static const WCHAR NameIpxW[] = {'I', 'P', 'X', '\0'};
 static const WCHAR NameSpxW[] = {'S', 'P', 'X', '\0'};
@@ -267,12 +280,60 @@ static CRITICAL_SECTION_DEBUG critsect_debug =
 };
 static CRITICAL_SECTION csWSgetXXXbyYYY = { &critsect_debug, -1, 0, 0, 0, 0 };
 
+/* Serialises every access to socklist. */
+static CRITICAL_SECTION critical_queue;
+static CRITICAL_SECTION_DEBUG critical_queue_debug =
+{
+    0, 0, &critical_queue,
+    { &critical_queue_debug.ProcessLocksList, &critical_queue_debug.ProcessLocksList },
+      0, 0, { (DWORD_PTR)(__FILE__ ": critical_queue") }
+};
+static CRITICAL_SECTION critical_queue = { &critical_queue_debug, -1, 0, 0, 0, 0 };
+
 union generic_unix_sockaddr
 {
     struct sockaddr addr;
     char data[128]; /* should be big enough for all families */
 };
 
+/* Record socket s in socklist so that the last WSACleanup() can close it.
+ * An allocation failure is only logged: the socket is then simply not
+ * tracked. */
+static void enqueue(SOCKET s)
+{
+    struct sockentry *sockentry = HeapAlloc(GetProcessHeap(), 0, sizeof(*sockentry));
+
+    if (!sockentry)
+    {
+        ERR("Failed to alloc sockentry memory for socket %04lx\n", s);
+        return;
+    }
+
+    sockentry->socket = s;
+    EnterCriticalSection(&critical_queue);
+    list_add_tail(&socklist, &sockentry->entry);
+    TRACE("socket %04lx added, current count %u\n", s, list_count(&socklist));
+    LeaveCriticalSection(&critical_queue);
+}
+
+/* Remove and free every socklist entry that refers to socket s.  Sockets
+ * that are not in the list are ignored. */
+static void dequeue(SOCKET s)
+{
+    struct sockentry *sockentry, *next;
+
+    EnterCriticalSection(&critical_queue);
+    /* _SAFE variant: the current node is freed inside the loop */
+    LIST_FOR_EACH_ENTRY_SAFE(sockentry, next, &socklist, struct sockentry, entry)
+    {
+        if (sockentry->socket != s) continue;
+        list_remove(&sockentry->entry);
+        HeapFree(GetProcessHeap(), 0, sockentry);
+        TRACE("socket %04lx removed, current count %u\n", s, list_count(&socklist));
+    }
+    LeaveCriticalSection(&critical_queue);
+}
+
 static inline const char *debugstr_sockaddr( const struct WS_sockaddr *a )
 {
     if (!a) return "(nil)";
@@ -1242,8 +1303,25 @@ int WINAPI WSAStartup(WORD wVersionRequested, LPWSADATA lpWSAData)
  */
 INT WINAPI WSACleanup(void)
 {
-    if (num_startup) {
+    if (num_startup)
+    {
         num_startup--;
+        if (!num_startup)
+        {
+            unsigned int count;
+            struct list *head;
+
+            /* last reference dropped: close whatever is still tracked */
+            EnterCriticalSection(&critical_queue);
+            if ((count = list_count(&socklist)))
+                TRACE("auto closing %u sockets\n", count);
+            /* WS_closesocket() dequeues every entry for the socket, so the
+             * head is re-read each round instead of trusting a cached next
+             * pointer that the callee may already have freed. */
+            while ((head = list_head(&socklist)))
+                WS_closesocket(LIST_ENTRY(head, struct sockentry, entry)->socket);
+            LeaveCriticalSection(&critical_queue);
+        }
         return 0;
     }
     SetLastError(WSANOTINITIALISED);
@@ -2080,6 +2158,8 @@ static NTSTATUS WS2_async_accept( void *arg, IO_STATUS_BLOCK *iosb, NTSTATUS sta
                    (struct WS_sockaddr *)(addr + sizeof(int)), &len);
     *(int *)addr = len;
 
+    enqueue(HANDLE2SOCKET(wsa->accept_socket));
+
     if (!wsa->read)
         goto finish;
 
@@ -2339,6 +2419,7 @@ SOCKET WINAPI WS_accept(SOCKET s, struct WS_sockaddr *addr, int *addrlen32)
                 WS_closesocket(as);
                 return SOCKET_ERROR;
             }
+            enqueue(as);
             return as;
         }
         if (is_blocking && status == STATUS_CANT_WAIT)
@@ -2698,6 +2779,9 @@ int WINAPI WS_bind(SOCKET s, const struct WS_sockaddr* name, int namelen)
 int WINAPI WS_closesocket(SOCKET s)
 {
     TRACE("socket %04lx\n", s);
+    /* untrack before closing: WSACleanup() relies on this to drain socklist */
+    dequeue(s);
+
     if (CloseHandle(SOCKET2HANDLE(s))) return 0;
     return SOCKET_ERROR;
 }
@@ -6029,6 +6113,8 @@ SOCKET WINAPI WSASocketW(int af, int type, int protocol,
     if (lpProtocolInfo && lpProtocolInfo->dwServiceFlags4 == 0xff00ff00) {
       ret = lpProtocolInfo->dwServiceFlags3;
       TRACE("\tgot duplicate %04lx\n", ret);
+
+      enqueue(ret);
       return ret;
     }
 
@@ -6129,7 +6215,9 @@ SOCKET WINAPI WSASocketW(int af, int type, int protocol,
         TRACE("\tcreated %04lx\n", ret );
         if (ipxptype > 0)
             set_ipx_packettype(ret, ipxptype);
-       return ret;
+
+        enqueue(ret);
+        return ret;
     }
 
     err = GetLastError();