From: Paul Gofman <pgofman@codeweavers.com>
--- configure.ac | 2 + dlls/kernel32/tests/virtual.c | 16 ++- dlls/ntdll/unix/virtual.c | 219 +++++++++++++++++++++++++++++++--- dlls/ws2_32/tests/sock.c | 16 ++- 4 files changed, 229 insertions(+), 24 deletions(-)
diff --git a/configure.ac b/configure.ac index e7a3eb9d08a..4a23c0442d9 100644 --- a/configure.ac +++ b/configure.ac @@ -385,6 +385,7 @@ AC_CHECK_HEADERS(\ link.h \ linux/cdrom.h \ linux/filter.h \ + linux/fs.h \ linux/hdreg.h \ linux/hidraw.h \ linux/input.h \ @@ -394,6 +395,7 @@ AC_CHECK_HEADERS(\ linux/serial.h \ linux/types.h \ linux/ucdrom.h \ + linux/userfaultfd.h \ linux/wireless.h \ lwp.h \ mach-o/loader.h \ diff --git a/dlls/kernel32/tests/virtual.c b/dlls/kernel32/tests/virtual.c index 565a71a8c23..c1a8a829f48 100644 --- a/dlls/kernel32/tests/virtual.c +++ b/dlls/kernel32/tests/virtual.c @@ -2292,7 +2292,7 @@ static void test_write_watch(void) count = 64; ret = pGetWriteWatch( 0, base, size, results, &count, &pagesize ); ok( !ret, "GetWriteWatch failed %lu\n", GetLastError() ); - todo_wine ok( !count, "wrong count %Iu\n", count ); + todo_wine_if(count == 2) ok( !count, "wrong count %Iu\n", count );
base = VirtualAlloc( base, size, MEM_COMMIT, PAGE_READWRITE ); ok(!!base, "VirtualAlloc failed.\n"); @@ -2300,7 +2300,7 @@ static void test_write_watch(void) count = 64; ret = pGetWriteWatch( 0, base, size, results, &count, &pagesize ); ok( !ret, "GetWriteWatch failed %lu\n", GetLastError() ); - todo_wine ok( !count, "wrong count %Iu\n", count ); + todo_wine_if(count == 2) ok( !count, "wrong count %Iu\n", count );
/* Looks like VirtualProtect latches write watch state somewhere, so if pages are decommitted after, * (which normally clears write watch state), a page from range which previously had protection change @@ -2344,9 +2344,15 @@ static void test_write_watch(void) ok( !ret, "GetWriteWatch failed %lu\n", GetLastError() ); todo_wine ok( count == 4, "wrong count %Iu\n", count ); ok( results[0] == base + 2*pagesize, "wrong result %p\n", results[0] ); - ok( results[1] == base + 3*pagesize, "wrong result %p\n", results[1] ); - ok( results[2] == base + 4*pagesize, "wrong result %p\n", results[2] ); - todo_wine ok( results[3] == base + 6*pagesize, "wrong result %p\n", results[3] ); + i = 1; /* index of the next result to check; base + 3*pagesize is only expected when count >= 4 */ + if (count >= 4) + { + ok( results[i] == base + 3*pagesize, "wrong result %p\n", results[i] ); + ++i; + } + ok( results[i] == base + 4*pagesize, "wrong result %p\n", results[i] ); + ++i; + todo_wine_if(count == 5) ok( results[i] == base + 6*pagesize, "wrong result %p\n", results[i] );
VirtualFree( base, 0, MEM_RELEASE );
diff --git a/dlls/ntdll/unix/virtual.c b/dlls/ntdll/unix/virtual.c index ca5d50b0fe6..f64a5477382 100644 --- a/dlls/ntdll/unix/virtual.c +++ b/dlls/ntdll/unix/virtual.c @@ -72,6 +72,15 @@ #undef host_page_size #endif
+#include <sys/ioctl.h> +#if defined(HAVE_LINUX_USERFAULTFD_H) && defined(HAVE_LINUX_FS_H) +# include <linux/userfaultfd.h> +# include <linux/fs.h> +#if defined(UFFD_FEATURE_WP_ASYNC) && defined(PM_SCAN_WP_MATCHING) +#define HAVE_UFFD_WRITEWATCH /* build-time support only; runtime support is probed by kernel_writewatch_init() */ +#endif +#endif + #include "ntstatus.h" #define WIN32_NO_STATUS #include "windef.h" @@ -220,6 +229,11 @@ static BYTE **pages_vprot; static BYTE *pages_vprot; #endif
+static int use_kernel_writewatch; /* nonzero when write watches are tracked through userfaultfd and PAGEMAP_SCAN instead of page protections */ +#ifdef HAVE_UFFD_WRITEWATCH +static int uffd_fd, pagemap_fd; /* userfaultfd and /proc/self/pagemap descriptors, opened by kernel_writewatch_init() */ +#endif + static struct file_view *view_block_start, *view_block_end, *next_free_view; static const size_t view_block_size = 0x100000; static void *preload_reserve_start; @@ -264,6 +278,160 @@ void *anon_mmap_alloc( size_t size, int prot ) return mmap( NULL, size, prot, MAP_PRIVATE | MAP_ANON, -1, 0 ); }
+#ifdef HAVE_UFFD_WRITEWATCH
+/***********************************************************************
+ *           kernel_writewatch_init
+ *
+ * Probe for kernel-assisted write watch tracking: an async write-protect
+ * userfaultfd plus PAGEMAP_SCAN on /proc/self/pagemap.  Sets
+ * use_kernel_writewatch only when both are available.
+ */
+static void kernel_writewatch_init(void)
+{
+    struct uffdio_api uffdio_api;
+
+    uffd_fd = syscall( __NR_userfaultfd, O_CLOEXEC | O_NONBLOCK | UFFD_USER_MODE_ONLY );
+    if (uffd_fd == -1) return;
+
+    uffdio_api.api = UFFD_API;
+    uffdio_api.features = UFFD_FEATURE_WP_ASYNC | UFFD_FEATURE_WP_UNPOPULATED;
+    if (ioctl( uffd_fd, UFFDIO_API, &uffdio_api ) || uffdio_api.api != UFFD_API)
+    {
+        close( uffd_fd );
+        uffd_fd = -1;
+        return;
+    }
+    pagemap_fd = open( "/proc/self/pagemap", O_CLOEXEC | O_RDONLY );
+    if (pagemap_fd == -1)
+    {
+        ERR( "Error opening /proc/self/pagemap.\n" );
+        close( uffd_fd );
+        uffd_fd = -1;
+        return;
+    }
+    use_kernel_writewatch = 1;
+}
+
+/***********************************************************************
+ *           kernel_writewatch_reset
+ *
+ * Clear the kernel "written" state of [start, start + len), widened to host
+ * page boundaries, by write-protecting the written pages again.
+ */
+static void kernel_writewatch_reset( void *start, SIZE_T len )
+{
+    struct pm_scan_arg arg = { 0 };
+
+    len = ROUND_SIZE( start, len, host_page_mask );
+    start = (char *)ROUND_ADDR( start, host_page_mask );
+    if (!len) return;
+
+    arg.size = sizeof(arg);
+    arg.start = (UINT_PTR)start;
+    arg.end = arg.start + len;
+    arg.flags = PM_SCAN_WP_MATCHING;
+    arg.category_mask = PAGE_IS_WRITTEN;
+    arg.return_mask = PAGE_IS_WRITTEN;
+    if (ioctl( pagemap_fd, PAGEMAP_SCAN, &arg ) < 0)
+        ERR( "ioctl(PAGEMAP_SCAN) failed, err %s.\n", strerror(errno) );
+}
+
+/***********************************************************************
+ *           kernel_writewatch_register_range
+ *
+ * Register [base, base + size), widened to host pages, with the userfaultfd
+ * in write-protect mode and write-protect it, so that later writes show up
+ * as PAGE_IS_WRITTEN.  Does nothing unless the view has VPROT_WRITEWATCH
+ * and kernel write watches are in use.
+ */
+static void kernel_writewatch_register_range( struct file_view *view, void *base, size_t size )
+{
+    struct uffdio_register uffdio_register;
+    struct uffdio_writeprotect wp;
+
+    if (!(view->protect & VPROT_WRITEWATCH) || !use_kernel_writewatch) return;
+
+    size = ROUND_SIZE( base, size, host_page_mask );
+    base = (char *)ROUND_ADDR( base, host_page_mask );
+
+    /* Reject empty ranges and ranges wrapping around the address space
+     * (e.g. a negative length converted to size_t by a caller). */
+    if (!size || (UINT_PTR)base + size < (UINT_PTR)base) return;
+
+    /* Transparent huge pages will result in larger areas reported as dirty. */
+    madvise( base, size, MADV_NOHUGEPAGE );
+
+    uffdio_register.range.start = (UINT_PTR)base;
+    uffdio_register.range.len = size;
+    uffdio_register.mode = UFFDIO_REGISTER_MODE_WP;
+    if (ioctl( uffd_fd, UFFDIO_REGISTER, &uffdio_register ) == -1)
+    {
+        ERR( "ioctl( UFFDIO_REGISTER ) failed, %s.\n", strerror(errno) );
+        return;
+    }
+
+    /* uffdio_register.ioctls is a bitmask of (1 << _UFFDIO_*), not of ioctl request numbers. */
+    if (!(uffdio_register.ioctls & ((__u64)1 << _UFFDIO_WRITEPROTECT)))
+    {
+        ERR( "uffdio_register.ioctls %s.\n", wine_dbgstr_longlong(uffdio_register.ioctls) );
+        return;
+    }
+    wp.range.start = (UINT_PTR)base;
+    wp.range.len = size;
+    wp.mode = UFFDIO_WRITEPROTECT_MODE_WP;
+
+    if (ioctl( uffd_fd, UFFDIO_WRITEPROTECT, &wp ) == -1)
+        ERR( "ioctl( UFFDIO_WRITEPROTECT ) failed, %s.\n", strerror(errno) );
+}
+
+/***********************************************************************
+ *           kernel_get_write_watches
+ *
+ * Store into buffer the addresses of the pages in [base, base + size) that
+ * the kernel reports as written.  On entry *count is the buffer capacity,
+ * on return the number of addresses stored.  When reset is set, the
+ * scanned written pages are write-protected again.
+ */
+static void kernel_get_write_watches( void *base, SIZE_T size, void **buffer, ULONG_PTR *count, BOOL reset )
+{
+    /* PAGEMAP_SCAN counts max_pages in host pages; each one holds this many of our pages. */
+    const SIZE_T host_ratio = host_page_size >> page_shift;
+    struct pm_scan_arg arg = { 0 };
+    struct page_region rgns[256];
+    SIZE_T buffer_len = *count;
+    char *addr, *next_addr;
+    int rgn_count, i;
+    size_t end;
+
+    assert( !(size & page_mask) );
+
+    *count = 0;
+    /* max_pages == 0 would disable the kernel limit (and reset the whole range). */
+    if (!buffer_len) return;
+
+    end = (size_t)((char *)base + size);
+    size = ROUND_SIZE( base, size, host_page_mask );
+    addr = (char *)ROUND_ADDR( base, host_page_mask );
+
+    arg.size = sizeof(arg);
+    arg.vec = (ULONG_PTR)rgns;
+    arg.vec_len = ARRAY_SIZE(rgns);
+    if (reset) arg.flags |= PM_SCAN_WP_MATCHING;
+    arg.category_mask = PAGE_IS_WRITTEN;
+    arg.return_mask = PAGE_IS_WRITTEN;
+
+    while (buffer_len)
+    {
+        arg.start = (UINT_PTR)addr;
+        arg.end = arg.start + size;
+        /* Round up so at least one host page is scanned; the copy loop below
+         * never stores more than buffer_len entries. */
+        arg.max_pages = buffer_len / host_ratio + !!(buffer_len % host_ratio);
+
+        if ((rgn_count = ioctl( pagemap_fd, PAGEMAP_SCAN, &arg )) < 0)
+        {
+            ERR( "ioctl( PAGEMAP_SCAN ) failed, error %s.\n", strerror(errno) );
+            return;
+        }
+        if (!rgn_count) break;
+
+        assert( rgn_count <= ARRAY_SIZE(rgns) );
+        for (i = 0; i < rgn_count && buffer_len; ++i)
+        {
+            size_t c_addr = max( rgns[i].start, (size_t)base );
+
+            rgns[i].end = min( rgns[i].end, end );
+            assert( rgns[i].categories == PAGE_IS_WRITTEN );
+            while (buffer_len && c_addr < rgns[i].end)
+            {
+                buffer[(*count)++] = (void *)c_addr;
+                --buffer_len;
+                c_addr += page_size;
+            }
+        }
+
+        /* The kernel stops early when rgns is full or max_pages is reached;
+         * resume from walk_end until the whole range has been walked. */
+        next_addr = (char *)(ULONG_PTR)arg.walk_end;
+        if (next_addr <= addr || next_addr >= addr + size) break;
+        size -= next_addr - addr;
+        addr = next_addr;
+    }
+}
+#else
+/* Without HAVE_UFFD_WRITEWATCH use_kernel_writewatch stays 0, so these stubs have no effect. */
+static void kernel_writewatch_init(void)
+{
+}
+
+static void kernel_writewatch_reset( void *start, SIZE_T len )
+{
+}
+
+static void kernel_writewatch_register_range( struct file_view *view, void *base, size_t size )
+{
+}
+
+static void kernel_get_write_watches( void *base, SIZE_T size, void **buffer, ULONG_PTR *count, BOOL reset )
+{
+    /* only called when use_kernel_writewatch is set */
+    assert( 0 );
+}
+#endif
+
static void mmap_add_reserved_area( void *addr, SIZE_T size ) { @@ -1191,7 +1359,7 @@ static int get_unix_prot( BYTE vprot ) if (vprot & VPROT_WRITE) prot |= PROT_WRITE | PROT_READ; if (vprot & VPROT_WRITECOPY) prot |= PROT_WRITE | PROT_READ; if (vprot & VPROT_EXEC) prot |= PROT_EXEC | PROT_READ; - if (vprot & VPROT_WRITEWATCH) prot &= ~PROT_WRITE; + if (vprot & VPROT_WRITEWATCH && !use_kernel_writewatch) prot &= ~PROT_WRITE; /* kernel write watches leave the page writable */ } if (!prot) prot = PROT_NONE; return prot; @@ -1691,6 +1859,8 @@ static NTSTATUS create_view( struct file_view **view_ret, void *base, size_t siz TRACE( "forcing exec permission on %p-%p\n", base, (char *)base + size - 1 ); mprotect( base, size, unix_prot | PROT_EXEC ); } + + kernel_writewatch_register_range( view, view->base, view->size ); return STATUS_SUCCESS; }
@@ -1812,7 +1982,7 @@ static int mprotect_range( void *base, size_t size, BYTE set, BYTE clear ) */ static BOOL set_vprot( struct file_view *view, void *base, size_t size, BYTE vprot ) { - if (view->protect & VPROT_WRITEWATCH) + if (!use_kernel_writewatch && view->protect & VPROT_WRITEWATCH) { /* each page may need different protections depending on write watch flag */ set_page_vprot_bits( base, size, vprot & ~VPROT_WRITEWATCH, ~vprot & ~VPROT_WRITEWATCH ); @@ -1890,8 +2060,12 @@ static void update_write_watches( void *base, size_t size, size_t accessed_size */ static void reset_write_watches( void *base, SIZE_T size ) { - set_page_vprot_bits( base, size, VPROT_WRITEWATCH, 0 ); - mprotect_range( base, size, 0, 0 ); + if (use_kernel_writewatch) kernel_writewatch_reset( base, size ); /* write-protect written pages again via PAGEMAP_SCAN */ + else + { + set_page_vprot_bits( base, size, VPROT_WRITEWATCH, 0 ); + mprotect_range( base, size, 0, 0 ); + } }
@@ -2085,7 +2259,11 @@ static NTSTATUS map_view( struct file_view **view_ret, void *base, size_t size,
view->protect = vprot | VPROT_PLACEHOLDER; set_vprot( view, base, size, vprot ); - if (vprot & VPROT_WRITEWATCH) reset_write_watches( base, size ); + if (vprot & VPROT_WRITEWATCH) + { + kernel_writewatch_register_range( view, base, size ); /* no-op unless kernel write watches are in use */ + reset_write_watches( base, size ); + } *view_ret = view; return STATUS_SUCCESS; } @@ -2305,6 +2483,7 @@ static NTSTATUS decommit_pages( struct file_view *view, char *base, size_t size
if (host_start < host_end) anon_mmap_fixed( host_start, host_end - host_start, PROT_NONE, 0 ); set_page_vprot_bits( base, size, 0, VPROT_COMMITTED ); + if (host_start < host_end) kernel_writewatch_register_range( view, host_start, host_end - host_start ); /* re-register the remapped range; an empty or inverted host range would wrap to a huge size_t */ return STATUS_SUCCESS; }
@@ -3477,6 +3656,10 @@ void virtual_init(void) host_addr_space_limit = address_space_limit; #endif
+ kernel_writewatch_init(); + + if (use_kernel_writewatch) TRACE( "Using kernel write watches.\n" ); + if (preload_info && *preload_info) for (i = 0; (*preload_info)[i].size; i++) mmap_add_reserved_area( (*preload_info)[i].addr, (*preload_info)[i].size ); @@ -4201,7 +4384,7 @@ NTSTATUS virtual_handle_fault( EXCEPTION_RECORD *rec, void *stack ) } else ret = grow_thread_stack( page, &stack_info ); } - else if (err & EXCEPTION_WRITE_FAULT) + else if (!use_kernel_writewatch && err & EXCEPTION_WRITE_FAULT) /* with kernel write watches, get_unix_prot() leaves write-watched pages writable */ { if (vprot & VPROT_WRITEWATCH) { @@ -4295,11 +4478,11 @@ static NTSTATUS check_write_access( void *base, size_t size, BOOL *has_write_wat for (i = 0; i < size; i += host_page_size) { BYTE vprot = get_host_page_vprot( addr + i ); - if (vprot & VPROT_WRITEWATCH) *has_write_watch = TRUE; + if (!use_kernel_writewatch && vprot & VPROT_WRITEWATCH) *has_write_watch = TRUE; if (!(get_unix_prot( vprot & ~VPROT_WRITEWATCH ) & PROT_WRITE)) return STATUS_INVALID_USER_BUFFER; } - if (*has_write_watch) + if (!use_kernel_writewatch && *has_write_watch) mprotect_range( addr, size, 0, VPROT_WRITEWATCH ); /* temporarily enable write access */ return STATUS_SUCCESS; } @@ -4341,7 +4524,7 @@ ssize_t virtual_locked_read( int fd, void *addr, size_t size ) int err = EFAULT;
ssize_t ret = read( fd, addr, size ); - if (ret != -1 || errno != EFAULT) return ret; + if (ret != -1 || use_kernel_writewatch || errno != EFAULT) return ret; /* no write-watch protection to lift when the kernel tracks writes */
server_enter_uninterrupted_section( &virtual_mutex, &sigset ); if (!check_write_access( addr, size, &has_write_watch )) @@ -4366,7 +4549,7 @@ ssize_t virtual_locked_pread( int fd, void *addr, size_t size, off_t offset ) int err = EFAULT;
ssize_t ret = pread( fd, addr, size, offset ); - if (ret != -1 || errno != EFAULT) return ret; + if (ret != -1 || use_kernel_writewatch || errno != EFAULT) return ret;
server_enter_uninterrupted_section( &virtual_mutex, &sigset ); if (!check_write_access( addr, size, &has_write_watch )) @@ -4392,7 +4575,7 @@ ssize_t virtual_locked_recvmsg( int fd, struct msghdr *hdr, int flags ) int err = EFAULT;
ssize_t ret = recvmsg( fd, hdr, flags ); - if (ret != -1 || errno != EFAULT) return ret; + if (ret != -1 || use_kernel_writewatch || errno != EFAULT) return ret;
server_enter_uninterrupted_section( &virtual_mutex, &sigset ); for (i = 0; i < hdr->msg_iovlen; i++) @@ -6191,13 +6374,17 @@ NTSTATUS WINAPI NtGetWriteWatch( HANDLE process, ULONG flags, PVOID base, SIZE_T char *addr = base; char *end = addr + size;
- while (pos < *count && addr < end) + if (use_kernel_writewatch) kernel_get_write_watches( base, size, addresses, count, flags & WRITE_WATCH_FLAG_RESET ); /* fills addresses and sets *count itself */ + else { - if (!(get_page_vprot( addr ) & VPROT_WRITEWATCH)) addresses[pos++] = addr; - addr += page_size; + while (pos < *count && addr < end) + { + if (!(get_page_vprot( addr ) & VPROT_WRITEWATCH)) addresses[pos++] = addr; + addr += page_size; + } + if (flags & WRITE_WATCH_FLAG_RESET) reset_write_watches( base, addr - (char *)base ); + *count = pos; } - if (flags & WRITE_WATCH_FLAG_RESET) reset_write_watches( base, addr - (char *)base ); - *count = pos; *granularity = page_size; } else status = STATUS_INVALID_PARAMETER; diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c index 2ca5b9fc8b2..a3fb6744235 100644 --- a/dlls/ws2_32/tests/sock.c +++ b/dlls/ws2_32/tests/sock.c @@ -8113,6 +8113,7 @@ static void test_write_watch(void) ok( count == 9 || !count /* Win 11 */, "wrong count %Iu\n", count ); ok( !base[0], "data set\n" );
+ base[0x1000] = 1; /* dirty a page that the received data (base and base + 0x4000) does not touch */ send(src, "test message", sizeof("test message"), 0);
ret = GetOverlappedResult( (HANDLE)dest, &ov, &bytesReturned, TRUE ); @@ -8121,10 +8122,19 @@ static void test_write_watch(void) ok( !memcmp( base, "test ", 5 ), "wrong data %s\n", base ); ok( !memcmp( base + 0x4000, "message", 8 ), "wrong data %s\n", base + 0x4000 );
+ count = 64; + ret = pGetWriteWatch( 0, base, size, results, &count, &pagesize ); + ok( !ret, "GetWriteWatch failed %lu\n", GetLastError() ); + todo_wine_if( count == 3 ) ok( count == 1, "wrong count %Iu\n", count ); + todo_wine_if( count == 3 ) ok( results[0] == base + 0x1000, "got page %Iu.\n", ((char *)results[0] - base) / 0x1000 ); + + base[0x2000] = 1; count = 64; ret = pGetWriteWatch( WRITE_WATCH_FLAG_RESET, base, size, results, &count, &pagesize ); ok( !ret, "GetWriteWatch failed %lu\n", GetLastError() ); - ok( count == 0, "wrong count %Iu\n", count ); + todo_wine_if( count == 4 ) ok( count == 2, "wrong count %Iu\n", count ); + todo_wine_if( count == 4 ) ok( results[0] == base + 0x1000, "got page %Iu.\n", ((char *)results[0] - base) / 0x1000 ); + todo_wine_if( count == 4 ) ok( results[1] == base + 0x2000, "got page %Iu.\n", ((char *)results[1] - base) / 0x1000 );
memset( base, 0, size ); count = 64; @@ -8155,7 +8165,7 @@ static void test_write_watch(void) count = 64; ret = pGetWriteWatch( WRITE_WATCH_FLAG_RESET, base, size, results, &count, &pagesize ); ok( !ret, "GetWriteWatch failed %lu\n", GetLastError() ); - ok( count == 0, "wrong count %Iu\n", count ); + todo_wine_if( count == 2 ) ok( count == 0, "wrong count %Iu\n", count );
memset( base, 0, size ); count = 64; @@ -8184,7 +8194,7 @@ static void test_write_watch(void) count = 64; ret = pGetWriteWatch( WRITE_WATCH_FLAG_RESET, base, size, results, &count, &pagesize ); ok( !ret, "GetWriteWatch failed %lu\n", GetLastError() ); - ok( count == 0, "wrong count %Iu\n", count ); + todo_wine_if( count == 1 ) ok( count == 0, "wrong count %Iu\n", count ); } WSACloseEvent( event ); closesocket( dest );