Split the shape computation out of expr_common_type() into a new function, expr_common_shape(), so that it can also be used for boolean operators: for those, a common shape must still be determined even though the base type of the result is always bool.
Signed-off-by: Giovanni Mascellani <gmascellani@codeweavers.com>
---
 libs/vkd3d-shader/hlsl.y | 82 +++++++++++++++++++++-------------------
 1 file changed, 44 insertions(+), 38 deletions(-)
diff --git a/libs/vkd3d-shader/hlsl.y b/libs/vkd3d-shader/hlsl.y index 61789def..5f320323 100644 --- a/libs/vkd3d-shader/hlsl.y +++ b/libs/vkd3d-shader/hlsl.y @@ -912,13 +912,9 @@ static enum hlsl_base_type expr_common_base_type(enum hlsl_base_type t1, enum hl return HLSL_TYPE_INT; }
-static struct hlsl_type *expr_common_type(struct hlsl_ctx *ctx, struct hlsl_type *t1, struct hlsl_type *t2, - struct vkd3d_shader_location *loc) +static bool expr_common_shape(struct hlsl_ctx *ctx, struct hlsl_type *t1, struct hlsl_type *t2, + struct vkd3d_shader_location *loc, enum hlsl_type_class *type, unsigned int *dimx, unsigned int *dimy) { - enum hlsl_type_class type; - enum hlsl_base_type base; - unsigned int dimx, dimy; - if (t1->type > HLSL_CLASS_LAST_NUMERIC) { struct vkd3d_string_buffer *string; @@ -927,7 +923,7 @@ static struct hlsl_type *expr_common_type(struct hlsl_ctx *ctx, struct hlsl_type hlsl_error(ctx, *loc, VKD3D_SHADER_ERROR_HLSL_INVALID_TYPE, "Expression of type "%s" cannot be used in a numeric expression.", string->buffer); hlsl_release_string_buffer(ctx, string); - return NULL; + return false; }
if (t2->type > HLSL_CLASS_LAST_NUMERIC) @@ -938,12 +934,9 @@ static struct hlsl_type *expr_common_type(struct hlsl_ctx *ctx, struct hlsl_type hlsl_error(ctx, *loc, VKD3D_SHADER_ERROR_HLSL_INVALID_TYPE, "Expression of type "%s" cannot be used in a numeric expression.", string->buffer); hlsl_release_string_buffer(ctx, string); - return NULL; + return false; }
- if (hlsl_types_are_equal(t1, t2)) - return t1; - if (!expr_compatible_data_types(t1, t2)) { struct vkd3d_string_buffer *t1_string = hlsl_type_to_string(ctx, t1); @@ -955,28 +948,26 @@ static struct hlsl_type *expr_common_type(struct hlsl_ctx *ctx, struct hlsl_type t1_string->buffer, t2_string->buffer); hlsl_release_string_buffer(ctx, t1_string); hlsl_release_string_buffer(ctx, t2_string); - return NULL; + return false; }
- base = expr_common_base_type(t1->base_type, t2->base_type); - if (t1->dimx == 1 && t1->dimy == 1) { - type = t2->type; - dimx = t2->dimx; - dimy = t2->dimy; + *type = t2->type; + *dimx = t2->dimx; + *dimy = t2->dimy; } else if (t2->dimx == 1 && t2->dimy == 1) { - type = t1->type; - dimx = t1->dimx; - dimy = t1->dimy; + *type = t1->type; + *dimx = t1->dimx; + *dimy = t1->dimy; } else if (t1->type == HLSL_CLASS_MATRIX && t2->type == HLSL_CLASS_MATRIX) { - type = HLSL_CLASS_MATRIX; - dimx = min(t1->dimx, t2->dimx); - dimy = min(t1->dimy, t2->dimy); + *type = HLSL_CLASS_MATRIX; + *dimx = min(t1->dimx, t2->dimx); + *dimy = min(t1->dimy, t2->dimy); } else { @@ -987,40 +978,55 @@ static struct hlsl_type *expr_common_type(struct hlsl_ctx *ctx, struct hlsl_type max_dim_2 = max(t2->dimx, t2->dimy); if (t1->dimx * t1->dimy == t2->dimx * t2->dimy) { - type = HLSL_CLASS_VECTOR; - dimx = max(t1->dimx, t2->dimx); - dimy = 1; + *type = HLSL_CLASS_VECTOR; + *dimx = max(t1->dimx, t2->dimx); + *dimy = 1; } else if (max_dim_1 <= max_dim_2) { - type = t1->type; - if (type == HLSL_CLASS_VECTOR) + *type = t1->type; + if (*type == HLSL_CLASS_VECTOR) { - dimx = max_dim_1; - dimy = 1; + *dimx = max_dim_1; + *dimy = 1; } else { - dimx = t1->dimx; - dimy = t1->dimy; + *dimx = t1->dimx; + *dimy = t1->dimy; } } else { - type = t2->type; - if (type == HLSL_CLASS_VECTOR) + *type = t2->type; + if (*type == HLSL_CLASS_VECTOR) { - dimx = max_dim_2; - dimy = 1; + *dimx = max_dim_2; + *dimy = 1; } else { - dimx = t2->dimx; - dimy = t2->dimy; + *dimx = t2->dimx; + *dimy = t2->dimy; } } }
+ return true; +} + +static struct hlsl_type *expr_common_type(struct hlsl_ctx *ctx, struct hlsl_type *t1, struct hlsl_type *t2, + struct vkd3d_shader_location *loc) +{ + enum hlsl_type_class type; + enum hlsl_base_type base; + unsigned int dimx, dimy; + + if (!expr_common_shape(ctx, t1, t2, loc, &type, &dimx, &dimy)) + return NULL; + + base = expr_common_base_type(t1->base_type, t2->base_type); + return hlsl_get_numeric_type(ctx, type, base, dimx, dimy); }