Module: vkd3d Branch: master Commit: fef5760af061b3d3536911681005938e99b34ae8 URL: https://gitlab.winehq.org/wine/vkd3d/-/commit/fef5760af061b3d3536911681005938e99b34ae8
Author: Conor McCarthy <cmccarthy@codeweavers.com> Date: Tue Apr 23 21:14:22 2024 +1000
vkd3d-shader/spirv: Implement the WAVE_ACTIVE_BIT_* instructions.
---
 libs/vkd3d-shader/spirv.c            | 44 ++++++++++++++++++++++++++++++++++++
 tests/hlsl/wave-ops-uint.shader_test |  6 ++---
 2 files changed, 47 insertions(+), 3 deletions(-)
diff --git a/libs/vkd3d-shader/spirv.c b/libs/vkd3d-shader/spirv.c index 7dc8feea..60c08be2 100644 --- a/libs/vkd3d-shader/spirv.c +++ b/libs/vkd3d-shader/spirv.c @@ -9816,6 +9816,45 @@ static void spirv_compiler_emit_wave_active_ballot(struct spirv_compiler *compil spirv_compiler_emit_store_dst(compiler, dst, val_id); }
+static SpvOp map_wave_alu_op(enum vkd3d_shader_opcode handler_idx, bool is_float) +{ + switch (handler_idx) + { + case VKD3DSIH_WAVE_ACTIVE_BIT_AND: + return SpvOpGroupNonUniformBitwiseAnd; + case VKD3DSIH_WAVE_ACTIVE_BIT_OR: + return SpvOpGroupNonUniformBitwiseOr; + case VKD3DSIH_WAVE_ACTIVE_BIT_XOR: + return SpvOpGroupNonUniformBitwiseXor; + default: + vkd3d_unreachable(); + } +} + +static void spirv_compiler_emit_wave_alu_op(struct spirv_compiler *compiler, + const struct vkd3d_shader_instruction *instruction) +{ + struct vkd3d_spirv_builder *builder = &compiler->spirv_builder; + const struct vkd3d_shader_dst_param *dst = instruction->dst; + const struct vkd3d_shader_src_param *src = instruction->src; + uint32_t type_id, val_id; + SpvOp op; + + op = map_wave_alu_op(instruction->handler_idx, data_type_is_floating_point(src->reg.data_type)); + + type_id = vkd3d_spirv_get_type_id_for_data_type(builder, dst->reg.data_type, + vsir_write_mask_component_count(dst->write_mask)); + val_id = spirv_compiler_emit_load_src(compiler, &src[0], dst->write_mask); + + vkd3d_spirv_enable_capability(builder, SpvCapabilityGroupNonUniformArithmetic); + val_id = vkd3d_spirv_build_op_tr3(builder, &builder->function_stream, op, type_id, + vkd3d_spirv_get_op_scope_subgroup(builder), + SpvGroupOperationReduce, + val_id); + + spirv_compiler_emit_store_dst(compiler, dst, val_id); +} + /* This function is called after declarations are processed. 
*/ static void spirv_compiler_emit_main_prolog(struct spirv_compiler *compiler) { @@ -10168,6 +10207,11 @@ static int spirv_compiler_handle_instruction(struct spirv_compiler *compiler, case VKD3DSIH_WAVE_ACTIVE_BALLOT: spirv_compiler_emit_wave_active_ballot(compiler, instruction); break; + case VKD3DSIH_WAVE_ACTIVE_BIT_AND: + case VKD3DSIH_WAVE_ACTIVE_BIT_OR: + case VKD3DSIH_WAVE_ACTIVE_BIT_XOR: + spirv_compiler_emit_wave_alu_op(compiler, instruction); + break; case VKD3DSIH_DCL: case VKD3DSIH_DCL_HS_MAX_TESSFACTOR: case VKD3DSIH_DCL_INPUT_CONTROL_POINT_COUNT: diff --git a/tests/hlsl/wave-ops-uint.shader_test b/tests/hlsl/wave-ops-uint.shader_test index cf2daebd..cb502e86 100644 --- a/tests/hlsl/wave-ops-uint.shader_test +++ b/tests/hlsl/wave-ops-uint.shader_test @@ -232,7 +232,7 @@ void main(uint id : SV_GroupIndex) }
[test] -todo dispatch 4 1 1 +dispatch 4 1 1 probe uav 1 (0) rui (8) probe uav 1 (1) rui (8) probe uav 1 (2) rui (8) @@ -250,7 +250,7 @@ void main(uint id : SV_GroupIndex) }
[test] -todo dispatch 4 1 1 +dispatch 4 1 1 probe uav 1 (0) rui (15) probe uav 1 (1) rui (15) probe uav 1 (2) rui (15) @@ -268,7 +268,7 @@ void main(uint id : SV_GroupIndex) }
[test] -todo dispatch 4 1 1 +dispatch 4 1 1 probe uav 1 (0) rui (5) probe uav 1 (1) rui (5) probe uav 1 (2) rui (5)