From e6d3b912afa26c447d6faa9cb4e0f821ba799bbf Mon Sep 17 00:00:00 2001 From: gf712 Date: Fri, 26 Jun 2026 22:42:44 +0100 Subject: [PATCH] mlir: fix break/continue inside a try/with --- integration/tests/break_continue_in_try.py | 288 ++++++++++++++++++ .../PythonToPythonBytecode.cpp | 286 ++++++++++++++--- src/executable/mlir/compile.cpp | 22 +- 3 files changed, 547 insertions(+), 49 deletions(-) create mode 100644 integration/tests/break_continue_in_try.py diff --git a/integration/tests/break_continue_in_try.py b/integration/tests/break_continue_in_try.py new file mode 100644 index 00000000..66e90cdf --- /dev/null +++ b/integration/tests/break_continue_in_try.py @@ -0,0 +1,288 @@ +# Regression test: `break`/`continue` that leaves a `try`/`except` or `with` +# block inside a loop. These previously produced a `cf.br` from inside the +# still-nested python.try/with region to a block in the enclosing loop region +# (an invalid cross-region branch), which sent MLIR's region DCE into unbounded +# recursion and crashed the compiler. os.py's removedirs()/_walk() hit this, +# so `import os` segfaulted. + +# break out of an except handler (while) -- the removedirs() shape. +def break_in_except_while(items): + out = [] + i = 0 + while i < 5: + try: + out.append(items[i]) + except IndexError: + break + i += 1 + return out + +assert break_in_except_while([10, 20]) == [10, 20], break_in_except_while([10, 20]) + +# continue out of an except handler (for). +def continue_in_except_for(items): + out = [] + for x in items: + try: + if x == 0: + raise ValueError("zero") + out.append(x) + except ValueError: + continue + return out + +assert continue_in_except_for([1, 0, 2, 0, 5]) == [1, 2, 5], continue_in_except_for([1, 0, 2, 0, 5]) + +# break from the try body itself (not from a handler). 
+def break_in_try_body(): + out = [] + i = 0 + while i < 10: + try: + out.append(i) + if i == 3: + break + except Exception: + out.append(-1) + i += 1 + return out + +assert break_in_try_body() == [0, 1, 2, 3], break_in_try_body() + +# break/continue out of a `with` must still run __exit__. +class CM: + def __init__(self, log): + self._log = log + def __enter__(self): + self._log.append("enter") + return self + def __exit__(self, *a): + self._log.append("exit") + return False + +def break_out_of_with(): + log = [] + i = 0 + while i < 3: + with CM(log): + if i == 1: + break + i += 1 + return log + +assert break_out_of_with() == ["enter", "exit", "enter", "exit"], break_out_of_with() + +def continue_out_of_with(): + log = [] + for i in range(3): + with CM(log): + if i == 1: + continue + log.append(("after", i)) + return log + +assert continue_out_of_with() == [ + "enter", "exit", ("after", 0), + "enter", "exit", + "enter", "exit", ("after", 2), +], continue_out_of_with() + +# nested try/except with break in the inner handler -- the _walk() shape. +def nested_try_break(values): + out = [] + it = iter(values) + while True: + try: + try: + v = next(it) + except StopIteration: + break + except RuntimeError: + out.append("runtime") + continue + out.append(v) + return out + +assert nested_try_break([1, 2, 3]) == [1, 2, 3], nested_try_break([1, 2, 3]) + +# break/continue out of a try must still run its finally first. 
+def break_in_except_with_finally(): + out = [] + i = 0 + while i < 5: + try: + out.append(("body", i)) + raise ValueError + except ValueError: + out.append(("except", i)) + break + finally: + out.append(("finally", i)) + i += 1 + return out + +assert break_in_except_with_finally() == [ + ("body", 0), ("except", 0), ("finally", 0), +], break_in_except_with_finally() + +def continue_with_finally(): + out = [] + for i in range(3): + try: + if i == 1: + raise ValueError + out.append(("ok", i)) + except ValueError: + continue + finally: + out.append(("fin", i)) + return out + +assert continue_with_finally() == [ + ("ok", 0), ("fin", 0), ("fin", 1), ("ok", 2), ("fin", 2), +], continue_with_finally() + +# both break and continue, each unwinding the same finally. +def break_and_continue_with_finally(): + out = [] + for i in range(6): + try: + if i == 1: + continue + if i == 4: + break + out.append(("ok", i)) + finally: + out.append(("fin", i)) + return out + +assert break_and_continue_with_finally() == [ + ("ok", 0), ("fin", 0), + ("fin", 1), + ("ok", 2), ("fin", 2), + ("ok", 3), ("fin", 3), + ("fin", 4), +], break_and_continue_with_finally() + +# break from the try body (no exception raised) still runs finally. +def break_in_body_with_finally(): + out = [] + i = 0 + while i < 5: + try: + out.append(("body", i)) + if i == 2: + break + finally: + out.append(("fin", i)) + i += 1 + return out + +assert break_in_body_with_finally() == [ + ("body", 0), ("fin", 0), + ("body", 1), ("fin", 1), + ("body", 2), ("fin", 2), +], break_in_body_with_finally() + +# break through *nested* try/finally runs every finally, innermost first. 
+def break_through_nested_finally(): + out = [] + i = 0 + while i < 4: + try: + try: + out.append(("body", i)) + if i == 1: + break + finally: + out.append(("inner-fin", i)) + finally: + out.append(("outer-fin", i)) + i += 1 + return out + +assert break_through_nested_finally() == [ + ("body", 0), ("inner-fin", 0), ("outer-fin", 0), + ("body", 1), ("inner-fin", 1), ("outer-fin", 1), +], break_through_nested_finally() + +# break written *inside* a finally exits the loop. +def break_inside_finally(): + out = [] + for i in range(4): + try: + out.append(("body", i)) + finally: + out.append(("fin", i)) + if i == 1: + break + return out + +assert break_inside_finally() == [ + ("body", 0), ("fin", 0), ("body", 1), ("fin", 1), +], break_inside_finally() + +# a break inside a finally swallows an exception that is in flight -- even one +# raised by a called function. +def boom(): + raise RuntimeError("from callee") + +def break_inside_finally_swallows_exception(): + out = [] + for i in range(4): + try: + out.append(("body", i)) + if i == 1: + boom() + finally: + out.append(("fin", i)) + if i == 1: + break + return out + +assert break_inside_finally_swallows_exception() == [ + ("body", 0), ("fin", 0), ("body", 1), ("fin", 1), +], break_inside_finally_swallows_exception() + +# a continue inside a finally overrides a break in the try body. +def finally_continue_overrides_break(): + out = [] + for i in range(4): + try: + out.append(("body", i)) + if i == 1: + break + finally: + out.append(("fin", i)) + if i == 1: + continue + return out + +assert finally_continue_overrides_break() == [ + ("body", 0), ("fin", 0), + ("body", 1), ("fin", 1), + ("body", 2), ("fin", 2), + ("body", 3), ("fin", 3), +], finally_continue_overrides_break() + +# a break inside an inner finally still runs the enclosing finally. 
+def break_inside_inner_finally(): + out = [] + for i in range(3): + try: + try: + out.append(("inner-body", i)) + finally: + out.append(("inner-fin", i)) + if i == 1: + break + finally: + out.append(("outer-fin", i)) + return out + +assert break_inside_inner_finally() == [ + ("inner-body", 0), ("inner-fin", 0), ("outer-fin", 0), + ("inner-body", 1), ("inner-fin", 1), ("outer-fin", 1), +], break_inside_inner_finally() + +print("break_continue_in_try: ok") diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.cpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.cpp index 117a5742..8354be03 100644 --- a/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.cpp +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.cpp @@ -29,6 +29,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "utilities.hpp" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -104,6 +106,91 @@ namespace py { region.walk(callback); } + // Collects the loop-control (break/continue) kinds that appear directly + // in `region` — i.e. that belong to the enclosing loop rather than to a + // nested loop/try/with (whose own lowering owns them). Descends into + // TryHandlerOp so except-handler bodies are scanned. 
+ void collect_loop_control_kinds(mlir::Region &region, llvm::SmallSet &kinds) + { + if (region.empty()) { return; } + region.walk([&kinds](mlir::Operation *op) { + if (mlir::isa(op)) { + return WalkResult::skip(); + } + if (auto y = mlir::dyn_cast(op); + y && y.getKind().has_value()) { + kinds.insert(static_cast(*y.getKind())); + } + return WalkResult::advance(); + }); + } + + // A `break`/`continue` leaving a try/with body is represented as a + // `br_yield break_/continue_` marker that only the enclosing loop pass + // can resolve (to the loop's exit / continue target). Because try/with + // are flattened *before* the loops, we re-emit the marker here, after + // the block's exception cleanup ops, so it survives region inlining and + // ends up directly in the loop body region for the loop pass to lower. + // + // When the try has a `finally`, the marker cannot fire directly: the + // finally must run first. `finally_exits` maps each loop-control kind to + // the entry of a pre-built finally clone whose normal exit *is* the + // marker, so we branch there instead (see build_finally_loop_exits). + void forward_loop_control_yield(mlir::PatternRewriter &rewriter, + const llvm::DenseMap &finally_exits, + mlir::py::BranchYieldOp yield_op) + { + if (finally_exits.empty()) { + rewriter.create(yield_op.getLoc(), yield_op.getKindAttr()); + return; + } + auto it = finally_exits.find(static_cast(*yield_op.getKind())); + ASSERT(it != finally_exits.end()); + rewriter.create(yield_op.getLoc(), it->second); + } + + // For each break/continue kind that escapes the try through its finally, + // clone the (still-pristine) finally region onto a dedicated exit path + // and rewrite the clone's normal-completion (kindless) terminators into + // the loop-control marker. The result is a per-kind entry block: branch + // to it to "run finally, then break/continue". Must be called before the + // original finally region is rewritten/inlined below. 
+ llvm::DenseMap build_finally_loop_exits(mlir::PatternRewriter &rewriter, + mlir::py::TryOp op, + mlir::Block *endBlock) + { + llvm::DenseMap exits; + if (op.getFinally().empty()) { return exits; } + + llvm::SmallSet kinds; + collect_loop_control_kinds(op.getBody(), kinds); + for (mlir::Region &handler : op.getHandlers()) { + collect_loop_control_kinds(handler, kinds); + } + collect_loop_control_kinds(op.getOrelse(), kinds); + + for (int kind : kinds) { + auto kind_attr = mlir::py::LoopOpKindAttr::get( + rewriter.getContext(), static_cast(kind)); + mlir::IRMapping mapping; + rewriter.cloneRegionBefore( + op.getFinally(), *endBlock->getParent(), endBlock->getIterator(), mapping); + for (mlir::Block &orig : op.getFinally()) { + auto *cloned = mapping.lookup(&orig); + if (auto y = mlir::dyn_cast(cloned->getTerminator()); + y && !y.getKind().has_value()) { + rewriter.setInsertionPoint(y); + rewriter.replaceOpWithNewOp(y, kind_attr); + } + } + exits[kind] = mapping.lookup(&op.getFinally().front()); + } + return exits; + } + struct ForLoopOpLowering : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -269,8 +356,10 @@ namespace py { || mlir::isa(childOp)) { return WalkResult::skip(); } - if (mlir::isa(childOp) - && !mlir::cast(childOp).getKind().has_value()) { + if (mlir::isa(childOp)) { + // Both normal-completion (kindless) and loop-control + // (break/continue) yields are surfaced; the callback + // dispatches on the kind. callback(childOp); return WalkResult::skip(); } @@ -281,6 +370,31 @@ namespace py { mlir::LogicalResult matchAndRewrite(mlir::py::TryOp op, mlir::PatternRewriter &rewriter) const final { + // Lower innermost-first when a finally is involved. A + // break/continue nested in an inner try only surfaces into this + // try's body once the inner try is lowered, and we must see it + // before pre-scanning the loop-control kinds to thread through + // our own finally (build_finally_loop_exits). 
MLIR's greedy + // worklist doesn't guarantee inner-first, so defer (fail this + // match) until any nested try in our regions has been lowered + // away; the driver re-tries us when the inner try rewrites. + // (Nested with/loops don't need this: with lowers in a later + // pass, and a loop consumes its own break/continue.) + if (!op.getFinally().empty()) { + bool has_nested_try = false; + auto scan = [&](mlir::Region &region) { + if (has_nested_try || region.empty()) { return; } + region.walk([&](mlir::py::TryOp) { + has_nested_try = true; + return WalkResult::interrupt(); + }); + }; + scan(op.getBody()); + for (mlir::Region &handler : op.getHandlers()) { scan(handler); } + scan(op.getOrelse()); + if (has_nested_try) { return mlir::failure(); } + } + auto *initBlock = rewriter.getInsertionBlock(); auto initPos = rewriter.getInsertionPoint(); @@ -288,13 +402,27 @@ namespace py { auto *body_start = &op.getBody().front(); - replace_controlflow_yield( - op.getBody(), [&rewriter, &op, endBlock](mlir::Operation *childOp) { + // Pre-build the per-kind finally exit paths for break/continue + // while the finally region is still pristine (the loop below + // rewrites it). Empty when there is no finally. + const auto finally_exits = build_finally_loop_exits(rewriter, op, endBlock); + + replace_controlflow_yield(op.getBody(), + [&rewriter, &op, &finally_exits, endBlock](mlir::Operation *childOp) { auto *current = childOp->getBlock(); auto *next = rewriter.splitBlock(current, childOp->getIterator()); rewriter.setInsertionPointToEnd(current); rewriter.create( childOp->getLoc()); + if (auto y = mlir::cast(childOp); + y.getKind().has_value()) { + // break/continue out of the try body: pop the + // exception handler, then defer to the enclosing loop + // (running the finally first if there is one). 
+ forward_loop_control_yield(rewriter, finally_exits, y); + rewriter.eraseBlock(next); + return; + } if (op.getHandlers().empty()) { ASSERT(!op.getFinally().empty()); rewriter.create( @@ -323,22 +451,47 @@ namespace py { replace_controlflow_yield(op.getFinally(), [&rewriter, &op, &finally_mapping, endBlock](mlir::Operation *childOp) { + // A break/continue written *inside* the finally overrides + // whatever exit the try was heading for. Its kind drives + // both the normal and the exceptional finally copies. + auto kind_attr = + mlir::cast(childOp).getKindAttr(); + // Normal-completion copy: kindless yields fall through to + // the try's exit; an inside break/continue re-emits the + // loop-control marker for the enclosing loop pass. { auto *current = childOp->getBlock(); auto *next = rewriter.splitBlock(current, childOp->getIterator()); rewriter.setInsertionPointToEnd(current); - rewriter.create(childOp->getLoc(), endBlock); + if (kind_attr) { + rewriter.create( + childOp->getLoc(), kind_attr); + } else { + rewriter.create( + childOp->getLoc(), endBlock); + } rewriter.eraseBlock(next); } childOp = finally_mapping->lookup(childOp); ASSERT(childOp); + // Exceptional copy: kindless yields re-raise the in-flight + // exception after the finally; an inside break/continue + // instead *swallows* it (Python semantics) and performs + // the loop control flow. 
{ auto *current = childOp->getBlock(); auto *next = rewriter.splitBlock(current, childOp->getIterator()); rewriter.setInsertionPointToEnd(current); - rewriter.create( - childOp->getLoc(), endBlock); + if (kind_attr) { + rewriter.create( + childOp->getLoc()); + rewriter.create( + childOp->getLoc(), kind_attr); + } else { + rewriter.create( + childOp->getLoc(), endBlock); + } rewriter.eraseBlock(next); } }); @@ -393,12 +546,21 @@ namespace py { rewriter.inlineRegionBefore(handler_scope.getCond(), endBlock); } replace_controlflow_yield(handler_scope.getHandler(), - [&rewriter, &op, endBlock](mlir::Operation *childOp) { + [&rewriter, &op, &finally_exits, endBlock](mlir::Operation *childOp) { auto *current = childOp->getBlock(); auto *next = rewriter.splitBlock(current, childOp->getIterator()); rewriter.setInsertionPointToEnd(current); rewriter.create( op.getLoc()); + if (auto y = mlir::cast(childOp); + y.getKind().has_value()) { + // break/continue out of an except handler: + // clear the active exception, then defer to + // the enclosing loop. + forward_loop_control_yield(rewriter, finally_exits, y); + rewriter.eraseBlock(next); + return; + } if (!op.getFinally().empty()) { rewriter.create( childOp->getLoc(), &op.getFinally().front()); @@ -439,12 +601,21 @@ namespace py { } replace_controlflow_yield(handler_scope.getHandler(), - [&rewriter, &op, endBlock](mlir::Operation *childOp) { + [&rewriter, &op, &finally_exits, endBlock](mlir::Operation *childOp) { auto *current = childOp->getBlock(); auto *next = rewriter.splitBlock(current, childOp->getIterator()); rewriter.setInsertionPointToEnd(current); rewriter.create( op.getLoc()); + if (auto y = mlir::cast(childOp); + y.getKind().has_value()) { + // break/continue out of an except handler: + // clear the active exception, then defer to + // the enclosing loop. 
+ forward_loop_control_yield(rewriter, finally_exits, y); + rewriter.eraseBlock(next); + return; + } if (!op.getFinally().empty()) { rewriter.create( childOp->getLoc(), &op.getFinally().front()); @@ -458,11 +629,20 @@ namespace py { } } - replace_controlflow_yield( - op.getOrelse(), [&rewriter, &op, endBlock](mlir::Operation *childOp) { + replace_controlflow_yield(op.getOrelse(), + [&rewriter, &op, &finally_exits, endBlock](mlir::Operation *childOp) { auto *current = childOp->getBlock(); auto *next = rewriter.splitBlock(current, childOp->getIterator()); rewriter.setInsertionPointToEnd(current); + if (auto y = mlir::cast(childOp); + y.getKind().has_value()) { + // break/continue out of the else clause: the handler + // was already left when the body completed normally, + // so just defer to the enclosing loop. + forward_loop_control_yield(rewriter, finally_exits, y); + rewriter.eraseBlock(next); + return; + } if (!op.getFinally().empty()) { rewriter.create( childOp->getLoc(), &op.getFinally().front()); @@ -497,7 +677,37 @@ namespace py { auto *cleanup_block = rewriter.createBlock(endBlock); auto *exit_block = rewriter.createBlock(endBlock); - op.getBody().walk([&rewriter, exit_block, cleanup_block]( + // Emits the non-exceptional __exit__(None, None, None) sequence + // at the current insertion point. Shared by the normal-exit path + // and the break/continue path (both leave without an exception). 
+ auto emit_normal_exit = [&rewriter, &op]() { + for (const auto &item : op.getItems()) { + auto exit = rewriter.create(item.getLoc(), + mlir::py::PyObjectType::get(rewriter.getContext()), + item, + "__exit__"); + auto none = rewriter.create( + item.getLoc(), rewriter.getNoneType()); + rewriter.create(item.getLoc(), + mlir::py::PyObjectType::get(rewriter.getContext()), + exit, + std::vector{ none, none, none }, + mlir::DenseStringElementsAttr::get( + mlir::VectorType::get( + { 0 }, mlir::StringAttr::get(rewriter.getContext()).getType()), + {}), + std::vector{}, + false, + false); + rewriter.create(item.getLoc()); + } + }; + + op.getBody().walk([&rewriter, + exit_block, + cleanup_block, + endBlock, + &emit_normal_exit]( mlir::Operation *childOp) { static_assert(mlir::py::BranchYieldOp::hasTrait::Impl>()); @@ -520,13 +730,28 @@ namespace py { rewriter.replaceOpWithNewOp( op, BlockRange{ cleanup_block }); } - } else if (auto op = mlir::dyn_cast(childOp); - op && !op.getKind().has_value()) { - auto *current = op->getBlock(); - auto *next = rewriter.splitBlock(current, op->getIterator()); + } else if (auto y = mlir::dyn_cast(childOp); + y && !y.getKind().has_value()) { + auto *current = y->getBlock(); + auto *next = rewriter.splitBlock(current, y->getIterator()); rewriter.setInsertionPointToEnd(current); - rewriter.create(op->getLoc()); - rewriter.create(op->getLoc(), exit_block); + rewriter.create(y->getLoc()); + rewriter.create(y->getLoc(), exit_block); + rewriter.eraseBlock(next); + } else if (auto y = mlir::dyn_cast(childOp); + y && y.getKind().has_value()) { + // break/continue out of the with body: leave the + // exception handler, run __exit__, then hand the marker + // to the enclosing loop on a dedicated exit path. 
+ auto *current = y->getBlock(); + auto *next = rewriter.splitBlock(current, y->getIterator()); + auto *lc_block = rewriter.createBlock(endBlock); + rewriter.setInsertionPointToEnd(current); + rewriter.create(y->getLoc()); + rewriter.create(y->getLoc(), lc_block); + rewriter.setInsertionPointToStart(lc_block); + emit_normal_exit(); + rewriter.create(y->getLoc(), y.getKindAttr()); rewriter.eraseBlock(next); } return WalkResult::advance(); @@ -571,30 +796,7 @@ namespace py { } rewriter.setInsertionPointToStart(exit_block); - for (const auto &item : op.getItems()) { - auto exit = rewriter.create(item.getLoc(), - mlir::py::PyObjectType::get(rewriter.getContext()), - item, - "__exit__"); - - auto none = rewriter.create( - item.getLoc(), rewriter.getNoneType()); - - rewriter.create(item.getLoc(), - mlir::py::PyObjectType::get(rewriter.getContext()), - exit, - std::vector{ none, none, none }, - mlir::DenseStringElementsAttr::get( - mlir::VectorType::get( - { 0 }, mlir::StringAttr::get(rewriter.getContext()).getType()), - {}), - std::vector{}, - false, - false); - - rewriter.create(item.getLoc()); - } - + emit_normal_exit(); rewriter.create(op.getLoc(), endBlock); rewriter.setInsertionPointToEnd(initBlock); diff --git a/src/executable/mlir/compile.cpp b/src/executable/mlir/compile.cpp index e539679c..f501d4a5 100644 --- a/src/executable/mlir/compile.cpp +++ b/src/executable/mlir/compile.cpp @@ -62,17 +62,25 @@ std::shared_ptr compile(std::shared_ptr node, // own pass so we can interleave canonicalize + CSE between them. The // patterns perform structural surgery (block splits, region inlining, // IRMapping clones) and previously ran as part of a single greedy - // rewrite that couldn't simplify between them. Order: ForLoop and - // While first (they bake in step/condition blocks that the inner - // Try/With patterns may walk), then Try and With. 
- pm.addPass(::mlir::py::createConvertForLoopPass()); - pm.addPass(::mlir::py::createConvertWhileLoopPass()); - pm.addPass(::mlir::createCanonicalizerPass()); - pm.addPass(::mlir::createCSEPass()); + // rewrite that couldn't simplify between them. + // + // Try and With lower *before* ForLoop and While: a `break`/`continue` + // inside a try/with body becomes a `br_yield break_/continue_` marker + // that only the enclosing loop pass can resolve (to the loop's exit / + // continue target). If the loop lowered first, it would reach into the + // still-nested try/with region and emit a branch to a block defined in + // the parent region — invalid IR (a cross-region branch) that sends the + // later region-DCE into unbounded recursion. Flattening try/with first + // hoists those markers into the loop body region, so the loop pass + // resolves them with valid same-region branches. pm.addPass(::mlir::py::createConvertTryPass()); pm.addPass(::mlir::py::createConvertWithPass()); pm.addPass(::mlir::createCanonicalizerPass()); pm.addPass(::mlir::createCSEPass()); + pm.addPass(::mlir::py::createConvertForLoopPass()); + pm.addPass(::mlir::py::createConvertWhileLoopPass()); + pm.addPass(::mlir::createCanonicalizerPass()); + pm.addPass(::mlir::createCSEPass()); pm.addPass(::mlir::py::createPythonToPythonBytecodePass()); // Post-lowering canonicalize + CSE: dedupes the emitpybytecode.LOAD_CONST // ops the lowering and MLIRGenerator emit. Idiomatic Python compiles to