From: Paul Gofman pgofman@codeweavers.com
--- dlls/kernel32/tests/module.c | 73 +++++++++++++++++++++++++----------- dlls/ntdll/loader.c | 51 +++++++++++++++++++------ 2 files changed, 92 insertions(+), 32 deletions(-)
diff --git a/dlls/kernel32/tests/module.c b/dlls/kernel32/tests/module.c index 4b4269c821b..93878a9bd22 100644 --- a/dlls/kernel32/tests/module.c +++ b/dlls/kernel32/tests/module.c @@ -141,6 +141,34 @@ static void create_test_dll( const char *name ) CloseHandle( handle ); }
+static BOOL is_old_loader_struct(void) +{ + LDR_DATA_TABLE_ENTRY *mod, *mod2; + LDR_DDAG_NODE *ddag_node; + NTSTATUS status; + HMODULE hexe; + + /* Check for old LDR data structure: returns TRUE (after win_skip) when the exe's loader entry has no DdagNode or when the first DdagNode module link does not map back to that same entry; returns FALSE otherwise. */ + hexe = GetModuleHandleW( NULL ); + ok( !!hexe, "Got NULL exe handle.\n" ); + status = LdrFindEntryForAddress( hexe, &mod ); + ok( !status, "got %#lx.\n", status ); + if (!(ddag_node = mod->DdagNode)) + { + win_skip( "DdagNode is NULL, skipping tests.\n" ); + return TRUE; + } + ok( !!ddag_node->Modules.Flink, "Got NULL module link.\n" ); + mod2 = CONTAINING_RECORD(ddag_node->Modules.Flink, LDR_DATA_TABLE_ENTRY, NodeModuleLink); + ok( mod2 == mod || broken( (void **)mod2 == (void **)mod - 1 ), "got %p, expected %p.\n", mod2, mod ); + if (mod2 != mod) + { + win_skip( "Old LDR_DATA_TABLE_ENTRY structure, skipping tests.\n" ); + return TRUE; + } + return FALSE; +} + static void testGetModuleFileName(const char* name) { HMODULE hMod; @@ -1641,7 +1669,10 @@ static void test_tls_links(void) TEB *teb = NtCurrentTeb(), *thread_teb; THREAD_BASIC_INFORMATION tbi; NTSTATUS status; + ULONG i, count; HANDLE thread; + SIZE_T size; + void **ptr;
ok(!!teb->ThreadLocalStoragePointer, "got NULL.\n");
@@ -1661,6 +1692,26 @@ static void test_tls_links(void) ResumeThread(thread); WaitForSingleObject(test_tls_links_started, INFINITE);
+ if (!is_old_loader_struct()) + { + ptr = teb->ThreadLocalStoragePointer; + count = (ULONG_PTR)ptr[-2]; + size = HeapSize(GetProcessHeap(), 0, ptr - 2); + ok(size == (count + 2) * sizeof(void *), "got count %lu, size %Iu.\n", count, size); + + for (i = 0; i < count; ++i) + { + if (!ptr[i]) continue; + size = HeapSize(GetProcessHeap(), 0, (void **)ptr[i] - 2); + ok(size && size < 100000, "got %Iu.\n", size); + } + + ptr = thread_teb->ThreadLocalStoragePointer; + count = (ULONG_PTR)ptr[-2]; + size = HeapSize(GetProcessHeap(), 0, ptr - 2); + ok(size == (count + 2) * sizeof(void *), "got count %lu, size %Iu.\n", count, size); + } + ok(!!thread_teb->ThreadLocalStoragePointer, "got NULL.\n"); ok(!teb->TlsLinks.Flink, "got %p.\n", teb->TlsLinks.Flink); ok(!teb->TlsLinks.Blink, "got %p.\n", teb->TlsLinks.Blink); @@ -1714,29 +1765,9 @@ static void test_base_address_index_tree(void) unsigned int tree_count, list_count = 0; LDR_DATA_TABLE_ENTRY *mod, *mod2; RTL_BALANCED_NODE *root, *node; - LDR_DDAG_NODE *ddag_node; - NTSTATUS status; - HMODULE hexe; char *base;
- /* Check for old LDR data strcuture. */ - hexe = GetModuleHandleW( NULL ); - ok( !!hexe, "Got NULL exe handle.\n" ); - status = LdrFindEntryForAddress( hexe, &mod ); - ok( !status, "got %#lx.\n", status ); - if (!(ddag_node = mod->DdagNode)) - { - win_skip( "DdagNode is NULL, skipping tests.\n" ); - return; - } - ok( !!ddag_node->Modules.Flink, "Got NULL module link.\n" ); - mod2 = CONTAINING_RECORD(ddag_node->Modules.Flink, LDR_DATA_TABLE_ENTRY, NodeModuleLink); - ok( mod2 == mod || broken( (void **)mod2 == (void **)mod - 1 ), "got %p, expected %p.\n", mod2, mod ); - if (mod2 != mod) - { - win_skip( "Old LDR_DATA_TABLE_ENTRY structure, skipping tests.\n" ); - return; - } + if (is_old_loader_struct()) return;
mod = CONTAINING_RECORD(first->Flink, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks); ok( mod->BaseAddressIndexNode.ParentValue || mod->BaseAddressIndexNode.Left || mod->BaseAddressIndexNode.Right, diff --git a/dlls/ntdll/loader.c b/dlls/ntdll/loader.c index aa3ac83cbb5..6322aff1711 100644 --- a/dlls/ntdll/loader.c +++ b/dlls/ntdll/loader.c @@ -1324,6 +1324,36 @@ static BOOL is_dll_native_subsystem( LDR_DATA_TABLE_ENTRY *mod, const IMAGE_NT_H return TRUE; }
+ +/************************************************************************* + * alloc_tls_memory + * + * Allocate memory for a TLS vector or index with extra data: the heap block is two pointer-sized slots larger than 'size' and the returned pointer points just past those two slots. For a vector the block is zero-initialized and the first slot holds the entry count (size / sizeof(void *)); otherwise both slots are set to zero and the rest is left uninitialized. Returns NULL if the heap allocation fails. + */ +static void *alloc_tls_memory( BOOL vector, ULONG_PTR size ) +{ + ULONG_PTR *ptr; + + if (!(ptr = RtlAllocateHeap( GetProcessHeap(), vector ? HEAP_ZERO_MEMORY : 0, size + sizeof(void *) * 2 ))) return NULL; + ptr += 2; + if (vector) ptr[-2] = size / sizeof(void *); + else ptr[-2] = ptr[-1] = 0; + return ptr; +} + + +/************************************************************************* + * free_tls_memory + * + * Free TLS vector or index memory: the pointer is moved back over the two leading slots before it is passed to RtlFreeHeap. A NULL pointer is ignored. + */ +static void free_tls_memory( void *ptr ) +{ + if (!ptr) return; + RtlFreeHeap( GetProcessHeap(), 0, (void **)ptr - 2 ); +} + + /************************************************************************* * alloc_tls_slot * @@ -1390,13 +1420,13 @@ static BOOL alloc_tls_slot( LDR_DATA_TABLE_ENTRY *mod )
t->ThreadData[j].Flags = 0;
- if (!(new_ptr = RtlAllocateHeap( GetProcessHeap(), 0, size + dir->SizeOfZeroFill ))) return FALSE; + if (!(new_ptr = alloc_tls_memory( FALSE, size + dir->SizeOfZeroFill ))) return FALSE; memcpy( new_ptr, (void *)dir->StartAddressOfRawData, size ); memset( (char *)new_ptr + size, 0, dir->SizeOfZeroFill );
if (t->OperationType == ProcessTlsReplaceVector) { - vector = RtlAllocateHeap( GetProcessHeap(), HEAP_ZERO_MEMORY, tls_module_count * sizeof(*vector) ); + vector = alloc_tls_memory( TRUE, tls_module_count * sizeof(*vector) ); if (!vector) return FALSE; t->ThreadData[j].TlsVector = vector; vector[i] = new_ptr; @@ -1416,12 +1446,12 @@ static BOOL alloc_tls_slot( LDR_DATA_TABLE_ENTRY *mod ) { /* There could be fewer active threads than we counted here due to force terminated threads, first * free extra TLS directory data set in the new TLS vector. */ - RtlFreeHeap( GetProcessHeap(), 0, ((void **)t->ThreadData[j].TlsVector)[i] ); + free_tls_memory( ((void **)t->ThreadData[j].TlsVector)[i] ); } if (!(t->ThreadData[j].Flags & THREAD_TLS_INFORMATION_ASSIGNED) || t->OperationType == ProcessTlsReplaceIndex) { /* FIXME: can't free old Tls vector here, should be freed at thread exit. */ - RtlFreeHeap( GetProcessHeap(), 0, t->ThreadData[j].TlsVector ); + free_tls_memory( t->ThreadData[j].TlsVector ); } } RtlFreeHeap( GetProcessHeap(), 0, t ); @@ -1633,8 +1663,7 @@ static NTSTATUS alloc_thread_tls(void) void **pointers; UINT i, size;
- if (!(pointers = RtlAllocateHeap( GetProcessHeap(), HEAP_ZERO_MEMORY, - tls_module_count * sizeof(*pointers) ))) + if (!(pointers = alloc_tls_memory( TRUE, tls_module_count * sizeof(*pointers) ))) return STATUS_NO_MEMORY;
for (i = 0; i < tls_module_count; i++) @@ -1645,10 +1674,10 @@ static NTSTATUS alloc_thread_tls(void) size = dir->EndAddressOfRawData - dir->StartAddressOfRawData; if (!size && !dir->SizeOfZeroFill) continue;
- if (!(pointers[i] = RtlAllocateHeap( GetProcessHeap(), 0, size + dir->SizeOfZeroFill ))) + if (!(pointers[i] = alloc_tls_memory( FALSE, size + dir->SizeOfZeroFill ))) { - while (i) RtlFreeHeap( GetProcessHeap(), 0, pointers[--i] ); - RtlFreeHeap( GetProcessHeap(), 0, pointers ); + while (i) free_tls_memory( pointers[--i] ); + free_tls_memory( pointers ); return STATUS_NO_MEMORY; } memcpy( pointers[i], (void *)dir->StartAddressOfRawData, size ); @@ -3946,8 +3975,8 @@ void WINAPI LdrShutdownThread(void) if (NtCurrentTeb()->Instrumentation[0]) ((TEB *)NtCurrentTeb()->Instrumentation[0])->ThreadLocalStoragePointer = NULL; #endif - for (i = 0; i < tls_module_count; i++) RtlFreeHeap( GetProcessHeap(), 0, pointers[i] ); - RtlFreeHeap( GetProcessHeap(), 0, pointers ); + for (i = 0; i < tls_module_count; i++) free_tls_memory( pointers[i] ); + free_tls_memory( pointers ); } RtlProcessFlsData( NtCurrentTeb()->FlsSlots, 2 ); NtCurrentTeb()->FlsSlots = NULL;