Signed-off-by: Brendan Shanks <bshanks@codeweavers.com>
---
 dlls/ntdll/tests/Makefile.in |   1 +
 dlls/ntdll/tests/umip.c      | 323 +++++++++++++++++++++++++++++++++++
 2 files changed, 324 insertions(+)
 create mode 100644 dlls/ntdll/tests/umip.c

diff --git a/dlls/ntdll/tests/Makefile.in b/dlls/ntdll/tests/Makefile.in
index ed15c51339..e866c54149 100644
--- a/dlls/ntdll/tests/Makefile.in
+++ b/dlls/ntdll/tests/Makefile.in
@@ -23,4 +23,5 @@ C_SRCS = \
 	string.c \
 	threadpool.c \
 	time.c \
+	umip.c \
 	virtual.c
diff --git a/dlls/ntdll/tests/umip.c b/dlls/ntdll/tests/umip.c
new file mode 100644
index 0000000000..f39149dd8d
--- /dev/null
+++ b/dlls/ntdll/tests/umip.c
@@ -0,0 +1,323 @@
+/*
+ * Unit test suite for x86 instructions protected by UMIP.
+ *
+ * Copyright (C) 2019 Brendan Shanks for CodeWeavers
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
+ */
+
+#include <stdio.h>
+#include <stdint.h>
+
+#include "ntstatus.h"
+#define WIN32_NO_STATUS
+#include "windef.h"
+#include "winternl.h"
+#include "wine/test.h"
+
+#if defined(__x86_64__) || defined(__i386__)
+
+static PVOID (WINAPI *pRtlAddVectoredExceptionHandler)(ULONG first, PVECTORED_EXCEPTION_HANDLER func);
+static ULONG (WINAPI *pRtlRemoveVectoredExceptionHandler)(PVOID handler);
+
+/* Executable buffer that each test copies its machine code into before calling it. */
+static void *code_mem;
+
+/* sldt, str, smsw all store a 16-bit value to a register or memory.
+ * When the destination is a register, the 16-bit value is zero-extended.
+ * When the destination is memory, only the 16-bit value is stored.
+ *
+ * Each *_code buffer holds a complete function (ending in ret) which is
+ * copied into code_mem and called. The register variants return the stored
+ * value in ax/eax/rax; the memory variant stores through its pointer argument.
+ * reg64_code is only executed on x86-64.
+ */
+static void test_reg_mem_16(const char *insn_name,
+                            const BYTE *reg16_code, UINT reg16_code_size,
+                            const BYTE *reg32_code, UINT reg32_code_size,
+                            const BYTE *reg64_code, UINT reg64_code_size,
+                            const BYTE *mem16_code, UINT mem16_code_size)
+{
+    UINT16 value16;
+    UINT32 value32;
+
+    /* CDECL makes these use the Microsoft calling convention on x86-64, which
+     * the machine code relies on (pointer argument in rcx, not rdi). */
+    UINT16 (CDECL *reg16_func)(void) = code_mem;
+    UINT32 (CDECL *reg32_func)(void) = code_mem;
+    void (CDECL *mem16_func)(UINT16 *value) = code_mem;
+
+    /* Destination is 16-bit register */
+    memset(&value16, 0xcc, sizeof(value16));
+    memcpy(code_mem, reg16_code, reg16_code_size);
+    value16 = reg16_func();
+    trace("%s reg16: 0x%x\n", insn_name, value16);
+
+    /* Destination is 32-bit register. */
+    memset(&value32, 0xcc, sizeof(value32));
+    memcpy(code_mem, reg32_code, reg32_code_size);
+    value32 = reg32_func();
+    trace("%s reg32: 0x%x\n", insn_name, value32);
+#if defined(__x86_64__)
+    /* For sldt/str in 64-bit mode, the upper 16 bits are defined to be zero */
+    if (!strcmp(insn_name, "sldt") || !strcmp(insn_name, "str"))
+        ok(value32 >> 16 == 0, "%s expected upper 16 bits = 0, got 0x%x\n", insn_name, value32 >> 16);
+#endif
+
+#if defined(__x86_64__)
+    /* Destination is 64-bit register */
+    {
+        UINT64 value64;
+        UINT64 (CDECL *reg64_func)(void) = code_mem;
+
+        memset(&value64, 0xcc, sizeof(value64));
+        memcpy(code_mem, reg64_code, reg64_code_size);
+        value64 = reg64_func();
+        trace("%s reg64: 0x%llx\n", insn_name, value64);
+
+        /* For sldt/str the upper 48 bits are defined to be zero */
+        if (!strcmp(insn_name, "sldt") || !strcmp(insn_name, "str"))
+            ok(value64 >> 16 == 0, "%s expected upper 48 bits = 0, got 0x%llx\n", insn_name, value64 >> 16);
+    }
+#endif
+
+    /* Destination is memory (only the low 16 bits are defined to be written) */
+    memset(&value32, 0xcc, sizeof(value32));
+    memcpy(code_mem, mem16_code, mem16_code_size);
+    mem16_func((UINT16 *)&value32);
+    trace("%s mem: 0x%x\n", insn_name, value32);
+    ok(value32 >> 16 == 0xcccc, "%s expected upper 16 bits = 0xcccc, got 0x%x\n", insn_name, value32 >> 16);
+}
+
+/* sgdt and sidt write the descriptor table register to memory.
+ * The descriptor consists of a 2-byte limit field, and a base field which
+ * is 4 bytes in 32-bit mode and 8 bytes in 64-bit mode.
+ *
+ * code is a complete function (ending in ret) that stores the descriptor
+ * through its pointer argument; the stored limit and base are only traced,
+ * not checked.
+ */
+static void test_mem_descriptor(const char *insn_name, const BYTE *code, UINT code_size)
+{
+    BYTE descriptor[10];
+    void (CDECL *func)(BYTE *value) = code_mem;
+
+    memset(descriptor, 0xcc, sizeof(descriptor));
+    memcpy(code_mem, code, code_size);
+    func(descriptor);
+    trace("%s limit: 0x%x\n", insn_name, *(UINT16 *)&descriptor[0]);
+    trace("%s base: 0x%p\n", insn_name, (void *) *(UINT_PTR *)&descriptor[2]);
+}
+
+/* Run sldt with 16/32/64-bit register and 16-bit memory destinations. */
+static void test_sldt(void)
+{
+    const BYTE reg16[] = {
+        0x66, 0x0f, 0x00, 0xc0, /* sldt ax */
+        0xc3,                   /* ret */
+    };
+    const BYTE reg32[] = {
+        0x0f, 0x00, 0xc0,       /* sldt eax */
+        0xc3,                   /* ret */
+    };
+    const BYTE reg64[] = {
+        0x48, 0x0f, 0x00, 0xc0, /* sldt rax */
+        0xc3,                   /* ret */
+    };
+#if defined(__x86_64__)
+    const BYTE mem16[] = {
+        0x0f, 0x00, 0x01,       /* sldt word [rcx] */
+        0xc3,                   /* ret */
+    };
+#else
+    const BYTE mem16[] = {
+        0x8b, 0x4c, 0x24, 0x04, /* mov ecx, dword [esp+4] */
+        0x0f, 0x00, 0x01,       /* sldt word [ecx] */
+        0xc3,                   /* ret */
+    };
+#endif
+
+    test_reg_mem_16("sldt", reg16, sizeof(reg16), reg32, sizeof(reg32), reg64, sizeof(reg64), mem16, sizeof(mem16));
+}
+
+/* Run str with 16/32/64-bit register and 16-bit memory destinations. */
+static void test_str(void)
+{
+    const BYTE reg16[] = {
+        0x66, 0x0f, 0x00, 0xc8, /* str ax */
+        0xc3,                   /* ret */
+    };
+    const BYTE reg32[] = {
+        0x0f, 0x00, 0xc8,       /* str eax */
+        0xc3,                   /* ret */
+    };
+    const BYTE reg64[] = {
+        0x48, 0x0f, 0x00, 0xc8, /* str rax */
+        0xc3,                   /* ret */
+    };
+#if defined(__x86_64__)
+    const BYTE mem16[] = {
+        0x0f, 0x00, 0x09,       /* str word [rcx] */
+        0xc3,                   /* ret */
+    };
+#else
+    const BYTE mem16[] = {
+        0x8b, 0x4c, 0x24, 0x04, /* mov ecx, dword [esp+4] */
+        0x0f, 0x00, 0x09,       /* str word [ecx] */
+        0xc3,                   /* ret */
+    };
+#endif
+
+    test_reg_mem_16("str", reg16, sizeof(reg16), reg32, sizeof(reg32), reg64, sizeof(reg64), mem16, sizeof(mem16));
+}
+
+/* Run sgdt with a memory destination and trace the stored descriptor. */
+static void test_sgdt(void)
+{
+    /* sgdt destination must be memory */
+#if defined(__x86_64__)
+    const BYTE mem_code[] = {
+        0x0f, 0x01, 0x01,       /* sgdt [rcx] */
+        0xc3,                   /* ret */
+    };
+#else
+    const BYTE mem_code[] = {
+        0x8b, 0x4c, 0x24, 0x04, /* mov ecx, dword [esp+4] */
+        0x0f, 0x01, 0x01,       /* sgdt [ecx] */
+        0xc3,                   /* ret */
+    };
+#endif
+
+    test_mem_descriptor("sgdt", mem_code, sizeof(mem_code));
+}
+
+/* Run sidt with a memory destination and trace the stored descriptor. */
+static void test_sidt(void)
+{
+    /* sidt destination must be memory */
+#if defined(__x86_64__)
+    const BYTE mem_code[] = {
+        0x0f, 0x01, 0x09,       /* sidt [rcx] */
+        0xc3,                   /* ret */
+    };
+#else
+    const BYTE mem_code[] = {
+        0x8b, 0x4c, 0x24, 0x04, /* mov ecx, dword [esp+4] */
+        0x0f, 0x01, 0x09,       /* sidt [ecx] */
+        0xc3,                   /* ret */
+    };
+#endif
+
+    test_mem_descriptor("sidt", mem_code, sizeof(mem_code));
+}
+
+/* Run smsw with 16/32/64-bit register and 16-bit memory destinations. */
+static void test_smsw(void)
+{
+    const BYTE reg16[] = {
+        0x66, 0x0f, 0x01, 0xe0, /* smsw ax */
+        0xc3,                   /* ret */
+    };
+    const BYTE reg32[] = {
+        0x0f, 0x01, 0xe0,       /* smsw eax */
+        0xc3,                   /* ret */
+    };
+    const BYTE reg64[] = {
+        0x48, 0x0f, 0x01, 0xe0, /* smsw rax */
+        0xc3,                   /* ret */
+    };
+
+#if defined(__x86_64__)
+    const BYTE mem16[] = {
+        0x0f, 0x01, 0x21,       /* smsw word [rcx] */
+        0xc3,                   /* ret */
+    };
+#else
+    const BYTE mem16[] = {
+        0x8b, 0x4c, 0x24, 0x04, /* mov ecx, dword [esp+4] */
+        0x0f, 0x01, 0x21,       /* smsw word [ecx] */
+        0xc3,                   /* ret */
+    };
+#endif
+
+    test_reg_mem_16("smsw", reg16, sizeof(reg16), reg32, sizeof(reg32), reg64, sizeof(reg64), mem16, sizeof(mem16));
+}
+
+/* Vectored exception handler installed while the tests run.
+ * An access violation reading address -1 is treated as the fault raised by
+ * an unemulated UMIP instruction and reported as a test failure.
+ * The process is terminated for any exception that reaches this handler.
+ */
+static LONG CALLBACK umip_vectored_handler(EXCEPTION_POINTERS *ExceptionInfo)
+{
+    PEXCEPTION_RECORD rec = ExceptionInfo->ExceptionRecord;
+    trace("vectored exception handler %08x addr:%p\n", rec->ExceptionCode, rec->ExceptionAddress);
+
+    ok (!(rec->ExceptionCode == EXCEPTION_ACCESS_VIOLATION &&
+          rec->ExceptionInformation[0] == 0 &&
+          rec->ExceptionInformation[1] == UINTPTR_MAX),
+        "vectored_handler caught fault for unemulated UMIP instruction, exiting\n");
+
+    ExitProcess(1);
+
+    return EXCEPTION_CONTINUE_SEARCH;
+}
+#endif /* __x86_64__ || __i386__ */
+
+START_TEST(umip)
+{
+    /* Test that sldt, str, sgdt, sidt, and smsw can be executed with
+     * all possible operand types (registers/memory of different widths).
+     *
+     * We mostly cannot test/predict the returned values, but on a UMIP-enabled
+     * system without emulation the instructions will trigger a SIGSEGV.
+     * A non-first-chance vectored exception handler is added to catch the exception,
+     * fail the test, and exit the process if that happens.
+     */
+
+#if defined(__x86_64__) || defined(__i386__)
+    PVOID vectored_handler;
+    HMODULE hntdll = GetModuleHandleA("ntdll.dll");
+
+    pRtlAddVectoredExceptionHandler = (void *)GetProcAddress(hntdll, "RtlAddVectoredExceptionHandler");
+    pRtlRemoveVectoredExceptionHandler = (void *)GetProcAddress(hntdll, "RtlRemoveVectoredExceptionHandler");
+
+    if (!pRtlAddVectoredExceptionHandler || !pRtlRemoveVectoredExceptionHandler) {
+        trace("RtlAddVectoredExceptionHandler or RtlRemoveVectoredExceptionHandler not found\n");
+        return;
+    }
+
+    vectored_handler = pRtlAddVectoredExceptionHandler(FALSE, &umip_vectored_handler);
+    if (!vectored_handler) {
+        trace("RtlAddVectoredExceptionHandler failed\n");
+        return;
+    }
+
+    code_mem = VirtualAlloc(NULL, 65536, MEM_RESERVE | MEM_COMMIT, PAGE_EXECUTE_READWRITE);
+    if(!code_mem) {
+        trace("VirtualAlloc failed\n");
+        pRtlRemoveVectoredExceptionHandler(vectored_handler);
+        return;
+    }
+
+    test_sldt();
+    test_str();
+    test_sgdt();
+    test_sidt();
+    test_smsw();
+
+    VirtualFree(code_mem, 0, MEM_RELEASE);
+    pRtlRemoveVectoredExceptionHandler(vectored_handler);
+#endif
+}