Factor the report buffer transfer out of the IOCTL_HID_GET_INPUT_REPORT code
into a new hid_device_xfer_report helper, and use it for IOCTL_HID_GET_FEATURE too.
Signed-off-by: Rémi Bernon <rbernon@codeweavers.com>
---
 dlls/hidclass.sys/device.c | 100 ++++++++++-----------------
 dlls/ntoskrnl.exe/tests/driver_hid.c | 3 +-
 dlls/ntoskrnl.exe/tests/ntoskrnl.c | 10 +--
 3 files changed, 39 insertions(+), 74 deletions(-)
diff --git a/dlls/hidclass.sys/device.c b/dlls/hidclass.sys/device.c index a7cb6a843e0..81ddd7e6d7e 100644 --- a/dlls/hidclass.sys/device.c +++ b/dlls/hidclass.sys/device.c @@ -299,46 +299,48 @@ static void handle_minidriver_string( BASE_DEVICE_EXTENSION *ext, IRP *irp, SHOR } }
-static void HID_get_feature( BASE_DEVICE_EXTENSION *ext, IRP *irp ) +static void hid_device_xfer_report( BASE_DEVICE_EXTENSION *ext, ULONG code, IRP *irp ) { - IO_STACK_LOCATION *irpsp = IoGetCurrentIrpStackLocation( irp ); - HID_XFER_PACKET *packet; - DWORD len; - BYTE *out_buffer; - - irp->IoStatus.Information = 0; + const WINE_HIDP_PREPARSED_DATA *preparsed = ext->u.pdo.preparsed_data; + IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation( irp ); + ULONG report_len = 0, buffer_len = stack->Parameters.DeviceIoControl.OutputBufferLength; + BYTE *buffer = MmGetSystemAddressForMdlSafe( irp->MdlAddress, NormalPagePriority ); + BYTE report_id = HID_INPUT_VALUE_CAPS( preparsed )->report_id; + HID_XFER_PACKET packet;
- out_buffer = MmGetSystemAddressForMdlSafe(irp->MdlAddress, NormalPagePriority); - TRACE_(hid_report)("Device %p Buffer length %i Buffer %p\n", ext, irpsp->Parameters.DeviceIoControl.OutputBufferLength, out_buffer); + switch (code) + { + case IOCTL_HID_GET_INPUT_REPORT: + report_len = preparsed->caps.InputReportByteLength; + break; + case IOCTL_HID_GET_FEATURE: + report_len = preparsed->caps.FeatureReportByteLength; + break; + }
- if (!irpsp->Parameters.DeviceIoControl.OutputBufferLength || !out_buffer) + if (!buffer) { - irp->IoStatus.Status = STATUS_BUFFER_TOO_SMALL; + irp->IoStatus.Status = STATUS_INVALID_USER_BUFFER; + return; + } + if (buffer_len < report_len) + { + irp->IoStatus.Status = STATUS_INVALID_PARAMETER; return; }
- len = sizeof(*packet) + irpsp->Parameters.DeviceIoControl.OutputBufferLength; - packet = malloc(len); - packet->reportBufferLen = irpsp->Parameters.DeviceIoControl.OutputBufferLength; - packet->reportBuffer = ((BYTE*)packet) + sizeof(*packet); - packet->reportId = out_buffer[0]; - - TRACE_(hid_report)("(id %i, len %i buffer %p)\n", packet->reportId, packet->reportBufferLen, packet->reportBuffer); - - call_minidriver( IOCTL_HID_GET_FEATURE, ext->u.pdo.parent_fdo, NULL, 0, packet, sizeof(*packet), - &irp->IoStatus ); + packet.reportId = buffer[0]; + packet.reportBuffer = buffer; + packet.reportBufferLen = buffer_len;
- if (irp->IoStatus.Status == STATUS_SUCCESS) + if (!report_id) { - irp->IoStatus.Information = packet->reportBufferLen; - memcpy(out_buffer, packet->reportBuffer, packet->reportBufferLen); + packet.reportId = 0; + packet.reportBuffer++; + packet.reportBufferLen--; } - else - irp->IoStatus.Information = 0; - - TRACE_(hid_report)( "Result 0x%x get %li bytes\n", irp->IoStatus.Status, irp->IoStatus.Information );
- free(packet); + call_minidriver( code, ext->u.pdo.parent_fdo, NULL, 0, &packet, sizeof(packet), &irp->IoStatus ); }
static void HID_set_to_device( DEVICE_OBJECT *device, IRP *irp ) @@ -390,10 +392,9 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp) { IO_STACK_LOCATION *irpsp = IoGetCurrentIrpStackLocation( irp ); BASE_DEVICE_EXTENSION *ext = device->DeviceExtension; - const WINE_HIDP_PREPARSED_DATA *data = ext->u.pdo.preparsed_data; - BYTE report_id = HID_INPUT_VALUE_CAPS( data )->report_id; NTSTATUS status; BOOL removed; + ULONG code; KIRQL irql;
irp->IoStatus.Information = 0; @@ -411,7 +412,7 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp) return STATUS_DELETE_PENDING; }
- switch (irpsp->Parameters.DeviceIoControl.IoControlCode) + switch ((code = irpsp->Parameters.DeviceIoControl.IoControlCode)) { case IOCTL_HID_GET_POLL_FREQUENCY_MSEC: TRACE("IOCTL_HID_GET_POLL_FREQUENCY_MSEC\n"); @@ -469,38 +470,6 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp) handle_IOCTL_HID_GET_COLLECTION_DESCRIPTOR( irp, ext ); break; } - case IOCTL_HID_GET_INPUT_REPORT: - { - HID_XFER_PACKET packet; - ULONG buffer_len = irpsp->Parameters.DeviceIoControl.OutputBufferLength; - BYTE *buffer = MmGetSystemAddressForMdlSafe( irp->MdlAddress, NormalPagePriority ); - - if (!buffer) - { - irp->IoStatus.Status = STATUS_INVALID_USER_BUFFER; - break; - } - if (buffer_len < data->caps.InputReportByteLength) - { - irp->IoStatus.Status = STATUS_INVALID_PARAMETER; - break; - } - - packet.reportId = buffer[0]; - packet.reportBuffer = buffer; - packet.reportBufferLen = buffer_len; - - if (!report_id) - { - packet.reportId = 0; - packet.reportBuffer++; - packet.reportBufferLen--; - } - - call_minidriver( IOCTL_HID_GET_INPUT_REPORT, ext->u.pdo.parent_fdo, NULL, 0, &packet, - sizeof(packet), &irp->IoStatus ); - break; - } case IOCTL_SET_NUM_DEVICE_INPUT_BUFFERS: { irp->IoStatus.Information = 0; @@ -527,7 +496,8 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp) break; } case IOCTL_HID_GET_FEATURE: - HID_get_feature( ext, irp ); + case IOCTL_HID_GET_INPUT_REPORT: + hid_device_xfer_report( ext, code, irp ); break; case IOCTL_HID_SET_FEATURE: case IOCTL_HID_SET_OUTPUT_REPORT: diff --git a/dlls/ntoskrnl.exe/tests/driver_hid.c b/dlls/ntoskrnl.exe/tests/driver_hid.c index 831b08c5c97..d268b62d288 100644 --- a/dlls/ntoskrnl.exe/tests/driver_hid.c +++ b/dlls/ntoskrnl.exe/tests/driver_hid.c @@ -570,9 +570,8 @@ static NTSTATUS WINAPI driver_internal_ioctl(DEVICE_OBJECT *device, IRP *irp) ok(!in_size, "got input size %u\n", in_size); ok(out_size == sizeof(*packet), "got output size %u\n", out_size);
- todo_wine_if(packet->reportId == 0x5a || packet->reportId == 0xa5) + todo_wine_if(packet->reportId == 0x5a) ok(packet->reportId == report_id, "got id %u\n", packet->reportId); - todo_wine_if(packet->reportBufferLen == 16) ok(packet->reportBufferLen >= expected_size, "got len %u\n", packet->reportBufferLen); ok(!!packet->reportBuffer, "got buffer %p\n", packet->reportBuffer);
diff --git a/dlls/ntoskrnl.exe/tests/ntoskrnl.c b/dlls/ntoskrnl.exe/tests/ntoskrnl.c index 17d82fa3482..f5ebf21b7f9 100644 --- a/dlls/ntoskrnl.exe/tests/ntoskrnl.c +++ b/dlls/ntoskrnl.exe/tests/ntoskrnl.c @@ -2486,13 +2486,11 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled SetLastError(0xdeadbeef); ret = HidD_GetFeature(file, report, 0); ok(!ret, "HidD_GetFeature succeeded\n"); - todo_wine ok(GetLastError() == ERROR_INVALID_USER_BUFFER, "HidD_GetFeature returned error %u\n", GetLastError()); + ok(GetLastError() == ERROR_INVALID_USER_BUFFER, "HidD_GetFeature returned error %u\n", GetLastError());
SetLastError(0xdeadbeef); ret = HidD_GetFeature(file, report, caps.FeatureReportByteLength - 1); - todo_wine ok(!ret, "HidD_GetFeature succeeded\n"); - todo_wine ok(GetLastError() == ERROR_INVALID_PARAMETER || broken(GetLastError() == ERROR_CRC), "HidD_GetFeature returned error %u\n", GetLastError());
@@ -2510,21 +2508,19 @@ static void test_hidp(HANDLE file, HANDLE async_file, int report_id, BOOL polled else { ok(ret, "HidD_GetFeature failed, last error %u\n", GetLastError()); - todo_wine ok(buffer[0] == 0x5a, "got buffer[0] %x, expected 0x5a\n", (BYTE)buffer[0]); + ok(buffer[0] == 0x5a, "got buffer[0] %x, expected 0x5a\n", (BYTE)buffer[0]); }
SetLastError(0xdeadbeef); ret = HidD_GetFeature(file, report, caps.FeatureReportByteLength); ok(ret, "HidD_GetFeature failed, last error %u\n", GetLastError()); - todo_wine_if(!report_id) ok(report[0] == report_id, "got report[0] %02x, expected %02x\n", report[0], report_id);
value = caps.FeatureReportByteLength * 2; SetLastError(0xdeadbeef); ret = sync_ioctl(file, IOCTL_HID_GET_FEATURE, NULL, 0, report, &value); ok(ret, "IOCTL_HID_GET_FEATURE failed, last error %u\n", GetLastError()); - todo_wine ok(value == 3, "got length %u, expected 3\n", value); - todo_wine_if(!report_id) + ok(value == 3, "got length %u, expected 3\n", value); ok(report[0] == report_id, "got report[0] %02x, expected %02x\n", report[0], report_id);