From: Paul Gofman <pgofman@codeweavers.com>
--- dlls/kernel32/tests/module.c | 9 ++++-- dlls/ntdll/loader.c | 53 ++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 3 deletions(-)
diff --git a/dlls/kernel32/tests/module.c b/dlls/kernel32/tests/module.c index 5bdf5d9f1ee..2f465f4be6d 100644 --- a/dlls/kernel32/tests/module.c +++ b/dlls/kernel32/tests/module.c @@ -1675,13 +1675,15 @@ static void test_known_dlls_load(void) h = LoadLibraryA( dll ); ret = pSetDefaultDllDirectories( LOAD_LIBRARY_SEARCH_DEFAULT_DIRS ); ok( ret, "SetDefaultDllDirectories failed err %lu\n", GetLastError() ); - todo_wine ok( !!h, "Got NULL.\n" ); + ok( !!h, "Got NULL.\n" ); + check_dll_path( h, system_path ); hapiset = GetModuleHandleA( apiset_dll ); ok( hapiset == h, "Got %p, %p.\n", hapiset, h ); FreeLibrary( h );
h = LoadLibraryExA( dll, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR ); - todo_wine ok( !!h, "Got NULL.\n" ); + ok( !!h, "Got NULL.\n" ); + check_dll_path( h, system_path ); hapiset = GetModuleHandleA( apiset_dll ); ok( hapiset == h, "Got %p, %p.\n", hapiset, h ); FreeLibrary( h ); @@ -1691,8 +1693,9 @@ static void test_known_dlls_load(void)
h = LoadLibraryExA( dll, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR ); ok( !!h, "Got NULL.\n" ); + check_dll_path( h, system_path ); hapiset = GetModuleHandleA( apiset_dll ); - todo_wine ok( hapiset == h, "Got %p, %p.\n", hapiset, h ); + ok( hapiset == h, "Got %p, %p.\n", hapiset, h ); FreeLibrary( h );
/* Local version can still be loaded if dll name contains path. */ diff --git a/dlls/ntdll/loader.c b/dlls/ntdll/loader.c index 6e9daecc4d5..7becff2c23b 100644 --- a/dlls/ntdll/loader.c +++ b/dlls/ntdll/loader.c @@ -36,6 +36,7 @@ #include "wine/exception.h" #include "wine/debug.h" #include "wine/list.h" +#include "wine/rbtree.h" #include "ntdll_misc.h" #include "ddk/wdm.h"
@@ -185,6 +186,13 @@ static WINE_MODREF *last_failed_modref;
static LDR_DDAG_NODE *node_ntdll, *node_kernel32;
+struct known_dll +{ + struct rb_entry entry; + WCHAR name[1]; +}; +static struct rb_tree known_dlls; + static NTSTATUS load_dll( const WCHAR *load_path, const WCHAR *libname, DWORD flags, WINE_MODREF** pwm, BOOL system ); static NTSTATUS process_attach( LDR_DDAG_NODE *node, LPVOID lpReserved ); static FARPROC find_ordinal_export( HMODULE module, const IMAGE_EXPORT_DIRECTORY *exports, @@ -3058,6 +3066,7 @@ static NTSTATUS find_dll_file( const WCHAR *load_path, const WCHAR *libname, UNI WINE_MODREF **pwm, HANDLE *mapping, SECTION_IMAGE_INFORMATION *image_info, struct file_id *id ) { + const WCHAR *known_dll_name = NULL; WCHAR *fullname = NULL; NTSTATUS status; ULONG wow64_old_value = 0; @@ -3090,6 +3099,12 @@ static NTSTATUS find_dll_file( const WCHAR *load_path, const WCHAR *libname, UNI goto done; } } + if (!fullname && rb_get( &known_dlls, libname )) + { + prepend_system_dir( libname, wcslen(libname), &fullname ); + known_dll_name = libname; + libname = fullname; + } }
if (RtlDetermineDosPathNameType_U( libname ) == RELATIVE_PATH) @@ -3099,7 +3114,11 @@ static NTSTATUS find_dll_file( const WCHAR *load_path, const WCHAR *libname, UNI status = find_builtin_without_file( libname, nt_name, pwm, mapping, image_info, id ); } else if (!(status = RtlDosPathNameToNtPathName_U_WithStatus( libname, nt_name, NULL, NULL ))) + { status = open_dll_file( nt_name, pwm, mapping, image_info, id ); + if (status == STATUS_DLL_NOT_FOUND && known_dll_name) + status = find_builtin_without_file( known_dll_name, nt_name, pwm, mapping, image_info, id ); + }
if (status == STATUS_IMAGE_MACHINE_TYPE_MISMATCH) status = STATUS_INVALID_IMAGE_FORMAT;
@@ -3952,17 +3971,33 @@ static void process_breakpoint(void) __ENDTRY }
+/************************************************************************* + * compare_known_dlls + * + * Comparison callback for the known_dlls tree (passed to rb_init() in load_global_options()). + * 'name' is the lookup key, a WCHAR dll name; 'entry' is the rb entry embedded in a struct known_dll. + * Returns the result of wcsicmp() between the key and the stored name, so the ordering ignores case. + */ +static int compare_known_dlls( const void *name, const struct wine_rb_entry *entry ) +{ + struct known_dll *known_dll = WINE_RB_ENTRY_VALUE( entry, struct known_dll, entry ); + + return wcsicmp( name, known_dll->name ); +}
/*********************************************************************** * load_global_options */ static void load_global_options(void) { + char buffer[256]; + KEY_VALUE_PARTIAL_INFORMATION *info = (KEY_VALUE_PARTIAL_INFORMATION *)buffer; OBJECT_ATTRIBUTES attr; UNICODE_STRING bootstrap_mode_str = RTL_CONSTANT_STRING( L"WINEBOOTSTRAPMODE" ); UNICODE_STRING session_manager_str = RTL_CONSTANT_STRING( L"\\Registry\\Machine\\System\\CurrentControlSet\\Control\\Session Manager" ); + UNICODE_STRING known_dlls_str = + RTL_CONSTANT_STRING( L"\\Registry\\Machine\\System\\CurrentControlSet\\Control\\Session Manager\\KnownDLLs" ); UNICODE_STRING val_str; + struct known_dll *known_dll; + ULONG idx = 0, size; + NTSTATUS status; HANDLE hkey;
val_str.MaximumLength = 0; @@ -3982,6 +4017,24 @@ static void load_global_options(void) query_dword_option( hkey, L"SafeDllSearchMode", &dll_safe_mode ); NtClose( hkey ); } + + rb_init( &known_dlls, compare_known_dlls ); + + attr.ObjectName = &known_dlls_str; + if (NtOpenKey( &hkey, KEY_QUERY_VALUE, &attr )) return; + while (1) + { + status = NtEnumerateValueKey( hkey, idx++, KeyValuePartialInformation, buffer, sizeof(buffer), &size ); + if (status == STATUS_BUFFER_OVERFLOW) continue; + if (status) break; + if (info->Type != REG_SZ) continue; + + known_dll = RtlAllocateHeap( GetProcessHeap(), 0, offsetof(struct known_dll, name[0]) + info->DataLength ); + if (!known_dll) break; + memcpy( known_dll->name, info->Data, info->DataLength ); + rb_put( &known_dlls, known_dll->name, &known_dll->entry ); + } + NtClose( hkey ); }