diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c
index 462f153..5218b5c 100644
--- a/dlls/ws2_32/socket.c
+++ b/dlls/ws2_32/socket.c
@@ -1018,6 +1018,64 @@ static inline int get_rcvsnd_timeo( int fd, int optname)
     return ret;
 }
 
+/* Return the unix address family of fd.  Uses SO_DOMAIN where the platform
+ * provides it (leaving AF_UNSPEC if that query fails); otherwise derives it
+ * from getsockname(), falling back to AF_INET when that gives no answer. */
+static inline int get_sock_family(int fd)
+{
+    int optval = AF_UNSPEC;
+#ifdef SO_DOMAIN
+    socklen_t optlen = sizeof(optval);
+    if (getsockopt(fd, SOL_SOCKET, SO_DOMAIN, (char *) &optval, &optlen))
+        ERR("getsockopt(SO_DOMAIN) failed\n");
+#else
+    union generic_unix_sockaddr uaddr;
+    socklen_t uaddrlen = sizeof(uaddr);
+
+    if (!getsockname(fd, &uaddr.addr, &uaddrlen))
+        optval = uaddr.addr.sa_family;
+
+    if (optval == AF_UNSPEC)
+    {
+        optval = AF_INET;
+        ERR("could not detect socket family - defaulting to AF_INET\n");
+    }
+#endif
+    return optval;
+}
+
+/* Return the unix socket type (SO_TYPE) of fd, or 0 if the query fails. */
+static inline int get_sock_type(int fd)
+{
+    int optval = 0;
+    socklen_t optlen = sizeof(optval);
+    if (getsockopt(fd, SOL_SOCKET, SO_TYPE, (char *) &optval, &optlen))
+        ERR("getsockopt(SO_TYPE) failed\n");
+    return optval;
+}
+
+/* Return the unix protocol of fd.  Uses SO_PROTOCOL (or SO_PROTOTYPE) where
+ * available; otherwise guesses IPPROTO_TCP for stream and IPPROTO_UDP for
+ * datagram sockets.  IPPROTO_IP is returned when nothing better is known. */
+static inline int get_sock_protocol(int fd)
+{
+    int optval = IPPROTO_IP;
+#ifdef SO_PROTOCOL
+    socklen_t optlen = sizeof(optval);
+    if (getsockopt(fd, SOL_SOCKET, SO_PROTOCOL, (char *) &optval, &optlen))
+        ERR("getsockopt(SO_PROTOCOL) failed\n");
+#elif defined(SO_PROTOTYPE)
+    socklen_t optlen = sizeof(optval);
+    if (getsockopt(fd, SOL_SOCKET, SO_PROTOTYPE, (char *) &optval, &optlen))
+        ERR("getsockopt(SO_PROTOTYPE) failed\n");
+#else
+    int socktype = get_sock_type(fd);
+    if (socktype == SOCK_STREAM) optval = IPPROTO_TCP;
+    else if (socktype == SOCK_DGRAM) optval = IPPROTO_UDP;
+#endif
+    return optval;
+}
+
 /* macro wrappers for portability */
 #ifdef SO_RCVTIMEO
 #define GET_RCVTIMEO(fd) get_rcvsnd_timeo( (fd), SO_RCVTIMEO)
@@ -1116,6 +1174,68 @@ convert_socktype_u2w(int unixsocktype) {
     return -1;
 }
 
+/* Fill the WSAPROTOCOL_INFOA (unicode == 0) or WSAPROTOCOL_INFOW structure at
+ * optval for socket fd and return the structure size.  Family, type and
+ * protocol are always set; the remaining fields are copied from the matching
+ * WSAEnumProtocolsW() entry when one exists.  The caller must ensure optval
+ * is large enough for the requested structure. */
+static int fill_protocol_info(int fd, int unicode, char *optval)
+{
+    int sockfamily, socktype, sockproto, items, sz, i;
+    DWORD listsize = 0;
+    WSAPROTOCOL_INFOW *buffer = NULL;
+
+    /* This relies on the A and W layouts being identical up to szProtocol,
+     * so the common fields are written through the A view for both. */
+    union _infow
+    {
+        WSAPROTOCOL_INFOA *a;
+        WSAPROTOCOL_INFOW *w;
+    } info;
+    info.a = (WSAPROTOCOL_INFOA *) optval;
+
+    sz = unicode ? sizeof(WSAPROTOCOL_INFOW) : sizeof(WSAPROTOCOL_INFOA);
+    memset(optval, 0, sz);
+
+    sockfamily = convert_af_u2w(get_sock_family(fd));
+    socktype = convert_socktype_u2w(get_sock_type(fd));
+    sockproto = convert_proto_u2w(get_sock_protocol(fd));
+
+    /* Start by filling basic information in case our search below fails */
+    info.a->iAddressFamily = sockfamily;
+    info.a->iSocketType = socktype;
+    info.a->iProtocol = sockproto;
+
+    /* The first call only queries the required buffer size */
+    items = WSAEnumProtocolsW(NULL, NULL, &listsize);
+    if (items == SOCKET_ERROR && WSAGetLastError() == WSAENOBUFS &&
+        (buffer = HeapAlloc(GetProcessHeap(), 0, listsize)))
+    {
+        items = WSAEnumProtocolsW(NULL, buffer, &listsize);
+        for (i = 0; i < items; i++)
+        {
+            if (buffer[i].iAddressFamily == sockfamily &&
+                buffer[i].iSocketType == socktype &&
+                buffer[i].iProtocol == sockproto)
+            {
+                if (unicode)
+                    memcpy(info.w, &buffer[i], sz);
+                else
+                {
+                    /* convert the structure from W to A */
+                    memcpy(info.a, &buffer[i], FIELD_OFFSET(WSAPROTOCOL_INFOA, szProtocol));
+                    WideCharToMultiByte(CP_ACP, 0, buffer[i].szProtocol, -1,
+                                        info.a->szProtocol, WSAPROTOCOL_LEN+1, NULL, NULL);
+                }
+                break;
+            }
+        }
+    }
+
+    HeapFree(GetProcessHeap(), 0, buffer);
+    return sz;
+}
+
 /* ----------------------------------- API -----
  *
  * Init / cleanup / error checking.
@@ -2776,6 +2896,25 @@ INT WINAPI WS_getsockopt(SOCKET s, INT level,
             TRACE("getting global SO_OPENTYPE = 0x%x\n", *((int*)optval) );
             return 0;
 
+        case WS_SO_PROTOCOL_INFOA:
+        case WS_SO_PROTOCOL_INFOW:
+            /* optval and optlen are required and *optlen must cover the whole
+             * A or W structure.  Compare as int so that a negative *optlen is
+             * rejected instead of being promoted to a huge unsigned value. */
+            if (!optlen || !optval ||
+                *optlen < (int)(optname == WS_SO_PROTOCOL_INFOA ?
+                sizeof(WSAPROTOCOL_INFOA) : sizeof(WSAPROTOCOL_INFOW)))
+            {
+                SetLastError(WSAEFAULT);
+                return SOCKET_ERROR;
+            }
+            if ( (fd = get_sock_fd( s, 0, NULL )) == -1)
+                return SOCKET_ERROR;
+
+            *optlen = fill_protocol_info(fd, optname == WS_SO_PROTOCOL_INFOW, optval);
+            release_sock_fd( s, fd );
+            return ret;
+
 #ifdef SO_RCVTIMEO
         case WS_SO_RCVTIMEO:
 #endif
diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c
index 5ce4959..38face9 100644
--- a/dlls/ws2_32/tests/sock.c
+++ b/dlls/ws2_32/tests/sock.c
@@ -1105,6 +1105,18 @@ static void test_set_getsockopt(void)
     int timeout;
     LINGER lingval;
     int size;
+    WSAPROTOCOL_INFOA infoA;
+    WSAPROTOCOL_INFOW infoW;
+    char providername[WSAPROTOCOL_LEN+1];
+    struct _prottest
+    {
+        int family, type, proto;
+    } prottest[] = {
+        {AF_INET, SOCK_STREAM, IPPROTO_TCP},
+        {AF_INET, SOCK_DGRAM, IPPROTO_UDP},
+        {AF_INET6, SOCK_STREAM, IPPROTO_TCP},
+        {AF_INET6, SOCK_DGRAM, IPPROTO_UDP}
+    };
 
     s = socket(AF_INET, SOCK_STREAM, 0);
     ok(s!=INVALID_SOCKET, "socket() failed error: %d\n", WSAGetLastError());
@@ -1221,6 +1233,82 @@ todo_wine
        err, WSAGetLastError());
 
     closesocket(s);
+
+    /* test SO_PROTOCOL_INFOA invalid parameters */
+    s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+    ok(getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, NULL, NULL),
+       "getsockopt should have failed\n");
+    err = WSAGetLastError();
+    ok(err == WSAEFAULT, "expected 10014, got %d instead\n", err);
+    ok(getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, (char *) &infoA, NULL),
+       "getsockopt should have failed\n");
+    err = WSAGetLastError();
+    ok(err == WSAEFAULT, "expected 10014, got %d instead\n", err);
+    ok(getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, NULL, &size),
+       "getsockopt should have failed\n");
+    err = WSAGetLastError();
+    ok(err == WSAEFAULT, "expected 10014, got %d instead\n", err);
+
+    size = sizeof(WSAPROTOCOL_INFOA) / 2;
+    ok(getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, (char *) &infoA, &size),
+       "getsockopt should have failed\n");
+    err = WSAGetLastError();
+    ok(err == WSAEFAULT, "expected 10014, got %d instead\n", err);
+    closesocket(s);
+
+    /* test SO_PROTOCOL_INFO structure returned for different protocols */
+    for (i = 0; i < sizeof(prottest) / sizeof(prottest[0]); i++)
+    {
+        s = socket(prottest[i].family, prottest[i].type, prottest[i].proto);
+        if (s == INVALID_SOCKET && prottest[i].family == AF_INET6) continue;
+
+        ok(s != INVALID_SOCKET, "Failed to create socket: %d\n",
+           WSAGetLastError());
+
+        /* compare both A and W version */
+        infoA.szProtocol[0] = 0;
+        size = sizeof(WSAPROTOCOL_INFOA);
+        err = getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOA, (char *) &infoA, &size);
+        ok(!err,"getsockopt failed with %d\n", WSAGetLastError());
+
+        infoW.szProtocol[0] = 0;
+        size = sizeof(WSAPROTOCOL_INFOW);
+        err = getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOW, (char *) &infoW, &size);
+        ok(!err,"getsockopt failed with %d\n", WSAGetLastError());
+
+        trace("provider name '%s', family %d, type %d, proto %d\n",
+              infoA.szProtocol, prottest[i].family, prottest[i].type, prottest[i].proto);
+
+        /* TODO: remove when WSAEnumProtocols return AF_INET6 data */
+        if (prottest[i].family == AF_INET6)
+        {
+            todo_wine {
+                ok(infoA.szProtocol[0], "WSAPROTOCOL_INFOA was not filled\n");
+                ok(infoW.szProtocol[0], "WSAPROTOCOL_INFOW was not filled\n");
+            }
+        }
+        else
+        {
+            ok(infoA.szProtocol[0], "WSAPROTOCOL_INFOA was not filled\n");
+            ok(infoW.szProtocol[0], "WSAPROTOCOL_INFOW was not filled\n");
+        }
+
+        WideCharToMultiByte(CP_ACP, 0, infoW.szProtocol, -1,
+                            providername, sizeof(providername), NULL, NULL);
+        ok(!strcmp(infoA.szProtocol,providername),
+           "different provider names '%s' != '%s'\n", infoA.szProtocol, providername);
+
+        ok(!memcmp(&infoA, &infoW, FIELD_OFFSET(WSAPROTOCOL_INFOA, szProtocol)),
+           "SO_PROTOCOL_INFO[A/W] comparison failed\n");
+        ok(infoA.iAddressFamily == prottest[i].family, "socket family invalid, expected %d received %d\n",
+           prottest[i].family, infoA.iAddressFamily);
+        ok(infoA.iSocketType == prottest[i].type, "socket type invalid, expected %d received %d\n",
+           prottest[i].type, infoA.iSocketType);
+        ok(infoA.iProtocol == prottest[i].proto, "socket protocol invalid, expected %d received %d\n",
+           prottest[i].proto, infoA.iProtocol);
+
+        closesocket(s);
+    }
 }
 
 static void test_so_reuseaddr(void)