Adds support for AcceptSecurityContext to secur32 schannel
Note: I'm not especially familiar with the Secure Channel API, and I arrived at this largely by trial and error until things worked. I have confirmed that it works well enough for [an application](https://store.steampowered.com/app/1787820/Mirror_Party/) that uses Microsoft's [Microsoft.Extensions.Hosting](https://learn.microsoft.com/en-us/dotnet/api/microsoft.extensions.hosting) .NET API to successfully negotiate a TLS connection with curl.
From: Evan Tang etang@codeweavers.com
--- dlls/secur32/schannel.c | 103 +++++++++++++++++++++++---------- dlls/secur32/schannel_gnutls.c | 51 +++++++++++----- 2 files changed, 107 insertions(+), 47 deletions(-)
diff --git a/dlls/secur32/schannel.c b/dlls/secur32/schannel.c index 23917626497..efb5c85a747 100644 --- a/dlls/secur32/schannel.c +++ b/dlls/secur32/schannel.c @@ -769,14 +769,11 @@ static BOOL validate_input_buffers(SecBufferDesc *desc) return TRUE; }
-/*********************************************************************** - * InitializeSecurityContextW - */ -static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( +static SECURITY_STATUS schan_AcceptOrInitializeSecurityContext( PCredHandle phCredential, PCtxtHandle phContext, SEC_WCHAR *pszTargetName, - ULONG fContextReq, ULONG Reserved1, ULONG TargetDataRep, - PSecBufferDesc pInput, ULONG Reserved2, PCtxtHandle phNewContext, - PSecBufferDesc pOutput, ULONG *pfContextAttr, PTimeStamp ptsExpiry) + PSecBufferDesc pInput, ULONG fContextReq, ULONG TargetDataRep, + PCtxtHandle phNewContext, PSecBufferDesc pOutput, ULONG *pfContextAttr, + PTimeStamp ptsTimeStamp, BOOLEAN bIsServer) /* bIsServer: TRUE from AcceptSecurityContext, FALSE from InitializeSecurityContextW */ { const ULONG extra_size = 0x10000; struct schan_context *ctx; @@ -791,26 +788,20 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( ULONG input_offset = 0, output_offset = 0; SecBufferDesc input_desc, output_desc;
- TRACE("%p %p %s 0x%08lx %ld %ld %p %ld %p %p %p %p\n", phCredential, phContext, - debugstr_w(pszTargetName), fContextReq, Reserved1, TargetDataRep, pInput, - Reserved1, phNewContext, pOutput, pfContextAttr, ptsExpiry); - - dump_buffer_desc(pInput); - dump_buffer_desc(pOutput); - - if (ptsExpiry) + if (ptsTimeStamp) { - ptsExpiry->LowPart = 0; - ptsExpiry->HighPart = 0; + ptsTimeStamp->LowPart = 0; + ptsTimeStamp->HighPart = 0; }
if (!pOutput || !pOutput->cBuffers) return SEC_E_INVALID_TOKEN; for (i = 0; i < pOutput->cBuffers; i++) { ULONG type = pOutput->pBuffers[i].BufferType; + ULONG allocate_memory_flag = bIsServer ? ASC_REQ_ALLOCATE_MEMORY : ISC_REQ_ALLOCATE_MEMORY; /* test the ASC_ or ISC_ request flag matching the calling API */
if (type != SECBUFFER_TOKEN && type != SECBUFFER_ALERT) continue; - if (!pOutput->pBuffers[i].cbBuffer && !(fContextReq & ISC_REQ_ALLOCATE_MEMORY)) + if (!pOutput->pBuffers[i].cbBuffer && !(fContextReq & allocate_memory_flag)) return SEC_E_INSUFFICIENT_MEMORY; }
@@ -818,15 +809,16 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( { ULONG_PTR handle; struct create_session_params create_params; + ULONG credential_use = bIsServer ? SECPKG_CRED_INBOUND : SECPKG_CRED_OUTBOUND; /* servers need inbound credentials, clients outbound */
if (!phCredential) return SEC_E_INVALID_HANDLE;
cred = schan_get_object(phCredential->dwLower, SCHAN_HANDLE_CRED); if (!cred) return SEC_E_INVALID_HANDLE;
- if (!(cred->credential_use & SECPKG_CRED_OUTBOUND)) + if (!(cred->credential_use & credential_use)) { - WARN("Invalid credential use %#lx\n", cred->credential_use); + WARN("Invalid credential use %#lx, expected %#lx\n", cred->credential_use, credential_use); return SEC_E_INVALID_HANDLE; }
@@ -848,7 +840,7 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( return SEC_E_INTERNAL_ERROR; }
- if (cred->enabled_protocols & (SP_PROT_DTLS1_0_CLIENT | SP_PROT_DTLS1_2_CLIENT)) + if (cred->enabled_protocols & SP_PROT_DTLS1_X) ctx->header_size = HEADER_SIZE_DTLS; else ctx->header_size = HEADER_SIZE_TLS; @@ -894,12 +886,13 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( phNewContext->dwLower = handle; phNewContext->dwUpper = 0; } - else + + if (bIsServer || phContext) /* a server consumes the client's input token even on the call that creates the context */ { SIZE_T record_size = 0; unsigned char *ptr;
- if (!(ctx = schan_get_object(phContext->dwLower, SCHAN_HANDLE_CTX))) return SEC_E_INVALID_HANDLE; + if (phContext && !(ctx = schan_get_object(phContext->dwLower, SCHAN_HANDLE_CTX))) return SEC_E_INVALID_HANDLE; if (!pInput && !ctx->shutdown_requested && !is_dtls_context(ctx)) return SEC_E_INCOMPLETE_MESSAGE;
if (!ctx->shutdown_requested && pInput) @@ -938,7 +931,7 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( TRACE("Using expected_size %Iu.\n", expected_size); }
- if (phNewContext) *phNewContext = *phContext; + if (phNewContext && phContext) *phNewContext = *phContext; }
ctx->req_ctx_attr = fContextReq; @@ -1014,16 +1007,45 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( if (buffer->BufferType == SECBUFFER_ALERT) buffer->cbBuffer = 0; }
- *pfContextAttr = ISC_RET_REPLAY_DETECT | ISC_RET_SEQUENCE_DETECT | ISC_RET_CONFIDENTIALITY | ISC_RET_STREAM; - if (ctx->req_ctx_attr & ISC_REQ_EXTENDED_ERROR) *pfContextAttr |= ISC_RET_EXTENDED_ERROR; - if (ctx->req_ctx_attr & ISC_REQ_DATAGRAM) *pfContextAttr |= ISC_RET_DATAGRAM; - if (ctx->req_ctx_attr & ISC_REQ_ALLOCATE_MEMORY) *pfContextAttr |= ISC_RET_ALLOCATED_MEMORY; - if (ctx->req_ctx_attr & ISC_REQ_USE_SUPPLIED_CREDS) *pfContextAttr |= ISC_RET_USED_SUPPLIED_CREDS; - if (ctx->req_ctx_attr & ISC_REQ_MANUAL_CRED_VALIDATION) *pfContextAttr |= ISC_RET_MANUAL_CRED_VALIDATION; + if (bIsServer) /* AcceptSecurityContext callers expect ASC_RET_* attribute flags */ + { + *pfContextAttr = ASC_RET_REPLAY_DETECT | ASC_RET_SEQUENCE_DETECT | ASC_RET_CONFIDENTIALITY | ASC_RET_STREAM; + if (ctx->req_ctx_attr & ASC_REQ_EXTENDED_ERROR) *pfContextAttr |= ASC_RET_EXTENDED_ERROR; + if (ctx->req_ctx_attr & ASC_REQ_DATAGRAM) *pfContextAttr |= ASC_RET_DATAGRAM; + if (ctx->req_ctx_attr & ASC_REQ_ALLOCATE_MEMORY) *pfContextAttr |= ASC_RET_ALLOCATED_MEMORY; + } + else + { + *pfContextAttr = ISC_RET_REPLAY_DETECT | ISC_RET_SEQUENCE_DETECT | ISC_RET_CONFIDENTIALITY | ISC_RET_STREAM; + if (ctx->req_ctx_attr & ISC_REQ_EXTENDED_ERROR) *pfContextAttr |= ISC_RET_EXTENDED_ERROR; + if (ctx->req_ctx_attr & ISC_REQ_DATAGRAM) *pfContextAttr |= ISC_RET_DATAGRAM; + if (ctx->req_ctx_attr & ISC_REQ_ALLOCATE_MEMORY) *pfContextAttr |= ISC_RET_ALLOCATED_MEMORY; + if (ctx->req_ctx_attr & ISC_REQ_USE_SUPPLIED_CREDS) *pfContextAttr |= ISC_RET_USED_SUPPLIED_CREDS; + if (ctx->req_ctx_attr & ISC_REQ_MANUAL_CRED_VALIDATION) *pfContextAttr |= ISC_RET_MANUAL_CRED_VALIDATION; + }
return ret; }
+/*********************************************************************** + * InitializeSecurityContextW + */ +static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( + PCredHandle phCredential, PCtxtHandle phContext, SEC_WCHAR *pszTargetName, + ULONG fContextReq, ULONG Reserved1, ULONG TargetDataRep, + PSecBufferDesc pInput, ULONG Reserved2, PCtxtHandle phNewContext, + PSecBufferDesc pOutput, ULONG *pfContextAttr, PTimeStamp ptsExpiry) +{ + TRACE("%p %p %s 0x%08lx %ld %ld %p %ld %p %p %p %p\n", phCredential, phContext, + debugstr_w(pszTargetName), fContextReq, Reserved1, TargetDataRep, pInput, + Reserved2, phNewContext, pOutput, pfContextAttr, ptsExpiry); + + dump_buffer_desc(pInput); + dump_buffer_desc(pOutput); + + return schan_AcceptOrInitializeSecurityContext(phCredential, phContext, pszTargetName, pInput, fContextReq, TargetDataRep, phNewContext, pOutput, pfContextAttr, ptsExpiry, FALSE); /* FALSE: client side */ +} + /*********************************************************************** * InitializeSecurityContextA */ @@ -1055,6 +1077,23 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextA( return ret; }
+/*********************************************************************** + * AcceptSecurityContext + */ +static SECURITY_STATUS SEC_ENTRY schan_AcceptSecurityContext( + PCredHandle phCredential, PCtxtHandle phContext, PSecBufferDesc pInput, + ULONG fContextReq, ULONG TargetDataRep, PCtxtHandle phNewContext, + PSecBufferDesc pOutput, ULONG *pfContextAttr, PTimeStamp ptsTimeStamp) +{ + TRACE("%p %p %p 0x%08lx %ld %p %p %p %p\n", phCredential, phContext, pInput, + fContextReq, TargetDataRep, phNewContext, pOutput, pfContextAttr, ptsTimeStamp); + + dump_buffer_desc(pInput); + dump_buffer_desc(pOutput); + + return schan_AcceptOrInitializeSecurityContext(phCredential, phContext, NULL, pInput, fContextReq, TargetDataRep, phNewContext, pOutput, pfContextAttr, ptsTimeStamp, TRUE); /* server side: no target name */ +} + static void *get_alg_name(ALG_ID id, BOOL wide) { static const struct { @@ -1604,7 +1643,7 @@ static const SecurityFunctionTableA schanTableA = { schan_FreeCredentialsHandle, NULL, /* Reserved2 */ schan_InitializeSecurityContextA, - NULL, /* AcceptSecurityContext */ + schan_AcceptSecurityContext, /* takes no string arguments, so the A and W tables share it */ NULL, /* CompleteAuthToken */ schan_DeleteSecurityContext, schan_ApplyControlToken, /* ApplyControlToken */ @@ -1635,7 +1674,7 @@ static const SecurityFunctionTableW schanTableW = { schan_FreeCredentialsHandle, NULL, /* Reserved2 */ schan_InitializeSecurityContextW, - NULL, /* AcceptSecurityContext */ + schan_AcceptSecurityContext, NULL, /* CompleteAuthToken */ schan_DeleteSecurityContext, schan_ApplyControlToken, /* ApplyControlToken */ diff --git a/dlls/secur32/schannel_gnutls.c b/dlls/secur32/schannel_gnutls.c index b26344aa85e..06d56fccee1 100644 --- a/dlls/secur32/schannel_gnutls.c +++ b/dlls/secur32/schannel_gnutls.c @@ -354,10 +354,12 @@ static ssize_t push_adapter(gnutls_transport_ptr_t transport, const void *buff, return len; }
-static const struct { +struct protocol_priority_flag { DWORD enable_flag; const char *gnutls_flag; -} protocol_priority_flags[] = { +}; + +static const struct protocol_priority_flag client_protocol_priority_flags[] = { {SP_PROT_DTLS1_2_CLIENT, "VERS-DTLS1.2"}, {SP_PROT_DTLS1_0_CLIENT, "VERS-DTLS1.0"}, {SP_PROT_TLS1_3_CLIENT, "VERS-TLS1.3"}, @@ -368,33 +370,46 @@ static const struct { /* {SP_PROT_SSL2_CLIENT} is not supported by GnuTLS */ };
+static const struct protocol_priority_flag server_protocol_priority_flags[] = { /* server-side counterpart of client_protocol_priority_flags, same order */ + {SP_PROT_DTLS1_2_SERVER, "VERS-DTLS1.2"}, + {SP_PROT_DTLS1_0_SERVER, "VERS-DTLS1.0"}, + {SP_PROT_TLS1_3_SERVER, "VERS-TLS1.3"}, + {SP_PROT_TLS1_2_SERVER, "VERS-TLS1.2"}, + {SP_PROT_TLS1_1_SERVER, "VERS-TLS1.1"}, + {SP_PROT_TLS1_0_SERVER, "VERS-TLS1.0"}, + {SP_PROT_SSL3_SERVER, "VERS-SSL3.0"} + /* {SP_PROT_SSL2_SERVER} is not supported by GnuTLS */ +}; + static DWORD supported_protocols;
-static void check_supported_protocols(void) +static void check_supported_protocols( + const struct protocol_priority_flag *flags, unsigned int num_flags, BOOLEAN server) { + const char *type_desc = server ? "server" : "client"; gnutls_session_t session; char priority[64]; unsigned i; int err;
- err = pgnutls_init(&session, GNUTLS_CLIENT); + err = pgnutls_init(&session, server ? GNUTLS_SERVER : GNUTLS_CLIENT); if (err != GNUTLS_E_SUCCESS) { pgnutls_perror(err); return; }
- for(i = 0; i < ARRAY_SIZE(protocol_priority_flags); i++) + for(i = 0; i < num_flags; i++) { - sprintf(priority, "NORMAL:-%s", protocol_priority_flags[i].gnutls_flag); + sprintf(priority, "NORMAL:-%s", flags[i].gnutls_flag); err = pgnutls_priority_set_direct(session, priority, NULL); if (err == GNUTLS_E_SUCCESS) { - TRACE("%s is supported\n", protocol_priority_flags[i].gnutls_flag); - supported_protocols |= protocol_priority_flags[i].enable_flag; + TRACE("%s %s is supported\n", type_desc, flags[i].gnutls_flag); + supported_protocols |= flags[i].enable_flag; } else - TRACE("%s is not supported\n", protocol_priority_flags[i].gnutls_flag); + TRACE("%s %s is not supported\n", type_desc, flags[i].gnutls_flag); }
pgnutls_deinit(session); @@ -420,6 +435,11 @@ static int pull_timeout(gnutls_transport_ptr_t transport, unsigned int timeout) static NTSTATUS set_priority(schan_credentials *cred, gnutls_session_t session) { char priority[128] = "NORMAL:%LATEST_RECORD_VERSION", *p; + BOOL server = !!(cred->credential_use & SECPKG_CRED_INBOUND); /* NOTE(review): credentials acquired for both directions pick the server table even for client sessions */ + const struct protocol_priority_flag *protocols = + server ? server_protocol_priority_flags : client_protocol_priority_flags; + int num_protocols = server ? ARRAY_SIZE(server_protocol_priority_flags) + : ARRAY_SIZE(client_protocol_priority_flags); BOOL using_vers_all = FALSE, disabled; int i, err;
@@ -447,16 +467,16 @@ static NTSTATUS set_priority(schan_credentials *cred, gnutls_session_t session) using_vers_all = TRUE; }
- for (i = 0; i < ARRAY_SIZE(protocol_priority_flags); i++) + for (i = 0; i < num_protocols; i++) { - if (!(supported_protocols & protocol_priority_flags[i].enable_flag)) continue; + if (!(supported_protocols & protocols[i].enable_flag)) continue;
- disabled = !(cred->enabled_protocols & protocol_priority_flags[i].enable_flag); + disabled = !(cred->enabled_protocols & protocols[i].enable_flag); if (using_vers_all && disabled) continue;
*p++ = ':'; *p++ = disabled ? '-' : '+'; - strcpy(p, protocol_priority_flags[i].gnutls_flag); + strcpy(p, protocols[i].gnutls_flag); p += strlen(p); }
@@ -483,7 +503,7 @@ static NTSTATUS schan_create_session( void *args )
*params->session = 0;
- if (cred->enabled_protocols & (SP_PROT_DTLS1_0_CLIENT | SP_PROT_DTLS1_2_CLIENT)) + if (cred->enabled_protocols & SP_PROT_DTLS1_X) /* client or server DTLS */ { flags |= GNUTLS_DATAGRAM | GNUTLS_NONBLOCK; } @@ -1505,7 +1525,8 @@ static NTSTATUS process_attach( void *args ) pgnutls_global_set_log_function(gnutls_log); }
- check_supported_protocols(); + check_supported_protocols(client_protocol_priority_flags, ARRAY_SIZE(client_protocol_priority_flags), FALSE); + check_supported_protocols(server_protocol_priority_flags, ARRAY_SIZE(server_protocol_priority_flags), TRUE); return STATUS_SUCCESS;
fail:
From: Evan Tang etang@codeweavers.com
--- dlls/secur32/schannel.c | 157 ++++++++++++++++------------------------ 1 file changed, 63 insertions(+), 94 deletions(-)
diff --git a/dlls/secur32/schannel.c b/dlls/secur32/schannel.c index efb5c85a747..6a73cb0e38c 100644 --- a/dlls/secur32/schannel.c +++ b/dlls/secur32/schannel.c @@ -161,23 +161,24 @@ static void read_config(void) DWORD enabled = 0, default_disabled = 0; HKEY protocols_key, key; WCHAR subkey_name[64]; - unsigned i; + unsigned i, server; DWORD res;
static BOOL config_read = FALSE; static const struct { WCHAR key_name[20]; DWORD prot_client_flag; + DWORD prot_server_flag; BOOL enabled; /* If no config is present, enable the protocol */ BOOL disabled_by_default; /* Disable if caller asks for default protocol set */ } protocol_config_keys[] = { - { L"SSL 2.0", SP_PROT_SSL2_CLIENT, FALSE, TRUE }, /* NOTE: TRUE, TRUE on Windows */ - { L"SSL 3.0", SP_PROT_SSL3_CLIENT, TRUE, FALSE }, - { L"TLS 1.0", SP_PROT_TLS1_0_CLIENT, TRUE, FALSE }, - { L"TLS 1.1", SP_PROT_TLS1_1_CLIENT, TRUE, FALSE /* NOTE: not enabled by default on Windows */ }, - { L"TLS 1.2", SP_PROT_TLS1_2_CLIENT, TRUE, FALSE /* NOTE: not enabled by default on Windows */ }, - { L"DTLS 1.0", SP_PROT_DTLS1_0_CLIENT, TRUE, TRUE }, - { L"DTLS 1.2", SP_PROT_DTLS1_2_CLIENT, TRUE, TRUE }, + { L"SSL 2.0", SP_PROT_SSL2_CLIENT, SP_PROT_SSL2_SERVER, FALSE, TRUE }, /* NOTE: TRUE, TRUE on Windows */ + { L"SSL 3.0", SP_PROT_SSL3_CLIENT, SP_PROT_SSL3_SERVER, TRUE, FALSE }, + { L"TLS 1.0", SP_PROT_TLS1_0_CLIENT, SP_PROT_TLS1_0_SERVER, TRUE, FALSE }, + { L"TLS 1.1", SP_PROT_TLS1_1_CLIENT, SP_PROT_TLS1_1_SERVER, TRUE, FALSE /* NOTE: not enabled by default on Windows */ }, + { L"TLS 1.2", SP_PROT_TLS1_2_CLIENT, SP_PROT_TLS1_2_SERVER, TRUE, FALSE /* NOTE: not enabled by default on Windows */ }, + { L"DTLS 1.0", SP_PROT_DTLS1_0_CLIENT, SP_PROT_DTLS1_0_SERVER, TRUE, TRUE }, + { L"DTLS 1.2", SP_PROT_DTLS1_2_CLIENT, SP_PROT_DTLS1_2_SERVER, TRUE, TRUE }, };
/* No need for thread safety */ @@ -191,44 +192,49 @@ static void read_config(void) DWORD type, size, value;
for(i = 0; i < ARRAY_SIZE(protocol_config_keys); i++) { - wcscpy(subkey_name, protocol_config_keys[i].key_name); - wcscat(subkey_name, L"\\Client"); - res = RegOpenKeyExW(protocols_key, subkey_name, 0, KEY_READ, &key); - if(res != ERROR_SUCCESS) { - if(protocol_config_keys[i].enabled) - enabled |= protocol_config_keys[i].prot_client_flag; - if(protocol_config_keys[i].disabled_by_default) - default_disabled |= protocol_config_keys[i].prot_client_flag; - continue; - } - - size = sizeof(value); - res = RegQueryValueExW(key, L"enabled", NULL, &type, (BYTE *)&value, &size); - if(res == ERROR_SUCCESS) { - if(type == REG_DWORD && value) - enabled |= protocol_config_keys[i].prot_client_flag; - }else if(protocol_config_keys[i].enabled) { - enabled |= protocol_config_keys[i].prot_client_flag; + for (server = 0; server < 2; server++) { + DWORD flag = server ? protocol_config_keys[i].prot_server_flag + : protocol_config_keys[i].prot_client_flag; + wcscpy(subkey_name, protocol_config_keys[i].key_name); + wcscat(subkey_name, server ? 
L"\\Server" : L"\\Client"); + res = RegOpenKeyExW(protocols_key, subkey_name, 0, KEY_READ, &key); + if(res != ERROR_SUCCESS) { + if(protocol_config_keys[i].enabled) + enabled |= flag; + if(protocol_config_keys[i].disabled_by_default) + default_disabled |= flag; + continue; + } + + size = sizeof(value); + res = RegQueryValueExW(key, L"enabled", NULL, &type, (BYTE *)&value, &size); + if(res == ERROR_SUCCESS) { + if(type == REG_DWORD && value) + enabled |= flag; + }else if(protocol_config_keys[i].enabled) { + enabled |= flag; + } + + size = sizeof(value); + res = RegQueryValueExW(key, L"DisabledByDefault", NULL, &type, (BYTE *)&value, &size); + if(res == ERROR_SUCCESS) { + if(type != REG_DWORD || value) + default_disabled |= flag; + }else if(protocol_config_keys[i].disabled_by_default) { + default_disabled |= flag; + } + + RegCloseKey(key); } - - size = sizeof(value); - res = RegQueryValueExW(key, L"DisabledByDefault", NULL, &type, (BYTE *)&value, &size); - if(res == ERROR_SUCCESS) { - if(type != REG_DWORD || value) - default_disabled |= protocol_config_keys[i].prot_client_flag; - }else if(protocol_config_keys[i].disabled_by_default) { - default_disabled |= protocol_config_keys[i].prot_client_flag; - } - - RegCloseKey(key); } }else { /* No config, enable all known protocols. */ for(i = 0; i < ARRAY_SIZE(protocol_config_keys); i++) { + DWORD flag = protocol_config_keys[i].prot_client_flag | protocol_config_keys[i].prot_server_flag; if(protocol_config_keys[i].enabled) - enabled |= protocol_config_keys[i].prot_client_flag; + enabled |= flag; if(protocol_config_keys[i].disabled_by_default) - default_disabled |= protocol_config_keys[i].prot_client_flag; + default_disabled |= flag; } }
@@ -533,8 +539,8 @@ static BYTE *get_key_blob(const CERT_CONTEXT *ctx, DWORD *size) return ret; }
-static SECURITY_STATUS schan_AcquireClientCredentials(const void *schanCred, - PCredHandle phCredential, PTimeStamp ptsExpiry) +static SECURITY_STATUS schan_AcquireCredentialsHandle(ULONG fCredentialUse, + const SCHANNEL_CRED *schanCred, PCredHandle phCredential, PTimeStamp ptsExpiry) { struct schan_credentials *creds; DWORD enabled_protocols, cred_enabled_protocols; @@ -545,7 +551,7 @@ static SECURITY_STATUS schan_AcquireClientCredentials(const void *schanCred, BYTE *key_blob = NULL; ULONG key_size = 0;
- TRACE("schanCred %p, phCredential %p, ptsExpiry %p\n", schanCred, phCredential, ptsExpiry); + TRACE("fCredentialUse %#lx, schanCred %p, phCredential %p, ptsExpiry %p\n", fCredentialUse, schanCred, phCredential, ptsExpiry);
if (schanCred) { @@ -563,6 +569,10 @@ static SECURITY_STATUS schan_AcquireClientCredentials(const void *schanCred,
status = SEC_E_OK; } + else if (fCredentialUse & SECPKG_CRED_INBOUND) + { + return SEC_E_NO_CREDENTIALS; /* inbound (server) credentials require an SCHANNEL_CRED */ + }
read_config(); if(schanCred && cred_enabled_protocols) @@ -575,7 +585,7 @@ static SECURITY_STATUS schan_AcquireClientCredentials(const void *schanCred, }
if (!(creds = malloc(sizeof(*creds)))) return SEC_E_INSUFFICIENT_MEMORY; - creds->credential_use = SECPKG_CRED_OUTBOUND; + creds->credential_use = fCredentialUse; /* stored as given; callers test the INBOUND/OUTBOUND bits individually */ creds->enabled_protocols = enabled_protocols;
if (cert && !(key_blob = get_key_blob(cert, &key_size))) goto fail; @@ -598,11 +608,18 @@ static SECURITY_STATUS schan_AcquireClientCredentials(const void *schanCred, phCredential->dwLower = handle; phCredential->dwUpper = 0;
- /* Outbound credentials have no expiry */ if (ptsExpiry) { - ptsExpiry->LowPart = 0; - ptsExpiry->HighPart = 0; + if (fCredentialUse & SECPKG_CRED_INBOUND) + { + /* FIXME: get expiry from cert; *ptsExpiry is currently left untouched */ + } + else + { + /* Outbound credentials have no expiry */ + ptsExpiry->LowPart = 0; + ptsExpiry->HighPart = 0; + } }
return status; @@ -612,54 +629,6 @@ fail: return SEC_E_INTERNAL_ERROR; }
-static SECURITY_STATUS schan_AcquireServerCredentials(const SCHANNEL_CRED *schanCred, - PCredHandle phCredential, PTimeStamp ptsExpiry) -{ - SECURITY_STATUS status; - const CERT_CONTEXT *cert = NULL; - - TRACE("schanCred %p, phCredential %p, ptsExpiry %p\n", schanCred, phCredential, ptsExpiry); - - if (!schanCred) return SEC_E_NO_CREDENTIALS; - - status = get_cert(schanCred, &cert); - if (status == SEC_E_OK) - { - ULONG_PTR handle; - struct schan_credentials *creds; - - if (!(creds = calloc(1, sizeof(*creds)))) return SEC_E_INSUFFICIENT_MEMORY; - creds->credential_use = SECPKG_CRED_INBOUND; - - handle = schan_alloc_handle(creds, SCHAN_HANDLE_CRED); - if (handle == SCHAN_INVALID_HANDLE) - { - free(creds); - return SEC_E_INTERNAL_ERROR; - } - - phCredential->dwLower = handle; - phCredential->dwUpper = 0; - - /* FIXME: get expiry from cert */ - } - return status; -} - -static SECURITY_STATUS schan_AcquireCredentialsHandle(ULONG fCredentialUse, - const SCHANNEL_CRED *schanCred, PCredHandle phCredential, PTimeStamp ptsExpiry) -{ - SECURITY_STATUS ret; - - if (fCredentialUse == SECPKG_CRED_OUTBOUND) - ret = schan_AcquireClientCredentials(schanCred, phCredential, - ptsExpiry); - else - ret = schan_AcquireServerCredentials(schanCred, phCredential, - ptsExpiry); - return ret; -} - static SECURITY_STATUS SEC_ENTRY schan_AcquireCredentialsHandleA( SEC_CHAR *pszPrincipal, SEC_CHAR *pszPackage, ULONG fCredentialUse, PLUID pLogonID, PVOID pAuthData, SEC_GET_KEY_FN pGetKeyFn,
From: Evan Tang etang@codeweavers.com
--- dlls/secur32/schannel.c | 13 ++++++++----- dlls/secur32/tests/schannel.c | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 6 deletions(-)
diff --git a/dlls/secur32/schannel.c b/dlls/secur32/schannel.c index 6a73cb0e38c..5a559a82de7 100644 --- a/dlls/secur32/schannel.c +++ b/dlls/secur32/schannel.c @@ -555,16 +555,15 @@ static SECURITY_STATUS schan_AcquireCredentialsHandle(ULONG fCredentialUse,
if (schanCred) { - const unsigned dtls_protocols = SP_PROT_DTLS_CLIENT | SP_PROT_DTLS1_2_CLIENT; - const unsigned tls_protocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_0_CLIENT | SP_PROT_TLS1_1_CLIENT | - SP_PROT_TLS1_2_CLIENT | SP_PROT_TLS1_3_CLIENT; + const unsigned dtls_protocols = SP_PROT_DTLS1_X; + const unsigned non_dtls_protocols = (SP_PROT_X_CLIENTS | SP_PROT_X_SERVERS) & ~SP_PROT_DTLS1_X;
status = get_cert(schanCred, &cert); if (status != SEC_E_OK && status != SEC_E_NO_CREDENTIALS) return status;
cred_enabled_protocols = get_enabled_protocols(schanCred); - if ((cred_enabled_protocols & tls_protocols) && + if ((cred_enabled_protocols & non_dtls_protocols) && (cred_enabled_protocols & dtls_protocols)) return SEC_E_ALGORITHM_MISMATCH;
status = SEC_E_OK; @@ -579,9 +578,13 @@ static SECURITY_STATUS schan_AcquireCredentialsHandle(ULONG fCredentialUse, enabled_protocols = cred_enabled_protocols & config_enabled_protocols; else enabled_protocols = config_enabled_protocols & ~config_default_disabled_protocols; + if (!(fCredentialUse & SECPKG_CRED_OUTBOUND)) /* SP_PROT_X_CLIENTS does not include the DTLS client bits */ + enabled_protocols &= ~(SP_PROT_X_CLIENTS | SP_PROT_DTLS1_X_CLIENT); + if (!(fCredentialUse & SECPKG_CRED_INBOUND)) /* SP_PROT_X_SERVERS does not include the DTLS server bits */ + enabled_protocols &= ~(SP_PROT_X_SERVERS | SP_PROT_DTLS1_X_SERVER); if(!enabled_protocols) { ERR("Could not find matching protocol\n"); - return SEC_E_NO_AUTHENTICATING_AUTHORITY; + return SEC_E_ALGORITHM_MISMATCH; }
if (!(creds = malloc(sizeof(*creds)))) return SEC_E_INSUFFICIENT_MEMORY; diff --git a/dlls/secur32/tests/schannel.c b/dlls/secur32/tests/schannel.c index 33915351cb3..455fcb97aa7 100644 --- a/dlls/secur32/tests/schannel.c +++ b/dlls/secur32/tests/schannel.c @@ -351,6 +351,8 @@ static void testAcquireSecurityContext(void) ok(st == SEC_E_OK, "AcquireCredentialsHandleA failed: %08lx\n", st); if(st == SEC_E_OK) FreeCredentialsHandle(&cred); + st = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_INBOUND, NULL, NULL, NULL, NULL, &cred, NULL); + ok(st == SEC_E_NO_CREDENTIALS, "st = %08lx\n", st); memset(&cred, 0, sizeof(cred)); st = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &cred, &exp); @@ -363,6 +365,22 @@ static void testAcquireSecurityContext(void)
FreeCredentialsHandle(&cred);
+ /* Should fail if no enabled protocols are available */ + init_cred(&schanCred); + schanCred.grbitEnabledProtocols = SP_PROT_TLS1_X_SERVER; + st = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_OUTBOUND, NULL, &schanCred, NULL, NULL, &cred, &exp); + ok(st == SEC_E_ALGORITHM_MISMATCH, "st = %08lx\n", st); + st = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_INBOUND, NULL, &schanCred, NULL, NULL, &cred, &exp); + ok(st == SEC_E_OK, "AcquireCredentialsHandleA failed: %08lx\n", st); + FreeCredentialsHandle(&cred); + + schanCred.grbitEnabledProtocols = SP_PROT_TLS1_X_CLIENT; + st = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_OUTBOUND, NULL, &schanCred, NULL, NULL, &cred, &exp); + ok(st == SEC_E_OK, "AcquireCredentialsHandleA failed: %08lx\n", st); + FreeCredentialsHandle(&cred); + st = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_INBOUND, NULL, &schanCred, NULL, NULL, &cred, &exp); + ok(st == SEC_E_ALGORITHM_MISMATCH, "st = %08lx\n", st); + /* Bad version in SCHANNEL_CRED */ memset(&schanCred, 0, sizeof(schanCred)); st = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_OUTBOUND, @@ -1668,7 +1686,7 @@ static void test_dtls(void) SECURITY_STATUS status; TimeStamp exp; SCHANNEL_CRED cred; - CredHandle cred_handle; + CredHandle cred_handle, cred_handle2; CtxtHandle ctx_handle, ctx_handle2; SecBufferDesc buffers[3]; ULONG flags_req, flags_ret, attr, prev_buf_len; @@ -1687,6 +1705,19 @@ static void test_dtls(void) } ok( status == SEC_E_OK, "got %08lx\n", status );
+ /* Should fail if both DTLS and TLS protocols are requested */ + cred.grbitEnabledProtocols |= SP_PROT_TLS1_CLIENT; + status = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_OUTBOUND, NULL, &cred, NULL, NULL, &cred_handle2, &exp); + ok(status == SEC_E_ALGORITHM_MISMATCH, "status = %08lx\n", status); + + cred.grbitEnabledProtocols = SP_PROT_DTLS1_X_CLIENT | SP_PROT_TLS1_SERVER; + status = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_OUTBOUND, NULL, &cred, NULL, NULL, &cred_handle2, &exp); + ok(status == SEC_E_ALGORITHM_MISMATCH, "status = %08lx\n", status); + + cred.grbitEnabledProtocols = SP_PROT_DTLS1_X_CLIENT | SP_PROT_SSL3_SERVER; + status = AcquireCredentialsHandleA(NULL, unisp_name_a, SECPKG_CRED_OUTBOUND, NULL, &cred, NULL, NULL, &cred_handle2, &exp); + ok(status == SEC_E_ALGORITHM_MISMATCH, "status = %08lx\n", status); + flags_req = ISC_REQ_MANUAL_CRED_VALIDATION | ISC_REQ_EXTENDED_ERROR | ISC_REQ_DATAGRAM | ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_CONFIDENTIALITY | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT; test_context_output_buffer_size(SP_PROT_DTLS_CLIENT | SP_PROT_DTLS1_2_CLIENT, SCH_CRED_NO_DEFAULT_CREDS, flags_req);
From: Evan Tang etang@codeweavers.com
--- dlls/secur32/tests/schannel.c | 141 ++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+)
diff --git a/dlls/secur32/tests/schannel.c b/dlls/secur32/tests/schannel.c index 455fcb97aa7..c6791c774dd 100644 --- a/dlls/secur32/tests/schannel.c +++ b/dlls/secur32/tests/schannel.c @@ -1671,6 +1671,146 @@ static void test_application_protocol_negotiation(void) closesocket(sock); }
+static void test_server_protocol_negotiation(void) { /* ALPN negotiated between InitializeSecurityContext (client) and AcceptSecurityContext (server) */ + BOOL ret; + SECURITY_STATUS client_status, server_status; + ULONG attrs; + SCHANNEL_CRED client_cred, server_cred; + CredHandle client_cred_handle, server_cred_handle; + CtxtHandle client_context, server_context, client_context2, server_context2; + SecPkgContext_ApplicationProtocol protocol; + SecBufferDesc buffers[3]; + PCCERT_CONTEXT cert; + HCRYPTPROV csp; + HCRYPTKEY key; + CRYPT_KEY_PROV_INFO keyProvInfo; + WCHAR ms_def_prov_w[MAX_PATH]; + unsigned buf_size = 8192; + unsigned char *alpn_buffer; + unsigned int *extension_len; + unsigned short *list_len; + int list_start_index, offset = 0; + + if (!pQueryContextAttributesA) + { + win_skip("Required secur32 functions not available\n"); + return; + } + + lstrcpyW(ms_def_prov_w, MS_DEF_PROV_W); + keyProvInfo.pwszContainerName = cspNameW; + keyProvInfo.pwszProvName = ms_def_prov_w; + keyProvInfo.dwProvType = PROV_RSA_FULL; + keyProvInfo.dwFlags = 0; + keyProvInfo.cProvParam = 0; + keyProvInfo.rgProvParam = NULL; + keyProvInfo.dwKeySpec = AT_SIGNATURE; + + cert = CertCreateCertificateContext(X509_ASN_ENCODING, selfSignedCert, sizeof(selfSignedCert)); + ret = CertSetCertificateContextProperty(cert, CERT_KEY_PROV_INFO_PROP_ID, 0, &keyProvInfo); + ok(ret, "CertSetCertificateContextProperty failed: %08lx\n", GetLastError()); + ret = CryptAcquireContextW(&csp, cspNameW, MS_DEF_PROV_W, PROV_RSA_FULL, CRYPT_NEWKEYSET); + ok(ret, "CryptAcquireContextW failed: %08lx\n", GetLastError()); + ret = CryptImportKey(csp, privKey, sizeof(privKey), 0, 0, &key); + ok(ret, "CryptImportKey failed: %08lx\n", GetLastError()); + if (!ret) return; + + init_cred(&client_cred); + init_cred(&server_cred); + client_cred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT; + client_cred.dwFlags = 
SCH_CRED_NO_DEFAULT_CREDS|SCH_CRED_MANUAL_CRED_VALIDATION; + server_cred.grbitEnabledProtocols = SP_PROT_TLS1_SERVER; + server_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS|SCH_CRED_MANUAL_CRED_VALIDATION; + 
server_cred.cCreds = 1; + server_cred.paCred = &cert; + + client_status = AcquireCredentialsHandleA(NULL, (SEC_CHAR *)UNISP_NAME_A, SECPKG_CRED_OUTBOUND, NULL, &client_cred, NULL, NULL, &client_cred_handle, NULL); + ok(client_status == SEC_E_OK, "got %08lx\n", client_status); + if (client_status != SEC_E_OK) return; + server_status = AcquireCredentialsHandleA(NULL, (SEC_CHAR *)UNISP_NAME_A, SECPKG_CRED_INBOUND, NULL, &server_cred, NULL, NULL, &server_cred_handle, NULL); + ok(server_status == SEC_E_OK, "got %08lx\n", server_status); + if (server_status != SEC_E_OK) return; + + init_buffers(&buffers[0], 4, buf_size); + init_buffers(&buffers[1], 4, buf_size); + init_buffers(&buffers[2], 1, 128); + + alpn_buffer = buffers[2].pBuffers[0].pvBuffer; + extension_len = (unsigned int *)&alpn_buffer[offset]; + offset += sizeof(*extension_len); + *(unsigned int *)&alpn_buffer[offset] = SecApplicationProtocolNegotiationExt_ALPN; + offset += sizeof(unsigned int); + list_len = (unsigned short *)&alpn_buffer[offset]; + offset += sizeof(*list_len); + list_start_index = offset; + + alpn_buffer[offset++] = sizeof("http/1.1") - 1; + memcpy(&alpn_buffer[offset], "http/1.1", sizeof("http/1.1") - 1); + offset += sizeof("http/1.1") - 1; + alpn_buffer[offset++] = sizeof("h2") - 1; + memcpy(&alpn_buffer[offset], "h2", sizeof("h2") - 1); + offset += sizeof("h2") - 1; + + *list_len = offset - list_start_index; + *extension_len = *list_len + sizeof(*extension_len) + sizeof(*list_len); + + buffers[2].pBuffers[0].BufferType = SECBUFFER_APPLICATION_PROTOCOLS; + buffers[2].pBuffers[0].cbBuffer = offset; + buffers[0].pBuffers[0].BufferType = SECBUFFER_TOKEN; + client_status = InitializeSecurityContextA(&client_cred_handle, NULL, (SEC_CHAR *)"localhost", ISC_REQ_CONFIDENTIALITY|ISC_REQ_STREAM, 0, 0, &buffers[2], 0, &client_context, &buffers[0], &attrs, NULL); + ok(client_status == SEC_I_CONTINUE_NEEDED, "got %08lx\n", client_status); + + buffers[1].pBuffers[0].cbBuffer = buf_size; + 
buffers[1].pBuffers[0].BufferType = SECBUFFER_TOKEN; + buffers[0].pBuffers[1] = buffers[2].pBuffers[0]; + server_status = AcceptSecurityContext(&server_cred_handle, NULL, &buffers[0], ASC_REQ_CONFIDENTIALITY|ASC_REQ_STREAM, 0, &server_context, &buffers[1], &attrs, NULL); + ok(server_status == SEC_I_CONTINUE_NEEDED, "got %08lx\n", server_status); + memset(&buffers[0].pBuffers[1], 0, sizeof(buffers[0].pBuffers[1])); + + client_context2.dwLower = client_context2.dwUpper = 0xdeadbeef; + buffers[0].pBuffers[0].cbBuffer = buf_size; + client_status = InitializeSecurityContextA(&client_cred_handle, &client_context, (SEC_CHAR *)"localhost", ISC_REQ_CONFIDENTIALITY|ISC_REQ_STREAM|ISC_REQ_USE_SUPPLIED_CREDS, 0, 0, &buffers[1], 0, &client_context2, &buffers[0], &attrs, NULL); + ok(client_context.dwLower == client_context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n", client_context.dwLower, client_context2.dwLower); + ok(client_context.dwUpper == client_context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n", client_context.dwUpper, client_context2.dwUpper); + ok(client_status == SEC_I_CONTINUE_NEEDED, "got %08lx\n", client_status); + + server_context2.dwLower = server_context2.dwUpper = 0xdeadbeef; + buffers[1].pBuffers[0].cbBuffer = buf_size; + server_status = AcceptSecurityContext(&server_cred_handle, &server_context, &buffers[0], ASC_REQ_CONFIDENTIALITY|ASC_REQ_STREAM, 0, &server_context2, &buffers[1], &attrs, NULL); + ok(server_context.dwLower == server_context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n", server_context.dwLower, server_context2.dwLower); + ok(server_context.dwUpper == server_context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n", server_context.dwUpper, server_context2.dwUpper); + ok(server_status == SEC_E_OK, "got %08lx\n", server_status); + + buffers[0].pBuffers[0].cbBuffer = buf_size; + client_status = InitializeSecurityContextA(&client_cred_handle, &client_context, (SEC_CHAR *)"localhost", 
ISC_REQ_USE_SUPPLIED_CREDS, 0, 0, &buffers[1], 0, NULL, &buffers[0], &attrs, NULL); + ok(client_status == SEC_E_OK, "got %08lx\n", client_status); + + memset(&protocol, 0, sizeof(protocol)); + client_status = pQueryContextAttributesA(&client_context, SECPKG_ATTR_APPLICATION_PROTOCOL, &protocol); + ok(client_status == SEC_E_OK || broken(client_status == SEC_E_UNSUPPORTED_FUNCTION) /* win2k8 */, "got %08lx\n", client_status); + if (client_status == SEC_E_OK) + { + ok(protocol.ProtoNegoStatus == SecApplicationProtocolNegotiationStatus_Success, "got %u\n", protocol.ProtoNegoStatus); + ok(protocol.ProtoNegoExt == SecApplicationProtocolNegotiationExt_ALPN, "got %u\n", protocol.ProtoNegoExt); + ok(protocol.ProtocolIdSize == 8, "got %u\n", protocol.ProtocolIdSize); + ok(!memcmp(protocol.ProtocolId, "http/1.1", 8), "wrong protocol id\n"); + } + + DeleteSecurityContext(&client_context); + DeleteSecurityContext(&server_context); + FreeCredentialsHandle(&client_cred_handle); + FreeCredentialsHandle(&server_cred_handle); + + free_buffers(&buffers[0]); + free_buffers(&buffers[1]); + free_buffers(&buffers[2]); + + CryptDestroyKey(key); + CryptReleaseContext(csp, 0); + CryptAcquireContextW(&csp, cspNameW, MS_DEF_PROV_W, PROV_RSA_FULL, CRYPT_DELETEKEYSET); + CertFreeCertificateContext(cert); +} + static void init_dtls_output_buffer(SecBufferDesc *buffer) { buffer->pBuffers[0].BufferType = SECBUFFER_TOKEN; @@ -1949,6 +2089,7 @@ START_TEST(schannel) test_InitializeSecurityContext(); test_communication(); test_application_protocol_negotiation(); + test_server_protocol_negotiation(); test_dtls(); test_connection_shutdown(); }
Hans Leidekker (@hans) commented about dlls/secur32/schannel.c:
return TRUE;
}
-/***********************************************************************
InitializeSecurityContextW
- */
-static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( +static SECURITY_STATUS schan_AcceptOrInitializeSecurityContext( PCredHandle phCredential, PCtxtHandle phContext, SEC_WCHAR *pszTargetName,
Please name it something like establish_context().
Hans Leidekker (@hans) commented about dlls/secur32/schannel.c:
return TRUE;
}
-/***********************************************************************
InitializeSecurityContextW
- */
-static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( +static SECURITY_STATUS schan_AcceptOrInitializeSecurityContext( PCredHandle phCredential, PCtxtHandle phContext, SEC_WCHAR *pszTargetName,
- ULONG fContextReq, ULONG Reserved1, ULONG TargetDataRep,
- PSecBufferDesc pInput, ULONG Reserved2, PCtxtHandle phNewContext,
- PSecBufferDesc pOutput, ULONG *pfContextAttr, PTimeStamp ptsExpiry)
- PSecBufferDesc pInput, ULONG fContextReq, ULONG TargetDataRep,
- PCtxtHandle phNewContext, PSecBufferDesc pOutput, ULONG *pfContextAttr,
- PTimeStamp ptsTimeStamp, BOOLEAN bIsServer)
{
BOOLEAN -> BOOL.
Hans Leidekker (@hans) commented about dlls/secur32/schannel.c:
return ret;
}
-static SECURITY_STATUS schan_AcquireClientCredentials(const void *schanCred,
- PCredHandle phCredential, PTimeStamp ptsExpiry)
+static SECURITY_STATUS schan_AcquireCredentialsHandle(ULONG fCredentialUse,
- const SCHANNEL_CRED *schanCred, PCredHandle phCredential, PTimeStamp ptsExpiry)
Please name it something like acquire_credentials_handle().
Hans Leidekker (@hans) commented about dlls/secur32/schannel.c:
enabled_protocols = cred_enabled_protocols & config_enabled_protocols; else enabled_protocols = config_enabled_protocols & ~config_default_disabled_protocols;
- if (!(fCredentialUse & SECPKG_CRED_OUTBOUND))
enabled_protocols &= ~SP_PROT_X_CLIENTS;
- if (!(fCredentialUse & SECPKG_CRED_INBOUND))
if(!enabled_protocols) { ERR("Could not find matching protocol\n");enabled_protocols &= ~SP_PROT_X_SERVERS;
return SEC_E_NO_AUTHENTICATING_AUTHORITY;
return SEC_E_ALGORITHM_MISMATCH;
Since you're not just adding tests, please change the patch title to something like 'Fix AcquireCredentialsHandle(Schannel) algorithm mismatch error return.'.