From: Zebediah Figura <zfigura@codeweavers.com>
---
 include/private/list.h           |  24 ++++
 libs/vkd3d-shader/hlsl.c         |  14 +++
 libs/vkd3d-shader/hlsl.h         |   5 +
 libs/vkd3d-shader/hlsl_codegen.c | 186 +++++++++++++++++++++++++++++++
 tests/return.shader_test         |  14 +--
 5 files changed, 236 insertions(+), 7 deletions(-)
diff --git a/include/private/list.h b/include/private/list.h
index 5e92cfb2..2e1d95f3 100644
--- a/include/private/list.h
+++ b/include/private/list.h
@@ -186,6 +186,30 @@ static inline void list_move_tail( struct list *dst, struct list *src )
     list_move_before( dst, src );
 }
 
+/* move the slice of elements from begin to end inclusive to the head of dst */
+static inline void list_move_slice_head( struct list *dst, struct list *begin, struct list *end )
+{
+    struct list *dst_next = dst->next;
+    begin->prev->next = end->next;
+    end->next->prev = begin->prev;
+    dst->next = begin;
+    dst_next->prev = end;
+    begin->prev = dst;
+    end->next = dst_next;
+}
+
+/* move the slice of elements from begin to end inclusive to the tail of dst */
+static inline void list_move_slice_tail( struct list *dst, struct list *begin, struct list *end )
+{
+    struct list *dst_prev = dst->prev;
+    begin->prev->next = end->next;
+    end->next->prev = begin->prev;
+    dst_prev->next = begin;
+    dst->prev = end;
+    begin->prev = dst_prev;
+    end->next = dst;
+}
+
 /* iterate through the list */
 #define LIST_FOR_EACH(cursor,list) \
     for ((cursor) = (list)->next; (cursor) != (list); (cursor) = (cursor)->next)
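A note for reviewers on the new helpers: they detach the inclusive range [begin, end] from whichever list currently holds it and splice it, in order, onto the head or tail of dst. A minimal, self-contained sketch of the intended usage follows; the "item" struct and the four nodes are hypothetical, purely for illustration:

    #include <stdio.h>
    #include "list.h" /* vkd3d's private list.h, as patched above */

    struct item
    {
        struct list entry;
        char name;
    };

    int main(void)
    {
        struct item a = {.name = 'a'}, b = {.name = 'b'}, c = {.name = 'c'}, d = {.name = 'd'};
        struct list src, dst;
        struct item *it;

        list_init(&src);
        list_init(&dst);
        list_add_tail(&src, &a.entry);
        list_add_tail(&src, &b.entry);
        list_add_tail(&src, &c.entry);
        list_add_tail(&src, &d.entry);

        /* Detach the slice [b, d] from "src" and append it to "dst",
         * preserving element order. */
        list_move_slice_tail(&dst, &b.entry, &d.entry);

        /* "src" now holds only a; "dst" holds b, c, d. */
        LIST_FOR_EACH_ENTRY(it, &dst, struct item, entry)
            printf("%c\n", it->name);
        return 0;
    }

lower_return() below uses list_move_slice_tail() to pull everything after a CF instruction into a synthesized "if" block.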
diff --git a/libs/vkd3d-shader/hlsl.c b/libs/vkd3d-shader/hlsl.c
index 3656d05d..0375e226 100644
--- a/libs/vkd3d-shader/hlsl.c
+++ b/libs/vkd3d-shader/hlsl.c
@@ -1209,6 +1209,8 @@ struct hlsl_ir_function_decl *hlsl_new_func_decl(struct hlsl_ctx *ctx,
         const struct hlsl_semantic *semantic, const struct vkd3d_shader_location *loc)
 {
     struct hlsl_ir_function_decl *decl;
+    struct hlsl_ir_constant *constant;
+    struct hlsl_ir_store *store;
 
     if (!(decl = hlsl_alloc(ctx, sizeof(*decl))))
         return NULL;
@@ -1227,6 +1229,18 @@ struct hlsl_ir_function_decl *hlsl_new_func_decl(struct hlsl_ctx *ctx,
         decl->return_var->semantic = *semantic;
     }
 
+    if (!(decl->early_return_var = hlsl_new_synthetic_var(ctx, "early_return",
+            hlsl_get_scalar_type(ctx, HLSL_TYPE_BOOL), loc)))
+        return decl;
+
+    if (!(constant = hlsl_new_bool_constant(ctx, false, loc)))
+        return decl;
+    list_add_tail(&decl->body.instrs, &constant->node.entry);
+
+    if (!(store = hlsl_new_simple_store(ctx, decl->early_return_var, &constant->node)))
+        return decl;
+    list_add_tail(&decl->body.instrs, &store->node.entry);
+
     return decl;
 }
diff --git a/libs/vkd3d-shader/hlsl.h b/libs/vkd3d-shader/hlsl.h
index d615ff8a..bc0c4b4c 100644
--- a/libs/vkd3d-shader/hlsl.h
+++ b/libs/vkd3d-shader/hlsl.h
@@ -429,6 +429,11 @@ struct hlsl_ir_function_decl
      * Not to be confused with the function parameters! */
     unsigned int attr_count;
     const struct hlsl_attribute *const *attrs;
+
+    /* Synthetic boolean variable marking whether a return statement has been
+     * executed. Needed to deal with return statements in non-uniform control
+     * flow, since some backends can't handle them. */
+    struct hlsl_ir_var *early_return_var;
 };
 
 struct hlsl_ir_call
diff --git a/libs/vkd3d-shader/hlsl_codegen.c b/libs/vkd3d-shader/hlsl_codegen.c
index 7c95ebd1..579b9b53 100644
--- a/libs/vkd3d-shader/hlsl_codegen.c
+++ b/libs/vkd3d-shader/hlsl_codegen.c
@@ -499,6 +499,185 @@ static bool find_recursive_calls(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
     return false;
 }
 
+static void insert_early_return_break(struct hlsl_ctx *ctx,
+        struct hlsl_ir_function_decl *func, struct hlsl_ir_node *cf_instr)
+{
+    struct hlsl_ir_jump *jump;
+    struct hlsl_ir_load *load;
+    struct hlsl_ir_if *iff;
+
+    if (!(load = hlsl_new_var_load(ctx, func->early_return_var, cf_instr->loc)))
+        return;
+    list_add_after(&cf_instr->entry, &load->node.entry);
+
+    if (!(iff = hlsl_new_if(ctx, &load->node, cf_instr->loc)))
+        return;
+    list_add_after(&load->node.entry, &iff->node.entry);
+
+    if (!(jump = hlsl_new_jump(ctx, HLSL_IR_JUMP_BREAK, cf_instr->loc)))
+        return;
+    list_add_tail(&iff->then_instrs.instrs, &jump->node.entry);
+}
+
+/* Remove HLSL_IR_JUMP_RETURN instructions by altering subsequent control flow. */
+static void lower_return(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *func,
+        struct hlsl_block *block, bool in_loop)
+{
+    struct hlsl_ir_node *return_instr = NULL, *cf_instr = NULL;
+    struct hlsl_ir_node *instr, *next;
+
+    /* SM1 has no function calls. SM4 does, but native d3dcompiler inlines
+     * everything anyway. We are safest following suit.
+     *
+     * The basic idea is to keep track of whether the function has executed
+     * an early return in a synthesized boolean variable
+     * (func->early_return_var) and guard all code after the return on that
+     * variable being false. In the case of loops we also replace the return
+     * with a break.
+     *
+     * The following algorithm loops over instructions in a block, recursing
+     * into inferior CF blocks, until it hits one of the following two
+     * things:
+     *
+     * - A return statement. In this case, we remove everything after the
+     *   return statement in this block. We have to stop and do this in a
+     *   separate loop, because instructions must be deleted in reverse order
+     *   (due to def-use chains).
+     *
+     *   If we're inside of a loop CF block, we can instead just turn the
+     *   return into a break, which offers the right semantics, except that
+     *   it won't break out of nested loops.
+     *
+     * - A CF block which might contain a return statement. After calling
+     *   lower_return() on the CF block body, we stop, pull out everything
+     *   after the CF instruction, shove it into an if block, and then lower
+     *   that if block.
+     *
+     *   (We could return a "did we make progress" boolean like
+     *   transform_ir() and run this pass multiple times, but we already know
+     *   the only block that still needs to be addressed, so there's not much
+     *   point.)
+     *
+     *   If we're inside of a loop CF block, we again do things differently.
+     *   We already turned any returns into breaks. If the block we just
+     *   processed was conditional, then "break" did our work for us. If it
+     *   was a loop, we need to propagate that break to the outer loop.
+     */
+
+    LIST_FOR_EACH_ENTRY_SAFE(instr, next, &block->instrs, struct hlsl_ir_node, entry)
+    {
+        if (instr->type == HLSL_IR_CALL)
+        {
+            struct hlsl_ir_call *call = hlsl_ir_call(instr);
+
+            lower_return(ctx, call->decl, &call->decl->body, false);
+        }
+        else if (instr->type == HLSL_IR_IF)
+        {
+            struct hlsl_ir_if *iff = hlsl_ir_if(instr);
+
+            lower_return(ctx, func, &iff->then_instrs, in_loop);
+            lower_return(ctx, func, &iff->else_instrs, in_loop);
+
+            /* If we're in a loop, we don't need to do anything here. We
+             * turned the return into a break, and that will already skip
+             * anything that comes after this "if" block. */
+            if (!in_loop)
+            {
+                cf_instr = instr;
+                break;
+            }
+        }
+        else if (instr->type == HLSL_IR_LOOP)
+        {
+            lower_return(ctx, func, &hlsl_ir_loop(instr)->body, true);
+
+            if (in_loop)
+            {
+                /* "instr" is a nested loop. "return" breaks out of all
+                 * loops, so break out of this one too now. */
+                insert_early_return_break(ctx, func, instr);
+            }
+            else
+            {
+                cf_instr = instr;
+                break;
+            }
+        }
+        else if (instr->type == HLSL_IR_JUMP)
+        {
+            struct hlsl_ir_jump *jump = hlsl_ir_jump(instr);
+            struct hlsl_ir_constant *constant;
+            struct hlsl_ir_store *store;
+
+            if (jump->type == HLSL_IR_JUMP_RETURN)
+            {
+                if (!(constant = hlsl_new_bool_constant(ctx, true, &jump->node.loc)))
+                    return;
+                list_add_before(&jump->node.entry, &constant->node.entry);
+
+                if (!(store = hlsl_new_simple_store(ctx, func->early_return_var, &constant->node)))
+                    return;
+                list_add_after(&constant->node.entry, &store->node.entry);
+
+                if (in_loop)
+                {
+                    jump->type = HLSL_IR_JUMP_BREAK;
+                }
+                else
+                {
+                    return_instr = instr;
+                    break;
+                }
+            }
+        }
+    }
+
+    if (return_instr)
+    {
+        /* If we're in a loop, we should have used "break" instead. */
+        assert(!in_loop);
+
+        /* Iterate in reverse, to avoid use-after-free when unlinking sources
+         * from the "uses" list. */
+        LIST_FOR_EACH_ENTRY_SAFE_REV(instr, next, &block->instrs, struct hlsl_ir_node, entry)
+        {
+            list_remove(&instr->entry);
+            hlsl_free_instr(instr);
+
+            /* Yes, we just freed it, but we're comparing pointers. */
+            if (instr == return_instr)
+                break;
+        }
+    }
+    else if (cf_instr)
+    {
+        struct list *tail = list_tail(&block->instrs);
+        struct hlsl_ir_load *load;
+        struct hlsl_ir_node *not;
+        struct hlsl_ir_if *iff;
+
+        /* If we're in a loop, we should have used "break" instead. */
+        assert(!in_loop);
+
+        if (tail == &cf_instr->entry)
+            return;
+
+        if (!(load = hlsl_new_var_load(ctx, func->early_return_var, cf_instr->loc)))
+            return;
+        list_add_tail(&block->instrs, &load->node.entry);
+
+        if (!(not = hlsl_new_unary_expr(ctx, HLSL_OP1_LOGIC_NOT, &load->node, cf_instr->loc)))
+            return;
+        list_add_tail(&block->instrs, &not->entry);
+
+        if (!(iff = hlsl_new_if(ctx, not, cf_instr->loc)))
+            return;
+        list_add_tail(&block->instrs, &iff->node.entry);
+
+        list_move_slice_tail(&iff->then_instrs.instrs, list_next(&block->instrs, &cf_instr->entry), tail);
+
+        lower_return(ctx, func, &iff->then_instrs, in_loop);
+    }
+}
+
 /* Lower casts from vec1 to vecN to swizzles. */
 static bool lower_broadcasts(struct hlsl_ctx *ctx, struct hlsl_ir_node *instr, void *context)
 {
@@ -2926,6 +3105,13 @@ int hlsl_emit_bytecode(struct hlsl_ctx *ctx, struct hlsl_ir_function_decl *entry_func,
     transform_ir(ctx, find_recursive_calls, body, &recursive_call_ctx);
     vkd3d_free(recursive_call_ctx.backtrace);
+    /* Avoid going into an infinite loop when processing call instructions.
+     * lower_return() recurses into inferior calls. */
+    if (ctx->result)
+        return ctx->result;
+
+    lower_return(ctx, entry_func, body, false);
+
     LIST_FOR_EACH_ENTRY(var, &ctx->globals->vars, struct hlsl_ir_var, scope_entry)
     {
         if (var->storage_modifiers & HLSL_STORAGE_UNIFORM)
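To make the lowering concrete, here is roughly what it amounts to at the source level, using a shader in the style of the tests below. This is an illustrative sketch of the intended semantics only, not literal output of the pass; the real transformation happens on the IR, tracking the flag through the synthesized early_return_var and guarding trailing instructions with an if block built around HLSL_OP1_LOGIC_NOT. Before:

    uniform float f;

    float4 main() : sv_target
    {
        if (f < 0.5)
            return float4(0.1, 0.2, 0.3, 0.4);
        return float4(0.5, 0.6, 0.7, 0.8);
    }

After, conceptually:

    uniform float f;

    float4 main() : sv_target
    {
        bool early_return = false; /* synthesized prologue, see hlsl_new_func_decl() */
        float4 ret;

        if (f < 0.5)
        {
            ret = float4(0.1, 0.2, 0.3, 0.4);
            early_return = true;   /* the return becomes a store to the flag */
        }
        if (!early_return)         /* trailing code is guarded on the flag */
        {
            ret = float4(0.5, 0.6, 0.7, 0.8);
        }
        return ret;
    }

Inside a loop the return instead becomes a break (with the flag still set), and insert_early_return_break() propagates it outward by emitting the equivalent of "if (early_return) break;" after each nested loop.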
diff --git a/tests/return.shader_test b/tests/return.shader_test
index e913d15d..3847765e 100644
--- a/tests/return.shader_test
+++ b/tests/return.shader_test
@@ -10,7 +10,7 @@ float4 main() : sv_target
 }
 
 [test]
 draw quad
-todo probe all rgba (0.1, 0.2, 0.3, 0.4)
+probe all rgba (0.1, 0.2, 0.3, 0.4)
 
 [pixel shader]
@@ -23,7 +23,7 @@ void main(out float4 ret : sv_target)
 }
 
 [test]
 draw quad
-todo probe all rgba (0.1, 0.2, 0.3, 0.4)
+probe all rgba (0.1, 0.2, 0.3, 0.4)
 
 [pixel shader]
@@ -39,7 +39,7 @@ float4 main() : sv_target
 [test]
 uniform 0 float 0.2
 draw quad
-todo probe all rgba (0.1, 0.2, 0.3, 0.4)
+probe all rgba (0.1, 0.2, 0.3, 0.4)
 uniform 0 float 0.8
 draw quad
 probe all rgba (0.5, 0.6, 0.7, 0.8)
@@ -69,7 +69,7 @@ draw quad
 probe all rgba (0.3, 0.4, 0.5, 0.6)
 uniform 0 float 0.8
 draw quad
-todo probe all rgba (0.1, 0.2, 0.3, 0.4)
+probe all rgba (0.1, 0.2, 0.3, 0.4)
 
 [pixel shader]
@@ -93,10 +93,10 @@ void main(out float4 ret : sv_target)
 [test]
 uniform 0 float 0.1
 draw quad
-todo probe all rgba (0.1, 0.2, 0.3, 0.4) 1
+probe all rgba (0.1, 0.2, 0.3, 0.4) 1
 uniform 0 float 0.5
 draw quad
-todo probe all rgba (0.2, 0.3, 0.4, 0.5) 1
+probe all rgba (0.2, 0.3, 0.4, 0.5) 1
 uniform 0 float 0.9
 draw quad
 probe all rgba (0.5, 0.6, 0.7, 0.8) 1
@@ -120,7 +120,7 @@ void main(out float4 ret : sv_target)
 [test]
 uniform 0 float 0.1
 draw quad
-todo probe all rgba (0.1, 0.2, 0.3, 0.4) 1
+probe all rgba (0.1, 0.2, 0.3, 0.4) 1
 uniform 0 float 0.5
 draw quad
 probe all rgba (0.5, 0.6, 0.7, 0.8) 1