Hi all,
Here’s another attempt, with all of Ken’s suggestions in place.
-Matt
---
dlls/ws2_32/socket.c | 116 +++++++++++++++++++++++++++++++++++++++++++++--
dlls/ws2_32/tests/sock.c | 2 -
2 files changed, 113 insertions(+), 5 deletions(-)
diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c
index ca82ec9..465be2e 100644
--- a/dlls/ws2_32/socket.c
+++ b/dlls/ws2_32/socket.c
@@ -1429,6 +1429,15 @@ static int set_ipx_packettype(int sock, int ptype)
#endif
}
+#define TABLE_SIZE 256 /* number of buckets in socket_table */
+#define BUCKET_DEPTH 16 /* number of socket slots in each bucket */
+static SOCKET socket_table[TABLE_SIZE][BUCKET_DEPTH]; /* a slot holding 0 is treated as free */
+
+/* Socket table helpers; their definitions are added further down in this file */
+static void add_to_table(SOCKET s);
+static BOOL remove_from_table(SOCKET s);
+static BOOL socket_in_table(SOCKET s);
+
/* ----------------------------------- API -----
*
* Init / cleanup / error checking.
@@ -1470,8 +1479,24 @@ int WINAPI WSAStartup(WORD wVersionRequested, LPWSADATA lpWSAData)
INT WINAPI WSACleanup(void)
{
if (num_startup) {
- num_startup--;
- TRACE("pending cleanups: %d\n", num_startup);
+ /* WS_closesocket needs num_startup to be non-zero, so decrement afterwards */
+ if (num_startup == 1) {
+ TRACE("cleaning up sockets");
+ int i, j;
+ for (i = 0; i < TABLE_SIZE; ++i) {
+ for (j = 0; j < BUCKET_DEPTH; ++j) {
+ SOCKET s = socket_table[i][j];
+ if (s) {
+ WS_closesocket(s);
+ }
+ }
+ }
+ num_startup = 0;
+ } else {
+ num_startup--;
+ TRACE("pending cleanups: %d\n", num_startup);
+ }
+
return 0;
}
SetLastError(WSANOTINITIALISED);
@@ -2595,6 +2620,11 @@ SOCKET WINAPI WS_accept(SOCKET s, struct WS_sockaddr *addr, int *addrlen32)
WS_closesocket(as);
return SOCKET_ERROR;
}
+ else
+ {
+ add_to_table(as);
+ }
+
TRACE("\taccepted %04lx\n", as);
return as;
}
@@ -2965,15 +2995,21 @@ int WINAPI WS_closesocket(SOCKET s)
int res = SOCKET_ERROR, fd;
if (num_startup)
{
+ BOOL success = FALSE;
fd = get_sock_fd(s, FILE_READ_DATA, NULL);
if (fd >= 0)
{
release_sock_fd(s, fd);
if (CloseHandle(SOCKET2HANDLE(s)))
res = 0;
+
+ success = remove_from_table(s);
}
- else
+
+ if (!success) {
SetLastError(WSAENOTSOCK);
+ res = SOCKET_ERROR;
+ }
}
else
SetLastError(WSANOTINITIALISED);
@@ -4987,6 +5023,11 @@ static int WS2_sendto( SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount,
DWORD bytes_sent;
BOOL is_blocking;
+ if (!socket_in_table(s)) {
+ SetLastError(WSAENOTSOCK);
+ return SOCKET_ERROR;;
+ }
+
TRACE("socket %04lx, wsabuf %p, nbufs %d, flags %d, to %p, tolen %d, ovl %p, func %p\n",
s, lpBuffers, dwBufferCount, dwFlags,
to, tolen, lpOverlapped, lpCompletionRoutine);
@@ -6706,6 +6747,7 @@ SOCKET WINAPI WSASocketW(int af, int type, int protocol,
SERVER_END_REQ;
if (ret)
{
+ add_to_table(ret);
TRACE("\tcreated %04lx\n", ret );
if (ipxptype > 0)
set_ipx_packettype(ret, ipxptype);
@@ -8107,3 +8149,71 @@ INT WINAPI WSCEnumProtocols( LPINT protocols, LPWSAPROTOCOL_INFOW buffer, LPDWOR
return ret;
}
+
+/*****************/
+/* Cache support */
+
+/* Hash the first four bytes of s's in-memory representation to a bucket index in [0, TABLE_SIZE). */
+static inline DWORD socket_to_index(SOCKET s)
+{
+    const BYTE *bytes = (const BYTE *)&s;
+    DWORD hash = 52812;
+    int i;
+    /* Bernstein-style mixing step per byte: hash = (hash * 33) ^ byte */
+    for (i = 0; i < 4; ++i)
+        hash = (hash * 33) ^ bytes[i];
+    return hash % TABLE_SIZE;
+}
+
+/* Atomically swap old_entry for new_entry in one slot of the bucket that the
+ * non-zero argument hashes to. cache_update(s, 0) records s in a free slot;
+ * cache_update(0, s) frees the slot holding s. Logs an error if no slot in the
+ * bucket held old_entry after three passes. */
+static void cache_update(SOCKET new_entry, SOCKET old_entry)
+{
+    /* One argument is 0 (the free-slot marker); hash the other one. */
+    const DWORD index = socket_to_index(new_entry == 0 ? old_entry : new_entry);
+    int depth, tries;
+
+    for (tries = 0; tries < 3; ++tries)
+    {
+        for (depth = 0; depth < BUCKET_DEPTH; ++depth)
+        {
+            /* Slots are SOCKET-sized, so address them as SOCKETs and use the
+             * pointer-sized compare-exchange: a LONG view is too narrow when
+             * SOCKET is wider than LONG (64-bit builds). */
+            void **slot = (void **)&socket_table[index][depth];
+            if (InterlockedCompareExchangePointer(slot, (void *)new_entry,
+                                                  (void *)old_entry) == (void *)old_entry)
+                return;
+        }
+    }
+
+    ERR("Socket hash table bucket overflow. Resize buckets\n");
+}
+
+static void add_to_table(SOCKET s)
+{
+ cache_update(s, 0); /* record s: swap it into a slot of its bucket that currently holds 0 */
+}
+
+/* Drop s from the table. Returns FALSE if s was not found, TRUE otherwise. */
+static BOOL remove_from_table(SOCKET s)
+{
+    const BOOL present = socket_in_table(s);
+    if (present)
+        cache_update(0, s); /* separate step from the lookup: swaps s's slot back to 0 */
+    return present;
+}
+
+/* Return TRUE if s is currently stored in the bucket it hashes to, FALSE otherwise. */
+static inline BOOL socket_in_table(SOCKET s)
+{
+    const DWORD bucket = socket_to_index(s);
+    int depth;
+    /* A bucket has BUCKET_DEPTH slots; bounding by TABLE_SIZE would read past its end. */
+    for (depth = 0; depth < BUCKET_DEPTH; ++depth)
+        if (socket_table[bucket][depth] == s)
+            return TRUE;
+    return FALSE;
+}
diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c
index 2d14496..afa4063 100644
--- a/dlls/ws2_32/tests/sock.c
+++ b/dlls/ws2_32/tests/sock.c
@@ -1119,7 +1119,6 @@ static void test_WithWSAStartup(void)
ok(res == 0, "WSAStartup() failed unexpectedly: %d\n", res);
/* show that sockets are destroyed automatically after WSACleanup */
- todo_wine {
SetLastError(0xdeadbeef);
res = send(src, "TEST", 4, 0);
error = WSAGetLastError();
@@ -1131,7 +1130,6 @@ static void test_WithWSAStartup(void)
error = WSAGetLastError();
ok(res == SOCKET_ERROR, "closesocket should have failed\n");
ok(error == WSAENOTSOCK, "expected 10038, got %d\n", error);
- }
closesocket(src);
closesocket(dst);
--
2.3.2 (Apple Git-55)
Responses inline.
On Tuesday, Aug 25, 2015 at 12:46 AM, Ken Thomases <ken@codeweavers.com> wrote:
Thanks for continuing to work on this. By no means a thorough review:
> On Aug 24, 2015, at 10:49 PM, Matt Durgavich <mattdurgavich@gmail.com> wrote:
> +#define CACHE_SIZE 256
> +#define CACHE_DEPTH 16
> +static SOCKET socket_cache[CACHE_SIZE][CACHE_DEPTH];
> +
> +/* Cache support */
> +static void add_to_cache(SOCKET s);
> +static BOOL remove_from_cache(SOCKET s);
> +static BOOL socket_in_cache(SOCKET s);
I'm not sure this is a "cache", per se. It's a hash table, but naming issues are fairly minor.
Let's go with socket_table
> @@ -1470,8 +1479,24 @@ int WINAPI WSAStartup(WORD wVersionRequested, LPWSADATA lpWSAData)
> INT WINAPI WSACleanup(void)
> {
> if (num_startup) {
> - num_startup--;
> - TRACE("pending cleanups: %d\n", num_startup);
> + /* WS_closesocket needs num_startup to be non-zero, so decrement afterwards */
> + if (num_startup - 1 == 0) {
Why not "if (num_startup == 1)"?
> +static void add_to_cache(SOCKET s)
> +{
> + int index, depth;
> + SOCKET old;
> + LONG *dest;
> + index = socket_to_index(s) % CACHE_SIZE;
If you always mod the result of socket_to_index() by CACHE_SIZE, why not do that in that function?
Point taken.
> + for (depth = 0; depth < CACHE_DEPTH; ++depth) {
> + if (socket_cache[index][depth] == 0)
> + break;
> + }
> +
> + if (depth == CACHE_DEPTH) {
> + ERR("Socket hash table collision\n");
This should exit here or you access beyond the end of socket_cache[index] just below.
Good catch.
> + }
> +
> + dest = (PLONG)&socket_cache[index][depth];
Don't use PLONG, just use LONG*.
Ok
> + old = InterlockedExchange(dest, s);
> +
> + if (old != 0) {
> + ERR("Socket hash table internal corruption");
Reporting corruption isn't right. You should use InterlockedCompareExchange() with 0 as the comparand to _avoid_ corruption. If it fails, loop back to the depth loop and try again. The only error possible should be filling the hash table bucket.
Right. Makes sense. Will fix.
> + }
> +}
> +
> +static BOOL remove_from_cache(SOCKET s)
> +{
> + int index,depth;
> + SOCKET old;
> + LONG *dest;
> + index = socket_to_index(s) % CACHE_SIZE;
> + for (depth = 0; depth < CACHE_DEPTH; ++depth) {
> + if (socket_cache[index][depth] == s)
> + break;
> + }
> +
> + if (depth == CACHE_DEPTH) {
> + return FALSE;
> + }
> +
> + dest = (PLONG)&socket_cache[index][depth];
> + old = InterlockedExchange(dest, 0);
> + return (old == s);
Under what circumstances could old not equal s? If that can happen, then this should also use InterlockedCompareExchange(), otherwise you've removed some other socket.
If I fork a child, it will inherit my sockets, right? So some paranoia in the remove is needed, unless Wine doesn't work that way. I'll write it as the analog of add. Thanks! I appreciate the feedback.