From: Paul Gofman <pgofman@codeweavers.com>
---
 dlls/secur32/schannel.c        | 12 +++++++++---
 dlls/secur32/schannel_gnutls.c | 24 ++++++++++++++++++++++++
 dlls/secur32/secur32_priv.h    |  7 +++++++
 dlls/secur32/tests/schannel.c  | 12 ++++++------
 4 files changed, 46 insertions(+), 9 deletions(-)
diff --git a/dlls/secur32/schannel.c b/dlls/secur32/schannel.c
index 29b870e95a5..98c6f93bc1c 100644
--- a/dlls/secur32/schannel.c
+++ b/dlls/secur32/schannel.c
@@ -65,6 +65,7 @@ struct schan_context
     ULONG req_ctx_attr;
     const CERT_CONTEXT *cert;
     SIZE_T header_size;
+    BOOL shutdown_requested; /* set by ApplyControlToken(SCHANNEL_SHUTDOWN); read and cleared by the next InitializeSecurityContextW call */
 };
 
 static struct schan_handle *schan_handle_table;
@@ -901,9 +902,9 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
         unsigned char *ptr;
 
         if (!(ctx = schan_get_object(phContext->dwLower, SCHAN_HANDLE_CTX))) return SEC_E_INVALID_HANDLE;
-        if (!pInput && !is_dtls_context(ctx)) return SEC_E_INCOMPLETE_MESSAGE;
+        if (!pInput && !ctx->shutdown_requested && !is_dtls_context(ctx)) return SEC_E_INCOMPLETE_MESSAGE;
 
-        if (pInput)
+        if (!ctx->shutdown_requested && pInput)
         {
             if (!validate_input_buffers(pInput)) return SEC_E_INVALID_TOKEN;
             if ((idx = schan_find_sec_buffer_idx(pInput, 0, SECBUFFER_TOKEN)) == -1) return SEC_E_INCOMPLETE_MESSAGE;
@@ -976,6 +977,8 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
     params.input_offset = &input_offset;
     params.output_buffer_idx = &output_buffer_idx;
     params.output_offset = &output_offset;
+    params.control_token = ctx->shutdown_requested ? control_token_shutdown : control_token_none;
+    ctx->shutdown_requested = FALSE;
     ret = GNUTLS_CALL( handshake, &params );
 
     if (output_buffer_idx != -1)
@@ -1575,6 +1578,8 @@ static SECURITY_STATUS SEC_ENTRY schan_DeleteSecurityContext(PCtxtHandle context
 
 static SECURITY_STATUS SEC_ENTRY schan_ApplyControlToken(PCtxtHandle context_handle, PSecBufferDesc input)
 {
+    struct schan_context *ctx;
+
     TRACE("%p %p\n", context_handle, input);
 
     dump_buffer_desc(input);
@@ -1587,7 +1592,8 @@ static SECURITY_STATUS SEC_ENTRY schan_ApplyControlToken(PCtxtHandle context_han
     if (input->pBuffers[0].cbBuffer < sizeof(DWORD)) return SEC_E_UNSUPPORTED_FUNCTION;
     if (*(DWORD *)input->pBuffers[0].pvBuffer != SCHANNEL_SHUTDOWN) return SEC_E_UNSUPPORTED_FUNCTION;
 
-    FIXME("stub.\n");
+    if (!(ctx = schan_get_object(context_handle->dwLower, SCHAN_HANDLE_CTX))) return SEC_E_INVALID_HANDLE;
+    ctx->shutdown_requested = TRUE; /* the close_notify alert is produced by the next InitializeSecurityContext call */
 
     return SEC_E_OK;
 }
diff --git a/dlls/secur32/schannel_gnutls.c b/dlls/secur32/schannel_gnutls.c
index 6d65f41ca1b..b26344aa85e 100644
--- a/dlls/secur32/schannel_gnutls.c
+++ b/dlls/secur32/schannel_gnutls.c
@@ -121,6 +121,7 @@ MAKE_FUNCPTR(gnutls_x509_crt_deinit);
 MAKE_FUNCPTR(gnutls_x509_crt_import);
 MAKE_FUNCPTR(gnutls_x509_crt_init);
 MAKE_FUNCPTR(gnutls_x509_privkey_deinit);
+MAKE_FUNCPTR(gnutls_alert_send);
 #undef MAKE_FUNCPTR
 
 #if GNUTLS_VERSION_MAJOR < 3
@@ -557,6 +558,25 @@ static NTSTATUS schan_handshake( void *args )
     t->in.limit = params->input_size;
     init_schan_buffers(&t->out, params->output);
 
+    if (params->control_token == control_token_shutdown)
+    {
+        err = pgnutls_alert_send(s, GNUTLS_AL_WARNING, GNUTLS_A_CLOSE_NOTIFY);
+        if (err == GNUTLS_E_SUCCESS)
+        {
+            status = SEC_E_OK;
+        }
+        else if (err == GNUTLS_E_AGAIN)
+        {
+            status = SEC_E_INVALID_TOKEN;
+        }
+        else
+        {
+            pgnutls_perror(err);
+            status = SEC_E_INTERNAL_ERROR;
+        }
+        goto done;
+    }
+
     while (1)
     {
         err = pgnutls_handshake(s);
@@ -598,6 +618,7 @@ static NTSTATUS schan_handshake( void *args )
         break;
     }
 
+done:
     *params->input_offset = t->in.offset;
     *params->output_buffer_idx = t->out.current_buffer_idx;
     *params->output_offset = t->out.offset;
@@ -1427,6 +1448,7 @@ static NTSTATUS process_attach( void *args )
     LOAD_FUNCPTR(gnutls_x509_crt_import)
     LOAD_FUNCPTR(gnutls_x509_crt_init)
     LOAD_FUNCPTR(gnutls_x509_privkey_deinit)
+    LOAD_FUNCPTR(gnutls_alert_send)
 #undef LOAD_FUNCPTR
 
     if (!(pgnutls_cipher_get_block_size = dlsym(libgnutls_handle, "gnutls_cipher_get_block_size")))
@@ -1707,6 +1729,7 @@ static NTSTATUS wow64_schan_handshake( void *args )
         PTR32 input_offset;
         PTR32 output_buffer_idx;
         PTR32 output_offset;
+        enum control_token control_token;
     } const *params32 = args;
     struct handshake_params params =
     {
@@ -1717,6 +1740,7 @@ static NTSTATUS wow64_schan_handshake( void *args )
         ULongToPtr(params32->input_offset),
         ULongToPtr(params32->output_buffer_idx),
         ULongToPtr(params32->output_offset),
+        params32->control_token,
     };
     if (params32->input)
     {
diff --git a/dlls/secur32/secur32_priv.h b/dlls/secur32/secur32_priv.h
index d1321b7d6fd..c43b1f446c4 100644
--- a/dlls/secur32/secur32_priv.h
+++ b/dlls/secur32/secur32_priv.h
@@ -147,6 +147,12 @@ struct get_unique_channel_binding_params
     ULONG *bufsize;
 };
 
+enum control_token
+{
+    control_token_none,
+    control_token_shutdown,
+};
+
 struct handshake_params
 {
     schan_session session;
@@ -156,6 +162,7 @@ struct handshake_params
     ULONG *input_offset;
     int *output_buffer_idx;
     ULONG *output_offset;
+    enum control_token control_token;
 };
 
 struct recv_params
diff --git a/dlls/secur32/tests/schannel.c b/dlls/secur32/tests/schannel.c
index 62ef9c75837..33915351cb3 100644
--- a/dlls/secur32/tests/schannel.c
+++ b/dlls/secur32/tests/schannel.c
@@ -1869,13 +1869,13 @@ static void test_connection_shutdown(void)
     context2.dwLower = context2.dwUpper = 0xdeadbeef;
     status = InitializeSecurityContextA( &cred_handle, &context, NULL, 0, 0, 0,
                                          &buffers[1], 0, &context2, &buffers[0], &attrs, NULL );
-    todo_wine ok( status == SEC_E_OK, "got %08lx.\n", status );
+    ok( status == SEC_E_OK, "got %08lx.\n", status );
     ok( context.dwLower == context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n",
         context.dwLower, context2.dwLower );
     ok( context.dwUpper == context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n",
         context.dwUpper, context2.dwUpper );
-    todo_wine ok( buf->cbBuffer == sizeof(message), "got cbBuffer %#lx.\n", buf->cbBuffer );
-    todo_wine ok( !memcmp( buf->pvBuffer, message, sizeof(message) ), "message data mismatch.\n" );
+    ok( buf->cbBuffer == sizeof(message), "got cbBuffer %#lx.\n", buf->cbBuffer );
+    ok( !memcmp( buf->pvBuffer, message, sizeof(message) ), "message data mismatch.\n" );
 
     buf->BufferType = SECBUFFER_TOKEN;
     buf->cbBuffer = 1000;
@@ -1896,9 +1896,9 @@ static void test_connection_shutdown(void)
     buf->cbBuffer = 1000;
     status = InitializeSecurityContextA( &cred_handle, &context, NULL, 0, 0, 0,
                                          NULL, 0, NULL, &buffers[0], &attrs, NULL );
-    todo_wine ok( status == SEC_E_OK, "got %08lx.\n", status );
-    todo_wine ok( buf->cbBuffer == sizeof(message), "got cbBuffer %#lx.\n", buf->cbBuffer );
-    todo_wine ok( !memcmp( buf->pvBuffer, message, sizeof(message) ), "message data mismatch.\n" );
+    ok( status == SEC_E_OK, "got %08lx.\n", status );
+    ok( buf->cbBuffer == sizeof(message), "got cbBuffer %#lx.\n", buf->cbBuffer );
+    ok( !memcmp( buf->pvBuffer, message, sizeof(message) ), "message data mismatch.\n" );
 
     free_buffers( &buffers[0] );
     free_buffers( &buffers[1] );