Signed-off-by: Paul Gofman <pgofman@codeweavers.com>
---
v2:
    - don't add redundant checks in the server.
v3:
    - always pass &fd_type instead of type in server_get_unix_fd().
dlls/ntdll/file.c | 6 +++++- dlls/ntdll/server.c | 18 ++++++++++++++---- dlls/ntdll/tests/om.c | 44 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 61 insertions(+), 7 deletions(-)
diff --git a/dlls/ntdll/file.c b/dlls/ntdll/file.c
index 9997a5e1cd..33f78b63ac 100644
--- a/dlls/ntdll/file.c
+++ b/dlls/ntdll/file.c
@@ -3266,10 +3266,11 @@ NTSTATUS WINAPI NtQueryVolumeInformationFile( HANDLE handle, PIO_STATUS_BLOCK io
                                               PVOID buffer, ULONG length,
                                               FS_INFORMATION_CLASS info_class )
 {
+    enum server_fd_type type;
     int fd, needs_close;
     struct stat st;
 
-    io->u.Status = server_get_unix_fd( handle, 0, &fd, &needs_close, NULL, NULL );
+    io->u.Status = server_get_unix_fd( handle, 0, &fd, &needs_close, &type, NULL );
     if (io->u.Status == STATUS_BAD_DEVICE_TYPE)
     {
         SERVER_START_REQ( get_volume_info )
@@ -3285,6 +3286,13 @@ NTSTATUS WINAPI NtQueryVolumeInformationFile( HANDLE handle, PIO_STATUS_BLOCK io
     }
     else if (io->u.Status) return io->u.Status;
 
+    if (type == FD_TYPE_MAPPING)
+    {
+        /* section handles are not files: release the fd before failing */
+        if (needs_close) close( fd );
+        return io->u.Status = STATUS_OBJECT_TYPE_MISMATCH;
+    }
+
     io->u.Status = STATUS_NOT_IMPLEMENTED;
     io->Information = 0;
 
diff --git a/dlls/ntdll/server.c b/dlls/ntdll/server.c
index 3832a80f1e..214c11c8f2 100644
--- a/dlls/ntdll/server.c
+++ b/dlls/ntdll/server.c
@@ -1073,6 +1073,7 @@ int server_remove_fd_from_cache( HANDLE handle )
 int server_get_unix_fd( HANDLE handle, unsigned int wanted_access, int *unix_fd,
                         int *needs_close, enum server_fd_type *type, unsigned int *options )
 {
+    enum server_fd_type fd_type = FD_TYPE_NB_TYPES;
     sigset_t sigset;
     obj_handle_t fd_handle;
     int ret, fd = -1;
@@ -1082,11 +1083,11 @@ int server_get_unix_fd( HANDLE handle, unsigned int wanted_access, int *unix_fd,
     *needs_close = 0;
     wanted_access &= FILE_READ_DATA | FILE_WRITE_DATA | FILE_APPEND_DATA;
 
-    ret = get_cached_fd( handle, &fd, type, &access, options );
+    ret = get_cached_fd( handle, &fd, &fd_type, &access, options );
     if (ret != STATUS_INVALID_HANDLE) goto done;
 
     server_enter_uninterrupted_section( &fd_cache_section, &sigset );
-    ret = get_cached_fd( handle, &fd, type, &access, options );
+    ret = get_cached_fd( handle, &fd, &fd_type, &access, options );
     if (ret == STATUS_INVALID_HANDLE)
     {
         SERVER_START_REQ( get_handle_fd )
@@ -1094,7 +1095,7 @@ int server_get_unix_fd( HANDLE handle, unsigned int wanted_access, int *unix_fd,
             req->handle = wine_server_obj_handle( handle );
             if (!(ret = wine_server_call( req )))
             {
-                if (type) *type = reply->type;
+                fd_type = reply->type;
                 if (options) *options = reply->options;
                 access = reply->access;
                 if ((fd = receive_fd( &fd_handle )) != -1)
@@ -1116,12 +1117,23 @@ int server_get_unix_fd( HANDLE handle, unsigned int wanted_access, int *unix_fd,
     server_leave_uninterrupted_section( &fd_cache_section, &sigset );
 
 done:
+    /* section (mapping) handles may yield an fd, but must not be usable for
+     * data access: report them as invalid handles when access is requested */
+    if (!ret && wanted_access && fd_type == FD_TYPE_MAPPING)
+    {
+        ret = STATUS_INVALID_HANDLE;
+        if (*needs_close) close( fd );
+    }
     if (!ret && ((access & wanted_access) != wanted_access))
     {
         ret = STATUS_ACCESS_DENIED;
         if (*needs_close) close( fd );
     }
-    if (!ret) *unix_fd = fd;
+    if (!ret)
+    {
+        if (type) *type = fd_type;
+        *unix_fd = fd;
+    }
     return ret;
 }
 
diff --git a/dlls/ntdll/tests/om.c b/dlls/ntdll/tests/om.c
index c17b6ffa8d..b6a338596e 100644
--- a/dlls/ntdll/tests/om.c
+++ b/dlls/ntdll/tests/om.c
@@ -23,6 +23,7 @@
 #include "winternl.h"
 #include "stdio.h"
 #include "winnt.h"
+#include "winioctl.h"
 #include "stdlib.h"
 
 static HANDLE (WINAPI *pCreateWaitableTimerA)(SECURITY_ATTRIBUTES*, BOOL, LPCSTR);
@@ -72,6 +73,8 @@ static NTSTATUS (WINAPI *pNtReleaseKeyedEvent)( HANDLE, const void *, BOOLEAN, c
 static NTSTATUS (WINAPI *pNtCreateIoCompletion)(PHANDLE, ACCESS_MASK, POBJECT_ATTRIBUTES, ULONG);
 static NTSTATUS (WINAPI *pNtOpenIoCompletion)( PHANDLE, ACCESS_MASK, POBJECT_ATTRIBUTES );
 static NTSTATUS (WINAPI *pNtQueryInformationFile)(HANDLE, PIO_STATUS_BLOCK, void *, ULONG, FILE_INFORMATION_CLASS);
+static NTSTATUS (WINAPI *pNtQueryVolumeInformationFile)( HANDLE handle, PIO_STATUS_BLOCK io, PVOID buffer,
+        ULONG length, FS_INFORMATION_CLASS info_class );
 static NTSTATUS (WINAPI *pNtQuerySystemTime)( LARGE_INTEGER * );
 static NTSTATUS (WINAPI *pRtlWaitOnAddress)( const void *, const void *, SIZE_T, const LARGE_INTEGER * );
 static void (WINAPI *pRtlWakeAddressAll)( const void * );
@@ -1585,9 +1588,16 @@ static void test_query_object(void)
 
 static void test_type_mismatch(void)
 {
-    HANDLE h;
-    NTSTATUS res;
+    char tmp_path[MAX_PATH], file_name[MAX_PATH + 16];
+    FILE_FS_DEVICE_INFORMATION info;
     OBJECT_ATTRIBUTES attr;
+    IO_STATUS_BLOCK io;
+    LARGE_INTEGER size;
+    HANDLE h, hfile;
+    DWORD length;
+    NTSTATUS res;
+    DWORD type;
+    BOOL ret;
 
     attr.Length = sizeof(attr);
     attr.RootDirectory = 0;
@@ -1603,6 +1613,39 @@ static void test_type_mismatch(void)
     ok(res == STATUS_OBJECT_TYPE_MISMATCH, "expected 0xc0000024, got %x\n", res);
 
     pNtClose( h );
+
+    GetTempPathA(MAX_PATH, tmp_path);
+    GetTempFileNameA(tmp_path, "foo", 0, file_name);
+    hfile = CreateFileA(file_name, GENERIC_READ | GENERIC_WRITE, 0, NULL, CREATE_ALWAYS,
+            FILE_FLAG_DELETE_ON_CLOSE, 0);
+    ok(hfile != INVALID_HANDLE_VALUE, "Got unexpected hfile %p.\n", hfile);
+
+    /* a regular file handle must still support volume queries */
+    res = pNtQueryVolumeInformationFile(hfile, &io, &info, sizeof(info), FileFsDeviceInformation);
+    ok(res == STATUS_SUCCESS, "Got unexpected result %#x.\n", res);
+
+    size.QuadPart = 256;
+    res = pNtCreateSection(&h, SECTION_MAP_WRITE | SECTION_MAP_READ, NULL, &size,
+            PAGE_READWRITE, SEC_COMMIT, hfile);
+    ok(res == STATUS_SUCCESS, "Got unexpected res %x\n", res);
+
+    res = pNtQueryVolumeInformationFile(h, &io, &info, sizeof(info), FileFsDeviceInformation);
+    ok(res == STATUS_OBJECT_TYPE_MISMATCH, "Got unexpected result %#x.\n", res);
+
+    type = GetFileType(h);
+    ok(type == FILE_TYPE_UNKNOWN, "Got unexpected type %#x.\n", type);
+
+    SetLastError(0xdeadbeef);
+    ret = WriteFile(h, file_name, 16, &length, NULL);
+    ok(!ret && GetLastError() == ERROR_INVALID_HANDLE, "Got unexpected ret %#x, GetLastError() %#x.\n",
+            ret, GetLastError());
+
+    SetLastError(0xdeadbeef);
+    ret = ReadFile(h, file_name, 16, &length, NULL);
+    ok(!ret && GetLastError() == ERROR_INVALID_HANDLE, "Got unexpected ret %#x, GetLastError() %#x.\n",
+            ret, GetLastError());
+    pNtClose(h);
+    CloseHandle(hfile);
 }
 
 static void test_event(void)
@@ -2207,6 +2250,7 @@ START_TEST(om)
     pRtlWaitOnAddress = (void *)GetProcAddress(hntdll, "RtlWaitOnAddress");
     pRtlWakeAddressAll = (void *)GetProcAddress(hntdll, "RtlWakeAddressAll");
     pRtlWakeAddressSingle = (void *)GetProcAddress(hntdll, "RtlWakeAddressSingle");
+    pNtQueryVolumeInformationFile = (void *)GetProcAddress(hntdll, "NtQueryVolumeInformationFile");
 
     test_case_sensitive();
     test_namespace_pipe();