From: Aric Stewart <aric(a)codeweavers.com>
Signed-off-by: Zebediah Figura <z.figura12(a)gmail.com>
---
 dlls/ntoskrnl.exe/ntoskrnl.c | 110 +++++++++++++++++++++++++++++++++++++------
 1 file changed, 96 insertions(+), 14 deletions(-)
diff --git a/dlls/ntoskrnl.exe/ntoskrnl.c b/dlls/ntoskrnl.exe/ntoskrnl.c
index e52cb168a8..aec37daa59 100644
--- a/dlls/ntoskrnl.exe/ntoskrnl.c
+++ b/dlls/ntoskrnl.exe/ntoskrnl.c
@@ -102,6 +102,16 @@ struct wine_driver
     SERVICE_STATUS_HANDLE service_handle;
 };
 
+struct device_interface  /* one interface registered through IoRegisterDeviceInterface */
+{
+    struct wine_rb_entry entry;     /* node in the device_interfaces tree */
+
+    UNICODE_STRING symbolic_link;   /* tree key: the device path built at registration */
+    DEVICE_OBJECT *device;          /* device whose PDO name the symlink points at */
+    GUID interface_class;           /* interface class GUID given at registration */
+    BOOL enabled;                   /* last state successfully set by IoSetDeviceInterfaceState */
+};
+
 static NTSTATUS get_device_id( DEVICE_OBJECT *device, BUS_QUERY_ID_TYPE type, WCHAR **id );
 
 static int wine_drivers_rb_compare( const void *key, const struct wine_rb_entry *entry )
@@ -114,6 +124,16 @@ static int wine_drivers_rb_compare( const void *key, const struct wine_rb_entry
 
 static struct wine_rb_tree wine_drivers = { wine_drivers_rb_compare };
 
+/* Orders device_interfaces entries by symbolic link name (case-sensitive comparison). */
+static int interface_rb_compare( const void *key, const struct wine_rb_entry *entry )
+{
+    const UNICODE_STRING *name = key;
+    const struct device_interface *iface = WINE_RB_ENTRY_VALUE( entry, const struct device_interface, entry );
+    return RtlCompareUnicodeString( name, &iface->symbolic_link, FALSE );
+}
+
+static struct wine_rb_tree device_interfaces = { interface_rb_compare };
+
 static CRITICAL_SECTION drivers_cs;
 static CRITICAL_SECTION_DEBUG critsect_debug =
 {
@@ -1321,6 +1341,30 @@ NTSTATUS WINAPI IoDeleteSymbolicLink( UNICODE_STRING *name )
     return status;
 }
 
+/* Create the symbolic link 'symlink_name' pointing at the physical device object name of 'device'. */
+static NTSTATUS create_device_symlink( DEVICE_OBJECT *device, UNICODE_STRING *symlink_name )
+{
+    UNICODE_STRING device_nameU;
+    WCHAR *device_name;
+    ULONG len = 0;
+    NTSTATUS ret;
+
+    /* a size query with no buffer is expected to fail with STATUS_BUFFER_TOO_SMALL and fill len */
+    ret = IoGetDeviceProperty( device, DevicePropertyPhysicalDeviceObjectName, 0, NULL, &len );
+    if (ret != STATUS_BUFFER_TOO_SMALL)
+        return ret;
+
+    if (!(device_name = heap_alloc( len )))
+        return STATUS_NO_MEMORY;
+    ret = IoGetDeviceProperty( device, DevicePropertyPhysicalDeviceObjectName, len, device_name, &len );
+    if (!ret)
+    {
+        RtlInitUnicodeString( &device_nameU, device_name );
+        ret = IoCreateSymbolicLink( symlink_name, &device_nameU );
+    }
+    heap_free( device_name );
+    return ret;
+}
 
 /***********************************************************************
  *           IoSetDeviceInterfaceState   (NTOSKRNL.EXE.@)
@@ -1339,16 +1383,31 @@ NTSTATUS WINAPI IoSetDeviceInterfaceState( UNICODE_STRING *name, BOOLEAN enable
 
     size_t namelen = name->Length / sizeof(WCHAR);
     DEV_BROADCAST_DEVICEINTERFACE_W *broadcast;
+    struct device_interface *iface;
     HANDLE iface_key, control_key;
     OBJECT_ATTRIBUTES attr = {0};
+    struct wine_rb_entry *entry;
     WCHAR *path, *refstr, *p;
     UNICODE_STRING string;
+    DWORD data = enable;
     NTSTATUS ret;
-    size_t len;
     GUID class;
+    ULONG len;
 
     TRACE("(%s, %d)\n", debugstr_us(name), enable);
 
+    /* only names inserted by IoRegisterDeviceInterface are known */
+    if (!(entry = wine_rb_get( &device_interfaces, name )))
+        return STATUS_OBJECT_NAME_NOT_FOUND;
+
+    iface = WINE_RB_ENTRY_VALUE( entry, struct device_interface, entry );
+
+    /* reject requests that would not change the current state */
+    if (enable && iface->enabled)
+        return STATUS_OBJECT_NAME_EXISTS;
+    if (!enable && !iface->enabled)
+        return STATUS_OBJECT_NAME_NOT_FOUND;
+
     refstr = memrchrW(name->Buffer + 4, '\\', namelen - 4);
 
     if (!guid_from_string( (refstr ? refstr : name->Buffer + namelen) - 38, &class ))
@@ -1374,22 +1433,37 @@ NTSTATUS WINAPI IoSetDeviceInterfaceState( UNICODE_STRING *name, BOOLEAN enable
     attr.ObjectName = &string;
     RtlInitUnicodeString( &string, path );
     ret = NtOpenKey( &iface_key, KEY_CREATE_SUB_KEY, &attr );
-    if (!ret)
+    heap_free(path);  /* the key path is not referenced again after NtOpenKey */
+    if (ret)
+        return ret;
+
+    attr.RootDirectory = iface_key;  /* create controlW relative to the interface key */
+    RtlInitUnicodeString( &string, controlW );
+    ret = NtCreateKey( &control_key, KEY_SET_VALUE, &attr, 0, NULL, 0, NULL );
+    NtClose( iface_key );  /* only control_key is used from here on */
+    if (ret)
+        return ret;
+
+    RtlInitUnicodeString( &string, linkedW );
+    ret = NtSetValueKey( control_key, &string, 0, REG_DWORD, &data, sizeof(data) );  /* data == enable */
+    if (ret)
     {
-        attr.RootDirectory = iface_key;
-        RtlInitUnicodeString( &string, controlW );
-        ret = NtCreateKey( &control_key, KEY_SET_VALUE, &attr, 0, NULL, 0, NULL );
-        if (!ret)
-        {
-            DWORD data = enable;
-            RtlInitUnicodeString( &string, linkedW );
-            ret = NtSetValueKey( control_key, &string, 0, REG_DWORD, &data, sizeof(data) );
-            NtClose( control_key );
-        }
-        NtClose( iface_key );
+        NtClose( control_key );
+        return ret;
     }
 
-    heap_free( path );
+    if (enable)
+        ret = create_device_symlink( iface->device, name );
+    else
+        ret = IoDeleteSymbolicLink( name );
+    /* on failure, drop the "Linked" value written above; close the key on every path */
+    if (ret)
+        NtDeleteValueKey( control_key, &string );
+    NtClose( control_key );
+    if (ret)
+        return ret;
+
+    iface->enabled = enable;
 
     len = offsetof(DEV_BROADCAST_DEVICEINTERFACE_W, dbcc_name[namelen + 1]);
 
@@ -1696,6 +1770,7 @@ NTSTATUS WINAPI IoRegisterDeviceInterface(DEVICE_OBJECT *device, const GUID *cla
     SP_DEVICE_INTERFACE_DETAIL_DATA_W *data;
     DWORD required;
     BOOL rc;
+    struct device_interface *iface;
 
     TRACE( "(%p, %s, %s, %p)\n", device, debugstr_guid(class_guid), debugstr_us(reference_string), symbolic_link );
 
@@ -1766,9 +1841,25 @@ NTSTATUS WINAPI IoRegisterDeviceInterface(DEVICE_OBJECT *device, const GUID *cla
     data->DevicePath[1] = '?';
     TRACE( "Device path %s\n",debugstr_w(data->DevicePath) );
 
+    if (!(iface = heap_alloc_zero( sizeof(*iface) )))
+    {
+        HeapFree( GetProcessHeap(), 0, data );
+        return STATUS_NO_MEMORY;
+    }
+    iface->device = device;
+    iface->interface_class = *class_guid;
+    RtlCreateUnicodeString( &iface->symbolic_link, data->DevicePath );
     if (symbolic_link)
         RtlCreateUnicodeString( symbolic_link, data->DevicePath);
 
+    /* registering the same interface again is allowed; keep the existing entry */
+    if (wine_rb_put( &device_interfaces, &iface->symbolic_link, &iface->entry ))
+    {
+        TRACE( "interface %s already registered\n", debugstr_us(&iface->symbolic_link) );
+        RtlFreeUnicodeString( &iface->symbolic_link );
+        heap_free( iface );
+    }
+
     HeapFree( GetProcessHeap(), 0, data );
 
     return status;
-- 
2.14.1