Signed-off-by: Andrew Wesie <awesie@gmail.com> --- dlls/ntdll/ntdll.spec | 7 +++++-- dlls/ntdll/ntdll_misc.h | 3 --- dlls/ntdll/time.c | 8 ++++++-- include/winternl.h | 1 + 4 files changed, 12 insertions(+), 7 deletions(-)
diff --git a/dlls/ntdll/ntdll.spec b/dlls/ntdll/ntdll.spec index 050ebc7..4ffb3c9 100644 --- a/dlls/ntdll/ntdll.spec +++ b/dlls/ntdll/ntdll.spec @@ -194,7 +194,7 @@ @ stdcall NtGetCurrentProcessorNumber() # @ stub NtGetDevicePowerState @ stub NtGetPlugPlayEvent -@ stdcall -ret64 NtGetTickCount() get_tick_count64 +@ stdcall NtGetTickCount() @ stdcall NtGetWriteWatch(long long ptr long ptr ptr ptr) @ stdcall NtImpersonateAnonymousToken(long) @ stub NtImpersonateClientOfPort @@ -1142,7 +1142,7 @@ @ stdcall -private ZwGetCurrentProcessorNumber() NtGetCurrentProcessorNumber # @ stub ZwGetDevicePowerState @ stub ZwGetPlugPlayEvent -@ stdcall -private -ret64 ZwGetTickCount() get_tick_count64 +@ stdcall -private ZwGetTickCount() NtGetTickCount @ stdcall -private ZwGetWriteWatch(long long ptr long ptr ptr ptr) NtGetWriteWatch @ stdcall -private ZwImpersonateAnonymousToken(long) NtImpersonateAnonymousToken @ stub ZwImpersonateClientOfPort @@ -1529,3 +1529,6 @@ # Filesystem @ cdecl wine_nt_to_unix_file_name(ptr ptr long long) @ cdecl wine_unix_to_nt_file_name(ptr ptr) + +# Time +@ cdecl -ret64 wine_get_tick_count64() diff --git a/dlls/ntdll/ntdll_misc.h b/dlls/ntdll/ntdll_misc.h index 2d83f54..72510b1c 100644 --- a/dlls/ntdll/ntdll_misc.h +++ b/dlls/ntdll/ntdll_misc.h @@ -269,7 +269,4 @@ void WINAPI LdrInitializeThunk(CONTEXT*,void**,ULONG_PTR,ULONG_PTR); int __cdecl NTDLL_tolower( int c ); int __cdecl _stricmp( LPCSTR str1, LPCSTR str2 );
-/* time functions */ -ULONGLONG WINAPI get_tick_count64( void ); -#define NtGetTickCount get_tick_count64 #endif diff --git a/dlls/ntdll/time.c b/dlls/ntdll/time.c index 41e4563..d3853e0 100644 --- a/dlls/ntdll/time.c +++ b/dlls/ntdll/time.c @@ -552,14 +552,18 @@ NTSTATUS WINAPI NtQueryPerformanceCounter( LARGE_INTEGER *counter, LARGE_INTEGER return STATUS_SUCCESS; }
+/******************************************************************************
+ *              wine_get_tick_count64 (NTDLL.@)
+ *
+ * Wine-specific cdecl export (see ntdll.spec) returning the full 64-bit
+ * tick count: monotonic_counter() divided by TICKSPERMSEC, i.e. milliseconds.
+ * Nothing is truncated here; 32-bit callers narrow the result themselves.
+ */
+ULONGLONG CDECL wine_get_tick_count64(void)
+{
+    return monotonic_counter() / TICKSPERMSEC;
+}
 /******************************************************************************
  *              NtGetTickCount (NTDLL.@)
  *              ZwGetTickCount (NTDLL.@)
+ *
+ * Returns wine_get_tick_count64() implicitly narrowed to the 32-bit ULONG
+ * return type, so the value wraps around every 2^32 milliseconds.
  */
-ULONGLONG WINAPI DECLSPEC_HOTPATCH get_tick_count64(void)
+ULONG WINAPI NtGetTickCount(void)
 {
-    return monotonic_counter() / TICKSPERMSEC;
+    return wine_get_tick_count64();
 }
/* calculate the mday of dst change date, so that for instance Sun 5 Oct 2007 diff --git a/include/winternl.h b/include/winternl.h index e7f89b0..921386d 100644 --- a/include/winternl.h +++ b/include/winternl.h @@ -2946,6 +2946,7 @@ NTSYSAPI void WINAPI TpWaitForWork(TP_WORK *,BOOL); NTSYSAPI NTSTATUS CDECL wine_nt_to_unix_file_name( const UNICODE_STRING *nameW, ANSI_STRING *unix_name_ret, UINT disposition, BOOLEAN check_case ); NTSYSAPI NTSTATUS CDECL wine_unix_to_nt_file_name( const ANSI_STRING *name, UNICODE_STRING *nt ); +NTSYSAPI ULONGLONG CDECL wine_get_tick_count64( void );
/***********************************************************************
Fixes regression in 3e927c4aec9dbeef930b83f62ee0651b8c147247.
Wine-Bug: https://bugs.winehq.org/show_bug.cgi?id=47265 Signed-off-by: Andrew Wesie <awesie@gmail.com> --- dlls/kernel32/kernel32.spec | 4 ++-- dlls/kernel32/kernel_main.c | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-)
diff --git a/dlls/kernel32/kernel32.spec b/dlls/kernel32/kernel32.spec index 5f3f1fa..c01f132 100644 --- a/dlls/kernel32/kernel32.spec +++ b/dlls/kernel32/kernel32.spec @@ -858,8 +858,8 @@ @ stdcall GetThreadPriorityBoost(long ptr) @ stdcall GetThreadSelectorEntry(long long ptr) @ stdcall GetThreadTimes(long ptr ptr ptr ptr) -@ stdcall GetTickCount() ntdll.NtGetTickCount -@ stdcall -ret64 GetTickCount64() ntdll.NtGetTickCount +@ stdcall GetTickCount() +@ stdcall -ret64 GetTickCount64() @ stdcall GetTimeFormatA(long long ptr str ptr long) @ stdcall GetTimeFormatEx(wstr long ptr wstr ptr long) @ stdcall GetTimeFormatW(long long ptr wstr ptr long) diff --git a/dlls/kernel32/kernel_main.c b/dlls/kernel32/kernel_main.c index dfa66f0..4c8edf9 100644 --- a/dlls/kernel32/kernel_main.c +++ b/dlls/kernel32/kernel_main.c @@ -179,6 +179,33 @@ INT WINAPI MulDiv( INT nMultiplicand, INT nMultiplier, INT nDivisor) }
 /******************************************************************************
+ *           GetTickCount64       (KERNEL32.@)
+ *
+ * Get the number of milliseconds the system has been running, as a 64-bit
+ * value: wine_get_tick_count64() is returned without truncation.
+ *
+ * PARAMS
+ *  None.
+ *
+ * RETURNS
+ *  The current 64-bit tick count.
+ */
+ULONGLONG WINAPI DECLSPEC_HOTPATCH GetTickCount64(void)
+{
+    return wine_get_tick_count64();
+}
+
+/***********************************************************************
+ *           GetTickCount       (KERNEL32.@)
+ *
+ * Get the number of milliseconds the system has been running.
+ *
+ * PARAMS
+ *  None.
+ *
+ * RETURNS
+ *  The current tick count (wine_get_tick_count64() narrowed to a DWORD).
+ *
+ * NOTES
+ *  The value returned will wrap around every 2^32 milliseconds.
+ */
+DWORD WINAPI DECLSPEC_HOTPATCH GetTickCount(void)
+{
+    return wine_get_tick_count64();
+}
+
+/******************************************************************************
  *           GetSystemRegistryQuota       (KERNEL32.@)
  */
 BOOL WINAPI GetSystemRegistryQuota(PDWORD pdwQuotaAllowed, PDWORD pdwQuotaUsed)
Signed-off-by: Andrew Wesie <awesie@gmail.com> --- dlls/ntdll/tests/Makefile.in | 1 + dlls/ntdll/tests/hooks.c | 728 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 729 insertions(+) create mode 100644 dlls/ntdll/tests/hooks.c
diff --git a/dlls/ntdll/tests/Makefile.in b/dlls/ntdll/tests/Makefile.in index 5c70f3f..542def9 100644 --- a/dlls/ntdll/tests/Makefile.in +++ b/dlls/ntdll/tests/Makefile.in @@ -10,6 +10,7 @@ C_SRCS = \ exception.c \ file.c \ generated.c \ + hooks.c \ info.c \ large_int.c \ om.c \ diff --git a/dlls/ntdll/tests/hooks.c b/dlls/ntdll/tests/hooks.c new file mode 100644 index 0000000..64dea87 --- /dev/null +++ b/dlls/ntdll/tests/hooks.c @@ -0,0 +1,728 @@ +/* + * Unit test suite for hooking ntdll functions + * + * Copyright 2018 Andrew Wesie + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
+ */
+
+#define NONAMELESSUNION
+#include "ntdll_test.h"
+#include "excpt.h"
+
+#ifdef __i386__
+
+/* ntdll exports */
+static void (WINAPI *pLdrInitializeThunk)(PCONTEXT,ULONG_PTR,ULONG_PTR,ULONG_PTR);
+static NTSTATUS (WINAPI *pNtContinue)(PCONTEXT,BOOLEAN);
+static NTSTATUS (WINAPI *pNtCreateThread)(PHANDLE,ACCESS_MASK,POBJECT_ATTRIBUTES,HANDLE,PCLIENT_ID,PCONTEXT,PINITIAL_TEB,BOOLEAN);
+static ULONG (WINAPI *pNtGetTickCount)(VOID);
+static NTSTATUS (WINAPI *pNtQueryInformationThread)(HANDLE,THREADINFOCLASS,PVOID,ULONG,PULONG);
+
+/* kernel32 exports */
+static DWORD (WINAPI *pGetTickCount)(VOID);
+static ULONGLONG (WINAPI *pGetTickCount64)(VOID);
+
+struct hook_state
+{
+    void *target;             /* patched code address; NULL until the original bytes are saved */
+    void *mem;                /* trampoline page, or NULL */
+    BYTE original[64];        /* code bytes overwritten at target */
+    SIZE_T count;             /* number of valid bytes in original */
+    void *callback;           /* function invoked by the hook */
+    PVOID exception_handler;  /* vectored handler for int3-based hooks, or NULL */
+};
+
+static DWORD callback_result;
+static ULONG callback_process_id;
+static ULONG callback_thread_id;
+static CLIENT_ID callback_client_id;
+static CLIENT_ID *callback_client_id_ptr;
+static HANDLE *callback_handle_ptr;
+
+/* Code from dlls/kernel32/thread.c. Required since API is missing on WinXP. */
+static DWORD get_thread_id(HANDLE thread)
+{
+    THREAD_BASIC_INFORMATION tbi;
+    NTSTATUS status;
+
+    status = pNtQueryInformationThread(thread, ThreadBasicInformation, &tbi,
+                                       sizeof(tbi), NULL);
+    if (status)
+    {
+        SetLastError( RtlNtStatusToDosError(status) );
+        return 0;
+    }
+
+    return HandleToULong(tbi.ClientId.UniqueThread);
+}
+
+/* Advance *pp past a ModRM byte plus any SIB byte and displacement that
+ * follow it.  addr16 selects 16-bit addressing (0x67 prefix). */
+static BOOL modrm_size(PBYTE *pp, BOOL addr16)
+{
+    PBYTE p = *pp;
+    BYTE mod = (p[0] >> 6) & 3, rm = p[0] & 7, base = 0;
+    p++;
+
+    if (!addr16 && mod != 3 && rm == 4)
+        base = *p++ & 7; /* SIB */
+
+    /* mod 00 has a base-less displacement for [disp32] (rm 101), for a SIB
+     * byte with base 101, and for 16-bit [disp16] (rm 110) */
+    if (mod == 0 && (addr16 ? rm == 6 : (rm == 5 || base == 5)))
+        p += addr16 ? 2 : 4; /* disp16 / disp32 */
+    else if (mod == 1)
+        p += 1; /* ... + disp 8 */
+    else if (mod == 2)
+        p += addr16 ? 2 : 4; /* ... + disp16 / disp32 */
+
+    *pp = p;
+    return TRUE;
+}
+
+/* Return the length in bytes of the i386 instruction at ip, or -1 when the
+ * opcode is not supported.  Unsupported opcodes include relative jumps and
+ * calls, whose displacement would be wrong once copied into a trampoline. */
+static int instruction_size(void *ip)
+{
+    BOOL pfx66 = FALSE; /* operand-size prefix */
+    BOOL pfx67 = FALSE; /* address-size prefix */
+    PBYTE p = ip;
+    BYTE ext;
+
+    while (1)
+    {
+        switch (*p)
+        {
+        case 0x66:
+            pfx66 = TRUE;
+            p++;
+            break;
+        case 0x67:
+            pfx67 = TRUE;
+            p++;
+            break;
+        case 0x26:
+        case 0x2E:
+        case 0x36:
+        case 0x3E:
+        case 0x64:
+        case 0x65:
+        case 0x9B:
+        case 0xF0:
+        case 0xF2:
+        case 0xF3:
+            p++;
+            break;
+        default:
+            goto no_prefix;
+        }
+    }
+no_prefix:
+
+    switch (*p)
+    {
+    case 0x00: case 0x01: case 0x02: case 0x03:
+    case 0x08: case 0x09: case 0x0A: case 0x0B:
+    case 0x10: case 0x11: case 0x12: case 0x13:
+    case 0x18: case 0x19: case 0x1A: case 0x1B:
+    case 0x20: case 0x21: case 0x22: case 0x23:
+    case 0x28: case 0x29: case 0x2A: case 0x2B:
+    case 0x30: case 0x31: case 0x32: case 0x33:
+    case 0x38: case 0x39: case 0x3A: case 0x3B:
+    case 0x84: case 0x85: case 0x86: case 0x87:
+    case 0x88: case 0x89: case 0x8A: case 0x8B:
+    case 0x8C: case 0x8D: case 0x8E: case 0x8F:
+    case 0xC4: case 0xC5: case 0xD0: case 0xD1:
+    case 0xD2: case 0xD3: case 0xFE: case 0xFF:
+        p++;
+        if (!modrm_size(&p, pfx67)) return -1;
+        break;
+    case 0x04: case 0x0C: case 0x14: case 0x1C:
+    case 0x24: case 0x2C: case 0x34: case 0x3C:
+    case 0x6A: case 0xA8:
+    case 0xB0: case 0xB1: case 0xB2: case 0xB3:
+    case 0xB4: case 0xB5: case 0xB6: case 0xB7:
+    case 0xCD: case 0xD4: case 0xD5:
+        p += 2;
+        break;
+    case 0x05: case 0x0D: case 0x15: case 0x1D:
+    case 0x25: case 0x2D: case 0x35: case 0x3D:
+    case 0x68: case 0xA9:
+    case 0xB8: case 0xB9: case 0xBA: case 0xBB:
+    case 0xBC: case 0xBD: case 0xBE: case 0xBF:
+        p += 1 + (pfx66 ? 2 : 4);
+        break;
+    case 0x06: case 0x07: case 0x0E: case 0x16:
+    case 0x17: case 0x1E: case 0x1F: case 0x27:
+    case 0x2F: case 0x37: case 0x3F:
+    case 0x40: case 0x41: case 0x42: case 0x43:
+    case 0x44: case 0x45: case 0x46: case 0x47:
+    case 0x48: case 0x49: case 0x4A: case 0x4B:
+    case 0x4C: case 0x4D: case 0x4E: case 0x4F:
+    case 0x50: case 0x51: case 0x52: case 0x53:
+    case 0x54: case 0x55: case 0x56: case 0x57:
+    case 0x58: case 0x59: case 0x5A: case 0x5B:
+    case 0x5C: case 0x5D: case 0x5E: case 0x5F:
+    case 0x60: case 0x61:
+    case 0x90: case 0x91: case 0x92: case 0x93:
+    case 0x94: case 0x95: case 0x96: case 0x97:
+    case 0x98: case 0x99: case 0x9B: case 0x9C:
+    case 0x9D: case 0x9E: case 0x9F: case 0xA4:
+    case 0xA5: case 0xA6: case 0xA7: case 0xAA:
+    case 0xAB: case 0xAC: case 0xAD: case 0xAE:
+    case 0xAF: case 0xC3: case 0xC9: case 0xCB:
+    case 0xCC: case 0xCE: case 0xCF: case 0xF1:
+    case 0xF4: case 0xF5: case 0xF8: case 0xF9:
+    case 0xFA: case 0xFB: case 0xFC: case 0xFD:
+        p++;
+        break;
+    case 0x6B: case 0x80: case 0x82: case 0x83:
+    case 0xC0: case 0xC1: case 0xC6:
+        p++;
+        if (!modrm_size(&p, pfx67)) return -1;
+        p++;
+        break;
+    case 0x69: case 0x81: case 0xC7:
+        p++;
+        if (!modrm_size(&p, pfx67)) return -1;
+        p += pfx66 ? 2 : 4;
+        break;
+    case 0xA0: case 0xA1: case 0xA2: case 0xA3:
+        p += 1 + (pfx67 ? 2 : 4);
+        break;
+    case 0xC2: case 0xCA:
+        p += 3;
+        break;
+    case 0xF6:
+        p++;
+        ext = (*p >> 3) & 7;
+        if (!modrm_size(&p, pfx67)) return -1;
+
+        switch (ext)
+        {
+        case 0: case 1:
+            p++;
+            break;
+        }
+        break;
+    case 0xF7:
+        p++;
+        ext = (*p >> 3) & 7;
+        if (!modrm_size(&p, pfx67)) return -1;
+
+        switch (ext)
+        {
+        case 0: case 1:
+            p += pfx66 ? 2 : 4;
+            break;
+        }
+        break;
+    /* 2-byte opcodes */
+    case 0x0F:
+        p++;
+        switch (*p)
+        {
+        case 0x0D: case 0x18: case 0x19: case 0x1A:
+        case 0x1B: case 0x1C: case 0x1D: case 0x1E:
+        case 0x1F:
+        case 0x90: case 0x91: case 0x92: case 0x93:
+        case 0x94: case 0x95: case 0x96: case 0x97:
+        case 0x98: case 0x99: case 0x9A: case 0x9B:
+        case 0x9C: case 0x9D: case 0x9E: case 0x9F:
+        case 0xA3: case 0xAB: case 0xAE: case 0xAF:
+        case 0xB0: case 0xB1: case 0xB3: case 0xB6:
+        case 0xB7: case 0xBB: case 0xBC: case 0xBD:
+        case 0xBE: case 0xBF: case 0xC0: case 0xC1:
+        case 0xC7:
+            p++;
+            if (!modrm_size(&p, pfx67)) return -1;
+            break;
+        case 0x31: case 0xA0: case 0xA1: case 0xA2:
+        case 0xA8: case 0xA9:
+        case 0xC8: case 0xC9: case 0xCA: case 0xCB:
+        case 0xCC: case 0xCD: case 0xCE: case 0xCF:
+            p++;
+            break;
+        case 0xBA:
+            p++;
+            if (!modrm_size(&p, pfx67)) return -1;
+            p++;
+            break;
+        /* unsupported instructions */
+        default:
+            return -1;
+        }
+        /* a supported 2-byte opcode must not fall into the cases below */
+        break;
+    /* unsupported instructions */
+    case 0x62: case 0x63: /* bound / arpl */
+    case 0x6C: case 0x6D: case 0x6E: case 0x6F: /* in / out */
+    case 0xE4: case 0xE5: case 0xE6: case 0xE7:
+    case 0xEC: case 0xED: case 0xEE: case 0xEF:
+    case 0x70: case 0x71: case 0x72: case 0x73: /* jump */
+    case 0x74: case 0x75: case 0x76: case 0x77:
+    case 0x78: case 0x79: case 0x7A: case 0x7B:
+    case 0x7C: case 0x7D: case 0x7E: case 0x7F:
+    case 0xE0: case 0xE1: case 0xE2: case 0xE3:
+    case 0xE9: case 0xEA: case 0xEB:
+    case 0x9A: case 0xE8: /* call */
+    case 0xC8: /* enter */
+    case 0xD7: /* xlat */
+    case 0xD8: case 0xD9: case 0xDA: case 0xDB: /* fpu */
+    case 0xDC: case 0xDD: case 0xDE: case 0xDF:
+    default:
+        return -1;
+    }
+    return p - (PBYTE)ip;
+}
+
+/* Build the trampoline in hook->mem, then redirect hook->target to it.
+ *
+ * The trampoline copies 16 dwords of the caller's stack arguments and calls
+ * hook->callback with that copy.  It then executes the hook->count saved
+ * original bytes and jumps back to the byte that follows them.
+ * hook->target, mem, original, count and callback must already be set. */
+static BOOL install_trampoline(struct hook_state *hook)
+{
+    PBYTE target = hook->target;
+    PBYTE stub = hook->mem;
+    PBYTE mem = stub;
+    DWORD prot;
+
+    /* copy the arguments */
+    *mem++ = 0x56; /* push esi */
+    *mem++ = 0x57; /* push edi */
+    *mem++ = 0x8d; /* lea esi, [esp + 12] */
+    *mem++ = 0x74;
+    *mem++ = 0x24;
+    *mem++ = 0x0c;
+    *mem++ = 0x83; /* sub esp, 64 */
+    *mem++ = 0xec;
+    *mem++ = 0x40;
+    *mem++ = 0x89; /* mov edi, esp */
+    *mem++ = 0xe7;
+    *mem++ = 0xb9; /* mov ecx, 16 */
+    *mem++ = 0x10;
+    *mem++ = 0x00;
+    *mem++ = 0x00;
+    *mem++ = 0x00;
+    *mem++ = 0xf3; /* rep movsd */
+    *mem++ = 0xa5;
+    /* execute the callback */
+    *mem = 0xe8; /* call */
+    *(PDWORD)(mem + 1) = (PBYTE)hook->callback - (mem + 5);
+    mem += 5;
+    /* restore the stack and registers */
+    *mem++ = 0x83; /* add esp, 64 */
+    *mem++ = 0xc4;
+    *mem++ = 0x40;
+    *mem++ = 0x5f; /* pop edi */
+    *mem++ = 0x5e; /* pop esi */
+    /* execute the original code */
+    memcpy(mem, hook->original, hook->count);
+    mem += hook->count;
+    *mem = 0xe9; /* jmp */
+    *(PDWORD)(mem + 1) = (target + hook->count) - (mem + 5);
+
+    /* only redirect the target once the trampoline is complete */
+    if (!VirtualProtect(target, hook->count, PAGE_EXECUTE_READWRITE, &prot))
+        return FALSE;
+
+    *target = 0xe9; /* jmp */
+    *(PDWORD)(target + 1) = stub - (target + 5);
+
+    if (!VirtualProtect(target, hook->count, prot, &prot))
+        return FALSE;
+
+    return TRUE;
+}
+
+/* Hook an arbitrary function.  Whole instructions covering at least 5 bytes
+ * at the start of func are relocated into a trampoline, making room for a
+ * jmp.  The trampoline calls callback before running func. */
+static BOOL hook_function(struct hook_state *hook, void *func, void *callback)
+{
+    PBYTE p = func;
+    SIZE_T count = 0;
+
+    hook->callback = callback;
+    hook->exception_handler = NULL;
+    /* target stays NULL until the original bytes are saved, so that
+     * remove_hook() never writes uninitialized bytes over the function */
+    hook->target = NULL;
+    hook->count = 0;
+    hook->mem = VirtualAlloc(NULL, 0x1000, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
+    if (!hook->mem)
+        return FALSE;
+
+    while (count < 5)
+    {
+        int rv = instruction_size(p + count);
+        if (rv < 0)
+            return FALSE;
+        count += rv;
+    }
+
+    memcpy(hook->original, p, count);
+    hook->count = count;
+    hook->target = p;
+
+    return install_trampoline(hook);
+}
+
+/* Hook a system call stub by relocating its leading 5-byte instruction
+ * (expected to be "mov eax, <imm32>") into a trampoline. */
+static BOOL hook_syscall(struct hook_state *hook, void *syscall, void *callback)
+{
+    PBYTE p = syscall;
+
+    hook->callback = callback;
+    hook->exception_handler = NULL;
+    hook->target = NULL;
+    hook->count = 0;
+    hook->mem = VirtualAlloc(NULL, 0x1000, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
+    if (!hook->mem)
+        return FALSE;
+
+    /* Both 32-bit Windows and WoW64 start with "mov eax, 0x..". */
+    todo_wine ok(p[0] == 0xb8, "syscall does not begin with 0xb8 got: 0x%02x\n", p[0]);
+
+    /* make a copy of the original bytes */
+    memcpy(hook->original, p, 5);
+    hook->count = 5;
+    hook->target = p;
+
+    return install_trampoline(hook);
+}
+
+static struct hook_state *hook_syscall_ret_hook;
+
+/* Vectored handler for the int3 that hook_syscall_ret() writes over a
+ * syscall's ret instruction.  It runs the callback, then emulates the
+ * replaced "ret" / "ret imm16". */
+static LONG CALLBACK hook_syscall_ret_handler(EXCEPTION_POINTERS *info)
+{
+    WORD args_size;
+
+    ok(info->ExceptionRecord->ExceptionCode == EXCEPTION_BREAKPOINT, "wrong exception expected: %08X got: %08X\n",
+       EXCEPTION_BREAKPOINT, info->ExceptionRecord->ExceptionCode);
+    /* do not fake a return for exceptions this hook did not raise */
+    if (info->ExceptionRecord->ExceptionCode != EXCEPTION_BREAKPOINT)
+        return EXCEPTION_CONTINUE_SEARCH;
+
+    ((void (*)(void))hook_syscall_ret_hook->callback)();
+
+    /* "ret imm16" (0xc2) also pops imm16 bytes of arguments */
+    if (hook_syscall_ret_hook->original[0] == 0xc2)
+        args_size = *(WORD *)(hook_syscall_ret_hook->original + 1);
+    else
+        args_size = 0;
+    info->ContextRecord->Eip = *(ULONG*)info->ContextRecord->Esp;
+    info->ContextRecord->Esp += 4 + args_size;
+    return EXCEPTION_CONTINUE_EXECUTION;
+}
+
+/* Hook the return of a system call stub by replacing its ret with int3. */
+static BOOL hook_syscall_ret(struct hook_state *hook, void *syscall, void *callback)
+{
+    DWORD prot;
+    PBYTE p = syscall;
+
+    hook->callback = callback;
+    hook->exception_handler = NULL;
+    hook->target = NULL;
+    hook->mem = NULL;
+    hook->count = 0;
+
+    /* We do not have enough space for a jump, so we use an exception handler instead. */
+    hook_syscall_ret_hook = hook;
+    hook->exception_handler = AddVectoredExceptionHandler(TRUE, hook_syscall_ret_handler);
+    if (hook->exception_handler == NULL)
+        return FALSE;
+
+    ok(p[0] == 0xb8, "syscall[0] expected: 0xb8 got: 0x%02x\n", p[0]);
+    p += 5;
+
+    while (*p != 0xc2 && *p != 0xc3)
+    {
+        int rv = instruction_size(p);
+        if (rv < 0)
+            return FALSE;
+        p += rv;
+    }
+
+    /* make a copy of the original bytes */
+    hook->count = instruction_size(p);
+    memcpy(hook->original, p, hook->count);
+    hook->target = p;
+
+    if (!VirtualProtect(hook->target, hook->count, PAGE_EXECUTE_READWRITE, &prot))
+        return FALSE;
+
+    *p = 0xcc; /* int 3 */
+
+    if (!VirtualProtect(hook->target, hook->count, prot, &prot))
+        return FALSE;
+
+    return TRUE;
+}
+
+/* Undo a hook.  Returns FALSE if the hook was never installed or any
+ * cleanup step failed.  The exception handler and the trampoline are
+ * released in every case, so no handler is left pointing at a stale
+ * hook_state. */
+static BOOL remove_hook(struct hook_state *hook)
+{
+    DWORD prot;
+    BOOL ret = TRUE;
+
+    if (!hook->target)
+        ret = FALSE; /* nothing was patched */
+    else if (!VirtualProtect(hook->target, hook->count, PAGE_EXECUTE_READWRITE, &prot))
+        ret = FALSE;
+    else
+    {
+        memcpy(hook->target, hook->original, hook->count);
+        if (!VirtualProtect(hook->target, hook->count, prot, &prot))
+            ret = FALSE;
+    }
+
+    if (hook->exception_handler && !RemoveVectoredExceptionHandler(hook->exception_handler))
+        ret = FALSE;
+
+    if (hook->mem && !VirtualFree(hook->mem, 0, MEM_RELEASE))
+        ret = FALSE;
+
+    return ret;
+}
+
+static DWORD WINAPI dummy_thread(LPVOID param)
+{
+    callback_process_id = GetCurrentProcessId();
+    callback_thread_id = GetCurrentThreadId();
+    return 0xdeadbeef;
+}
+
+static void callback_LdrInitializeThunk(PCONTEXT context, ULONG_PTR unknown1, ULONG_PTR unknown2, ULONG_PTR unknown3)
+{
+    callback_result = 1;
+
+    /* context is NULL on Windows XP */
+    if (context)
+    {
+        ok(context->Eax == (ULONG)dummy_thread, "context->Eax expected: %08X got: %08X\n",
+           (ULONG)dummy_thread, context->Eax);
+        ok(context->Ebx == (ULONG)0x12345678, "context->Ebx expected: 12345678 got: %08X\n", context->Ebx);
+    }
+}
+
+static void test_LdrInitializeThunk(void)
+{
+    HANDLE handle;
+    struct hook_state hook;
+
+    ok(hook_function(&hook, pLdrInitializeThunk, callback_LdrInitializeThunk), "failed to hook LdrInitializeThunk\n");
+
+    callback_result = 0;
+    handle = CreateThread(NULL, 0, dummy_thread, (LPVOID)0x12345678, 0, NULL);
+    ok(handle != NULL, "CreateThread failed\n");
+    ok(WaitForSingleObject(handle, 1000) == WAIT_OBJECT_0, "wait for thread failed\n");
+    ok(callback_result == 1, "callback never ran\n");
+    CloseHandle(handle);
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+}
+
+static LONG CALLBACK dummy_exception_handler(EXCEPTION_POINTERS *info)
+{
+    ok(info->ExceptionRecord->ExceptionCode == EXCEPTION_ACCESS_VIOLATION,
+       "wrong exception expected: %08X got: %08X\n", EXCEPTION_ACCESS_VIOLATION,
+       info->ExceptionRecord->ExceptionCode);
+
+    info->ContextRecord->Eax = 0x12345678;
+    info->ContextRecord->Eip = *(ULONG*)info->ContextRecord->Esp;
+    info->ContextRecord->Esp += 4;
+
+    return EXCEPTION_CONTINUE_EXECUTION;
+}
+
+static void callback_NtContinue(PCONTEXT context, BOOLEAN alert)
+{
+    callback_result = 1;
+    ok((context->ContextFlags & CONTEXT_FULL) == CONTEXT_FULL, "wrong context flags expected: %08x got: %08x\n",
+       CONTEXT_FULL, context->ContextFlags);
+    ok(context->Eax == 0x12345678, "wrong Eax expected: 0x12345678 got: %08x\n", context->Eax);
+    ok(context->Eip == *(ULONG*)(context->Esp - 4), "wrong Eip expected: %08x got: %08x\n",
+       *(ULONG*)(context->Esp - 4), context->Eip);
+    context->Eax = 0xdeadbeef;
+}
+
+static void test_NtContinue(void)
+{
+    int result;
+    PVOID handle;
+    struct hook_state hook;
+
+    ok(hook_function(&hook, pNtContinue, callback_NtContinue), "failed to hook NtContinue\n");
+
+    handle = AddVectoredExceptionHandler(TRUE, dummy_exception_handler);
+    ok(handle != NULL, "failed to register exception handler\n");
+
+    callback_result = 0;
+    /* raise an exception */
+    result = ((int (*)(void))0)();
+    todo_wine ok(callback_result == 1, "callback never ran\n");
+    todo_wine ok(result == 0xdeadbeef, "wrong return value expected: deadbeef got: %08x\n", result);
+
+    RemoveVectoredExceptionHandler(handle);
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+}
+
+static void callback_NtCreateThread( HANDLE *handle_ptr, ACCESS_MASK access, OBJECT_ATTRIBUTES *attr, HANDLE process,
+                                     CLIENT_ID *id, CONTEXT *context, INITIAL_TEB *teb, BOOLEAN suspended )
+{
+    static CLIENT_ID tmp_id;
+    static HANDLE tmp_handle;
+
+    callback_result |= 1;
+    callback_client_id_ptr = id ? id : &tmp_id;
+    callback_handle_ptr = handle_ptr ? handle_ptr : &tmp_handle;
+
+    ok(context != NULL, "context is NULL\n");
+    if (context)
+    {
+        ok(context->Eax == (ULONG)dummy_thread, "context->Eax expected: %08X got: %08X\n",
+           (ULONG)dummy_thread, context->Eax);
+        ok(context->Ebx == (ULONG)0x12345678, "context->Ebx expected: 12345678 got: %08X\n", context->Ebx);
+    }
+}
+
+static void callback_NtCreateThread_ret(void)
+{
+    callback_result |= 2;
+    callback_client_id = *callback_client_id_ptr;
+    ok(HandleToULong(callback_client_id.UniqueThread) == get_thread_id(*callback_handle_ptr),
+       "wrong thread id %d != %d\n", HandleToULong(callback_client_id.UniqueThread),
+       get_thread_id(*callback_handle_ptr));
+}
+
+static void test_NtCreateThread(void)
+{
+    BOOL is_winxp = NtCurrentTeb()->Peb->OSMajorVersion < 6;
+    HANDLE handle;
+    struct hook_state hook, hook2;
+
+    todo_wine ok(hook_syscall_ret(&hook2, pNtCreateThread, callback_NtCreateThread_ret), "failed to hook NtCreateThread ret\n");
+    ok(hook_syscall(&hook, pNtCreateThread, callback_NtCreateThread), "failed to hook NtCreateThread\n");
+
+    callback_result = 0;
+    callback_client_id_ptr = NULL;
+    callback_handle_ptr = NULL;
+    handle = CreateThread(NULL, 0, dummy_thread, (LPVOID)0x12345678, 0, NULL);
+    ok(handle != NULL, "CreateThread failed\n");
+    ok(WaitForSingleObject(handle, 1000) == WAIT_OBJECT_0, "wait for thread failed\n");
+
+    if (is_winxp)
+    {
+        ok(callback_result & 1, "callback never ran\n");
+        ok(callback_result & 2, "ret callback never ran\n");
+        ok(callback_process_id == HandleToULong(callback_client_id.UniqueProcess),
+           "wrong process id expected: %d got: %d\n", callback_process_id,
+           HandleToULong(callback_client_id.UniqueProcess));
+        ok(callback_thread_id == HandleToULong(callback_client_id.UniqueThread),
+           "wrong thread id expected: %d got: %d\n", callback_thread_id,
+           HandleToULong(callback_client_id.UniqueThread));
+    }
+    else
+    {
+        ok(callback_result == 0, "callbacks ran but should not on major version %d\n",
+           NtCurrentTeb()->Peb->OSMajorVersion);
+    }
+
+    CloseHandle(handle);
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+    todo_wine ok(remove_hook(&hook2), "failed to remove hook\n");
+}
+
+static void callback_NtGetTickCount(void)
+{
+    callback_result |= 1;
+}
+
+static void test_NtGetTickCount(void)
+{
+    struct hook_state hook;
+    BYTE tmp[64];
+
+    /* NtGetTickCount may not be a system call. */
+    ok(hook_function(&hook, pNtGetTickCount, callback_NtGetTickCount),
+       "failed to hook NtGetTickCount\n");
+
+    callback_result = 0;
+    pGetTickCount();
+    ok(callback_result == 0, "callbacks ran during GetTickCount\n");
+
+    callback_result = 0;
+    pGetTickCount64();
+    ok(callback_result == 0, "callbacks ran during GetTickCount64\n");
+
+    callback_result = 0;
+    pNtGetTickCount();
+    ok(callback_result & 1, "callback never ran\n");
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+
+    memcpy(tmp, pNtGetTickCount, sizeof(tmp));
+
+    ok(hook_function(&hook, pGetTickCount, callback_NtGetTickCount),
+       "failed to hook GetTickCount\n");
+    ok(memcmp(pNtGetTickCount, tmp, sizeof(tmp)) == 0, "hook modified NtGetTickCount\n");
+
+    callback_result = 0;
+    pNtGetTickCount();
+    ok(callback_result == 0, "callbacks ran during NtGetTickCount\n");
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+
+    ok(hook_function(&hook, pGetTickCount64, callback_NtGetTickCount),
+       "failed to hook GetTickCount64\n");
+    ok(memcmp(pNtGetTickCount, tmp, sizeof(tmp)) == 0, "hook modified NtGetTickCount\n");
+
+    callback_result = 0;
+    pNtGetTickCount();
+    ok(callback_result == 0, "callbacks ran during NtGetTickCount\n");
+
+    ok(remove_hook(&hook), "failed to remove hook\n");
+}
+#endif
+
+START_TEST(hooks)
+{
+#ifdef __i386__
+    HMODULE hntdll = GetModuleHandleA("ntdll.dll");
+    HMODULE hkernel32 = GetModuleHandleA("kernel32.dll");
+
+    pLdrInitializeThunk = (void *)GetProcAddress(hntdll, "LdrInitializeThunk");
+    pNtContinue = (void *)GetProcAddress(hntdll, "NtContinue");
+    pNtCreateThread = (void *)GetProcAddress(hntdll, "NtCreateThread");
+    pNtGetTickCount = (void *)GetProcAddress(hntdll, "NtGetTickCount");
+    pNtQueryInformationThread = (void *)GetProcAddress(hntdll, "NtQueryInformationThread");
+
+    pGetTickCount = (void *)GetProcAddress(hkernel32, "GetTickCount");
+    pGetTickCount64 = (void *)GetProcAddress(hkernel32, "GetTickCount64");
+
+    test_LdrInitializeThunk();
+    test_NtContinue();
+    test_NtCreateThread();
+    test_NtGetTickCount();
+#endif
+}