From: Paul Gofman <pgofman@codeweavers.com>
--- dlls/crypt32/str.c | 105 +++++++++++++++++++++++++++++---------- dlls/crypt32/tests/str.c | 55 +++++++++++++++----- include/wincrypt.h | 6 ++- 3 files changed, 126 insertions(+), 40 deletions(-)
diff --git a/dlls/crypt32/str.c b/dlls/crypt32/str.c index 0906ad883db..d74df308e4a 100644 --- a/dlls/crypt32/str.c +++ b/dlls/crypt32/str.c @@ -905,17 +905,7 @@ DWORD WINAPI CertGetNameStringA(PCCERT_CONTEXT cert, DWORD type, return ret; }
-/* Searches cert's extensions for the alternate name extension with OID - * altNameOID, and if found, searches it for the alternate name type entryType. - * If found, returns a pointer to the entry, otherwise returns NULL. - * Regardless of whether an entry of the desired type is found, if the - * alternate name extension is present, sets *info to the decoded alternate - * name extension, which you must free using LocalFree. - * The return value is a pointer within *info, so don't free *info before - * you're done with the return value. - */ -static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL alt_name_issuer, - DWORD entryType, PCERT_ALT_NAME_INFO *info) +static BOOL cert_get_alt_name_info(PCCERT_CONTEXT cert, BOOL alt_name_issuer, PCERT_ALT_NAME_INFO *info) { static const char *oids[][2] = { @@ -924,24 +914,48 @@ static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL a }; PCERT_EXTENSION ext; DWORD bytes = 0; - unsigned int i;
ext = CertFindExtension(oids[!!alt_name_issuer][0], cert->pCertInfo->cExtension, cert->pCertInfo->rgExtension); if (!ext) ext = CertFindExtension(oids[!!alt_name_issuer][1], cert->pCertInfo->cExtension, cert->pCertInfo->rgExtension); - if (!ext) return NULL; + if (!ext) return FALSE;
- if (!CryptDecodeObjectEx(cert->dwCertEncodingType, X509_ALTERNATE_NAME, ext->Value.pbData, ext->Value.cbData, - CRYPT_DECODE_ALLOC_FLAG, NULL, info, &bytes)) - return NULL; + return CryptDecodeObjectEx(cert->dwCertEncodingType, X509_ALTERNATE_NAME, ext->Value.pbData, ext->Value.cbData, + CRYPT_DECODE_ALLOC_FLAG, NULL, info, &bytes); +}
- for (i = 0; i < (*info)->cAltEntry; ++i) - if ((*info)->rgAltEntry[i].dwAltNameChoice == entryType) - return &(*info)->rgAltEntry[i]; +static PCERT_ALT_NAME_ENTRY cert_find_next_alt_name_entry(PCERT_ALT_NAME_INFO info, DWORD entry_type, + unsigned int *index) +{ + unsigned int i;
+ for (i = *index; i < info->cAltEntry; ++i) + if (info->rgAltEntry[i].dwAltNameChoice == entry_type) + { + *index = i + 1; + return &info->rgAltEntry[i]; + } return NULL; }
+/* Looks up cert's alternate name extension (the variant selected by + * alt_name_issuer) and, if found, searches it for an entry of type entry_type. + * If found, returns a pointer to the first such entry, otherwise returns NULL. + * Regardless of whether an entry of the desired type is found, if the + * alternate name extension is present and decodes, sets *info to the decoded + * extension, which you must free using LocalFree. + * The return value is a pointer within *info, so don't free *info before + * you're done with the return value. + */ +static PCERT_ALT_NAME_ENTRY cert_find_alt_name_entry(PCCERT_CONTEXT cert, BOOL alt_name_issuer, + DWORD entry_type, PCERT_ALT_NAME_INFO *info) +{ + unsigned int index = 0; + + if (!cert_get_alt_name_info(cert, alt_name_issuer, info)) return NULL; + return cert_find_next_alt_name_entry(*info, entry_type, &index); +} + static DWORD cert_get_name_from_rdn_attr(DWORD encodingType, const CERT_NAME_BLOB *name, LPCSTR oid, LPWSTR pszNameString, DWORD cchNameString) { @@ -978,9 +992,10 @@ static DWORD copy_output_str(WCHAR *dst, const WCHAR *src, DWORD dst_size) DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, void *type_para, LPWSTR name_string, DWORD name_len) { + static const DWORD supported_flags = CERT_NAME_ISSUER_FLAG | CERT_NAME_SEARCH_ALL_NAMES_FLAG; + BOOL alt_name_issuer, search_all_names; CERT_ALT_NAME_INFO *info = NULL; PCERT_ALT_NAME_ENTRY entry; - BOOL alt_name_issuer; PCERT_NAME_BLOB name; DWORD ret = 0;
@@ -989,6 +1004,16 @@ DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, vo if (!cert) goto done;
+ if (flags & ~supported_flags) + FIXME("Unsupported flags %#lx.\n", flags); + + search_all_names = flags & CERT_NAME_SEARCH_ALL_NAMES_FLAG; + if (search_all_names && type != CERT_NAME_DNS_TYPE) + { + WARN("CERT_NAME_SEARCH_ALL_NAMES_FLAG used with type %lu.\n", type); + goto done; + } + alt_name_issuer = flags & CERT_NAME_ISSUER_FLAG; name = alt_name_issuer ? &cert->pCertInfo->Issuer : &cert->pCertInfo->Subject;
@@ -1077,15 +1102,43 @@ DWORD WINAPI CertGetNameStringW(PCCERT_CONTEXT cert, DWORD type, DWORD flags, vo } case CERT_NAME_DNS_TYPE: { - entry = cert_find_alt_name_entry(cert, alt_name_issuer, CERT_ALT_NAME_DNS_NAME, &info); + unsigned int index = 0, len;
- if (entry) + if (cert_get_alt_name_info(cert, alt_name_issuer, &info) + && (entry = cert_find_next_alt_name_entry(info, CERT_ALT_NAME_DNS_NAME, &index))) { - ret = copy_output_str(name_string, entry->u.pwszDNSName, name_len); - break; + if (search_all_names) + { + do + { + if (name_string && name_len == 1) break; + ret += len = copy_output_str(name_string, entry->u.pwszDNSName, name_len ? name_len - 1 : 0); + if (name_string && name_len) + { + name_string += len; + name_len -= len; + } + } + while ((entry = cert_find_next_alt_name_entry(info, CERT_ALT_NAME_DNS_NAME, &index))); + } + else ret = copy_output_str(name_string, entry->u.pwszDNSName, name_len); + } + else + { + if (!search_all_names || !name_string || name_len != 1) /* a NULL buffer only queries the required size */ + { + len = search_all_names && name_len ? name_len - 1 : name_len; + ret = cert_get_name_from_rdn_attr(cert->dwCertEncodingType, name, szOID_COMMON_NAME, + name_string, len); + if (name_string && name_len) name_string += ret; + } + } + + if (search_all_names) + { + if (name_string && name_len) *name_string = 0; + ++ret; } - ret = cert_get_name_from_rdn_attr(cert->dwCertEncodingType, name, szOID_COMMON_NAME, - name_string, name_len); break; } case CERT_NAME_URL_TYPE: diff --git a/dlls/crypt32/tests/str.c b/dlls/crypt32/tests/str.c index 9fa9efff6b9..5fb05bdb836 100644 --- a/dlls/crypt32/tests/str.c +++ b/dlls/crypt32/tests/str.c @@ -847,39 +847,63 @@ static void test_CertStrToNameW(void) static void test_CertGetNameString_value_(unsigned int line, PCCERT_CONTEXT context, DWORD type, DWORD flags, void *type_para, const char *expected) { + DWORD len, retlen, expected_len; WCHAR expectedW[512]; - DWORD len, retlen; WCHAR strW[512]; - unsigned int i; char str[512];
- for (i = 0; expected[i]; ++i) - expectedW[i] = expected[i]; - expectedW[i] = 0; + expected_len = 0; + while(expected[expected_len]) + { + while((expectedW[expected_len] = expected[expected_len])) + ++expected_len; + if (!(flags & CERT_NAME_SEARCH_ALL_NAMES_FLAG)) + break; + expectedW[expected_len++] = 0; + } + expectedW[expected_len++] = 0;
len = CertGetNameStringA(context, type, flags, type_para, NULL, 0); - ok(len == strlen(expected) + 1, "line %u: unexpected length %ld.\n", line, len); + ok(len == expected_len, "line %u: unexpected length %lu, expected %lu.\n", line, len, expected_len); + memset(str, 0xcc, len); retlen = CertGetNameStringA(context, type, flags, type_para, str, len); ok(retlen == len, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, len); - ok(!strcmp(str, expected), "line %u: unexpected value %s.\n", line, str); + ok(!memcmp(str, expected, expected_len), "line %u: unexpected value %s.\n", line, debugstr_an(str, expected_len)); str[0] = str[1] = 0xcc; retlen = CertGetNameStringA(context, type, flags, type_para, str, len - 1); ok(retlen == 1, "line %u: Unexpected len %lu, expected 1.\n", line, retlen); if (len == 1) return; ok(!str[0], "line %u: unexpected str[0] %#x.\n", line, str[0]); ok(str[1] == expected[1], "line %u: unexpected str[1] %#x.\n", line, str[1]); - + ok(!memcmp(str + 1, expected + 1, len - 2), + "line %u: str %s, string data mismatch.\n", line, debugstr_a(str + 1)); retlen = CertGetNameStringA(context, type, flags, type_para, str, 0); ok(retlen == len, "line %u: Unexpected len %lu, expected 1.\n", line, retlen);
+ memset(strW, 0xcc, len * sizeof(*strW)); retlen = CertGetNameStringW(context, type, flags, type_para, strW, len); - ok(retlen == len, "line %u: unexpected len %lu, expected 1.\n", line, retlen); - ok(!wcscmp(strW, expectedW), "line %u: unexpected value %s.\n", line, debugstr_w(strW)); + ok(retlen == expected_len, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, expected_len); + ok(!memcmp(strW, expectedW, len * sizeof(*strW)), "line %u: unexpected value %s.\n", line, debugstr_wn(strW, len)); strW[0] = strW[1] = 0xcccc; retlen = CertGetNameStringW(context, type, flags, type_para, strW, len - 1); ok(retlen == len - 1, "line %u: unexpected len %lu, expected %lu.\n", line, retlen, len - 1); - ok(!wcsncmp(strW, expectedW, retlen - 1), "line %u: string data mismatch.\n", line); - ok(!strW[retlen - 1], "line %u: string is not zero terminated.\n", line); + if (flags & CERT_NAME_SEARCH_ALL_NAMES_FLAG) + { + ok(!memcmp(strW, expectedW, (retlen - 2) * sizeof(*strW)), + "line %u: str %s, string data mismatch.\n", line, debugstr_wn(strW, retlen - 2)); + ok(!strW[retlen - 2], "line %u: string is not zero terminated.\n", line); + ok(!strW[retlen - 1], "line %u: string sequence is not zero terminated.\n", line); + + retlen = CertGetNameStringW(context, type, flags, type_para, strW, 1); + ok(retlen == 1, "line %u: unexpected len %lu, expected 1.\n", line, retlen); + ok(!strW[retlen - 1], "line %u: string sequence is not zero terminated.\n", line); + } + else + { + ok(!memcmp(strW, expectedW, (retlen - 1) * sizeof(*strW)), + "line %u: str %s, string data mismatch.\n", line, debugstr_wn(strW, retlen - 1)); + ok(!strW[retlen - 1], "line %u: string is not zero terminated.\n", line); + } retlen = CertGetNameStringA(context, type, flags, type_para, NULL, len - 1); ok(retlen == len, "line %u: unexpected len %lu, expected %lu\n", line, retlen, len); retlen = CertGetNameStringW(context, type, flags, type_para, NULL, len - 1); @@ -924,6 +948,9 @@ static void 
test_CertGetNameString(void) test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, NULL, localhost); test_CertGetNameString_value(context, CERT_NAME_FRIENDLY_DISPLAY_TYPE, 0, NULL, localhost); test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, 0, NULL, localhost); + test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, NULL, "localhost\0"); + test_CertGetNameString_value(context, CERT_NAME_EMAIL_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, NULL, ""); + test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, NULL, "");
CertFreeCertificateContext(context);
@@ -945,6 +972,10 @@ static void test_CertGetNameString(void) test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_ISSUER_FLAG, NULL, "ex3.org"); test_CertGetNameString_value(context, CERT_NAME_SIMPLE_DISPLAY_TYPE, 0, NULL, "server_cn.org"); test_CertGetNameString_value(context, CERT_NAME_ATTR_TYPE, 0, (void *)szOID_SUR_NAME, ""); + test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG, + NULL, "ex1.org\0*.ex2.org\0"); + test_CertGetNameString_value(context, CERT_NAME_DNS_TYPE, CERT_NAME_SEARCH_ALL_NAMES_FLAG | CERT_NAME_ISSUER_FLAG, + NULL, "ex3.org\0*.ex4.org\0"); CertFreeCertificateContext(context); }
diff --git a/include/wincrypt.h b/include/wincrypt.h index 2c1e3f0d4c3..77b2fb5d7cf 100644 --- a/include/wincrypt.h +++ b/include/wincrypt.h @@ -3351,8 +3351,10 @@ typedef struct _CTL_FIND_SUBJECT_PARA #define CERT_NAME_URL_TYPE 7 #define CERT_NAME_UPN_TYPE 8
-#define CERT_NAME_ISSUER_FLAG 0x00000001 -#define CERT_NAME_DISABLE_IE4_UTF8_FLAG 0x00010000 +#define CERT_NAME_ISSUER_FLAG 0x00000001 +#define CERT_NAME_SEARCH_ALL_NAMES_FLAG 0x00000002 /* CERT_NAME_DNS_TYPE only: return all names as a double-NUL-terminated list */ +#define CERT_NAME_DISABLE_IE4_UTF8_FLAG 0x00010000 +#define CERT_NAME_STR_ENABLE_PUNYCODE_FLAG 0x00200000
/* CryptFormatObject flags */ #define CRYPT_FORMAT_STR_MULTI_LINE 0x0001