Implement RSA optimal asymmetric encryption padding according to RFC 8017.
Signed-off-by: Zhiyi Zhang <zzhang@codeweavers.com>
---
 dlls/rsaenh/rsaenh.c       | 298 ++++++++++++++++++++++++++++++++++++++++++---
 dlls/rsaenh/tests/rsaenh.c |  49 +++++++-
 2 files changed, 332 insertions(+), 15 deletions(-)
diff --git a/dlls/rsaenh/rsaenh.c b/dlls/rsaenh/rsaenh.c index e4c6dd14b3..bc10fc731f 100644 --- a/dlls/rsaenh/rsaenh.c +++ b/dlls/rsaenh/rsaenh.c @@ -1694,26 +1694,189 @@ static BOOL pad_data_pkcs1(const BYTE *abData, DWORD dwDataLen, BYTE *abBuffer, return TRUE; }
+/******************************************************************************
+ * pkcs1_mgf1 [Internal]
+ *
+ * MGF1 mask generation function (with SHA-1) for RSAES-OAEP, as specified in
+ * RFC 8017 PKCS #1 v2.2, Appendix B.2.1.
+ *
+ * PARAMS
+ *  hProv        [I] Cryptographic provider handle
+ *  pbSeed       [I] Seed from which the mask is generated
+ *  dwSeedLength [I] Length of pbSeed
+ *  dwLength     [I] Intended length in octets of the mask
+ *  pbMask       [O] Generated mask on success; the caller releases it with free_data_blob.
+ *                   On failure pbMask->pbData is left NULL, so releasing it is harmless.
+ *
+ * RETURNS
+ *  Success: TRUE
+ *  Failure: FALSE
+ */
+static BOOL pkcs1_mgf1(HCRYPTPROV hProv, const BYTE *pbSeed, DWORD dwSeedLength, DWORD dwLength, PCRYPT_DATA_BLOB pbMask)
+{
+    HCRYPTHASH hHash;
+    BYTE *pbHashInput = NULL, *pbCounter;
+    DWORD dwCounter, dwBlocks;
+    DWORD dwLen, dwHashLen;
+    BOOL ret = FALSE;
+
+    /* Start from an empty blob so no failure path leaves a dangling pointer for the caller to free again */
+    pbMask->pbData = NULL;
+    pbMask->cbData = 0;
+
+    /* Query the SHA-1 digest length; HP_HASHSIZE does not need any hashed data */
+    if (!RSAENH_CPCreateHash(hProv, CALG_SHA1, 0, 0, &hHash))
+        return FALSE;
+    dwLen = sizeof(dwHashLen);
+    if (!RSAENH_CPGetHashParam(hProv, hHash, HP_HASHSIZE, (BYTE *)&dwHashLen, &dwLen, 0))
+    {
+        RSAENH_CPDestroyHash(hProv, hHash);
+        return FALSE;
+    }
+    RSAENH_CPDestroyHash(hProv, hHash);
+
+    /* ceil(dwLength / dwHashLen) digests are concatenated to form the mask */
+    dwBlocks = (dwLength + dwHashLen - 1) / dwHashLen;
+
+    /* Every digest is written whole, so allocate a multiple of the digest length */
+    pbMask->pbData = HeapAlloc(GetProcessHeap(), 0, dwBlocks * dwHashLen);
+    if (!pbMask->pbData)
+    {
+        SetLastError(NTE_NO_MEMORY);
+        return FALSE;
+    }
+
+    pbHashInput = HeapAlloc(GetProcessHeap(), 0, dwSeedLength + sizeof(DWORD));
+    if (!pbHashInput)
+    {
+        SetLastError(NTE_NO_MEMORY);
+        goto done;
+    }
+
+    /* Hash input is Seed || C, where C = I2OSP(counter, 4) is big-endian */
+    memcpy(pbHashInput, pbSeed, dwSeedLength);
+    pbCounter = pbHashInput + dwSeedLength;
+    for (dwCounter = 0; dwCounter < dwBlocks; dwCounter++)
+    {
+        pbCounter[0] = (BYTE)((dwCounter >> 24) & 0xff);
+        pbCounter[1] = (BYTE)((dwCounter >> 16) & 0xff);
+        pbCounter[2] = (BYTE)((dwCounter >> 8) & 0xff);
+        pbCounter[3] = (BYTE)(dwCounter & 0xff);
+
+        if (!RSAENH_CPCreateHash(hProv, CALG_SHA1, 0, 0, &hHash))
+            goto done;
+        dwLen = dwHashLen;
+        /* T = T || Hash(Seed || C) */
+        if (!RSAENH_CPHashData(hProv, hHash, pbHashInput, dwSeedLength + sizeof(DWORD), 0)
+            || !RSAENH_CPGetHashParam(hProv, hHash, HP_HASHVAL, pbMask->pbData + dwCounter * dwHashLen, &dwLen, 0))
+        {
+            RSAENH_CPDestroyHash(hProv, hHash);
+            goto done;
+        }
+        RSAENH_CPDestroyHash(hProv, hHash);
+    }
+
+    pbMask->cbData = dwLength;
+    ret = TRUE;
+done:
+    HeapFree(GetProcessHeap(), 0, pbHashInput);
+    if (!ret)
+    {
+        /* Reset the blob so the caller's cleanup cannot free the buffer a second time */
+        free_data_blob(pbMask);
+        pbMask->pbData = NULL;
+        pbMask->cbData = 0;
+    }
+    return ret;
+}
+
+/******************************************************************************
+ * pad_data_oaep [Internal]
+ *
+ * Helper function for data OAEP padding scheme (EME-OAEP encoding with SHA-1 and
+ * an empty label) according to RFC 8017 PKCS #1 v2.2, section 7.1.1.
+ *
+ * PARAMS
+ *  hProv       [I] Cryptographic provider handle
+ *  abData      [I] The data to be padded
+ *  dwDataLen   [I] Length of the data
+ *  abBuffer    [O] Padded data will be stored here (may be the same buffer as abData)
+ *  dwBufferLen [I] Length of the buffer (also length of padded data)
+ *  dwFlags     [I] Unused; pad_data dispatches here only for CRYPT_OAEP
+ *
+ * RETURN
+ *  Success: TRUE
+ *  Failure: FALSE. Last error is ERROR_MORE_DATA when dwBufferLen cannot hold any
+ *           OAEP block, NTE_BAD_LEN when dwDataLen is too long for dwBufferLen.
+ */
+static BOOL pad_data_oaep(HCRYPTPROV hProv, const BYTE *abData, DWORD dwDataLen, BYTE *abBuffer, DWORD dwBufferLen,
+                          DWORD dwFlags)
+{
+    CRYPT_DATA_BLOB blobDbMask = {0}, blobSeedMask = {0};
+    HCRYPTHASH hHash = 0;
+    BYTE *pbPadded = NULL, *pbDb, *pbSeed;
+    DWORD dwLen, dwHashLen;
+    DWORD dwDbLen, dwSeedLen;
+    BOOL ret = FALSE;
+    DWORD i;
+
+    /* lHash = Hash(L) with an empty label L */
+    if (!RSAENH_CPCreateHash(hProv, CALG_SHA1, 0, 0, &hHash))
+    {
+        hHash = 0;
+        goto done;
+    }
+    if (!RSAENH_CPHashData(hProv, hHash, 0, 0, 0))
+        goto done;
+    dwLen = sizeof(dwHashLen);
+    if (!RSAENH_CPGetHashParam(hProv, hHash, HP_HASHSIZE, (BYTE *)&dwHashLen, &dwLen, 0))
+        goto done;
+
+    /* Check the buffer first so that dwBufferLen - 2 * dwHashLen - 2 below cannot wrap around */
+    if (dwBufferLen < 2 * dwHashLen + 2)
+    {
+        SetLastError(ERROR_MORE_DATA);
+        goto done;
+    }
+
+    /* mLen must not exceed k - 2hLen - 2 */
+    if (dwDataLen > dwBufferLen - 2 * dwHashLen - 2)
+    {
+        SetLastError(NTE_BAD_LEN);
+        goto done;
+    }
+
+    /* Build the block in a scratch buffer because abData and abBuffer may alias */
+    pbPadded = HeapAlloc(GetProcessHeap(), 0, dwBufferLen);
+    if (!pbPadded)
+    {
+        SetLastError(NTE_NO_MEMORY);
+        goto done;
+    }
+
+    /* EM = 00 || maskedSeed || maskedDB */
+    pbPadded[0] = 0;
+    pbSeed = pbPadded + 1;
+    dwSeedLen = dwHashLen;
+    pbDb = pbPadded + 1 + dwSeedLen;
+    dwDbLen = dwBufferLen - dwSeedLen - 1;
+
+    /* DB = lHash || PS || 01 || M */
+    /* Set lHash in DB */
+    dwLen = dwHashLen;
+    if (!RSAENH_CPGetHashParam(hProv, hHash, HP_HASHVAL, pbDb, &dwLen, 0))
+        goto done;
+    /* Set PS (zero octets) in DB */
+    memset(pbDb + dwHashLen, 0, dwDbLen - dwHashLen - 1 - dwDataLen);
+    /* Set 01 in DB */
+    pbDb[dwDbLen - dwDataLen - 1] = 1;
+    /* Set M in DB */
+    memcpy(pbDb + dwDbLen - dwDataLen, abData, dwDataLen);
+
+    /* Random seed of hLen octets; never continue with an uninitialized seed */
+    if (!gen_rand_impl(pbSeed, dwSeedLen))
+    {
+        SetLastError(NTE_FAIL);
+        goto done;
+    }
+
+    /* maskedDB = DB xor MGF(seed, k - hLen - 1) */
+    if (!pkcs1_mgf1(hProv, pbSeed, dwSeedLen, dwDbLen, &blobDbMask))
+        goto done;
+    for (i = 0; i < dwDbLen; i++) pbDb[i] ^= blobDbMask.pbData[i];
+
+    /* maskedSeed = seed xor MGF(maskedDB, hLen) */
+    if (!pkcs1_mgf1(hProv, pbDb, dwDbLen, dwSeedLen, &blobSeedMask))
+        goto done;
+    for (i = 0; i < dwSeedLen; i++) pbSeed[i] ^= blobSeedMask.pbData[i];
+
+    memcpy(abBuffer, pbPadded, dwBufferLen);
+    ret = TRUE;
+done:
+    /* Only destroy a hash that was actually created, so the last error is not clobbered */
+    if (hHash) RSAENH_CPDestroyHash(hProv, hHash);
+    HeapFree(GetProcessHeap(), 0, pbPadded);
+    free_data_blob(&blobDbMask);
+    free_data_blob(&blobSeedMask);
+    return ret;
+}
+
 /******************************************************************************
  * pad_data [Internal]
  *
  * Helper function for data padding according to padding format
  *
  * PARAMS
+ *  hProv       [I] Cryptographic provider handle
  *  abData      [I] The data to be padded
  *  dwDataLen   [I] Length of the data
  *  abBuffer    [O] Padded data will be stored here
  *  dwBufferLen [I] Length of the buffer (also length of padded data)
- *  dwFlags     [I] 0 or CRYPT_SSL2_FALLBACK
+ *  dwFlags     [I] 0, CRYPT_SSL2_FALLBACK or CRYPT_OAEP
  *
  * RETURN
  *  Success: TRUE
  *  Failure: FALSE
  */
-static BOOL pad_data(const BYTE *abData, DWORD dwDataLen, BYTE *abBuffer, DWORD dwBufferLen,
+static BOOL pad_data(HCRYPTPROV hProv, const BYTE *abData, DWORD dwDataLen, BYTE *abBuffer, DWORD dwBufferLen,
                      DWORD dwFlags)
 {
-    return pad_data_pkcs1(abData, dwDataLen, abBuffer, dwBufferLen, dwFlags);
+    /* Exactly CRYPT_OAEP selects OAEP; every other value keeps PKCS #1 v1.5 padding */
+    if (dwFlags == CRYPT_OAEP)
+        return pad_data_oaep(hProv, abData, dwDataLen, abBuffer, dwBufferLen, dwFlags);
+    else
+        return pad_data_pkcs1(abData, dwDataLen, abBuffer, dwBufferLen, dwFlags);
 }
/****************************************************************************** @@ -1757,26 +1920,133 @@ static BOOL unpad_data_pkcs1(const BYTE *abData, DWORD dwDataLen, BYTE *abBuffer return TRUE; }
+/******************************************************************************
+ * unpad_data_oaep [Internal]
+ *
+ * Remove the OAEP padding (SHA-1, empty label) from RSA decrypted data according to
+ * RFC 8017 PKCS #1 v2.2, section 7.1.2.
+ *
+ * PARAMS
+ *  hProv       [I]   Cryptographic provider handle
+ *  abData      [I]   The padded data
+ *  dwDataLen   [I]   Length of the padded data
+ *  abBuffer    [O]   Data without padding will be stored here (may be the same buffer as abData)
+ *  dwBufferLen [I/O] I: Length of the buffer, O: Length of unpadded data
+ *  dwFlags     [I]   Unused; unpad_data dispatches here only for CRYPT_OAEP
+ *
+ * RETURNS
+ *  Success: TRUE
+ *  Failure: FALSE. Malformed padding and a too small output buffer both report NTE_BAD_DATA.
+ */
+static BOOL unpad_data_oaep(HCRYPTPROV hProv, const BYTE *abData, DWORD dwDataLen, BYTE *abBuffer, DWORD *dwBufferLen,
+                            DWORD dwFlags)
+{
+    CRYPT_DATA_BLOB blobDbMask = {0}, blobSeedMask = {0};
+    HCRYPTHASH hHash = 0;
+    BYTE *pbBuffer = NULL, *pbHashValue = NULL;
+    const BYTE *pbPaddedSeed, *pbPaddedDb;
+    BYTE *pbUnpaddedSeed, *pbUnpaddedDb;
+    DWORD dwLen, dwHashLen;
+    DWORD dwSeedLen, dwDbLen;
+    DWORD dwSeparator, dwMsgCount;
+    BYTE bad, found;
+    BOOL ret = FALSE;
+    DWORD i;
+
+    /* lHash = Hash(L) with an empty label L */
+    if (!RSAENH_CPCreateHash(hProv, CALG_SHA1, 0, 0, &hHash))
+    {
+        hHash = 0;
+        goto done;
+    }
+    if (!RSAENH_CPHashData(hProv, hHash, 0, 0, 0))
+        goto done;
+    dwLen = sizeof(dwHashLen);
+    if (!RSAENH_CPGetHashParam(hProv, hHash, HP_HASHSIZE, (BYTE *)&dwHashLen, &dwLen, 0))
+        goto done;
+
+    if (dwDataLen < 2 * dwHashLen + 2)
+    {
+        SetLastError(NTE_BAD_DATA);
+        goto done;
+    }
+
+    /* Get expected lHash */
+    pbHashValue = HeapAlloc(GetProcessHeap(), 0, dwHashLen);
+    if (!pbHashValue)
+    {
+        SetLastError(NTE_NO_MEMORY);
+        goto done;
+    }
+    dwLen = dwHashLen;
+    if (!RSAENH_CPGetHashParam(hProv, hHash, HP_HASHVAL, pbHashValue, &dwLen, 0))
+        goto done;
+
+    /* Store unmasked seed and DB; a separate buffer because abData and abBuffer may alias */
+    pbBuffer = HeapAlloc(GetProcessHeap(), 0, dwDataLen - 1);
+    if (!pbBuffer)
+    {
+        SetLastError(NTE_NO_MEMORY);
+        goto done;
+    }
+
+    pbPaddedSeed = abData + 1;
+    pbPaddedDb = abData + 1 + dwHashLen;
+    pbUnpaddedSeed = pbBuffer;
+    pbUnpaddedDb = pbBuffer + dwHashLen;
+    dwSeedLen = dwHashLen;
+    dwDbLen = dwDataLen - dwHashLen - 1;
+
+    /* seed = maskedSeed xor MGF(maskedDB, hLen) */
+    if (!pkcs1_mgf1(hProv, pbPaddedDb, dwDbLen, dwSeedLen, &blobSeedMask))
+        goto done;
+    for (i = 0; i < dwSeedLen; i++) pbUnpaddedSeed[i] = pbPaddedSeed[i] ^ blobSeedMask.pbData[i];
+
+    /* DB = maskedDB xor MGF(seed, k - hLen - 1) */
+    if (!pkcs1_mgf1(hProv, pbUnpaddedSeed, dwSeedLen, dwDbLen, &blobDbMask))
+        goto done;
+    for (i = 0; i < dwDbLen; i++) pbUnpaddedDb[i] = pbPaddedDb[i] ^ blobDbMask.pbData[i];
+
+    /* Validate EM = 00 || seed || lHash || PS || 01 || M in full passes without early
+     * exits, so the individual failure causes are not told apart by an obvious timing
+     * difference (RFC 8017, section 7.1.2, note to step 3.g). */
+    bad = abData[0];
+    for (i = 0; i < dwHashLen; i++)
+        bad |= pbUnpaddedDb[i] ^ pbHashValue[i];
+
+    dwSeparator = 0;
+    found = 0;
+    for (i = dwHashLen; i < dwDbLen; i++)
+    {
+        BYTE is_one = (pbUnpaddedDb[i] == 1);
+        BYTE is_zero = (pbUnpaddedDb[i] == 0);
+
+        /* Record the index of the first 01 octet without branching on it */
+        dwSeparator |= (0 - (DWORD)(is_one & !found)) & i;
+        /* Before the separator only zero octets (PS) are allowed */
+        bad |= !found & !is_zero & !is_one;
+        found |= is_one;
+    }
+    bad |= !found;
+
+    if (bad)
+    {
+        SetLastError(NTE_BAD_DATA);
+        goto done;
+    }
+
+    /* M follows the 01 separator; computed only once the separator is known to exist */
+    dwMsgCount = dwDbLen - dwSeparator - 1;
+    if (*dwBufferLen < dwMsgCount)
+    {
+        SetLastError(NTE_BAD_DATA);
+        goto done;
+    }
+
+    *dwBufferLen = dwMsgCount;
+    memcpy(abBuffer, pbUnpaddedDb + dwSeparator + 1, dwMsgCount);
+    ret = TRUE;
+done:
+    /* Only destroy a hash that was actually created, so the last error is not clobbered */
+    if (hHash) RSAENH_CPDestroyHash(hProv, hHash);
+    HeapFree(GetProcessHeap(), 0, pbHashValue);
+    HeapFree(GetProcessHeap(), 0, pbBuffer);
+    free_data_blob(&blobDbMask);
+    free_data_blob(&blobSeedMask);
+    return ret;
+}
+
 /******************************************************************************
  * unpad_data [Internal]
  *
  * Remove the padding from RSA decrypted data according to padding format
  *
  * PARAMS
+ *  hProv       [I]   Cryptographic provider handle
  *  abData      [I]   The padded data
  *  dwDataLen   [I]   Length of the padded data
  *  abBuffer    [O]   Data without padding will be stored here
  *  dwBufferLen [I/O] I: Length of the buffer, O: Length of unpadded data
- *  dwFlags     [I]   Currently none defined
+ *  dwFlags     [I]   0 or CRYPT_OAEP
  *
  * RETURNS
  *  Success: TRUE
  *  Failure: FALSE
  */
-static BOOL unpad_data(const BYTE *abData, DWORD dwDataLen, BYTE *abBuffer, DWORD *dwBufferLen,
+static BOOL unpad_data(HCRYPTPROV hProv, const BYTE *abData, DWORD dwDataLen, BYTE *abBuffer, DWORD *dwBufferLen,
                        DWORD dwFlags)
 {
-    return unpad_data_pkcs1(abData, dwDataLen, abBuffer, dwBufferLen, dwFlags);
+    /* Exactly CRYPT_OAEP selects OAEP; every other value keeps PKCS #1 v1.5 unpadding */
+    if (dwFlags == CRYPT_OAEP)
+        return unpad_data_oaep(hProv, abData, dwDataLen, abBuffer, dwBufferLen, dwFlags);
+    else
+        return unpad_data_pkcs1(abData, dwDataLen, abBuffer, dwBufferLen, dwFlags);
 }
/****************************************************************************** @@ -2187,7 +2457,7 @@ BOOL WINAPI RSAENH_CPDuplicateKey(HCRYPTPROV hUID, HCRYPTKEY hKey, DWORD *pdwRes * hKey [I] The key used to encrypt the data. * hHash [I] An optional hash object for parallel hashing. See notes. * Final [I] Indicates if this is the last block of data to encrypt. - * dwFlags [I] Currently no flags defined. Must be zero. + * dwFlags [I] Must be zero or CRYPT_OAEP * pbData [I/O] Pointer to the data to encrypt. Encrypted data will also be stored there. * pdwDataLen [I/O] I: Length of data to encrypt, O: Length of encrypted data. * dwBufLen [I] Size of the buffer at pbData. @@ -2219,7 +2489,7 @@ BOOL WINAPI RSAENH_CPEncrypt(HCRYPTPROV hProv, HCRYPTKEY hKey, HCRYPTHASH hHash, return FALSE; }
- if (dwFlags) + if (dwFlags != 0 && dwFlags != CRYPT_OAEP) { SetLastError(NTE_BAD_FLAGS); return FALSE; @@ -2316,7 +2586,7 @@ BOOL WINAPI RSAENH_CPEncrypt(HCRYPTPROV hProv, HCRYPTKEY hKey, HCRYPTHASH hHash, SetLastError(ERROR_MORE_DATA); return FALSE; } - if (!pad_data(pbData, *pdwDataLen, pbData, pCryptKey->dwBlockLen, dwFlags)) return FALSE; + if (!pad_data(hProv, pbData, *pdwDataLen, pbData, pCryptKey->dwBlockLen, dwFlags)) return FALSE; encrypt_block_impl(pCryptKey->aiAlgid, PK_PUBLIC, &pCryptKey->context, pbData, pbData, RSAENH_ENCRYPT); *pdwDataLen = pCryptKey->dwBlockLen; Final = TRUE; @@ -2340,7 +2610,7 @@ BOOL WINAPI RSAENH_CPEncrypt(HCRYPTPROV hProv, HCRYPTKEY hKey, HCRYPTHASH hHash, * hKey [I] The key used to decrypt the data. * hHash [I] An optional hash object for parallel hashing. See notes. * Final [I] Indicates if this is the last block of data to decrypt. - * dwFlags [I] Currently no flags defined. Must be zero. + * dwFlags [I] Must be zero or CRYPT_OAEP * pbData [I/O] Pointer to the data to decrypt. Plaintext will also be stored there. * pdwDataLen [I/O] I: Length of ciphertext, O: Length of plaintext. * @@ -2371,7 +2641,7 @@ BOOL WINAPI RSAENH_CPDecrypt(HCRYPTPROV hProv, HCRYPTKEY hKey, HCRYPTHASH hHash, return FALSE; }
- if (dwFlags) + if (dwFlags != 0 && dwFlags != CRYPT_OAEP) { SetLastError(NTE_BAD_FLAGS); return FALSE; @@ -2459,7 +2729,7 @@ BOOL WINAPI RSAENH_CPDecrypt(HCRYPTPROV hProv, HCRYPTKEY hKey, HCRYPTHASH hHash, return FALSE; } encrypt_block_impl(pCryptKey->aiAlgid, PK_PRIVATE, &pCryptKey->context, pbData, pbData, RSAENH_DECRYPT); - if (!unpad_data(pbData, pCryptKey->dwBlockLen, pbData, pdwDataLen, dwFlags)) return FALSE; + if (!unpad_data(hProv, pbData, pCryptKey->dwBlockLen, pbData, pdwDataLen, dwFlags)) return FALSE; Final = TRUE; } else { SetLastError(NTE_BAD_TYPE); @@ -2503,7 +2773,7 @@ static BOOL crypt_export_simple(CRYPTKEY *pCryptKey, CRYPTKEY *pPubKey,
*pAlgid = pPubKey->aiAlgid;
- if (!pad_data(pCryptKey->abKeyValue, pCryptKey->dwKeyLen, (BYTE*)(pAlgid+1), + if (!pad_data(pCryptKey->hProv, pCryptKey->abKeyValue, pCryptKey->dwKeyLen, (BYTE*)(pAlgid+1), pPubKey->dwBlockLen, dwFlags)) { return FALSE; @@ -2966,7 +3236,7 @@ static BOOL import_symmetric_key(HCRYPTPROV hProv, const BYTE *pbData, DWORD dwD RSAENH_DECRYPT);
dwKeyLen = RSAENH_MAX_KEY_SIZE; - if (!unpad_data(pbDecrypted, pPubKey->dwBlockLen, pbDecrypted, &dwKeyLen, dwFlags)) { + if (!unpad_data(hProv, pbDecrypted, pPubKey->dwBlockLen, pbDecrypted, &dwKeyLen, dwFlags)) { HeapFree(GetProcessHeap(), 0, pbDecrypted); return FALSE; } diff --git a/dlls/rsaenh/tests/rsaenh.c b/dlls/rsaenh/tests/rsaenh.c index df31aeb54b..3e782be279 100644 --- a/dlls/rsaenh/tests/rsaenh.c +++ b/dlls/rsaenh/tests/rsaenh.c @@ -2282,6 +2282,7 @@ static void test_rsa_encrypt(void) BYTE abData[2048] = "Wine rocks!"; BOOL result; DWORD dwVal, dwLen; + DWORD err;
/* It is allowed to use the key exchange key for encryption/decryption */ result = CryptGetUserKey(hProv, AT_KEYEXCHANGE, &hRSAKey); @@ -2297,6 +2298,7 @@ static void test_rsa_encrypt(void) } ok(result, "CryptEncrypt failed: %08x\n", GetLastError()); ok(dwLen == 128, "Unexpected length %d\n", dwLen); + /* PKCS1 V1.5 */ dwLen = 12; result = CryptEncrypt(hRSAKey, 0, TRUE, 0, abData, &dwLen, (DWORD)sizeof(abData)); ok (result, "%08x\n", GetLastError()); @@ -2304,7 +2306,52 @@ static void test_rsa_encrypt(void)
result = CryptDecrypt(hRSAKey, 0, TRUE, 0, abData, &dwLen); ok (result && dwLen == 12 && !memcmp(abData, "Wine rocks!", 12), "%08x\n", GetLastError()); - + + /* OAEP, RFC 8017 PKCS #1 V2.2 */ + /* Test minimal buffer length requirement */ + dwLen = 1; + SetLastError(0xdeadbeef); + result = CryptEncrypt(hRSAKey, 0, TRUE, CRYPT_OAEP, abData, &dwLen, 20 * 2 + 2); + err = GetLastError(); + ok(!result && err == ERROR_MORE_DATA, "%08x\n", err); + + /* Test data length limit */ + dwLen = sizeof(abData) - (20 * 2 + 2) + 1; + result = CryptEncrypt(hRSAKey, 0, TRUE, CRYPT_OAEP, abData, &dwLen, (DWORD)sizeof(abData)); + err = GetLastError(); + ok(!result && err == NTE_BAD_LEN, "%08x\n", err); + + /* Test malformed data */ + dwLen = 12; + SetLastError(0xdeadbeef); + memcpy(abData, "Wine rocks!", dwLen); + result = CryptDecrypt(hRSAKey, 0, TRUE, CRYPT_OAEP, abData, &dwLen); + err = GetLastError(); + /* NTE_DOUBLE_ENCRYPT on xp or 2003 */ + ok(!result && (err == NTE_BAD_DATA || broken(err == NTE_DOUBLE_ENCRYPT)), "%08x\n", err); + + /* Test decrypt with insufficient buffer */ + dwLen = 12; + SetLastError(0xdeadbeef); + memcpy(abData, "Wine rocks!", 12); + result = CryptEncrypt(hRSAKey, 0, TRUE, CRYPT_OAEP, abData, &dwLen, (DWORD)sizeof(abData)); + ok(result, "%08x\n", GetLastError()); + dwLen = 11; + SetLastError(0xdeadbeef); + result = CryptDecrypt(hRSAKey, 0, TRUE, CRYPT_OAEP, abData, &dwLen); + err = GetLastError(); + /* broken on xp or 2003 */ + ok((!result && dwLen == 11 && err == NTE_BAD_DATA) || broken(result == TRUE && dwLen == 12 && err == ERROR_NO_TOKEN), + "%08x %d %08x\n", result, dwLen, err); + + /* Test normal encryption and decryption */ + dwLen = 12; + memcpy(abData, "Wine rocks!", dwLen); + result = CryptEncrypt(hRSAKey, 0, TRUE, CRYPT_OAEP, abData, &dwLen, (DWORD)sizeof(abData)); + ok(result, "%08x\n", GetLastError()); + result = CryptDecrypt(hRSAKey, 0, TRUE, CRYPT_OAEP, abData, &dwLen); + ok(result && dwLen == 12 && !memcmp(abData, "Wine rocks!", 12), 
"%08x\n", GetLastError()); + dwVal = 0xdeadbeef; dwLen = sizeof(DWORD); result = CryptGetKeyParam(hRSAKey, KP_PERMISSIONS, (BYTE*)&dwVal, &dwLen, 0);