Signed-off-by: Paul Gofman <pgofman@codeweavers.com> --- dlls/ntoskrnl.exe/ntoskrnl.c | 24 +++++++++++++ dlls/ntoskrnl.exe/tests/driver.c | 60 +++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 1 deletion(-)
diff --git a/dlls/ntoskrnl.exe/ntoskrnl.c b/dlls/ntoskrnl.exe/ntoskrnl.c index 94733ec60a0..fbf6262b3eb 100644 --- a/dlls/ntoskrnl.exe/ntoskrnl.c +++ b/dlls/ntoskrnl.exe/ntoskrnl.c @@ -3673,6 +3673,30 @@ static HMODULE load_driver( const WCHAR *driver_name, const UNICODE_STRING *keyn TRACE( "loading driver %s\n", wine_dbgstr_w(str) );
module = load_driver_module( str ); + /* Pass the newly loaded driver module to every routine in load_image_notify_routines (presumably filled by PsSetLoadImageNotifyRoutine, which is outside this hunk). */ + if (module && load_image_notify_routine_count) + { + UNICODE_STRING module_name; + IMAGE_NT_HEADERS *nt; + IMAGE_INFO info; + unsigned int i; + + RtlInitUnicodeString(&module_name, str); + nt = RtlImageNtHeader(module); /* NOTE(review): not NULL-checked; assumes load_driver_module only returns valid PE images */ + memset(&info, 0, sizeof(info)); + info.u.s.ImageAddressingMode = IMAGE_ADDRESSING_MODE_32BIT; /* set unconditionally, on 32- and 64-bit builds alike */ + info.u.s.SystemModeImage = TRUE; + info.ImageSize = nt->OptionalHeader.SizeOfImage; + info.ImageBase = module; + + for (i = 0; i < load_image_notify_routine_count; ++i) + { + TRACE("Calling image load notify %p.\n", load_image_notify_routines[i]); + load_image_notify_routines[i](&module_name, NULL, &info); /* process id is NULL; the image is flagged SystemModeImage above */ + TRACE("Called image load notify %p.\n", load_image_notify_routines[i]); + } + } + HeapFree( GetProcessHeap(), 0, path ); return module; } diff --git a/dlls/ntoskrnl.exe/tests/driver.c b/dlls/ntoskrnl.exe/tests/driver.c index f6cc442bab0..bfc2db3adbf 100644 --- a/dlls/ntoskrnl.exe/tests/driver.c +++ b/dlls/ntoskrnl.exe/tests/driver.c @@ -308,21 +308,79 @@ static const WCHAR driver2_path[] = { '\','W','i','n','e','T','e','s','t','D','r','i','v','e','r','2',0 };
+static IMAGE_INFO test_image_info; +static int test_load_image_notify_count; +static WCHAR test_load_image_name[MAX_PATH]; +/* Load-image callbacks also fire for unrelated images and FullImageName may be NULL, so only the test driver's .tmp image is counted; the bounded copy never writes the last (zero) element of the static name buffer. */ +static void WINAPI test_load_image_notify_routine(UNICODE_STRING *image_name, HANDLE process_id, + IMAGE_INFO *image_info) +{ + if (image_name && image_name->Buffer + && wcsstr(image_name->Buffer, u".tmp")) + { + ++test_load_image_notify_count; + test_image_info = *image_info; + wcsncpy(test_load_image_name, image_name->Buffer, MAX_PATH - 1); + } +} + static void test_load_driver(void) { - UNICODE_STRING name; + static WCHAR image_path_key_name[] = u"ImagePath"; + RTL_QUERY_REGISTRY_TABLE query_table[2]; + UNICODE_STRING name, image_path; NTSTATUS ret;
+ ret = PsSetLoadImageNotifyRoutine(test_load_image_notify_routine); + ok(ret == STATUS_SUCCESS, "Got unexpected status %#x.\n", ret); + + /* Windows accepts a duplicate registration; the routine then has to be removed twice. */ + ret = PsSetLoadImageNotifyRoutine(test_load_image_notify_routine); + ok(ret == STATUS_SUCCESS, "Got unexpected status %#x.\n", ret); + + RtlInitUnicodeString(&image_path, NULL); + memset(query_table, 0, sizeof(query_table)); + query_table[0].QueryRoutine = NULL; + query_table[0].Name = image_path_key_name; + query_table[0].EntryContext = &image_path; + query_table[0].Flags = RTL_QUERY_REGISTRY_DIRECT | RTL_QUERY_REGISTRY_TYPECHECK; + query_table[0].DefaultType = REG_EXPAND_SZ << RTL_QUERY_REGISTRY_TYPECHECK_SHIFT; + + ret = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE, driver2_path, query_table, NULL, NULL); + ok(ret == STATUS_SUCCESS, "Got unexpected status %#x.\n", ret); + ok(!!image_path.Buffer, "image_path.Buffer is NULL.\n"); + RtlInitUnicodeString(&name, driver2_path);
ret = ZwLoadDriver(&name); ok(!ret, "got %#x\n", ret);
+ ok(test_load_image_notify_count == 2, "Got unexpected test_load_image_notify_count %d.\n", + test_load_image_notify_count); + ok(test_image_info.ImageAddressingMode == IMAGE_ADDRESSING_MODE_32BIT, + "Got unexpected ImageAddressingMode %#x.\n", test_image_info.ImageAddressingMode); + ok(test_image_info.SystemModeImage, + "Got unexpected SystemModeImage %#x.\n", test_image_info.SystemModeImage); + ok(image_path.Buffer && !wcscmp(test_load_image_name, image_path.Buffer), "Image path names do not match.\n"); + + test_load_image_notify_count = -1; + ret = ZwLoadDriver(&name); ok(ret == STATUS_IMAGE_ALREADY_LOADED, "got %#x\n", ret);
ret = ZwUnloadDriver(&name); ok(!ret, "got %#x\n", ret); + + ret = PsRemoveLoadImageNotifyRoutine(test_load_image_notify_routine); + ok(ret == STATUS_SUCCESS, "Got unexpected status %#x.\n", ret); + ret = PsRemoveLoadImageNotifyRoutine(test_load_image_notify_routine); + ok(ret == STATUS_SUCCESS, "Got unexpected status %#x.\n", ret); + ret = PsRemoveLoadImageNotifyRoutine(test_load_image_notify_routine); + ok(ret == STATUS_PROCEDURE_NOT_FOUND, "Got unexpected status %#x.\n", ret); + + ok(test_load_image_notify_count == -1, "Got unexpected test_load_image_notify_count %d.\n", + test_load_image_notify_count); + RtlFreeUnicodeString(&image_path); }
static NTSTATUS wait_single(void *obj, ULONGLONG timeout)