From: Paul Gofman <pgofman@codeweavers.com>
--- dlls/kernel32/tests/module.c | 93 ++++++++++++++++++++++++++++++++++++ dlls/ntdll/loader.c | 35 ++++++++++++++ 2 files changed, 128 insertions(+)
diff --git a/dlls/kernel32/tests/module.c b/dlls/kernel32/tests/module.c index 6b576ea25fd..4b4269c821b 100644 --- a/dlls/kernel32/tests/module.c +++ b/dlls/kernel32/tests/module.c @@ -1674,6 +1674,98 @@ static void test_tls_links(void) CloseHandle(test_tls_links_done); }
+/* Return the parent of a tree node: ParentValue with the RTL_BALANCED_NODE_RESERVED_PARENT_MASK bits cleared. */ +static RTL_BALANCED_NODE *rtl_node_parent( RTL_BALANCED_NODE *node ) +{ + return (RTL_BALANCED_NODE *)(node->ParentValue & ~(ULONG_PTR)RTL_BALANCED_NODE_RESERVED_PARENT_MASK); +} +/* Recursively check a subtree: the reserved ParentValue bits must be at most 1 and each direct child must be ordered by DllBase relative to its parent. Returns the number of nodes in the subtree (0 for NULL). */ +static unsigned int check_address_index_tree( RTL_BALANCED_NODE *node ) +{ + LDR_DATA_TABLE_ENTRY *mod; + unsigned int count; + char *base; + + if (!node) return 0; + ok( (node->ParentValue & RTL_BALANCED_NODE_RESERVED_PARENT_MASK) <= 1, "got ParentValue %#Ix.\n", + node->ParentValue ); + /* Only the direct children are compared with this node's DllBase here. */ + mod = CONTAINING_RECORD(node, LDR_DATA_TABLE_ENTRY, BaseAddressIndexNode); + base = mod->DllBase; + if (node->Left) + { + mod = CONTAINING_RECORD(node->Left, LDR_DATA_TABLE_ENTRY, BaseAddressIndexNode); + ok( (char *)mod->DllBase < base, "wrong ordering.\n" ); + } + if (node->Right) + { + mod = CONTAINING_RECORD(node->Right, LDR_DATA_TABLE_ENTRY, BaseAddressIndexNode); + ok( (char *)mod->DllBase > base, "wrong ordering.\n" ); + } + /* Recurse into both subtrees and count this node too. */ + count = check_address_index_tree( node->Left ); + count += check_address_index_tree( node->Right ); + return count + 1; +} +/* Test that the tree linked through BaseAddressIndexNode is ordered by DllBase and holds exactly the modules found in InLoadOrderModuleList. */ +static void test_base_address_index_tree(void) +{ + LIST_ENTRY *first = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList; + unsigned int tree_count, list_count = 0; + LDR_DATA_TABLE_ENTRY *mod, *mod2; + RTL_BALANCED_NODE *root, *node; + LDR_DDAG_NODE *ddag_node; + NTSTATUS status; + HMODULE hexe; + char *base; + + /* Check for old LDR data structure. 
*/ + hexe = GetModuleHandleW( NULL ); + ok( !!hexe, "Got NULL exe handle.\n" ); + status = LdrFindEntryForAddress( hexe, &mod ); + ok( !status, "got %#lx.\n", status ); + if (!(ddag_node = mod->DdagNode)) + { + win_skip( "DdagNode is NULL, skipping tests.\n" ); + return; + } + ok( !!ddag_node->Modules.Flink, "Got NULL module link.\n" ); + mod2 = CONTAINING_RECORD(ddag_node->Modules.Flink, LDR_DATA_TABLE_ENTRY, NodeModuleLink); + ok( mod2 == mod || broken( (void **)mod2 == (void **)mod - 1 ), "got %p, expected %p.\n", mod2, mod ); + if (mod2 != mod) + { + win_skip( "Old LDR_DATA_TABLE_ENTRY structure, skipping tests.\n" ); + return; + } + /* Start from the first load-order module's node and walk parent links up to the tree root. */ + mod = CONTAINING_RECORD(first->Flink, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks); + ok( mod->BaseAddressIndexNode.ParentValue || mod->BaseAddressIndexNode.Left || mod->BaseAddressIndexNode.Right, + "got zero BaseAddressIndexNode.\n" ); + root = &mod->BaseAddressIndexNode; + while (rtl_node_parent( root )) + root = rtl_node_parent( root ); + tree_count = check_address_index_tree( root ); + for (LIST_ENTRY *entry = first->Flink; entry != first; entry = entry->Flink) + { + ++list_count; + mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks); + base = mod->DllBase; + node = root; + mod2 = NULL; + while (1) /* search the tree from the root for this module's base address */ + { + ok( !!node, "got NULL.\n" ); + if (!node) break; + mod2 = CONTAINING_RECORD(node, LDR_DATA_TABLE_ENTRY, BaseAddressIndexNode); + if (base == (char *)mod2->DllBase) break; + if (base < (char *)mod2->DllBase) node = node->Left; + else node = node->Right; + } + ok( base == (char *)mod2->DllBase, "module %s not found.\n", debugstr_w(mod->BaseDllName.Buffer) ); + } + ok( tree_count == list_count, "count mismatch %u, %u.\n", tree_count, list_count ); +} + START_TEST(module) { WCHAR filenameW[MAX_PATH]; @@ -1711,4 +1803,5 @@ START_TEST(module) test_apisets(); test_ddag_node(); test_tls_links(); + test_base_address_index_tree(); } diff --git a/dlls/ntdll/loader.c b/dlls/ntdll/loader.c index 24652d5a663..57b33c0864a 
100644 --- a/dlls/ntdll/loader.c +++ b/dlls/ntdll/loader.c @@ -175,6 +175,8 @@ static PEB_LDR_DATA ldr = { &ldr.InInitializationOrderModuleList, &ldr.InInitializationOrderModuleList } };
+static RTL_RB_TREE base_address_index_tree; /* module entries keyed by DllBase, linked through BaseAddressIndexNode */ + static RTL_BITMAP tls_bitmap; static RTL_BITMAP tls_expansion_bitmap;
@@ -237,6 +239,24 @@ static void module_push_unload_trace( const WINE_MODREF *wm ) unload_trace_ptr = unload_traces; }
+static int rtl_rb_tree_put( RTL_RB_TREE *tree, const void *key, RTL_BALANCED_NODE *entry, + int (*compare_func)( const void *key, const RTL_BALANCED_NODE *entry )) +{ + RTL_BALANCED_NODE *parent = tree->root; + BOOLEAN right = 0; + int c; + /* Descend from the root via Children[c > 0] until the slot for key is empty; if compare_func reports an equal key, return -1 without inserting. */ + while (parent) + { + if (!(c = compare_func( key, parent ))) return -1; + right = c > 0; + if (!parent->Children[right]) break; + parent = parent->Children[right]; + } + RtlRbInsertNodeEx( tree, parent, right, entry ); /* parent is NULL when the tree was empty */ + return 0; +} + #ifdef __arm64ec__
static void update_hybrid_pointer( void *module, const IMAGE_SECTION_HEADER *sec, UINT rva, void *ptr ) @@ -558,6 +578,17 @@ static void call_ldr_notifications( ULONG reason, LDR_DATA_TABLE_ENTRY *module ) } }
+/* compare a base address key with the DllBase of the module containing the tree entry; returns -1, 0 or 1 */ +static int base_address_compare( const void *key, const RTL_BALANCED_NODE *entry ) +{ + const LDR_DATA_TABLE_ENTRY *mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, BaseAddressIndexNode); + const char *base = key; + + if (base < (char *)mod->DllBase) return -1; + if (base > (char *)mod->DllBase) return 1; + return 0; +} + /************************************************************************* * get_modref * @@ -1559,6 +1590,8 @@ static WINE_MODREF *alloc_module( HMODULE hModule, const UNICODE_STRING *nt_name &wm->ldr.InLoadOrderLinks); InsertTailList(&NtCurrentTeb()->Peb->LdrData->InMemoryOrderModuleList, &wm->ldr.InMemoryOrderLinks); + if (rtl_rb_tree_put( &base_address_index_tree, wm->ldr.DllBase, &wm->ldr.BaseAddressIndexNode, base_address_compare )) + ERR( "rtl_rb_tree_put failed.\n" ); /* wait until init is called for inserting into InInitializationOrderModuleList */
if (!(nt->OptionalHeader.DllCharacteristics & IMAGE_DLLCHARACTERISTICS_NX_COMPAT)) @@ -2255,6 +2288,7 @@ static NTSTATUS build_module( LPCWSTR load_path, const UNICODE_STRING *nt_name, /* the module has only be inserted in the load & memory order lists */ RemoveEntryList(&wm->ldr.InLoadOrderLinks); RemoveEntryList(&wm->ldr.InMemoryOrderLinks); + RtlRbRemoveNode( &base_address_index_tree, &wm->ldr.BaseAddressIndexNode );
/* FIXME: there are several more dangling references * left. Including dlls loaded by this dll before the @@ -3918,6 +3952,7 @@ static void free_modref( WINE_MODREF *wm )
RemoveEntryList(&wm->ldr.InLoadOrderLinks); RemoveEntryList(&wm->ldr.InMemoryOrderLinks); + RtlRbRemoveNode( &base_address_index_tree, &wm->ldr.BaseAddressIndexNode ); if (wm->ldr.InInitializationOrderLinks.Flink) RemoveEntryList(&wm->ldr.InInitializationOrderLinks);