From: Paul Gofman <pgofman@codeweavers.com>
---
 dlls/ntdll/tests/info.c | 274 ++++++++++++++++++++++++++++++++++++++++
 include/winternl.h      |  35 +++++
 2 files changed, 309 insertions(+)
diff --git a/dlls/ntdll/tests/info.c b/dlls/ntdll/tests/info.c
index b2a373ce1b3..d379ad5c152 100644
--- a/dlls/ntdll/tests/info.c
+++ b/dlls/ntdll/tests/info.c
@@ -4080,6 +4080,279 @@ static void test_processor_idle_cycle_time(void)
     ok( size == cpu_count * sizeof(*buffer), "got %#lx.\n", size );
 }
 
+static DWORD WINAPI test_set_process_tls_info_thread(void *param)
+{
+    return 0; /* no work to do: the test only manipulates this thread's TEB while it is suspended */
+}
+/* Exercise NtSetInformationProcess( ProcessTlsInformation ) on the current thread and a suspended helper thread. */
+static void test_set_process_tls_info(void)
+{
+    THREAD_BASIC_INFORMATION tbi;
+    TEB *teb = NtCurrentTeb(), *thread_teb;
+    char buffer[1024];
+    PROCESS_TLS_INFORMATION *tlsinfo = (PROCESS_TLS_INFORMATION *)buffer;
+    void **save_tls_pointers[2];
+    void *tls_pointer[4], *new_tls_pointer[4], *tls_pointer2[4], *new_tls_pointer2[4];
+    unsigned int i;
+    DWORD thread_id, curr_thread_id = GetCurrentThreadId();
+    NTSTATUS status;
+    HANDLE thread;
+    BOOL wow = is_wow64 && !old_wow64;
+
+    thread = CreateThread( NULL, 0, test_set_process_tls_info_thread, NULL, CREATE_SUSPENDED, &thread_id );
+    do
+    {
+        /* workaround currently present Wine bug when thread teb may be not available immediately
+         * after creating a thread before it is initialized on the Unix side. */
+        Sleep( 1 );
+        status = NtQueryInformationThread( thread, ThreadBasicInformation, &tbi, sizeof(tbi), NULL );
+        ok( !status, "got %#lx.\n", status );
+    } while (!(thread_teb = tbi.TebBaseAddress));
+    ok( !thread_teb->ThreadLocalStoragePointer, "got %p.\n", thread_teb->ThreadLocalStoragePointer );
+
+    save_tls_pointers[0] = teb->ThreadLocalStoragePointer;
+    save_tls_pointers[1] = thread_teb->ThreadLocalStoragePointer;
+
+    for (i = 0; i < ARRAY_SIZE(tls_pointer); ++i)
+    {
+        tls_pointer[i] = (void *)(ULONG_PTR)(i + 1);
+        new_tls_pointer[i] = (void *)(ULONG_PTR)(i + 20);
+    }
+    teb->ThreadLocalStoragePointer = tls_pointer;
+
+    /* This flag probably requests WOW64 teb update. */
+    tlsinfo->Flags = 1;
+    tlsinfo->ThreadDataCount = 1;
+    tlsinfo->OperationType = ProcessTlsReplaceVector;
+    tlsinfo->TlsVectorLength = 2;
+    tlsinfo->ThreadData[0].Flags = 0;
+    tlsinfo->ThreadData[0].ThreadId = thread_id;
+    tlsinfo->ThreadData[0].TlsVector = new_tls_pointer;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    if (wow)
+    {
+        todo_wine ok( !status, "got %#lx.\n", status );
+        ok( tlsinfo->Flags == 1, "got %#lx.\n", tlsinfo->Flags );
+        todo_wine ok( tlsinfo->ThreadData[0].Flags == THREAD_TLS_INFORMATION_ASSIGNED, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+        todo_wine ok( tlsinfo->ThreadData[0].ThreadId == curr_thread_id, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+    }
+    else
+    {
+        todo_wine ok( status == STATUS_INVALID_PARAMETER, "got %#lx.\n", status );
+        ok( tlsinfo->Flags == 1, "got %#lx.\n", tlsinfo->Flags );
+        ok( !tlsinfo->ThreadData[0].Flags, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+        ok( tlsinfo->ThreadData[0].ThreadId == thread_id, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+    }
+    if (status == STATUS_NOT_IMPLEMENTED) goto done; /* still restore the TEBs and release the helper thread */
+
+    /* Other PROCESS_TLS_INFORMATION flags are invalid. STATUS_INFO_LENGTH_MISMATCH is weird but that's for any flag
+     * besides 1. */
+    tlsinfo->Flags = 2;
+    tlsinfo->ThreadData[0].Flags = 0;
+    tlsinfo->ThreadData[0].ThreadId = thread_id;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_INFO_LENGTH_MISMATCH, "got %#lx.\n", status );
+    ok( tlsinfo->Flags == 2, "got %#lx.\n", tlsinfo->Flags );
+    ok( !tlsinfo->ThreadData[0].Flags, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+    ok( tlsinfo->ThreadData[0].ThreadId == thread_id, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+
+    /* Nonzero THREAD_TLS_INFORMATION flags on input are invalid. */
+    tlsinfo->Flags = 0;
+    tlsinfo->ThreadData[0].Flags = 1;
+    tlsinfo->ThreadData[0].ThreadId = thread_id;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_INVALID_PARAMETER, "got %#lx.\n", status );
+    ok( !tlsinfo->Flags, "got %#lx.\n", tlsinfo->Flags );
+    ok( tlsinfo->ThreadData[0].Flags == 1, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+    ok( tlsinfo->ThreadData[0].ThreadId == thread_id, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+
+    tlsinfo->ThreadData[0].Flags = 0;
+    tlsinfo->OperationType = MaxProcessTlsOperation;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    /* Unknown operation type. */
+    ok( status == STATUS_INFO_LENGTH_MISMATCH || status == STATUS_INVALID_PARAMETER, "got %#lx.\n", status );
+
+    tlsinfo->OperationType = ProcessTlsReplaceVector;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]) + 8);
+    /* Larger data size. */
+    ok( (!wow && status == STATUS_INFO_LENGTH_MISMATCH) || (wow && !status), "got %#lx.\n", status );
+
+    /* Zero thread count. */
+    tlsinfo->ThreadData[0].Flags = 0;
+    tlsinfo->ThreadDataCount = 0;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_INFO_LENGTH_MISMATCH, "got %#lx.\n", status );
+    ok( !tlsinfo->Flags, "got %#lx.\n", tlsinfo->Flags );
+
+    tlsinfo->ThreadDataCount = 1;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_SUCCESS, "got %#lx.\n", status );
+    ok( !tlsinfo->Flags, "got %#lx.\n", tlsinfo->Flags );
+    ok( tlsinfo->ThreadData[0].Flags == THREAD_TLS_INFORMATION_ASSIGNED, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+    /* ThreadId is output parameter, ignored on input and contains the thread where the data were assigned on
+     * output. */
+    ok( tlsinfo->ThreadData[0].ThreadId == curr_thread_id, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+    /* TlsVector contains the replaced vector on output. */
+    ok( tlsinfo->ThreadData[0].TlsVector == tls_pointer, "got %p.\n", tlsinfo->ThreadData[0].TlsVector );
+    ok( teb->ThreadLocalStoragePointer == new_tls_pointer, "wrong vector.\n" );
+    for (i = 0; i < ARRAY_SIZE(tls_pointer); ++i)
+    {
+        ok( tls_pointer[i] == (void *)(ULONG_PTR)(i + 1), "got %p.\n", tls_pointer[i] );
+        /* TlsVectorLength pointers are copied from the old vector to the new one. */
+        if (i < 2)
+            ok( new_tls_pointer[i] == (void *)(ULONG_PTR)(i + 1), "got %p.\n", new_tls_pointer[i] );
+        else
+            ok( new_tls_pointer[i] == (void *)(ULONG_PTR)(i + 20), "got %p.\n", new_tls_pointer[i] );
+    }
+
+    teb->ThreadLocalStoragePointer = NULL;
+    tlsinfo->ThreadData[0].Flags = 0;
+    tlsinfo->ThreadData[0].TlsVector = new_tls_pointer;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_SUCCESS, "got %#lx.\n", status );
+    /* Threads with NULL ThreadLocalStoragePointer are ignored. */
+    ok( !tlsinfo->Flags, "got %#lx.\n", tlsinfo->Flags );
+    ok( !tlsinfo->ThreadData[0].Flags, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+    ok( tlsinfo->ThreadData[0].TlsVector == new_tls_pointer, "got %p.\n", tlsinfo->ThreadData[0].TlsVector );
+
+    memcpy( tls_pointer2, tls_pointer, sizeof(tls_pointer2) );
+    thread_teb->ThreadLocalStoragePointer = tls_pointer2;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_SUCCESS, "got %#lx.\n", status );
+    ok( tlsinfo->ThreadData[0].Flags == THREAD_TLS_INFORMATION_ASSIGNED, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+    ok( thread_teb->ThreadLocalStoragePointer == new_tls_pointer, "wrong vector.\n" );
+    ok( tlsinfo->ThreadData[0].ThreadId == thread_id, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+    ok( tlsinfo->ThreadData[0].TlsVector == tls_pointer2, "got %p.\n", tlsinfo->ThreadData[0].TlsVector );
+
+    /* Two eligible threads, data for only one are provided, that succeeds. */
+    teb->ThreadLocalStoragePointer = tls_pointer;
+    thread_teb->ThreadLocalStoragePointer = tls_pointer2;
+    tlsinfo->ThreadData[0].Flags = 0;
+    tlsinfo->ThreadData[0].TlsVector = new_tls_pointer;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_SUCCESS, "got %#lx.\n", status );
+    ok( tlsinfo->ThreadDataCount == 1, "got %#lx.\n", tlsinfo->ThreadDataCount );
+    ok( tlsinfo->ThreadData[0].Flags == THREAD_TLS_INFORMATION_ASSIGNED, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+    ok( teb->ThreadLocalStoragePointer == new_tls_pointer, "wrong vector.\n" );
+    ok( tlsinfo->ThreadData[0].ThreadId == curr_thread_id, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+    ok( tlsinfo->ThreadData[0].TlsVector == tls_pointer, "got %p.\n", tlsinfo->ThreadData[0].TlsVector );
+    ok( thread_teb->ThreadLocalStoragePointer == tls_pointer2, "wrong vector.\n" );
+
+    /* Set for both threads at once as probably intended. Provide an extra data for the missing third thread
+     * which won't be used. */
+    teb->ThreadLocalStoragePointer = tls_pointer;
+    thread_teb->ThreadLocalStoragePointer = tls_pointer2;
+    memcpy( new_tls_pointer2, new_tls_pointer, sizeof(new_tls_pointer2) );
+    tlsinfo->ThreadDataCount = 3;
+    tlsinfo->ThreadData[0].TlsVector = new_tls_pointer;
+    tlsinfo->ThreadData[0].Flags = 0;
+    tlsinfo->ThreadData[1].TlsVector = new_tls_pointer2;
+    tlsinfo->ThreadData[1].Flags = 0;
+    tlsinfo->ThreadData[2].TlsVector = (void *)0xdeadbeef;
+    tlsinfo->ThreadData[2].Flags = 0;
+    tlsinfo->ThreadData[2].ThreadId = 0xdeadbeef;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_SUCCESS, "got %#lx.\n", status );
+    ok( !tlsinfo->Flags, "got %#lx.\n", tlsinfo->Flags );
+    ok( tlsinfo->ThreadDataCount == 3, "got %#lx.\n", tlsinfo->ThreadDataCount );
+    ok( tlsinfo->ThreadData[0].Flags == THREAD_TLS_INFORMATION_ASSIGNED, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+    ok( teb->ThreadLocalStoragePointer == new_tls_pointer, "wrong vector.\n" );
+    ok( tlsinfo->ThreadData[0].ThreadId == curr_thread_id, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+    ok( tlsinfo->ThreadData[0].TlsVector == tls_pointer, "got %p.\n", tlsinfo->ThreadData[0].TlsVector );
+    ok( tlsinfo->ThreadData[1].Flags == THREAD_TLS_INFORMATION_ASSIGNED, "got %#lx.\n", tlsinfo->ThreadData[1].Flags );
+    ok( thread_teb->ThreadLocalStoragePointer == new_tls_pointer2, "wrong vector.\n" );
+    ok( tlsinfo->ThreadData[1].ThreadId == thread_id, "got %#Ix.\n", tlsinfo->ThreadData[1].ThreadId );
+    ok( tlsinfo->ThreadData[1].TlsVector == tls_pointer2, "got %p.\n", tlsinfo->ThreadData[1].TlsVector );
+    ok( !tlsinfo->ThreadData[2].Flags, "got %#lx.\n", tlsinfo->ThreadData[2].Flags );
+    ok( tlsinfo->ThreadData[2].ThreadId == 0xdeadbeef, "got %#Ix.\n", tlsinfo->ThreadData[2].ThreadId );
+    ok( tlsinfo->ThreadData[2].TlsVector == (void *)0xdeadbeef, "got %p.\n", tlsinfo->ThreadData[2].TlsVector );
+
+    /* Test with inaccessible data. */
+    tlsinfo->ThreadData[0].TlsVector = new_tls_pointer;
+    tlsinfo->ThreadData[0].ThreadId = 0;
+    tlsinfo->ThreadData[0].Flags = 0;
+    tlsinfo->ThreadData[1].TlsVector = new_tls_pointer2;
+    tlsinfo->ThreadData[1].ThreadId = 0;
+    tlsinfo->ThreadData[1].Flags = 0;
+    teb->ThreadLocalStoragePointer = tls_pointer;
+    thread_teb->ThreadLocalStoragePointer = (void *)0xdeadbee0;
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_ACCESS_VIOLATION, "got %#lx.\n", status );
+    ok( !tlsinfo->Flags, "got %#lx.\n", tlsinfo->Flags );
+    ok( tlsinfo->ThreadDataCount == 3, "got %#lx.\n", tlsinfo->ThreadDataCount );
+    if (wow)
+    {
+        ok( !tlsinfo->ThreadData[0].Flags, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+        ok( !tlsinfo->ThreadData[0].ThreadId, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+        ok( tlsinfo->ThreadData[0].TlsVector == new_tls_pointer, "got %p.\n", tlsinfo->ThreadData[0].TlsVector );
+    }
+    else
+    {
+        ok( tlsinfo->ThreadData[0].Flags == THREAD_TLS_INFORMATION_ASSIGNED, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+        ok( tlsinfo->ThreadData[0].ThreadId == curr_thread_id, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+        ok( tlsinfo->ThreadData[0].TlsVector == tls_pointer, "got %p.\n", tlsinfo->ThreadData[0].TlsVector );
+    }
+    ok( teb->ThreadLocalStoragePointer == new_tls_pointer, "wrong vector.\n" );
+    ok( !tlsinfo->ThreadData[1].Flags, "got %#lx.\n", tlsinfo->ThreadData[1].Flags );
+    ok( teb->ThreadLocalStoragePointer == new_tls_pointer, "wrong vector.\n" );
+    ok( !tlsinfo->ThreadData[1].ThreadId, "got %#Ix.\n", tlsinfo->ThreadData[1].ThreadId );
+    ok( tlsinfo->ThreadData[1].TlsVector == new_tls_pointer2, "got %p.\n", tlsinfo->ThreadData[1].TlsVector );
+    ok( !tlsinfo->ThreadData[2].Flags, "got %#lx.\n", tlsinfo->ThreadData[2].Flags );
+    ok( tlsinfo->ThreadData[2].ThreadId == 0xdeadbeef, "got %#Ix.\n", tlsinfo->ThreadData[2].ThreadId );
+    ok( tlsinfo->ThreadData[2].TlsVector == (void *)0xdeadbeef, "got %p.\n", tlsinfo->ThreadData[2].TlsVector );
+
+    /* Test replacing TLS index. */
+    teb->ThreadLocalStoragePointer = new_tls_pointer;
+    thread_teb->ThreadLocalStoragePointer = new_tls_pointer2;
+    new_tls_pointer[1] = (void *)0xcccccccc;
+    new_tls_pointer2[1] = (void *)0xdddddddd;
+    tlsinfo->ThreadDataCount = 3;
+    tlsinfo->OperationType = ProcessTlsReplaceIndex;
+    tlsinfo->TlsIndex = 1;
+    for (i = 0; i < 3; ++i)
+    {
+        tlsinfo->ThreadData[i].Flags = 0;
+        tlsinfo->ThreadData[i].ThreadId = 0xdeadbeef;
+        tlsinfo->ThreadData[i].TlsModulePointer = (void *)((ULONG_PTR)i + 1);
+    }
+    status = NtSetInformationProcess( GetCurrentProcess(), ProcessTlsInformation, tlsinfo,
+                                      offsetof(PROCESS_TLS_INFORMATION, ThreadData[tlsinfo->ThreadDataCount]));
+    ok( status == STATUS_SUCCESS, "got %#lx.\n", status );
+    ok( !tlsinfo->Flags, "got %#lx.\n", tlsinfo->Flags );
+    ok( tlsinfo->ThreadDataCount == 3, "got %#lx.\n", tlsinfo->ThreadDataCount );
+    ok( tlsinfo->ThreadData[0].Flags == THREAD_TLS_INFORMATION_ASSIGNED, "got %#lx.\n", tlsinfo->ThreadData[0].Flags );
+    ok( (ULONG_PTR)new_tls_pointer[1] == 1, "got %p.\n", new_tls_pointer[1] );
+    ok( tlsinfo->ThreadData[0].ThreadId == 0xdeadbeef, "got %#Ix.\n", tlsinfo->ThreadData[0].ThreadId );
+    ok( (ULONG_PTR)tlsinfo->ThreadData[0].TlsModulePointer == 0xcccccccc, "got %p.\n", tlsinfo->ThreadData[0].TlsModulePointer );
+    ok( tlsinfo->ThreadData[1].Flags == THREAD_TLS_INFORMATION_ASSIGNED, "got %#lx.\n", tlsinfo->ThreadData[1].Flags );
+    ok( (ULONG_PTR)new_tls_pointer2[1] == 2, "got %p.\n", new_tls_pointer2[1] );
+    ok( tlsinfo->ThreadData[1].ThreadId == 0xdeadbeef, "got %#Ix.\n", tlsinfo->ThreadData[1].ThreadId );
+    ok( (ULONG_PTR)tlsinfo->ThreadData[1].TlsModulePointer == 0xdddddddd, "got %p.\n", tlsinfo->ThreadData[1].TlsModulePointer );
+    ok( !tlsinfo->ThreadData[2].Flags, "got %#lx.\n", tlsinfo->ThreadData[2].Flags );
+    ok( tlsinfo->ThreadData[2].ThreadId == 0xdeadbeef, "got %#Ix.\n", tlsinfo->ThreadData[2].ThreadId );
+    ok( (ULONG_PTR)tlsinfo->ThreadData[2].TlsModulePointer == 3, "got %p.\n", tlsinfo->ThreadData[2].TlsModulePointer );
+
+done:
+    /* Restore original TLS data. */
+    teb->ThreadLocalStoragePointer = save_tls_pointers[0];
+    thread_teb->ThreadLocalStoragePointer = save_tls_pointers[1];
+    ResumeThread( thread );
+    WaitForSingleObject( thread, INFINITE );
+    CloseHandle( thread );
+}
+
 START_TEST(info)
 {
     char **argv;
@@ -4160,4 +4433,5 @@ START_TEST(info)
     test_process_token(argc, argv);
     test_process_id();
     test_processor_idle_cycle_time();
+    test_set_process_tls_info();
 }
diff --git a/include/winternl.h b/include/winternl.h
index 4f24a921bb1..0f3abe1ccfb 100644
--- a/include/winternl.h
+++ b/include/winternl.h
@@ -2653,6 +2653,41 @@ typedef struct _PROCESS_CYCLE_TIME_INFORMATION {
     ULONGLONG CurrentCycleCount;
 } PROCESS_CYCLE_TIME_INFORMATION, *PPROCESS_CYCLE_TIME_INFORMATION;
 
+typedef struct _THREAD_TLS_INFORMATION /* field notes follow the behaviour asserted in dlls/ntdll/tests/info.c */
+{
+    ULONG Flags;                /* in: must be 0; out: THREAD_TLS_INFORMATION_ASSIGNED when the entry was applied */
+    union
+    {
+        void *TlsVector;        /* ProcessTlsReplaceVector: new vector on input, replaced vector on output */
+        void *TlsModulePointer; /* ProcessTlsReplaceIndex: new slot value on input, previous slot value on output */
+    };
+    ULONG_PTR ThreadId;         /* ProcessTlsReplaceVector: output only, id of the thread that got this entry */
+} THREAD_TLS_INFORMATION, * PTHREAD_TLS_INFORMATION;
+
+#define THREAD_TLS_INFORMATION_ASSIGNED 0x2 /* THREAD_TLS_INFORMATION.Flags value reported on output */
+
+typedef enum _PROCESS_TLS_INFORMATION_TYPE
+{
+    ProcessTlsReplaceIndex,  /* replace slot TlsIndex inside each thread's existing TLS vector */
+    ProcessTlsReplaceVector, /* replace each thread's TLS vector pointer */
+    MaxProcessTlsOperation   /* not a valid operation; the tests expect it to be rejected */
+} PROCESS_TLS_INFORMATION_TYPE, *PPROCESS_TLS_INFORMATION_TYPE;
+
+typedef struct _PROCESS_TLS_INFORMATION
+{
+    ULONG Flags;           /* 0 or PROCESS_TLS_INFORMATION_WOW64; the tests see other values rejected */
+    ULONG OperationType;   /* a PROCESS_TLS_INFORMATION_TYPE value */
+    ULONG ThreadDataCount; /* number of ThreadData[] entries supplied by the caller */
+    union
+    {
+        ULONG TlsIndex;        /* ProcessTlsReplaceIndex: slot to replace */
+        ULONG TlsVectorLength; /* ProcessTlsReplaceVector: number of pointers copied from the old vector to the new one */
+    };
+    THREAD_TLS_INFORMATION ThreadData[1]; /* ThreadDataCount entries; tests size the buffer with offsetof(..., ThreadData[count]) */
+} PROCESS_TLS_INFORMATION, *PPROCESS_TLS_INFORMATION;
+
+#define PROCESS_TLS_INFORMATION_WOW64 1 /* NOTE(review): the test comment only guesses this requests a WOW64 TEB update */
+
 typedef struct _PROCESS_STACK_ALLOCATION_INFORMATION
 {
     SIZE_T ReserveSize;