Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Stack switching: fix some optimization passes #7271

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion scripts/test/fuzzing.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,13 @@
'stack_switching_suspend.wast',
'stack_switching_resume.wast',
'stack_switching_resume_throw.wast',
'stack_switching_switch.wast'
'stack_switching_switch.wast',
'stack_switching_switch.wast_2',
'O3_stack-switching.wast',
'coalesce-locals-stack-switching.wast',
'dce-stack-switching.wast',
'precompute-stack-switching.wast',
'vacuum-stack-switching.wast'
]


Expand Down
27 changes: 27 additions & 0 deletions src/cfg/cfg-traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,19 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
self->tryStack.pop_back();
}

static void doEndResume(SubType* self, Expression** currp) {
auto* module = self->getModule();
if (!module || module->features.hasExceptionHandling()) {
// This resume might throw, so run the code to handle that.
doEndThrowingInst(self, currp);
}
auto handlerBlocks = BranchUtils::getUniqueTargets(*currp);
// Add branches to the targets.
for (auto target : handlerBlocks) {
self->branches[target].push_back(self->currBasicBlock);
}
}

static bool isReturnCall(Expression* curr) {
switch (curr->_id) {
case Expression::Id::CallId:
Expand Down Expand Up @@ -521,6 +534,20 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
self->pushTask(SubType::doEndThrow, currp);
break;
}
case Expression::Id::ResumeId:
case Expression::Id::ResumeThrowId: {
self->pushTask(SubType::doEndResume, currp);
break;
}
case Expression::Id::SuspendId:
case Expression::Id::StackSwitchId: {
auto* module = self->getModule();
if (!module || module->features.hasExceptionHandling()) {
// This might throw, so run the code to handle that.
self->pushTask(SubType::doEndCall, currp);
}
break;
}
default: {
if (Properties::isBranch(curr)) {
self->pushTask(SubType::doEndBranch, currp);
Expand Down
14 changes: 12 additions & 2 deletions src/ir/ReFinalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,18 @@ void ReFinalize::visitStringSliceWTF(StringSliceWTF* curr) { curr->finalize(); }
void ReFinalize::visitContNew(ContNew* curr) { curr->finalize(); }
void ReFinalize::visitContBind(ContBind* curr) { curr->finalize(); }
void ReFinalize::visitSuspend(Suspend* curr) { curr->finalize(getModule()); }
void ReFinalize::visitResume(Resume* curr) { curr->finalize(); }
void ReFinalize::visitResumeThrow(ResumeThrow* curr) { curr->finalize(); }
void ReFinalize::visitResume(Resume* curr) {
curr->finalize();
for (size_t i = 0; i < curr->handlerBlocks.size(); i++) {
updateBreakValueType(curr->handlerBlocks[i], curr->sentTypes[i]);
}
}
void ReFinalize::visitResumeThrow(ResumeThrow* curr) {
curr->finalize();
for (size_t i = 0; i < curr->handlerBlocks.size(); i++) {
updateBreakValueType(curr->handlerBlocks[i], curr->sentTypes[i]);
}
}
void ReFinalize::visitStackSwitch(StackSwitch* curr) { curr->finalize(); }

void ReFinalize::visitExport(Export* curr) { WASM_UNREACHABLE("unimp"); }
Expand Down
8 changes: 4 additions & 4 deletions src/ir/branch-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ void operateOnScopeNameUsesAndSentTypes(Expression* expr, T func) {
}
}
} else if (auto* r = expr->dynCast<Resume>()) {
for (Index i = 0; i < r->handlerTags.size(); i++) {
auto dest = r->handlerTags[i];
for (Index i = 0; i < r->handlerBlocks.size(); i++) {
auto dest = r->handlerBlocks[i];
if (!dest.isNull() && dest == name) {
func(name, r->sentTypes[i]);
}
}
} else if (auto* r = expr->dynCast<ResumeThrow>()) {
for (Index i = 0; i < r->handlerTags.size(); i++) {
auto dest = r->handlerTags[i];
for (Index i = 0; i < r->handlerBlocks.size(); i++) {
auto dest = r->handlerBlocks[i];
if (!dest.isNull() && dest == name) {
func(name, r->sentTypes[i]);
}
Expand Down
26 changes: 20 additions & 6 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2617,15 +2617,29 @@ class ConstantExpressionRunner : public ExpressionRunner<SubType> {
}
return ExpressionRunner<SubType>::visitRefAs(curr);
}
Flow visitContNew(ContNew* curr) { WASM_UNREACHABLE("unimplemented"); }
Flow visitContBind(ContBind* curr) { WASM_UNREACHABLE("unimplemented"); }
Flow visitSuspend(Suspend* curr) { WASM_UNREACHABLE("unimplemented"); }
Flow visitResume(Resume* curr) { WASM_UNREACHABLE("unimplemented"); }
Flow visitContNew(ContNew* curr) {
NOTE_ENTER("ContNew");
return Flow(NONCONSTANT_FLOW);
}
Flow visitContBind(ContBind* curr) {
NOTE_ENTER("ContBind");
return Flow(NONCONSTANT_FLOW);
}
Flow visitSuspend(Suspend* curr) {
NOTE_ENTER("Suspend");
return Flow(NONCONSTANT_FLOW);
}
Flow visitResume(Resume* curr) {
NOTE_ENTER("Resume");
return Flow(NONCONSTANT_FLOW);
}
Flow visitResumeThrow(ResumeThrow* curr) {
WASM_UNREACHABLE("unimplemented");
NOTE_ENTER("ResumeThrow");
return Flow(NONCONSTANT_FLOW);
}
Flow visitStackSwitch(StackSwitch* curr) {
WASM_UNREACHABLE("unimplemented");
NOTE_ENTER("StackSwitch");
return Flow(NONCONSTANT_FLOW);
}

void trap(const char* why) override { throw NonconstantException(); }
Expand Down
7 changes: 6 additions & 1 deletion src/wasm/wasm-ir-builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2477,11 +2477,16 @@ Result<> IRBuilder::makeStackSwitch(HeapType ct, Name tag) {
}
StackSwitch curr(wasm.allocator);
curr.tag = tag;
auto nparams = ct.getContinuation().type.getSignature().params.size();
Type params = ct.getContinuation().type.getSignature().params;
auto nparams = params.size();
if (nparams < 1) {
return Err{"arity mismatch: the continuation argument must have, at least, "
"unary arity"};
}
if (!params[nparams - 1].isContinuation()) {
return Err{"the last argument of the continuation argument should be "
"itself a continuation"};
}

// The continuation argument of the continuation is synthetic,
// i.e. it is provided by the runtime.
Expand Down
6 changes: 5 additions & 1 deletion src/wasm/wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1441,8 +1441,12 @@ void StackSwitch::finalize() {
}

assert(this->cont->type.isContinuation());
type =
Type params =
this->cont->type.getHeapType().getContinuation().type.getSignature().params;
assert(params.size() > 0);
Type cont = params[params.size() - 1];
assert(cont.isContinuation());
type = cont.getHeapType().getContinuation().type.getSignature().params;
Comment on lines +1446 to +1449
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dhil Does this look correct to you?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it looks right to me. The last parameter of a switch-continuation must be a continuation type. This argument is synthetic in the sense that it is provided by switch rather than the programmer.

}

size_t Function::getNumParams() { return getParams().size(); }
Expand Down
29 changes: 29 additions & 0 deletions test/lit/basic/stack_switching_switch_2.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
;; NOTE: Assertions have been generated by update_lit_checks.py and should not be edited.
;; RUN: wasm-opt -all %s -S -o - | filecheck %s


(module
;; CHECK: (type $function (func (param i64)))
(type $function (func (param i64)))
;; CHECK: (type $cont (cont $function))
(type $cont (cont $function))
;; CHECK: (type $function_2 (func (param i32 (ref $cont))))
(type $function_2 (func (param i32 (ref $cont))))
;; CHECK: (type $cont_2 (cont $function_2))
(type $cont_2 (cont $function_2))
;; CHECK: (tag $tag (type $4))
(tag $tag)

;; CHECK: (func $switch (type $5) (param $c (ref $cont_2)) (result i64)
;; CHECK-NEXT: (switch $cont_2 $tag
;; CHECK-NEXT: (i32.const 0)
;; CHECK-NEXT: (local.get $c)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $switch (param $c (ref $cont_2)) (result i64)
(switch $cont_2 $tag
(i32.const 0)
(local.get $c)
)
)
)
167 changes: 167 additions & 0 deletions test/lit/passes/O3_stack-switching.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
;; NOTE: Assertions have been generated by update_lit_checks.py and should not be edited.
;; RUN: wasm-opt -all -O3 %s -S -o - | filecheck %s

;; Fairly comprehensive test case

(module
;; CHECK: (type $function_1 (func (param (ref eq) (ref eq)) (result (ref eq))))
(type $function_1 (func (param (ref eq) (ref eq)) (result (ref eq))))
;; CHECK: (type $closure (sub (struct (field (ref $function_1)))))
(type $closure (sub (struct (field (ref $function_1)))))
;; CHECK: (type $cont (cont $function_1))

;; CHECK: (type $function_2 (func (param (ref eq) (ref eq) (ref eq)) (result (ref eq))))
(type $function_2 (func (param (ref eq) (ref eq) (ref eq)) (result (ref eq))))
;; CHECK: (type $closure_2 (struct (field (ref $function_2))))
(type $closure_2 (struct (field (ref $function_2))))
;; CHECK: (type $handlers (struct (field $value (ref $closure)) (field $exn (ref $closure)) (field $effect (ref $closure_2))))
(type $handlers (struct (field $value (ref $closure)) (field $exn (ref $closure)) (field $effect (ref $closure_2))))
(type $cont (cont $function_1))
;; CHECK: (type $fiber (struct (field $handlers (ref $handlers)) (field $cont (ref $cont))))
(type $fiber (struct (field $handlers (ref $handlers)) (field $cont (ref $cont))))
;; CHECK: (tag $exception (type $8) (param (ref eq)))
(tag $exception (param (ref eq)))
;; CHECK: (tag $effect (type $9) (param (ref eq)) (result (ref eq) (ref eq)))
(tag $effect (param (ref eq)) (result (ref eq) (ref eq)))

;; CHECK: (func $resume (type $10) (param $0 (ref $fiber)) (param $1 (ref $closure)) (param $2 (ref eq)) (result (ref eq))
;; CHECK-NEXT: (local $3 (tuple (ref eq) (ref $cont)))
;; CHECK-NEXT: (local $4 (ref $handlers))
;; CHECK-NEXT: (local $5 (ref $closure_2))
;; CHECK-NEXT: (return_call_ref $function_1
;; CHECK-NEXT: (block $handle_exception (result (ref eq))
;; CHECK-NEXT: (return_call_ref $function_2
;; CHECK-NEXT: (tuple.extract 2 0
;; CHECK-NEXT: (local.tee $3
;; CHECK-NEXT: (block $handle_effect (type $7) (result (ref eq) (ref $cont))
;; CHECK-NEXT: (return_call_ref $function_1
;; CHECK-NEXT: (try_table (result (ref eq)) (catch $exception $handle_exception)
;; CHECK-NEXT: (resume $cont (on $effect $handle_effect)
;; CHECK-NEXT: (local.get $1)
;; CHECK-NEXT: (local.get $2)
;; CHECK-NEXT: (struct.get $fiber $cont
;; CHECK-NEXT: (local.get $0)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (local.tee $1
;; CHECK-NEXT: (struct.get $handlers $value
;; CHECK-NEXT: (struct.get $fiber $handlers
;; CHECK-NEXT: (local.get $0)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (struct.get $closure 0
;; CHECK-NEXT: (local.get $1)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (struct.new $fiber
;; CHECK-NEXT: (local.tee $4
;; CHECK-NEXT: (struct.get $fiber $handlers
;; CHECK-NEXT: (local.get $0)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (tuple.extract 2 1
;; CHECK-NEXT: (local.get $3)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (local.tee $5
;; CHECK-NEXT: (struct.get $handlers $effect
;; CHECK-NEXT: (local.get $4)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (struct.get $closure_2 0
;; CHECK-NEXT: (local.get $5)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (local.tee $1
;; CHECK-NEXT: (struct.get $handlers $exn
;; CHECK-NEXT: (struct.get $fiber $handlers
;; CHECK-NEXT: (local.get $0)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (struct.get $closure 0
;; CHECK-NEXT: (local.get $1)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $resume (export "resume") (param $fiber (ref $fiber)) (param $f (ref $closure)) (param $v (ref eq)) (result (ref eq))
(local $g (ref $closure_2))
(local $res (ref eq))
(local $exn (ref eq))
(local $resume_res (tuple (ref eq) (ref $cont)))
(local.set $exn
(block $handle_exception (result (ref eq))
(local.set $resume_res
(block $handle_effect (result (ref eq) (ref $cont))
(local.set $res
(try_table (result (ref eq)) (catch $exception $handle_exception)
(resume $cont (on $effect $handle_effect)
(local.get $f)
(local.get $v)
(struct.get $fiber $cont
(local.get $fiber)
)
)
)
)
(return_call_ref $function_1
(local.get $res)
(local.tee $f
(struct.get $handlers $value
(struct.get $fiber $handlers
(local.get $fiber)
)
)
)
(struct.get $closure 0
(local.get $f)
)
)
)
)
(return_call_ref $function_2
(tuple.extract 2 0
(local.get $resume_res)
)
(struct.new $fiber
(struct.get $fiber $handlers
(local.get $fiber)
)
(tuple.extract 2 1
(local.get $resume_res)
)
)
(local.tee $g
(struct.get $handlers $effect
(struct.get $fiber $handlers
(local.get $fiber)
)
)
)
(struct.get $closure_2 0
(local.get $g)
)
)
)
)
(return_call_ref $function_1
(local.get $exn)
(local.tee $f
(struct.get $handlers $exn
(struct.get $fiber $handlers
(local.get $fiber)
)
)
)
(struct.get $closure 0
(local.get $f)
)
)
)
)
Loading