From: Paul Gofman <pgofman@codeweavers.com>
--- dlls/ws2_32/tests/sock.c | 59 +++++++++++++++++++++++++++++++++++++++- server/sock.c | 10 +++++-- 2 files changed, 65 insertions(+), 4 deletions(-)
diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c index e53d9991e56..4fe50838c64 100644 --- a/dlls/ws2_32/tests/sock.c +++ b/dlls/ws2_32/tests/sock.c @@ -161,6 +161,7 @@ static GUID WSARecvMsg_GUID = WSAID_WSARECVMSG;
static SOCKET setup_server_socket(struct sockaddr_in *addr, int *len); static SOCKET setup_connector_socket(const struct sockaddr_in *addr, int len, BOOL nonblock); +static int sync_recv(SOCKET s, void *buffer, int len, DWORD flags);
static void tcp_socketpair_flags(SOCKET *src, SOCKET *dst, DWORD flags) { @@ -1225,7 +1226,7 @@ static void test_set_getsockopt(void) {AF_INET6, SOCK_DGRAM, IPPROTO_IPV6, IPV6_UNICAST_HOPS, TRUE, {1, 1, 4}, {0}, FALSE}, {AF_INET6, SOCK_DGRAM, IPPROTO_IPV6, IPV6_V6ONLY, TRUE, {1, 1, 1}, {0}, TRUE}, }; - SOCKET s, s2; + SOCKET s, s2, src, dst; int i, j, err, lasterr; int timeout; LINGER lingval; @@ -1237,6 +1238,7 @@ static void test_set_getsockopt(void) int expected_err, expected_size; DWORD value, save_value; UINT64 value64; + char buffer[4096];
struct _prottest { @@ -1302,6 +1304,61 @@ static void test_set_getsockopt(void) ok( !err, "getsockopt(SO_RCVBUF) failed error: %u\n", WSAGetLastError() ); ok( value == 4096, "expected 4096, got %lu\n", value );
+ value = 0; + size = sizeof(value); + err = setsockopt(s, SOL_SOCKET, SO_RCVBUF, (char *)&value, size); + ok( !err, "setsockopt(SO_RCVBUF) failed error: %u\n", WSAGetLastError() ); + value = 0xdeadbeef; + err = getsockopt(s, SOL_SOCKET, SO_RCVBUF, (char *)&value, &size); + ok( !err, "getsockopt(SO_RCVBUF) failed error: %u\n", WSAGetLastError() ); + ok( value == 0, "expected 0, got %lu\n", value ); + + /* Test non-blocking receive with too short SO_RCVBUF. */ + tcp_socketpair(&src, &dst); + set_blocking(src, FALSE); + set_blocking(dst, FALSE); + + value = 0; + size = sizeof(value); + err = setsockopt(src, SOL_SOCKET, SO_SNDBUF, (char *)&value, size); + ok( !err, "got %d, error %u.\n", err, WSAGetLastError() ); + + value = 0xdeadbeef; + err = getsockopt(dst, SOL_SOCKET, SO_RCVBUF, (char *)&value, &size); + ok( !err, "got %d, error %u.\n", err, WSAGetLastError() ); + if (value >= sizeof(buffer) * 3) + { + value = 1024; + size = sizeof(value); + err = setsockopt(dst, SOL_SOCKET, SO_RCVBUF, (char *)&value, size); + ok( !err, "got %d, error %u.\n", err, WSAGetLastError() ); + value = 0xdeadbeef; + err = getsockopt(dst, SOL_SOCKET, SO_RCVBUF, (char *)&value, &size); + ok( !err, "got %d, error %u.\n", err, WSAGetLastError() ); + ok( value == 1024, "expected 0, got %lu\n", value ); + + err = send(src, buffer, sizeof(buffer), 0); + ok(err == sizeof(buffer), "got %d\n", err); + err = send(src, buffer, sizeof(buffer), 0); + ok(err == sizeof(buffer), "got %d\n", err); + err = send(src, buffer, sizeof(buffer), 0); + ok(err == sizeof(buffer), "got %d\n", err); + + err = sync_recv(dst, buffer, sizeof(buffer), 0); + ok(err == sizeof(buffer), "got %d, error %u\n", err, WSAGetLastError()); + err = sync_recv(dst, buffer, sizeof(buffer), 0); + ok(err == sizeof(buffer), "got %d, error %u\n", err, WSAGetLastError()); + err = sync_recv(dst, buffer, sizeof(buffer), 0); + ok(err == sizeof(buffer), "got %d, error %u\n", err, WSAGetLastError()); + } + else + { + skip("Default SO_RCVBUF %lu 
is too small, skipping test.\n", value); + } + + closesocket(src); + closesocket(dst); + /* SO_LINGER */ for( i = 0; i < ARRAY_SIZE(linger_testvals);i++) { size = sizeof(lingval); diff --git a/server/sock.c b/server/sock.c index b63412ab216..abf267f13f6 100644 --- a/server/sock.c +++ b/server/sock.c @@ -240,6 +240,7 @@ struct sock struct poll_req *main_poll; /* main poll */ union win_sockaddr addr; /* socket name */ int addr_len; /* socket name length */ + unsigned int default_rcvbuf; /* initial advisory recv buffer size */ unsigned int rcvbuf; /* advisory recv buffer size */ unsigned int sndbuf; /* advisory send buffer size */ unsigned int rcvtimeo; /* receive timeout in ms */ @@ -1733,6 +1734,7 @@ static struct sock *create_socket(void) sock->reset = 0; sock->reuseaddr = 0; sock->exclusiveaddruse = 0; + sock->default_rcvbuf = 0; sock->rcvbuf = 0; sock->sndbuf = 0; sock->rcvtimeo = 0; @@ -1916,7 +1918,7 @@ static int init_socket( struct sock *sock, int family, int type, int protocol )
len = sizeof(value); if (!getsockopt( sockfd, SOL_SOCKET, SO_RCVBUF, &value, &len )) - sock->rcvbuf = value; + sock->rcvbuf = sock->default_rcvbuf = value;
len = sizeof(value); if (!getsockopt( sockfd, SOL_SOCKET, SO_SNDBUF, &value, &len )) @@ -2011,6 +2013,7 @@ static struct sock *accept_socket( struct sock *sock ) acceptsock->reuseaddr = sock->reuseaddr; acceptsock->exclusiveaddruse = sock->exclusiveaddruse; acceptsock->sndbuf = sock->sndbuf; + acceptsock->default_rcvbuf = sock->default_rcvbuf; acceptsock->rcvbuf = sock->rcvbuf; acceptsock->sndtimeo = sock->sndtimeo; acceptsock->rcvtimeo = sock->rcvtimeo; @@ -3064,7 +3067,7 @@ static void sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async )
case IOCTL_AFD_WINE_SET_SO_RCVBUF: { - DWORD rcvbuf; + DWORD rcvbuf, set_rcvbuf;
if (get_req_data_size() < sizeof(rcvbuf)) { @@ -3072,8 +3075,9 @@ static void sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async ) return; } rcvbuf = *(DWORD *)get_req_data(); + set_rcvbuf = max( rcvbuf, sock->default_rcvbuf );
- if (!setsockopt( unix_fd, SOL_SOCKET, SO_RCVBUF, (char *)&rcvbuf, sizeof(rcvbuf) )) + if (!setsockopt( unix_fd, SOL_SOCKET, SO_RCVBUF, (char *)&set_rcvbuf, sizeof(set_rcvbuf) )) sock->rcvbuf = rcvbuf; else set_error( sock_get_ntstatus( errno ) );