Signed-off-by: Paul Gofman <pgofman(a)codeweavers.com>
---
v2:
- don't add redundant checks in the server.
v3:
- always pass &fd_type instead of type in server_get_unix_fd().
dlls/ntdll/file.c | 6 +++++-
dlls/ntdll/server.c | 18 ++++++++++++++----
dlls/ntdll/tests/om.c | 44 +++++++++++++++++++++++++++++++++++++++++--
3 files changed, 61 insertions(+), 7 deletions(-)
diff --git a/dlls/ntdll/file.c b/dlls/ntdll/file.c
index 9997a5e1cd..33f78b63ac 100644
--- a/dlls/ntdll/file.c
+++ b/dlls/ntdll/file.c
@@ -3266,10 +3266,11 @@ NTSTATUS WINAPI NtQueryVolumeInformationFile( HANDLE handle, PIO_STATUS_BLOCK io
PVOID buffer, ULONG length,
FS_INFORMATION_CLASS info_class )
{
+ enum server_fd_type type;
int fd, needs_close;
struct stat st;
- io->u.Status = server_get_unix_fd( handle, 0, &fd, &needs_close, NULL, NULL );
+ io->u.Status = server_get_unix_fd( handle, 0, &fd, &needs_close, &type, NULL );
if (io->u.Status == STATUS_BAD_DEVICE_TYPE)
{
SERVER_START_REQ( get_volume_info )
@@ -3285,6 +3286,13 @@ NTSTATUS WINAPI NtQueryVolumeInformationFile( HANDLE handle, PIO_STATUS_BLOCK io
     }
     else if (io->u.Status) return io->u.Status;
 
+    /* Section (mapping) handles are not valid for volume queries; release the
+     * fd obtained above before bailing out so it is not leaked. */
+    if (type == FD_TYPE_MAPPING)
+    {
+        if (needs_close) close( fd );
+        return STATUS_OBJECT_TYPE_MISMATCH;
+    }
+
     io->u.Status = STATUS_NOT_IMPLEMENTED;
     io->Information = 0;
 
diff --git a/dlls/ntdll/server.c b/dlls/ntdll/server.c
index 3832a80f1e..214c11c8f2 100644
--- a/dlls/ntdll/server.c
+++ b/dlls/ntdll/server.c
@@ -1073,6 +1073,7 @@ int server_remove_fd_from_cache( HANDLE handle )
int server_get_unix_fd( HANDLE handle, unsigned int wanted_access, int *unix_fd,
int *needs_close, enum server_fd_type *type, unsigned int *options )
{
+ enum server_fd_type fd_type = FD_TYPE_NB_TYPES;
sigset_t sigset;
obj_handle_t fd_handle;
int ret, fd = -1;
@@ -1082,11 +1083,11 @@ int server_get_unix_fd( HANDLE handle, unsigned int wanted_access, int *unix_fd,
*needs_close = 0;
wanted_access &= FILE_READ_DATA | FILE_WRITE_DATA | FILE_APPEND_DATA;
- ret = get_cached_fd( handle, &fd, type, &access, options );
+ ret = get_cached_fd( handle, &fd, &fd_type, &access, options );
if (ret != STATUS_INVALID_HANDLE) goto done;
server_enter_uninterrupted_section( &fd_cache_section, &sigset );
- ret = get_cached_fd( handle, &fd, type, &access, options );
+ ret = get_cached_fd( handle, &fd, &fd_type, &access, options );
if (ret == STATUS_INVALID_HANDLE)
{
SERVER_START_REQ( get_handle_fd )
@@ -1094,7 +1095,7 @@ int server_get_unix_fd( HANDLE handle, unsigned int wanted_access, int *unix_fd,
req->handle = wine_server_obj_handle( handle );
if (!(ret = wine_server_call( req )))
{
- if (type) *type = reply->type;
+ fd_type = reply->type;
if (options) *options = reply->options;
access = reply->access;
if ((fd = receive_fd( &fd_handle )) != -1)
@@ -1116,12 +1117,26 @@ int server_get_unix_fd( HANDLE handle, unsigned int wanted_access, int *unix_fd,
     server_leave_uninterrupted_section( &fd_cache_section, &sigset );
 
 done:
+    /* Section (mapping) handles cannot be used to read or write data. Like the
+     * access check below, close an uncached fd on failure so it is not leaked. */
+    if (!ret && wanted_access && fd_type == FD_TYPE_MAPPING)
+    {
+        ret = STATUS_INVALID_HANDLE;
+        if (*needs_close) close( fd );
+    }
+
     if (!ret && ((access & wanted_access) != wanted_access))
     {
         ret = STATUS_ACCESS_DENIED;
         if (*needs_close) close( fd );
     }
-    if (!ret) *unix_fd = fd;
+    if (!ret)
+    {
+        if (type)
+            *type = fd_type;
+
+        *unix_fd = fd;
+    }
     return ret;
 }
 
diff --git a/dlls/ntdll/tests/om.c b/dlls/ntdll/tests/om.c
index c17b6ffa8d..b6a338596e 100644
--- a/dlls/ntdll/tests/om.c
+++ b/dlls/ntdll/tests/om.c
@@ -23,6 +23,7 @@
#include "winternl.h"
#include "stdio.h"
#include "winnt.h"
+#include "winioctl.h"
#include "stdlib.h"
static HANDLE (WINAPI *pCreateWaitableTimerA)(SECURITY_ATTRIBUTES*, BOOL, LPCSTR);
@@ -72,6 +73,8 @@ static NTSTATUS (WINAPI *pNtReleaseKeyedEvent)( HANDLE, const void *, BOOLEAN, c
static NTSTATUS (WINAPI *pNtCreateIoCompletion)(PHANDLE, ACCESS_MASK, POBJECT_ATTRIBUTES, ULONG);
static NTSTATUS (WINAPI *pNtOpenIoCompletion)( PHANDLE, ACCESS_MASK, POBJECT_ATTRIBUTES );
static NTSTATUS (WINAPI *pNtQueryInformationFile)(HANDLE, PIO_STATUS_BLOCK, void *, ULONG, FILE_INFORMATION_CLASS);
+static NTSTATUS (WINAPI *pNtQueryVolumeInformationFile)( HANDLE handle, PIO_STATUS_BLOCK io, PVOID buffer,
+ ULONG length, FS_INFORMATION_CLASS info_class );
static NTSTATUS (WINAPI *pNtQuerySystemTime)( LARGE_INTEGER * );
static NTSTATUS (WINAPI *pRtlWaitOnAddress)( const void *, const void *, SIZE_T, const LARGE_INTEGER * );
static void (WINAPI *pRtlWakeAddressAll)( const void * );
@@ -1585,9 +1588,16 @@ static void test_query_object(void)
static void test_type_mismatch(void)
{
- HANDLE h;
- NTSTATUS res;
+ char tmp_path[MAX_PATH], file_name[MAX_PATH + 16];
+ FILE_FS_DEVICE_INFORMATION info;
OBJECT_ATTRIBUTES attr;
+ IO_STATUS_BLOCK io;
+ LARGE_INTEGER size;
+ HANDLE h, hfile;
+ DWORD length;
+ NTSTATUS res;
+ DWORD type;
+ BOOL ret;
attr.Length = sizeof(attr);
attr.RootDirectory = 0;
@@ -1603,6 +1613,35 @@ static void test_type_mismatch(void)
ok(res == STATUS_OBJECT_TYPE_MISMATCH, "expected 0xc0000024, got %x\n", res);
pNtClose( h );
+
+ GetTempPathA(MAX_PATH, tmp_path);
+ GetTempFileNameA(tmp_path, "foo", 0, file_name);
+ hfile = CreateFileA(file_name, GENERIC_READ | GENERIC_WRITE, 0, NULL, CREATE_ALWAYS,
+ FILE_FLAG_DELETE_ON_CLOSE, 0);
+ ok(hfile != INVALID_HANDLE_VALUE, "Got unexpected hfile %p.\n", hfile);
+
+ size.QuadPart = 256;
+ res = pNtCreateSection(&h, SECTION_MAP_WRITE | SECTION_MAP_READ, NULL, &size,
+ PAGE_READWRITE, SEC_COMMIT, hfile);
+ ok(res == STATUS_SUCCESS , "Got unexpected res %x\n", res);
+
+ res = pNtQueryVolumeInformationFile(h, &io, &info, sizeof(info), FileFsDeviceInformation);
+ ok(res == STATUS_OBJECT_TYPE_MISMATCH, "Got unexpected result %#x.\n", res);
+
+ type = GetFileType(h);
+ ok(type == FILE_TYPE_UNKNOWN, "Got unexpected type %#x.\n", type);
+
+ SetLastError(0xdeadbeef);
+ ret = WriteFile(h, file_name, 16, &length, NULL);
+ ok(!ret && GetLastError() == ERROR_INVALID_HANDLE, "Got unexpected ret %#x, GetLastError() %#x.\n",
+ ret, GetLastError());
+
+ SetLastError(0xdeadbeef);
+ ret = ReadFile(h, file_name, 16, &length, NULL);
+ ok(!ret && GetLastError() == ERROR_INVALID_HANDLE, "Got unexpected ret %#x, GetLastError() %#x.\n",
+ ret, GetLastError());
+ pNtClose(h);
+ CloseHandle(hfile);
}
static void test_event(void)
@@ -2207,6 +2246,7 @@ START_TEST(om)
pRtlWaitOnAddress = (void *)GetProcAddress(hntdll, "RtlWaitOnAddress");
pRtlWakeAddressAll = (void *)GetProcAddress(hntdll, "RtlWakeAddressAll");
pRtlWakeAddressSingle = (void *)GetProcAddress(hntdll, "RtlWakeAddressSingle");
+ pNtQueryVolumeInformationFile = (void *)GetProcAddress(hntdll, "NtQueryVolumeInformationFile");
test_case_sensitive();
test_namespace_pipe();
--
2.26.2