Module: wine Branch: master Commit: 1d1faa0283727ddec8855ee877bf7033b9869d37 URL: http://source.winehq.org/git/wine.git/?a=commit;h=1d1faa0283727ddec8855ee877bf7033b9869d37
Author: Bruno Jesus <00cpxxx@gmail.com> Date: Thu Sep 5 09:56:19 2013 -0300
ws2_32: Cope with invalid protocols in WSAEnumProtocols.
---
dlls/ws2_32/socket.c | 45 +++++++++++++++++++++-------- dlls/ws2_32/tests/protocol.c | 64 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 91 insertions(+), 18 deletions(-)
diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c index 1c72ce5..a916772 100644 --- a/dlls/ws2_32/socket.c +++ b/dlls/ws2_32/socket.c @@ -191,6 +191,16 @@ static const GUID ProviderIdIPX = { 0x11058240, 0xbe47, 0x11cf, static const GUID ProviderIdSPX = { 0x11058241, 0xbe47, 0x11cf, { 0x95, 0xc8, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92 } };
+static const INT valid_protocols[] = +{ + WS_IPPROTO_TCP, + WS_IPPROTO_UDP, + NSPROTO_IPX, + NSPROTO_SPX, + NSPROTO_SPXII, + 0 +}; + #if defined(IP_UNICAST_IF) && defined(SO_ATTACH_FILTER) # define LINUX_BOUND_IF struct interface_filter { @@ -1259,6 +1269,14 @@ static inline BOOL supported_pf(int pf) } }
+static inline BOOL supported_protocol(int protocol) +{ + int i; + for (i = 0; i < sizeof(valid_protocols) / sizeof(valid_protocols[0]); i++) + if (protocol == valid_protocols[i]) + return TRUE; + return FALSE; +}
/**********************************************************************/
@@ -1665,11 +1683,10 @@ static INT WS_EnterSingleProtocolA( INT protocol, WSAPROTOCOL_INFOA* info ) return ret; }
-static INT WS_EnumProtocols( BOOL unicode, LPINT protocols, LPWSAPROTOCOL_INFOW buffer, LPDWORD len ) +static INT WS_EnumProtocols( BOOL unicode, const INT *protocols, LPWSAPROTOCOL_INFOW buffer, LPDWORD len ) { - INT i = 0; + INT i = 0, items = 0; DWORD size = 0; - INT local[] = { WS_IPPROTO_TCP, WS_IPPROTO_UDP, NSPROTO_IPX, NSPROTO_SPX, NSPROTO_SPXII, 0 }; union _info { LPWSAPROTOCOL_INFOA a; @@ -1677,11 +1694,15 @@ static INT WS_EnumProtocols( BOOL unicode, LPINT protocols, LPWSAPROTOCOL_INFOW } info; info.w = buffer;
- if (!protocols) protocols = local; + if (!protocols) protocols = valid_protocols;
- while (protocols[i]) i++; + while (protocols[i]) + { + if(supported_protocol(protocols[i++])) + items++; + }
- size = i * (unicode ? sizeof(WSAPROTOCOL_INFOW) : sizeof(WSAPROTOCOL_INFOA)); + size = items * (unicode ? sizeof(WSAPROTOCOL_INFOW) : sizeof(WSAPROTOCOL_INFOA));
if (*len < size || !buffer) { @@ -1690,20 +1711,20 @@ static INT WS_EnumProtocols( BOOL unicode, LPINT protocols, LPWSAPROTOCOL_INFOW return SOCKET_ERROR; }
- for (i = 0; protocols[i]; i++) + for (i = items = 0; protocols[i]; i++) { if (unicode) { - if (WS_EnterSingleProtocolW( protocols[i], &info.w[i] ) == SOCKET_ERROR) - break; + if (WS_EnterSingleProtocolW( protocols[i], &info.w[items] ) != SOCKET_ERROR) + items++; } else { - if (WS_EnterSingleProtocolA( protocols[i], &info.a[i] ) == SOCKET_ERROR) - break; + if (WS_EnterSingleProtocolA( protocols[i], &info.a[items] ) != SOCKET_ERROR) + items++; } } - return i; + return items; }
/************************************************************************** diff --git a/dlls/ws2_32/tests/protocol.c b/dlls/ws2_32/tests/protocol.c index 2b95381..1b9a429 100644 --- a/dlls/ws2_32/tests/protocol.c +++ b/dlls/ws2_32/tests/protocol.c @@ -60,9 +60,10 @@ static void test_service_flags(int family, int version, int socktype, int protoc
static void test_WSAEnumProtocolsA(void) { - INT ret; + INT ret, i, j, found; DWORD len = 0, error; WSAPROTOCOL_INFOA info, *buffer; + INT ptest[] = {0xdead, IPPROTO_TCP, 0xcafe, IPPROTO_UDP, 0xbeef, 0};
ret = WSAEnumProtocolsA( NULL, NULL, &len ); ok( ret == SOCKET_ERROR, "WSAEnumProtocolsA() succeeded unexpectedly\n"); @@ -80,8 +81,6 @@ static void test_WSAEnumProtocolsA(void)
if (buffer) { - INT i; - ret = WSAEnumProtocolsA( NULL, buffer, &len ); ok( ret != SOCKET_ERROR, "WSAEnumProtocolsA() failed unexpectedly: %d\n", WSAGetLastError() ); @@ -96,13 +95,41 @@ static void test_WSAEnumProtocolsA(void)
HeapFree( GetProcessHeap(), 0, buffer ); } + + /* Test invalid protocols in the list */ + ret = WSAEnumProtocolsA( ptest, NULL, &len ); + ok( ret == SOCKET_ERROR, "WSAEnumProtocolsA() succeeded unexpectedly\n"); + error = WSAGetLastError(); + ok( error == WSAENOBUFS, "Expected 10055, received %d\n", error); + + buffer = HeapAlloc( GetProcessHeap(), 0, len ); + + if (buffer) + { + ret = WSAEnumProtocolsA( ptest, buffer, &len ); + ok( ret != SOCKET_ERROR, "WSAEnumProtocolsA() failed unexpectedly: %d\n", + WSAGetLastError() ); + ok( ret >= 2, "Expected at least 2 items, received %d\n", ret); + + for (i = found = 0; i < ret; i++) + for (j = 0; j < sizeof(ptest) / sizeof(ptest[0]); j++) + if (buffer[i].iProtocol == ptest[j]) + { + found |= 1 << j; + break; + } + ok(found == 0x0A, "Expected 2 bits represented as 0xA, received 0x%x\n", found); + + HeapFree( GetProcessHeap(), 0, buffer ); + } }
static void test_WSAEnumProtocolsW(void) { - INT ret; + INT ret, i, j, found; DWORD len = 0, error; WSAPROTOCOL_INFOW info, *buffer; + INT ptest[] = {0xdead, IPPROTO_TCP, 0xcafe, IPPROTO_UDP, 0xbeef, 0};
ret = WSAEnumProtocolsW( NULL, NULL, &len ); ok( ret == SOCKET_ERROR, "WSAEnumProtocolsW() succeeded unexpectedly\n"); @@ -120,8 +147,6 @@ static void test_WSAEnumProtocolsW(void)
if (buffer) { - INT i; - ret = WSAEnumProtocolsW( NULL, buffer, &len ); ok( ret != SOCKET_ERROR, "WSAEnumProtocolsW() failed unexpectedly: %d\n", WSAGetLastError() ); @@ -136,6 +161,33 @@ static void test_WSAEnumProtocolsW(void)
HeapFree( GetProcessHeap(), 0, buffer ); } + + /* Test invalid protocols in the list */ + ret = WSAEnumProtocolsW( ptest, NULL, &len ); + ok( ret == SOCKET_ERROR, "WSAEnumProtocolsW() succeeded unexpectedly\n"); + error = WSAGetLastError(); + ok( error == WSAENOBUFS, "Expected 10055, received %d\n", error); + + buffer = HeapAlloc( GetProcessHeap(), 0, len ); + + if (buffer) + { + ret = WSAEnumProtocolsW( ptest, buffer, &len ); + ok( ret != SOCKET_ERROR, "WSAEnumProtocolsW() failed unexpectedly: %d\n", + WSAGetLastError() ); + ok( ret >= 2, "Expected at least 2 items, received %d\n", ret); + + for (i = found = 0; i < ret; i++) + for (j = 0; j < sizeof(ptest) / sizeof(ptest[0]); j++) + if (buffer[i].iProtocol == ptest[j]) + { + found |= 1 << j; + break; + } + ok(found == 0x0A, "Expected 2 bits represented as 0xA, received 0x%x\n", found); + + HeapFree( GetProcessHeap(), 0, buffer ); + } }
START_TEST( protocol )