-- v2: ws2_32/tests: Register an exception with the firewall to avoid a dialog.
From: Jinoh Kang <jinoh.kang.kr@gmail.com>
--- dlls/ws2_32/tests/Makefile.in | 2 +- dlls/ws2_32/tests/sock.c | 275 ++++++++++++++++++++++++++-------- 2 files changed, 215 insertions(+), 62 deletions(-)
diff --git a/dlls/ws2_32/tests/Makefile.in b/dlls/ws2_32/tests/Makefile.in index b1b10c1636e..7cc79b8838b 100644 --- a/dlls/ws2_32/tests/Makefile.in +++ b/dlls/ws2_32/tests/Makefile.in @@ -1,5 +1,5 @@ TESTDLL = ws2_32.dll -IMPORTS = iphlpapi ws2_32 user32 +IMPORTS = iphlpapi ws2_32 user32 oleaut32 ole32 advapi32 shlwapi
C_SRCS = \ afd.c \ diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c index e53d9991e56..c8abc56443f 100644 --- a/dlls/ws2_32/tests/sock.c +++ b/dlls/ws2_32/tests/sock.c @@ -23,6 +23,7 @@
#include <ntstatus.h> #define WIN32_NO_STATUS +#define COBJMACROS #include <winsock2.h> #include <windows.h> #include <winternl.h> @@ -33,6 +34,9 @@ #include <mswsock.h> #include <mstcpip.h> #include <stdio.h> +#include <initguid.h> +#include <shlwapi.h> +#include <netfw.h> #include "wine/test.h"
#define MAX_CLIENTS 4 /* Max number of clients */ @@ -2008,7 +2012,7 @@ static void test_set_getsockopt(void) } }
-static void test_reuseaddr(void) +static void test_reuseaddr(BOOL firewall_may_trigger) { static struct sockaddr_in6 saddr_in6_any, saddr_in6_loopback; static struct sockaddr_in6 saddr_in6_any_v4mapped, saddr_in6_loopback_v4mapped; @@ -2255,90 +2259,97 @@ static void test_reuseaddr(void) closesocket(s3); closesocket(s4);
- /* Test binding and listening on any addr together with loopback, any addr first. */ - s1 = socket(tests[i].domain, SOCK_STREAM, 0); - ok(s1 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + if (firewall_may_trigger) + { + skip( "skipping binding to ANY to avoid triggering firewall dialog\n" ); + } + else + { + /* Test binding and listening on any addr together with loopback, any addr first. */ + s1 = socket(tests[i].domain, SOCK_STREAM, 0); + ok(s1 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- rc = bind(s1, tests[i].addr_any, tests[i].addrlen); - ok(!rc, "got error %d.\n", WSAGetLastError()); + rc = bind(s1, tests[i].addr_any, tests[i].addrlen); + ok(!rc, "got error %d.\n", WSAGetLastError());
- rc = listen(s1, 1); - ok(!rc, "got error %d.\n", WSAGetLastError()); + rc = listen(s1, 1); + ok(!rc, "got error %d.\n", WSAGetLastError());
- s2 = socket(tests[i].domain, SOCK_STREAM, 0); - ok(s2 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + s2 = socket(tests[i].domain, SOCK_STREAM, 0); + ok(s2 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- rc = bind(s2, tests[i].addr_loopback, tests[i].addrlen); - todo_wine ok(!rc, "got error %d.\n", WSAGetLastError()); + rc = bind(s2, tests[i].addr_loopback, tests[i].addrlen); + todo_wine ok(!rc, "got error %d.\n", WSAGetLastError());
- rc = listen(s2, 1); - todo_wine ok(!rc, "got error %d.\n", WSAGetLastError()); + rc = listen(s2, 1); + todo_wine ok(!rc, "got error %d.\n", WSAGetLastError());
- s3 = socket(tests[i].domain, SOCK_STREAM, 0); - ok(s3 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + s3 = socket(tests[i].domain, SOCK_STREAM, 0); + ok(s3 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- rc = connect(s3, tests[i].addr_loopback, tests[i].addrlen); - ok(!rc, "got error %d.\n", WSAGetLastError()); + rc = connect(s3, tests[i].addr_loopback, tests[i].addrlen); + ok(!rc, "got error %d.\n", WSAGetLastError());
- size = tests[i].addrlen; - s4 = accept(s2, (struct sockaddr *)&saddr, &size); - todo_wine ok(s4 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + size = tests[i].addrlen; + s4 = accept(s2, (struct sockaddr *)&saddr, &size); + todo_wine ok(s4 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- closesocket(s1); - closesocket(s2); - closesocket(s3); - closesocket(s4); + closesocket(s1); + closesocket(s2); + closesocket(s3); + closesocket(s4);
- /* Test binding and listening on any addr together with loopback, loopback addr first. */ + /* Test binding and listening on any addr together with loopback, loopback addr first. */
- s1 = socket(tests[i].domain, SOCK_STREAM, 0); - ok(s1 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + s1 = socket(tests[i].domain, SOCK_STREAM, 0); + ok(s1 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- rc = bind(s1, tests[i].addr_loopback, tests[i].addrlen); - ok(!rc, "got error %d.\n", WSAGetLastError()); + rc = bind(s1, tests[i].addr_loopback, tests[i].addrlen); + ok(!rc, "got error %d.\n", WSAGetLastError());
- rc = listen(s1, 1); - ok(!rc, "got error %d.\n", WSAGetLastError()); + rc = listen(s1, 1); + ok(!rc, "got error %d.\n", WSAGetLastError());
- s2 = socket(tests[i].domain, SOCK_STREAM, 0); - ok(s2 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + s2 = socket(tests[i].domain, SOCK_STREAM, 0); + ok(s2 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- rc = bind(s2, tests[i].addr_any, tests[i].addrlen); - todo_wine ok(!rc, "got rc %d, error %d.\n", rc, WSAGetLastError()); + rc = bind(s2, tests[i].addr_any, tests[i].addrlen); + todo_wine ok(!rc, "got rc %d, error %d.\n", rc, WSAGetLastError());
- rc = listen(s2, 1); - todo_wine ok(!rc, "got error %d.\n", WSAGetLastError()); + rc = listen(s2, 1); + todo_wine ok(!rc, "got error %d.\n", WSAGetLastError());
- s3 = socket(tests[i].domain, SOCK_STREAM, 0); - ok(s3 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + s3 = socket(tests[i].domain, SOCK_STREAM, 0); + ok(s3 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- rc = connect(s3, tests[i].addr_loopback, tests[i].addrlen); - ok(!rc, "got error %d.\n", WSAGetLastError()); - size = tests[i].addrlen; - s4 = accept(s1, (struct sockaddr *)&saddr, &size); + rc = connect(s3, tests[i].addr_loopback, tests[i].addrlen); + ok(!rc, "got error %d.\n", WSAGetLastError()); + size = tests[i].addrlen; + s4 = accept(s1, (struct sockaddr *)&saddr, &size);
- ok(s4 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + ok(s4 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- closesocket(s1); - closesocket(s2); - closesocket(s3); - closesocket(s4); + closesocket(s1); + closesocket(s2); + closesocket(s3); + closesocket(s4);
- /* Test binding to INADDR_ANY on two sockets. */ - s1 = socket(tests[i].domain, SOCK_STREAM, 0); - ok(s1 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + /* Test binding to INADDR_ANY on two sockets. */ + s1 = socket(tests[i].domain, SOCK_STREAM, 0); + ok(s1 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- rc = bind(s1, tests[i].addr_any, tests[i].addrlen); - ok(!rc, "got error %d.\n", WSAGetLastError()); + rc = bind(s1, tests[i].addr_any, tests[i].addrlen); + ok(!rc, "got error %d.\n", WSAGetLastError());
- s2 = socket(tests[i].domain, SOCK_STREAM, 0); - ok(s2 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError()); + s2 = socket(tests[i].domain, SOCK_STREAM, 0); + ok(s2 != INVALID_SOCKET, "got error %d.\n", WSAGetLastError());
- rc = bind(s2, tests[i].addr_any, tests[i].addrlen); - ok(rc == SOCKET_ERROR && WSAGetLastError() == WSAEADDRINUSE, "got rc %d, error %d.\n", rc, WSAGetLastError()); + rc = bind(s2, tests[i].addr_any, tests[i].addrlen); + ok(rc == SOCKET_ERROR && WSAGetLastError() == WSAEADDRINUSE, "got rc %d, error %d.\n", rc, WSAGetLastError());
- closesocket(s1); - closesocket(s2); + closesocket(s1); + closesocket(s2); + }
winetest_pop_context(); } @@ -13820,8 +13831,130 @@ static void test_tcp_sendto_recvfrom(void) closesocket(server); }
+enum firewall_op /* operation selector for set_firewall() */
+{
+    APP_ADD, /* set_firewall() calls INetFwAuthorizedApplications_Add */
+    APP_REMOVE /* set_firewall() calls INetFwAuthorizedApplications_Remove */
+};
+/* Returns TRUE only if the process token can be queried and its elevation type is TokenElevationTypeFull. */
+static BOOL is_process_elevated(void)
+{
+    HANDLE token;
+    if (OpenProcessToken( GetCurrentProcess(), TOKEN_QUERY, &token ))
+    {
+        TOKEN_ELEVATION_TYPE type;
+        DWORD size;
+        BOOL ret;
+
+        ret = GetTokenInformation( token, TokenElevationType, &type, sizeof(type), &size );
+        CloseHandle( token ); /* the handle is closed before the result is inspected */
+        return (ret && type == TokenElevationTypeFull);
+    }
+    return FALSE; /* the token could not be opened */
+}
+/* Returns TRUE if the current firewall profile reports FirewallEnabled == VARIANT_TRUE. */
+static BOOL is_firewall_enabled(void)
+{
+    HRESULT hr, init;
+    INetFwMgr *mgr = NULL;
+    INetFwPolicy *policy = NULL;
+    INetFwProfile *profile = NULL;
+    VARIANT_BOOL enabled = VARIANT_FALSE;
+
+    init = CoInitializeEx( 0, COINIT_APARTMENTTHREADED ); /* result kept so CoUninitialize() is only called on success */
+
+    hr = CoCreateInstance( &CLSID_NetFwMgr, NULL, CLSCTX_INPROC_SERVER, &IID_INetFwMgr,
+                           (void **)&mgr );
+    ok( hr == S_OK, "got %#lx\n", hr );
+    if (hr != S_OK) goto done;
+
+    hr = INetFwMgr_get_LocalPolicy( mgr, &policy );
+    ok( hr == S_OK, "got %#lx\n", hr );
+    if (hr != S_OK) goto done;
+
+    hr = INetFwPolicy_get_CurrentProfile( policy, &profile ); /* NOTE(review): no ok() here, unlike the neighbouring calls -- confirm failure is expected on some setups */
+    if (hr != S_OK) goto done;
+
+    hr = INetFwProfile_get_FirewallEnabled( profile, &enabled );
+    ok( hr == S_OK, "got %#lx\n", hr );
+
+done: /* release only the interfaces that were actually obtained */
+    if (policy) INetFwPolicy_Release( policy );
+    if (profile) INetFwProfile_Release( profile );
+    if (mgr) INetFwMgr_Release( mgr );
+    if (SUCCEEDED( init )) CoUninitialize();
+    return (enabled == VARIANT_TRUE); /* enabled was initialised to VARIANT_FALSE */
+}
+/* Adds (APP_ADD) or removes (APP_REMOVE) this executable in the current profile's authorized applications. */
+static HRESULT set_firewall( enum firewall_op op )
+{
+    HRESULT hr, init;
+    INetFwMgr *mgr = NULL;
+    INetFwPolicy *policy = NULL;
+    INetFwProfile *profile = NULL;
+    INetFwAuthorizedApplication *app = NULL;
+    INetFwAuthorizedApplications *apps = NULL;
+    BSTR name, image = SysAllocStringLen( NULL, MAX_PATH ); /* NOTE(review): allocation result is not checked before use */
+
+    if (!GetModuleFileNameW( NULL, image, MAX_PATH )) /* image receives the path of the running executable */
+    {
+        SysFreeString( image );
+        return E_FAIL;
+    }
+    init = CoInitializeEx( 0, COINIT_APARTMENTTHREADED );
+
+    hr = CoCreateInstance( &CLSID_NetFwMgr, NULL, CLSCTX_INPROC_SERVER, &IID_INetFwMgr,
+                           (void **)&mgr );
+    ok( hr == S_OK, "got %#lx\n", hr );
+    if (hr != S_OK) goto done;
+
+    hr = INetFwMgr_get_LocalPolicy( mgr, &policy );
+    ok( hr == S_OK, "got %#lx\n", hr );
+    if (hr != S_OK) goto done;
+
+    hr = INetFwPolicy_get_CurrentProfile( policy, &profile ); /* NOTE(review): no ok() here, unlike the neighbouring calls -- confirm this is intentional */
+    if (hr != S_OK) goto done;
+
+    hr = INetFwProfile_get_AuthorizedApplications( profile, &apps );
+    ok( hr == S_OK, "got %#lx\n", hr );
+    if (hr != S_OK) goto done;
+
+    hr = CoCreateInstance( &CLSID_NetFwAuthorizedApplication, NULL, CLSCTX_INPROC_SERVER,
+                           &IID_INetFwAuthorizedApplication, (void **)&app ); /* created for both operations, although only Add receives it */
+    ok( hr == S_OK, "got %#lx\n", hr );
+    if (hr != S_OK) goto done;
+
+    hr = INetFwAuthorizedApplication_put_ProcessImageFileName( app, image );
+    if (hr != S_OK) goto done;
+
+    name = SysAllocString( PathFindFileNameW( image ) ); /* the entry's display name is the file-name part of the image path */
+    hr = INetFwAuthorizedApplication_put_Name( app, name );
+    SysFreeString( name );
+    ok( hr == S_OK, "got %#lx\n", hr );
+    if (hr != S_OK) goto done;
+
+    if (op == APP_ADD)
+        hr = INetFwAuthorizedApplications_Add( apps, app );
+    else if (op == APP_REMOVE)
+        hr = INetFwAuthorizedApplications_Remove( apps, image ); /* removal is keyed by the image path, not by the app object */
+    else
+        hr = E_INVALIDARG;
+
+done: /* release only what was obtained; image is freed on every path that reaches here */
+    if (app) INetFwAuthorizedApplication_Release( app );
+    if (apps) INetFwAuthorizedApplications_Release( apps );
+    if (policy) INetFwPolicy_Release( policy );
+    if (profile) INetFwProfile_Release( profile );
+    if (mgr) INetFwMgr_Release( mgr );
+    if (SUCCEEDED( init )) CoUninitialize();
+    SysFreeString( image );
+    return hr; /* HRESULT of the last call attempted */
+}
+
 START_TEST( sock )
 {
+    BOOL firewall_enabled = is_firewall_enabled(); /* NOTE(review): evaluated before the WSAStartup-order tests below -- confirm the COM calls do not initialise winsock */
+    BOOL firewall_app_added = FALSE; /* set once set_firewall( APP_ADD ) succeeds, so the entry is removed at exit */
     int i;
/* Leave these tests at the beginning. They depend on WSAStartup not having been @@ -13829,10 +13962,28 @@ START_TEST( sock ) test_WithoutWSAStartup(); test_WithWSAStartup();
+ if (firewall_enabled) + { + HRESULT hr; + + if (!is_process_elevated()) + { + trace( "no privileges or token query failed, skipping firewall configuration\n" ); + } + else if ((hr = set_firewall( APP_ADD )) != S_OK) + { + trace( "can't authorize app in firewall %#lx\n", hr ); + } + else + { + firewall_app_added = TRUE; + } + } + Init();
test_set_getsockopt(); - test_reuseaddr(); + test_reuseaddr(firewall_enabled && !firewall_app_added); test_ip_pktinfo(); test_ipv4_cmsg(); test_ipv6_cmsg(); @@ -13931,4 +14082,6 @@ START_TEST( sock ) test_send();
Exit(); + + if (firewall_app_added) set_firewall( APP_REMOVE ); }
'ws2_32_test'.
Thanks! I made sure there are no more references to `webservices_test`.
Maybe we should just reuse 'image'.
I'm unsure about the backslashes, so I just extracted the basename.