Signed-off-by: Hans Leidekker <hans@codeweavers.com> --- dlls/winhttp/request.c | 101 ++++++++++++++++++++++++++++++++- dlls/winhttp/session.c | 13 ++--- dlls/winhttp/tests/winhttp.c | 34 +++++------ dlls/winhttp/winhttp_private.h | 14 +++++ 4 files changed, 133 insertions(+), 29 deletions(-)
diff --git a/dlls/winhttp/request.c b/dlls/winhttp/request.c index 9173bdb56d1..0511dd44eec 100644 --- a/dlls/winhttp/request.c +++ b/dlls/winhttp/request.c @@ -33,6 +33,7 @@ #include "httprequestid.h" #include "schannel.h" #include "winhttp.h" +#include "ntsecapi.h"
#include "wine/debug.h" #include "winhttp_private.h" @@ -2098,6 +2099,25 @@ static char *build_wire_request( struct request *request, DWORD *len ) return ret; }
+static WCHAR *create_websocket_key(void) /* Sec-WebSocket-Key value: base64 text of 16 random bytes in a heap_alloc'd buffer the caller frees with heap_free; NULL if RtlGenRandom or the allocation fails */ +{ + WCHAR *ret; + char buf[16]; + DWORD base64_len = ((sizeof(buf) + 2) * 4) / 3; /* 24 for 16 input bytes, i.e. 4 * ceil(16 / 3) */ + if (!RtlGenRandom( buf, sizeof(buf) )) return NULL; + if ((ret = heap_alloc( (base64_len + 1) * sizeof(WCHAR) ))) encode_base64( buf, sizeof(buf), ret ); /* +1 presumably leaves room for a terminator written by encode_base64 -- confirm */ + return ret; +} + +static DWORD add_websocket_key_header( struct request *request ) /* sets a freshly generated Sec-WebSocket-Key request header, replacing any previous one, so every send_request gets a new key */ +{ + WCHAR *key = create_websocket_key(); + if (!key) return ERROR_OUTOFMEMORY; /* NOTE(review): an RtlGenRandom failure is also reported as out-of-memory here */ + process_header( request, L"Sec-WebSocket-Key", key, WINHTTP_ADDREQ_FLAG_ADD | WINHTTP_ADDREQ_FLAG_REPLACE, TRUE ); + heap_free( key ); + return ERROR_SUCCESS; +} + static DWORD send_request( struct request *request, const WCHAR *headers, DWORD headers_len, void *optional, DWORD optional_len, DWORD total_len, DWORD_PTR context, BOOL async ) { @@ -2125,7 +2145,14 @@ static DWORD send_request( struct request *request, const WCHAR *headers, DWORD swprintf( length, ARRAY_SIZE(length), L"%ld", total_len ); process_header( request, L"Content-Length", length, WINHTTP_ADDREQ_FLAG_ADD_IF_NEW, TRUE ); } - if (!(request->hdr.disable_flags & WINHTTP_DISABLE_KEEP_ALIVE)) + if (request->flags & REQUEST_FLAG_WEBSOCKET_UPGRADE) /* websocket handshake headers take the place of Connection: Keep-Alive */ + { + process_header( request, L"Upgrade", L"websocket", WINHTTP_ADDREQ_FLAG_ADD_IF_NEW, TRUE ); + process_header( request, L"Connection", L"Upgrade", WINHTTP_ADDREQ_FLAG_ADD_IF_NEW, TRUE ); + process_header( request, L"Sec-WebSocket-Version", L"13", WINHTTP_ADDREQ_FLAG_ADD_IF_NEW, TRUE ); + if ((ret = add_websocket_key_header( request ))) return ret; /* fail the send if no key could be generated */ + } + else if (!(request->hdr.disable_flags & WINHTTP_DISABLE_KEEP_ALIVE)) { process_header( request, L"Connection", L"Keep-Alive", WINHTTP_ADDREQ_FLAG_ADD_IF_NEW, TRUE ); } @@ -3016,10 +3043,78 @@ BOOL WINAPI WinHttpWriteData( HINTERNET hrequest, LPCVOID buffer, DWORD to_write return !ret; }
+static BOOL socket_query_option( struct object_header *hdr, DWORD option, void *buffer, DWORD *buflen ) /* no websocket handle options are implemented yet; always fails */ +{ + FIXME("unimplemented option %u\n", option); + SetLastError( ERROR_WINHTTP_INVALID_OPTION ); + return FALSE; +} + +static void socket_destroy( struct object_header *hdr ) /* frees the socket and drops its reference on the request it was created from */ +{ + struct socket *socket = (struct socket *)hdr; + + TRACE("%p\n", socket); + + release_object( &socket->request->hdr ); /* balances addref_object in WinHttpWebSocketCompleteUpgrade */ + heap_free( socket ); +} + +static BOOL socket_set_option( struct object_header *hdr, DWORD option, void *buffer, DWORD buflen ) /* no websocket handle options are implemented yet; always fails */ +{ + FIXME("unimplemented option %u\n", option); + SetLastError( ERROR_WINHTTP_INVALID_OPTION ); + return FALSE; +} + +static const struct object_vtbl socket_vtbl = +{ + socket_destroy, + socket_query_option, + socket_set_option, +}; + HINTERNET WINAPI WinHttpWebSocketCompleteUpgrade( HINTERNET hrequest, DWORD_PTR context ) { - FIXME("%p, %08lx\n", hrequest, context); - return NULL; + struct socket *socket; /* wraps the request in a new websocket handle; returns that handle or NULL on failure */ + struct request *request; + HINTERNET hsocket = NULL; + + TRACE("%p, %08lx\n", hrequest, context); + + if (!(request = (struct request *)grab_object( hrequest ))) + { + SetLastError( ERROR_INVALID_HANDLE ); + return NULL; + } + if (request->hdr.type != WINHTTP_HANDLE_TYPE_REQUEST) /* NOTE(review): no check that the upgrade option was set or that the response was 101 -- confirm against Windows */ + { + release_object( &request->hdr ); + SetLastError( ERROR_WINHTTP_INCORRECT_HANDLE_TYPE ); + return NULL; + } + if (!(socket = heap_alloc_zero( sizeof(struct socket) ))) + { + release_object( &request->hdr ); /* NOTE(review): last error is left untouched on allocation failure */ + return NULL; + } + socket->hdr.type = WINHTTP_HANDLE_TYPE_SOCKET; + socket->hdr.vtbl = &socket_vtbl; + socket->hdr.refs = 1; + socket->hdr.context = context; + list_init( &socket->hdr.children ); + + addref_object( &request->hdr ); /* the socket keeps the request alive until socket_destroy */ + socket->request = request; + list_add_head( &request->hdr.children, &socket->hdr.entry ); + + if ((hsocket = alloc_handle( &socket->hdr ))) socket->hdr.handle = hsocket; + + release_object( &socket->hdr ); /* drop the initial reference; presumably alloc_handle holds its own on success -- confirm */ + release_object( &request->hdr ); /* drop the grab_object reference */ + TRACE("returning %p\n", hsocket); + if (hsocket) SetLastError( ERROR_SUCCESS ); + 
return hsocket; }
DWORD WINAPI WinHttpWebSocketSend( HINTERNET hsocket, WINHTTP_WEB_SOCKET_BUFFER_TYPE type, void *buf, DWORD len ) diff --git a/dlls/winhttp/session.c b/dlls/winhttp/session.c index 24455d858a6..87974ffbd8e 100644 --- a/dlls/winhttp/session.c +++ b/dlls/winhttp/session.c @@ -63,9 +63,6 @@ BOOL WINAPI WinHttpCheckPlatform( void ) return TRUE; }
-/*********************************************************************** - * session_destroy (internal) - */ static void session_destroy( struct object_header *hdr ) { struct session *session = (struct session *)hdr; @@ -296,9 +293,6 @@ end: return handle; }
-/*********************************************************************** - * connect_destroy (internal) - */ static void connect_destroy( struct object_header *hdr ) { struct connect *connect = (struct connect *)hdr; @@ -581,9 +575,6 @@ end: return hconnect; }
-/*********************************************************************** - * request_destroy (internal) - */ static void request_destroy( struct object_header *hdr ) { struct request *request = (struct request *)hdr; @@ -1038,6 +1029,10 @@ static BOOL request_set_option( struct object_header *hdr, DWORD option, void *b return FALSE; }
+ case WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET: + request->flags |= REQUEST_FLAG_WEBSOCKET_UPGRADE; + return TRUE; + case WINHTTP_OPTION_CONNECT_RETRIES: FIXME("WINHTTP_OPTION_CONNECT_RETRIES\n"); return TRUE; diff --git a/dlls/winhttp/tests/winhttp.c b/dlls/winhttp/tests/winhttp.c index f08f50b9a9d..8083c4825b2 100644 --- a/dlls/winhttp/tests/winhttp.c +++ b/dlls/winhttp/tests/winhttp.c @@ -3119,7 +3119,7 @@ static void test_websocket(int port) ok(request != NULL, "got %u\n", GetLastError());
ret = WinHttpSetOption(request, WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET, NULL, 0); - todo_wine ok(ret, "got %u\n", GetLastError()); + ok(ret, "got %u\n", GetLastError());
size = sizeof(header); SetLastError(0xdeadbeef); @@ -3175,41 +3175,41 @@ static void test_websocket(int port) size = sizeof(buf); ret = WinHttpQueryHeaders(request, WINHTTP_QUERY_CUSTOM | WINHTTP_QUERY_FLAG_REQUEST_HEADERS, L"Sec-WebSocket-Key", buf, &size, &index); - todo_wine ok(ret, "got %u\n", GetLastError()); + ok(ret, "got %u\n", GetLastError());
index = 0; buf[0] = 0; size = sizeof(buf); ret = WinHttpQueryHeaders(request, WINHTTP_QUERY_CUSTOM | WINHTTP_QUERY_FLAG_REQUEST_HEADERS, L"Sec-WebSocket-Version", buf, &size, &index); - todo_wine ok(ret, "got %u\n", GetLastError()); + ok(ret, "got %u\n", GetLastError());
ret = WinHttpReceiveResponse(request, NULL); - todo_wine ok(ret, "got %u\n", GetLastError()); + ok(ret, "got %u\n", GetLastError());
count = 0xdeadbeef; ret = WinHttpQueryDataAvailable(request, &count); ok(ret, "got %u\n", GetLastError()); - todo_wine ok(!count, "got %u\n", count); + ok(!count, "got %u\n", count);
header[0] = 0; size = sizeof(header); ret = WinHttpQueryHeaders(request, WINHTTP_QUERY_UPGRADE, NULL, &header, &size, NULL); - todo_wine ok(ret, "got %u\n", GetLastError()); - todo_wine ok(!wcscmp( header, L"websocket" ), "got %s\n", wine_dbgstr_w(header)); + ok(ret, "got %u\n", GetLastError()); + ok(!wcscmp( header, L"websocket" ), "got %s\n", wine_dbgstr_w(header));
header[0] = 0; size = sizeof(header); ret = WinHttpQueryHeaders(request, WINHTTP_QUERY_CONNECTION, NULL, &header, &size, NULL); - todo_wine ok(ret, "got %u\n", GetLastError()); - todo_wine ok(!wcscmp( header, L"Upgrade" ), "got %s\n", wine_dbgstr_w(header)); + ok(ret, "got %u\n", GetLastError()); + ok(!wcscmp( header, L"Upgrade" ), "got %s\n", wine_dbgstr_w(header));
status = 0xdeadbeef; size = sizeof(status); ret = WinHttpQueryHeaders(request, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, NULL, &status, &size, NULL); ok(ret, "got %u\n", GetLastError()); - todo_wine ok(status == HTTP_STATUS_SWITCH_PROTOCOLS, "got %u\n", status); + ok(status == HTTP_STATUS_SWITCH_PROTOCOLS, "got %u\n", status);
len = 0xdeadbeef; size = sizeof(len); @@ -3220,27 +3220,27 @@ static void test_websocket(int port) index = 0; size = sizeof(buf); ret = WinHttpQueryHeaders(request, WINHTTP_QUERY_CUSTOM, L"Sec-WebSocket-Accept", buf, &size, &index); - todo_wine ok(ret, "got %u\n", GetLastError()); + ok(ret, "got %u\n", GetLastError());
socket = pWinHttpWebSocketCompleteUpgrade(request, 0); - todo_wine ok(socket != NULL, "got %u\n", GetLastError()); + ok(socket != NULL, "got %u\n", GetLastError());
header[0] = 0; size = sizeof(header); ret = WinHttpQueryHeaders(request, WINHTTP_QUERY_UPGRADE, NULL, &header, &size, NULL); - todo_wine ok(ret, "got %u\n", GetLastError()); - todo_wine ok(!wcscmp( header, L"websocket" ), "got %s\n", wine_dbgstr_w(header)); + ok(ret, "got %u\n", GetLastError()); + ok(!wcscmp( header, L"websocket" ), "got %s\n", wine_dbgstr_w(header));
header[0] = 0; size = sizeof(header); ret = WinHttpQueryHeaders(request, WINHTTP_QUERY_CONNECTION, NULL, &header, &size, NULL); - todo_wine ok(ret, "got %u\n", GetLastError()); - todo_wine ok(!wcscmp( header, L"Upgrade" ), "got %s\n", wine_dbgstr_w(header)); + ok(ret, "got %u\n", GetLastError()); + ok(!wcscmp( header, L"Upgrade" ), "got %s\n", wine_dbgstr_w(header));
index = 0; size = sizeof(buf); ret = WinHttpQueryHeaders(request, WINHTTP_QUERY_CUSTOM, L"Sec-WebSocket-Accept", buf, &size, &index); - todo_wine ok(ret, "got %u\n", GetLastError()); + ok(ret, "got %u\n", GetLastError());
/* Send/Receive on websock */
diff --git a/dlls/winhttp/winhttp_private.h b/dlls/winhttp/winhttp_private.h index 657f82f6421..af695f86c3e 100644 --- a/dlls/winhttp/winhttp_private.h +++ b/dlls/winhttp/winhttp_private.h @@ -26,6 +26,8 @@ #include "sspi.h" #include "wincrypt.h"
+#define WINHTTP_HANDLE_TYPE_SOCKET 4 /* internal type tag for handles created by WinHttpWebSocketCompleteUpgrade; NOTE(review): assumed distinct from the public WINHTTP_HANDLE_TYPE_* values -- confirm */ + struct object_header; struct object_vtbl { @@ -154,10 +156,16 @@ struct authinfo BOOL finished; /* finished authenticating */ };
+enum request_flags /* per-request option bits stored in the flags field of struct request */ +{ + REQUEST_FLAG_WEBSOCKET_UPGRADE = 0x01, /* set via WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET; makes send_request add the websocket handshake headers */ +}; + struct request { struct object_header hdr; struct connect *connect; + enum request_flags flags; /* REQUEST_FLAG_* bits */ WCHAR *verb; WCHAR *path; WCHAR *version; @@ -201,6 +209,12 @@ struct request } creds[TARGET_MAX][SCHEME_MAX]; };
+struct socket /* object behind a handle returned by WinHttpWebSocketCompleteUpgrade */ +{ + struct object_header hdr; + struct request *request; /* request the socket was upgraded from; a reference is held until the socket is destroyed */ +}; + struct task_header { struct list entry;