Signed-off-by: Paul Gofman <pgofman@codeweavers.com>
---
v2:
  - protect cached_modref with ldr_data_srw_lock instead of introducing interlocked access.
 dlls/ntdll/loader.c | 65 +++++++++++++++++++++++++++++++--------------
 1 file changed, 45 insertions(+), 20 deletions(-)
diff --git a/dlls/ntdll/loader.c b/dlls/ntdll/loader.c index 047e837238c..8e2ed03ad4c 100644 --- a/dlls/ntdll/loader.c +++ b/dlls/ntdll/loader.c @@ -206,7 +206,10 @@ static RTL_SRWLOCK ldr_data_srw_lock = RTL_SRWLOCK_INIT; static RTL_BITMAP tls_bitmap; static RTL_BITMAP tls_expansion_bitmap;
+/* Guarded by ldr_data_srw_lock; access through get_cached_modref() / set_cached_modref(). */ static WINE_MODREF *cached_modref; + +/* Used with exclusive loader lock only. */ static WINE_MODREF *current_modref; static WINE_MODREF *last_failed_modref;
@@ -470,6 +473,33 @@ static void lock_loader_restore_exclusive(void) locked_exclusive = TRUE; }
+/************************************************************************* + * get_cached_modref + * Returns the current cached_modref, read under the shared ldr_data_srw_lock. + */ +static WINE_MODREF *get_cached_modref(void) +{ + WINE_MODREF *ret; + + RtlAcquireSRWLockShared( &ldr_data_srw_lock ); + ret = cached_modref; + RtlReleaseSRWLockShared( &ldr_data_srw_lock ); + return ret; +} + +/************************************************************************* + * set_cached_modref + * + * Stores new as cached_modref under the exclusive ldr_data_srw_lock and returns it. + */ +static WINE_MODREF *set_cached_modref( WINE_MODREF *new ) +{ + RtlAcquireSRWLockExclusive( &ldr_data_srw_lock ); + cached_modref = new; + RtlReleaseSRWLockExclusive( &ldr_data_srw_lock ); + return new; +} + #define RTL_UNLOAD_EVENT_TRACE_NUMBER 64
typedef struct _RTL_UNLOAD_EVENT_TRACE @@ -753,17 +783,18 @@ static void call_ldr_notifications( ULONG reason, LDR_DATA_TABLE_ENTRY *module ) */ static WINE_MODREF *get_modref( HMODULE hmod ) { + WINE_MODREF *cached = get_cached_modref(); /* single read of cached_modref under ldr_data_srw_lock */ PLIST_ENTRY mark, entry; PLDR_DATA_TABLE_ENTRY mod;
- if (cached_modref && cached_modref->ldr.DllBase == hmod) return cached_modref; + if (cached && cached->ldr.DllBase == hmod) return cached;
mark = &NtCurrentTeb()->Peb->LdrData->InMemoryOrderModuleList; for (entry = mark->Flink; entry != mark; entry = entry->Flink) { mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InMemoryOrderLinks); if (mod->DllBase == hmod) - return cached_modref = CONTAINING_RECORD(mod, WINE_MODREF, ldr); + return set_cached_modref( CONTAINING_RECORD(mod, WINE_MODREF, ldr) ); } return NULL; } @@ -777,23 +808,21 @@ static WINE_MODREF *get_modref( HMODULE hmod ) */ static WINE_MODREF *find_basename_module( LPCWSTR name ) { + WINE_MODREF *cached = get_cached_modref(); PLIST_ENTRY mark, entry; UNICODE_STRING name_str;
RtlInitUnicodeString( &name_str, name );
- if (cached_modref && RtlEqualUnicodeString( &name_str, &cached_modref->ldr.BaseDllName, TRUE )) - return cached_modref; + if (cached && RtlEqualUnicodeString( &name_str, &cached->ldr.BaseDllName, TRUE )) + return cached;
mark = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList; for (entry = mark->Flink; entry != mark; entry = entry->Flink) { LDR_DATA_TABLE_ENTRY *mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks); if (RtlEqualUnicodeString( &name_str, &mod->BaseDllName, TRUE )) - { - cached_modref = CONTAINING_RECORD(mod, WINE_MODREF, ldr); - return cached_modref; - } + return set_cached_modref( CONTAINING_RECORD(mod, WINE_MODREF, ldr) ); } return NULL; } @@ -807,6 +836,7 @@ static WINE_MODREF *find_basename_module( LPCWSTR name ) */ static WINE_MODREF *find_fullname_module( const UNICODE_STRING *nt_name ) { + WINE_MODREF *cached = get_cached_modref(); PLIST_ENTRY mark, entry; UNICODE_STRING name = *nt_name;
@@ -814,18 +844,15 @@ static WINE_MODREF *find_fullname_module( const UNICODE_STRING *nt_name ) name.Length -= 4 * sizeof(WCHAR); /* for ??\ prefix */ name.Buffer += 4;
- if (cached_modref && RtlEqualUnicodeString( &name, &cached_modref->ldr.FullDllName, TRUE )) - return cached_modref; + if (cached && RtlEqualUnicodeString( &name, &cached->ldr.FullDllName, TRUE )) + return cached;
mark = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList; for (entry = mark->Flink; entry != mark; entry = entry->Flink) { LDR_DATA_TABLE_ENTRY *mod = CONTAINING_RECORD(entry, LDR_DATA_TABLE_ENTRY, InLoadOrderLinks); if (RtlEqualUnicodeString( &name, &mod->FullDllName, TRUE )) - { - cached_modref = CONTAINING_RECORD(mod, WINE_MODREF, ldr); - return cached_modref; - } + return set_cached_modref( CONTAINING_RECORD(mod, WINE_MODREF, ldr) ); } return NULL; } @@ -839,9 +866,10 @@ static WINE_MODREF *find_fullname_module( const UNICODE_STRING *nt_name ) */ static WINE_MODREF *find_fileid_module( const struct file_id *id ) { + WINE_MODREF *cached = get_cached_modref(); LIST_ENTRY *mark, *entry;
- if (cached_modref && !memcmp( &cached_modref->id, id, sizeof(*id) )) return cached_modref; + if (cached && !memcmp( &cached->id, id, sizeof(*id) )) return cached;
mark = &NtCurrentTeb()->Peb->LdrData->InLoadOrderModuleList; for (entry = mark->Flink; entry != mark; entry = entry->Flink) @@ -850,10 +878,7 @@ static WINE_MODREF *find_fileid_module( const struct file_id *id ) WINE_MODREF *wm = CONTAINING_RECORD( mod, WINE_MODREF, ldr );
if (!memcmp( &wm->id, id, sizeof(*id) )) - { - cached_modref = wm; - return wm; - } + return set_cached_modref( wm ); } return NULL; } @@ -3810,6 +3835,7 @@ static void free_modref( WINE_MODREF *wm ) RemoveEntryList(&wm->ldr.InMemoryOrderLinks); if (wm->ldr.InInitializationOrderLinks.Flink) RemoveEntryList(&wm->ldr.InInitializationOrderLinks); + if (cached_modref == wm) cached_modref = NULL; /* cleared under the exclusive ldr_data_srw_lock, together with the unlinking above */ RtlReleaseSRWLockExclusive( &ldr_data_srw_lock );
TRACE(" unloading %s\n", debugstr_w(wm->ldr.FullDllName.Buffer)); @@ -3821,7 +3847,6 @@ static void free_modref( WINE_MODREF *wm ) free_tls_slot( &wm->ldr ); RtlReleaseActivationContext( wm->ldr.ActivationContext ); NtUnmapViewOfSection( NtCurrentProcess(), wm->ldr.DllBase ); - if (cached_modref == wm) cached_modref = NULL; RtlFreeUnicodeString( &wm->ldr.FullDllName ); RtlFreeHeap( GetProcessHeap(), 0, wm->deps ); RtlFreeHeap( GetProcessHeap(), 0, wm );