From: Francisco Casas <fcasas@codeweavers.com>
---
I decided to keep the 'dst_type->type <= HLSL_CLASS_VECTOR' branch because, even though the code in the new branch is able to handle these cases, it would result in less efficient code, since we don't have a vectorization pass yet. --- libs/vkd3d-shader/hlsl.h | 3 ++ libs/vkd3d-shader/hlsl.y | 6 ++-- libs/vkd3d-shader/hlsl_codegen.c | 58 ++++++++++++++++++++++---------- tests/cast-broadcast.shader_test | 4 +-- 4 files changed, 48 insertions(+), 23 deletions(-)
diff --git a/libs/vkd3d-shader/hlsl.h b/libs/vkd3d-shader/hlsl.h index 794749aa..fa0adb07 100644 --- a/libs/vkd3d-shader/hlsl.h +++ b/libs/vkd3d-shader/hlsl.h @@ -810,6 +810,9 @@ void hlsl_pop_scope(struct hlsl_ctx *ctx);
bool hlsl_scope_add_type(struct hlsl_scope *scope, struct hlsl_type *type);
+void hlsl_initialize_var_components(struct hlsl_ctx *ctx, struct list *instrs, + struct hlsl_ir_var *dst, unsigned int *store_index, struct hlsl_ir_node *src); + struct hlsl_type *hlsl_type_clone(struct hlsl_ctx *ctx, struct hlsl_type *old, unsigned int default_majority, unsigned int modifiers); unsigned int hlsl_type_component_count(const struct hlsl_type *type); diff --git a/libs/vkd3d-shader/hlsl.y b/libs/vkd3d-shader/hlsl.y index 624481d8..f8bc7327 100644 --- a/libs/vkd3d-shader/hlsl.y +++ b/libs/vkd3d-shader/hlsl.y @@ -1761,7 +1761,7 @@ static bool add_increment(struct hlsl_ctx *ctx, struct list *instrs, bool decrem return true; }
-static void initialize_var_components(struct hlsl_ctx *ctx, struct list *instrs, +void hlsl_initialize_var_components(struct hlsl_ctx *ctx, struct list *instrs, struct hlsl_ir_var *dst, unsigned int *store_index, struct hlsl_ir_node *src) { unsigned int src_comp_count = hlsl_type_component_count(src->data_type); @@ -2000,7 +2000,7 @@ static struct list *declare_vars(struct hlsl_ctx *ctx, struct hlsl_type *basic_t
for (k = 0; k < v->initializer.args_count; ++k) { - initialize_var_components(ctx, v->initializer.instrs, var, + hlsl_initialize_var_components(ctx, v->initializer.instrs, var, &store_index, v->initializer.args[k]); } } @@ -2480,7 +2480,7 @@ static struct list *add_constructor(struct hlsl_ctx *ctx, struct hlsl_type *type continue; }
- initialize_var_components(ctx, params->instrs, var, &idx, arg); + hlsl_initialize_var_components(ctx, params->instrs, var, &idx, arg); }
if (!(load = hlsl_new_var_load(ctx, var, loc))) diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c index f919a30e..72f7b498 100644 --- a/libs/vkd3d-shader/hlsl_codegen.c +++ b/libs/vkd3d-shader/hlsl_codegen.c @@ -454,23 +454,26 @@ static bool transform_ir(struct hlsl_ctx *ctx, bool (*func)(struct hlsl_ctx *ctx return progress; }
-/* Lower casts from vec1 to vecN to swizzles. */ +/* Lower broadcast casts from scalars or vec1 values to other types. */ static bool lower_broadcasts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context) { - const struct hlsl_type *src_type, *dst_type; - struct hlsl_type *dst_scalar_type; + struct hlsl_type *src_type, *dst_type; struct hlsl_ir_expr *cast;
if (instr->type != HLSL_IR_EXPR) return false; cast = hlsl_ir_expr(instr); + if (cast->op != HLSL_OP1_CAST) + return false; src_type = cast->operands[0].node->data_type; dst_type = cast->node.data_type;
- if (cast->op == HLSL_OP1_CAST - && src_type->type <= HLSL_CLASS_VECTOR && dst_type->type <= HLSL_CLASS_VECTOR - && src_type->dimx == 1) + if (src_type->type > HLSL_CLASS_VECTOR || src_type->dimx != 1) + return false; + + if (dst_type->type <= HLSL_CLASS_VECTOR) { + struct hlsl_type *dst_scalar_type; struct hlsl_ir_node *replacement; struct hlsl_ir_swizzle *swizzle; struct hlsl_ir_expr *new_cast; @@ -494,6 +497,37 @@ static bool lower_broadcasts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, v hlsl_replace_node(&cast->node, replacement); return true; } + else if (HLSL_CLASS_VECTOR < dst_type->type) + { + unsigned int size = hlsl_type_component_count(dst_type); + struct vkd3d_string_buffer *string; + static unsigned int counter = 0; + struct list *broadcast_instrs; + unsigned int store_index = 0; + struct hlsl_ir_load *load; + struct hlsl_ir_var *var; + + if (!(string = hlsl_get_string_buffer(ctx))) + return false; + vkd3d_string_buffer_printf(string, "<broadcast-%x>", counter++); + if (!(var = hlsl_new_synthetic_var(ctx, string->buffer, dst_type, instr->loc))) + return false; + hlsl_release_string_buffer(ctx, string); + + if (!(broadcast_instrs = hlsl_alloc(ctx, sizeof(*broadcast_instrs)))) + return false; + list_init(broadcast_instrs); + + while (store_index < size) + hlsl_initialize_var_components(ctx, broadcast_instrs, var, &store_index, cast->operands[0].node); + + list_move_before(&cast->node.entry, broadcast_instrs); + vkd3d_free(broadcast_instrs); + + load = hlsl_new_var_load(ctx, var, var->loc); + list_add_before(&cast->node.entry, &load->node.entry); + hlsl_replace_node(&cast->node, &load->node); + }
return false; } @@ -1083,12 +1117,6 @@ static bool split_array_copies(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, return false; element_type = type->e.array.type;
- if (rhs->type != HLSL_IR_LOAD) - { - hlsl_fixme(ctx, &instr->loc, "Array store rhs is not HLSL_IR_LOAD. Broadcast may be missing."); - return false; - } - for (i = 0; i < type->e.array.elements_count; ++i) { if (!split_copy(ctx, store, hlsl_ir_load(rhs), i, element_type)) @@ -1119,12 +1147,6 @@ static bool split_struct_copies(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr if (type->type != HLSL_CLASS_STRUCT) return false;
- if (rhs->type != HLSL_IR_LOAD) - { - hlsl_fixme(ctx, &instr->loc, "Struct store rhs is not HLSL_IR_LOAD. Broadcast may be missing."); - return false; - } - for (i = 0; i < type->e.record.field_count; ++i) { const struct hlsl_struct_field *field = &type->e.record.fields[i]; diff --git a/tests/cast-broadcast.shader_test b/tests/cast-broadcast.shader_test index 02d14c0b..8b1948f5 100644 --- a/tests/cast-broadcast.shader_test +++ b/tests/cast-broadcast.shader_test @@ -20,5 +20,5 @@ float4 main() : SV_TARGET }
[test] -todo draw quad -todo probe all rgba (84.0, 84.0, 84.0, 84.0) +draw quad +probe all rgba (84.0, 84.0, 84.0, 84.0)