Module: wine
Branch: master
Commit: 0b69c706b99edac1663e433e99699b51d1adf11c
URL:    http://source.winehq.org/git/wine.git/?a=commit;h=0b69c706b99edac1663e433e99699b51d1adf11c
Author: Hans Leidekker <hans@codeweavers.com>
Date:   Thu Apr 30 11:50:26 2015 +0200
wininet: Reuse cached basic authorization across sessions.
---
 dlls/wininet/http.c       | 62 ++++++++++++++++++++++++++++++------------
 dlls/wininet/tests/http.c | 68 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 113 insertions(+), 17 deletions(-)
diff --git a/dlls/wininet/http.c b/dlls/wininet/http.c index 8a872e8..f91fd21 100644 --- a/dlls/wininet/http.c +++ b/dlls/wininet/http.c @@ -863,7 +863,7 @@ static void destroy_authinfo( struct HttpAuthInfo *authinfo ) heap_free(authinfo); }
-static UINT retrieve_cached_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR *auth_data) +static UINT retrieve_cached_basic_authorization(const WCHAR *host, const WCHAR *realm, char **auth_data) { basicAuthorizationData *ad; UINT rc = 0; @@ -873,7 +873,7 @@ static UINT retrieve_cached_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR EnterCriticalSection(&authcache_cs); LIST_FOR_EACH_ENTRY(ad, &basicAuthorizationCache, basicAuthorizationData, entry) { - if (!strcmpiW(host,ad->host) && !strcmpW(realm,ad->realm)) + if (!strcmpiW(host, ad->host) && (!realm || !strcmpW(realm, ad->realm))) { TRACE("Authorization found in cache\n"); *auth_data = heap_alloc(ad->authorizationLen); @@ -1620,6 +1620,21 @@ static UINT HTTP_DecodeBase64( LPCWSTR base64, LPSTR bin ) return n; }
+static WCHAR *encode_auth_data( const WCHAR *scheme, const char *data, UINT data_len ) +{ + WCHAR *ret; + UINT len, scheme_len = strlenW( scheme ); + + /* scheme + space + base64 encoded data (3/2/1 bytes data -> 4 bytes of characters) */ + len = scheme_len + 1 + ((data_len + 2) * 4) / 3; + if (!(ret = heap_alloc( (len + 1) * sizeof(WCHAR) ))) return NULL; + memcpy( ret, scheme, scheme_len * sizeof(WCHAR) ); + ret[scheme_len] = ' '; + HTTP_EncodeBase64( data, data_len, ret + scheme_len + 1 ); + return ret; +} + + /*********************************************************************** * HTTP_InsertAuthorization * @@ -1627,27 +1642,16 @@ static UINT HTTP_DecodeBase64( LPCWSTR base64, LPSTR bin ) */ static BOOL HTTP_InsertAuthorization( http_request_t *request, struct HttpAuthInfo *pAuthInfo, LPCWSTR header ) { + static const WCHAR wszBasic[] = {'B','a','s','i','c',0}; + WCHAR *host, *authorization = NULL; + if (pAuthInfo) { - static const WCHAR wszSpace[] = {' ',0}; - static const WCHAR wszBasic[] = {'B','a','s','i','c',0}; - unsigned int len; - WCHAR *authorization = NULL; - if (pAuthInfo->auth_data_len) { - /* scheme + space + base64 encoded data (3/2/1 bytes data -> 4 bytes of characters) */ - len = strlenW(pAuthInfo->scheme)+1+((pAuthInfo->auth_data_len+2)*4)/3; - authorization = heap_alloc((len+1)*sizeof(WCHAR)); - if (!authorization) + if (!(authorization = encode_auth_data(pAuthInfo->scheme, pAuthInfo->auth_data, pAuthInfo->auth_data_len))) return FALSE;
- strcpyW(authorization, pAuthInfo->scheme); - strcatW(authorization, wszSpace); - HTTP_EncodeBase64(pAuthInfo->auth_data, - pAuthInfo->auth_data_len, - authorization+strlenW(authorization)); - /* clear the data as it isn't valid now that it has been sent to the * server, unless it's Basic authentication which doesn't do * connection tracking */ @@ -1664,6 +1668,30 @@ static BOOL HTTP_InsertAuthorization( http_request_t *request, struct HttpAuthIn HTTP_ProcessHeader(request, header, authorization, HTTP_ADDHDR_FLAG_REQ | HTTP_ADDHDR_FLAG_REPLACE); heap_free(authorization); } + else if (!strcmpW(header, szAuthorization) && (host = get_host_header(request))) + { + UINT data_len; + char *data; + + if ((data_len = retrieve_cached_basic_authorization(host, NULL, &data))) + { + TRACE("Found cached basic authorization for %s\n", debugstr_w(host)); + + if (!(authorization = encode_auth_data(wszBasic, data, data_len))) + { + heap_free(data); + heap_free(host); + return FALSE; + } + + TRACE("Inserting authorization: %s\n", debugstr_w(authorization)); + + HTTP_ProcessHeader(request, header, authorization, HTTP_ADDHDR_FLAG_REQ | HTTP_ADDHDR_FLAG_REPLACE); + heap_free(data); + heap_free(authorization); + } + heap_free(host); + } return TRUE; }
diff --git a/dlls/wininet/tests/http.c b/dlls/wininet/tests/http.c index 3e66a27..3b2f9a5 100644 --- a/dlls/wininet/tests/http.c +++ b/dlls/wininet/tests/http.c @@ -2321,6 +2321,20 @@ static DWORD CALLBACK server_thread(LPVOID param) else send(c, notokmsg, sizeof notokmsg-1, 0); } + if (strstr(buffer, "HEAD /upload.txt")) + { + if (strstr(buffer, "Authorization: Basic dXNlcjpwd2Q=")) + send(c, okmsg, sizeof okmsg-1, 0); + else + send(c, noauthmsg, sizeof noauthmsg-1, 0); + } + if (strstr(buffer, "PUT /upload2.txt")) + { + if (strstr(buffer, "Authorization: Basic dXNlcjpwd2Q=")) + send(c, okmsg, sizeof okmsg-1, 0); + else + send(c, notokmsg, sizeof notokmsg-1, 0); + } shutdown(c, 2); closesocket(c); c = -1; @@ -4212,6 +4226,59 @@ static void test_accept_encoding(int port) InternetCloseHandle(ses); }
+static void test_basic_auth_credentials_reuse(int port) +{ + HINTERNET ses, con, req; + DWORD status, size; + BOOL ret; + + ses = InternetOpenA( "winetest", 0, NULL, NULL, 0 ); + ok( ses != NULL, "InternetOpenA failed\n" ); + + con = InternetConnectA( ses, "localhost", port, "user", "pwd", + INTERNET_SERVICE_HTTP, 0, 0 ); + ok( con != NULL, "InternetConnectA failed %u\n", GetLastError() ); + + req = HttpOpenRequestA( con, "HEAD", "/upload.txt", NULL, NULL, NULL, 0, 0 ); + ok( req != NULL, "HttpOpenRequestA failed %u\n", GetLastError() ); + + ret = HttpSendRequestA( req, NULL, 0, NULL, 0 ); + ok( ret, "HttpSendRequestA failed %u\n", GetLastError() ); + + status = 0xdeadbeef; + size = sizeof(status); + ret = HttpQueryInfoA( req, HTTP_QUERY_STATUS_CODE|HTTP_QUERY_FLAG_NUMBER, &status, &size, NULL ); + ok( ret, "HttpQueryInfoA failed %u\n", GetLastError() ); + ok( status == 200, "got %u\n", status ); + + InternetCloseHandle( req ); + InternetCloseHandle( con ); + InternetCloseHandle( ses ); + + ses = InternetOpenA( "winetest", 0, NULL, NULL, 0 ); + ok( ses != NULL, "InternetOpenA failed\n" ); + + con = InternetConnectA( ses, "localhost", port, NULL, NULL, + INTERNET_SERVICE_HTTP, 0, 0 ); + ok( con != NULL, "InternetConnectA failed %u\n", GetLastError() ); + + req = HttpOpenRequestA( con, "PUT", "/upload2.txt", NULL, NULL, NULL, 0, 0 ); + ok( req != NULL, "HttpOpenRequestA failed %u\n", GetLastError() ); + + ret = HttpSendRequestA( req, NULL, 0, NULL, 0 ); + ok( ret, "HttpSendRequestA failed %u\n", GetLastError() ); + + status = 0xdeadbeef; + size = sizeof(status); + ret = HttpQueryInfoA( req, HTTP_QUERY_STATUS_CODE|HTTP_QUERY_FLAG_NUMBER, &status, &size, NULL ); + ok( ret, "HttpQueryInfoA failed %u\n", GetLastError() ); + ok( status == 200, "got %u\n", status ); + + InternetCloseHandle( req ); + InternetCloseHandle( con ); + InternetCloseHandle( ses ); +} + static void test_http_connection(void) { struct server_info si; @@ -4259,6 +4326,7 @@ static void 
test_http_connection(void) test_head_request(si.port); test_request_content_length(si.port); test_accept_encoding(si.port); + test_basic_auth_credentials_reuse(si.port);
/* send the basic request again to shutdown the server thread */ test_basic_request(si.port, "GET", "/quit");