diff options
author | Yan Yanchii <yyanchiy@gmail.com> | 2025-02-21 18:54:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-21 17:54:22 +0000 |
commit | 38642bff139bde5c0118bc75fda25badc76b85fc (patch) | |
tree | 7e848b6faeda7761350e134cf57df6f104016cd2 /Python/ast_opt.c | |
parent | d88677ac20b9466387459d5adb2e87b7de64bc19 (diff) | |
download | cpython-38642bff139bde5c0118bc75fda25badc76b85fc.tar.gz cpython-38642bff139bde5c0118bc75fda25badc76b85fc.zip |
gh-126835: Move constant unaryop & binop folding to CFG (#129550)
Diffstat (limited to 'Python/ast_opt.c')
-rw-r--r-- | Python/ast_opt.c | 289 |
1 file changed, 41 insertions, 248 deletions
diff --git a/Python/ast_opt.c b/Python/ast_opt.c index ab1ee96b045..2c6e16817f2 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -56,199 +56,6 @@ has_starred(asdl_expr_seq *elts) return 0; } - -static PyObject* -unary_not(PyObject *v) -{ - int r = PyObject_IsTrue(v); - if (r < 0) - return NULL; - return PyBool_FromLong(!r); -} - -static int -fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) -{ - expr_ty arg = node->v.UnaryOp.operand; - - if (arg->kind != Constant_kind) { - /* Fold not into comparison */ - if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind && - asdl_seq_LEN(arg->v.Compare.ops) == 1) { - /* Eq and NotEq are often implemented in terms of one another, so - folding not (self == other) into self != other breaks implementation - of !=. Detecting such cases doesn't seem worthwhile. - Python uses </> for 'is subset'/'is superset' operations on sets. - They don't satisfy not folding laws. */ - cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0); - switch (op) { - case Is: - op = IsNot; - break; - case IsNot: - op = Is; - break; - case In: - op = NotIn; - break; - case NotIn: - op = In; - break; - // The remaining comparison operators can't be safely inverted - case Eq: - case NotEq: - case Lt: - case LtE: - case Gt: - case GtE: - op = 0; // The AST enums leave "0" free as an "unused" marker - break; - // No default case, so the compiler will emit a warning if new - // comparison operators are added without being handled here - } - if (op) { - asdl_seq_SET(arg->v.Compare.ops, 0, op); - COPY_NODE(node, arg); - return 1; - } - } - return 1; - } - - typedef PyObject *(*unary_op)(PyObject*); - static const unary_op ops[] = { - [Invert] = PyNumber_Invert, - [Not] = unary_not, - [UAdd] = PyNumber_Positive, - [USub] = PyNumber_Negative, - }; - PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value); - return make_const(node, newval, arena); -} - -/* Check whether a collection doesn't containing too much items (including - 
subcollections). This protects from creating a constant that needs - too much time for calculating a hash. - "limit" is the maximal number of items. - Returns the negative number if the total number of items exceeds the - limit. Otherwise returns the limit minus the total number of items. -*/ - -static Py_ssize_t -check_complexity(PyObject *obj, Py_ssize_t limit) -{ - if (PyTuple_Check(obj)) { - Py_ssize_t i; - limit -= PyTuple_GET_SIZE(obj); - for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) { - limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit); - } - return limit; - } - return limit; -} - -#define MAX_INT_SIZE 128 /* bits */ -#define MAX_COLLECTION_SIZE 256 /* items */ -#define MAX_STR_SIZE 4096 /* characters */ -#define MAX_TOTAL_ITEMS 1024 /* including nested collections */ - -static PyObject * -safe_multiply(PyObject *v, PyObject *w) -{ - if (PyLong_Check(v) && PyLong_Check(w) && - !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w) - ) { - int64_t vbits = _PyLong_NumBits(v); - int64_t wbits = _PyLong_NumBits(w); - assert(vbits >= 0); - assert(wbits >= 0); - if (vbits + wbits > MAX_INT_SIZE) { - return NULL; - } - } - else if (PyLong_Check(v) && PyTuple_Check(w)) { - Py_ssize_t size = PyTuple_GET_SIZE(w); - if (size) { - long n = PyLong_AsLong(v); - if (n < 0 || n > MAX_COLLECTION_SIZE / size) { - return NULL; - } - if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) { - return NULL; - } - } - } - else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) { - Py_ssize_t size = PyUnicode_Check(w) ? 
PyUnicode_GET_LENGTH(w) : - PyBytes_GET_SIZE(w); - if (size) { - long n = PyLong_AsLong(v); - if (n < 0 || n > MAX_STR_SIZE / size) { - return NULL; - } - } - } - else if (PyLong_Check(w) && - (PyTuple_Check(v) || PyUnicode_Check(v) || PyBytes_Check(v))) - { - return safe_multiply(w, v); - } - - return PyNumber_Multiply(v, w); -} - -static PyObject * -safe_power(PyObject *v, PyObject *w) -{ - if (PyLong_Check(v) && PyLong_Check(w) && - !_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w) - ) { - int64_t vbits = _PyLong_NumBits(v); - size_t wbits = PyLong_AsSize_t(w); - assert(vbits >= 0); - if (wbits == (size_t)-1) { - return NULL; - } - if ((uint64_t)vbits > MAX_INT_SIZE / wbits) { - return NULL; - } - } - - return PyNumber_Power(v, w, Py_None); -} - -static PyObject * -safe_lshift(PyObject *v, PyObject *w) -{ - if (PyLong_Check(v) && PyLong_Check(w) && - !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w) - ) { - int64_t vbits = _PyLong_NumBits(v); - size_t wbits = PyLong_AsSize_t(w); - assert(vbits >= 0); - if (wbits == (size_t)-1) { - return NULL; - } - if (wbits > MAX_INT_SIZE || (uint64_t)vbits > MAX_INT_SIZE - wbits) { - return NULL; - } - } - - return PyNumber_Lshift(v, w); -} - -static PyObject * -safe_mod(PyObject *v, PyObject *w) -{ - if (PyUnicode_Check(v) || PyBytes_Check(v)) { - return NULL; - } - - return PyNumber_Remainder(v, w); -} - - static expr_ty parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena) { @@ -468,58 +275,7 @@ fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state) return optimize_format(node, lv, rhs->v.Tuple.elts, arena); } - if (rhs->kind != Constant_kind) { - return 1; - } - - PyObject *rv = rhs->v.Constant.value; - PyObject *newval = NULL; - - switch (node->v.BinOp.op) { - case Add: - newval = PyNumber_Add(lv, rv); - break; - case Sub: - newval = PyNumber_Subtract(lv, rv); - break; - case Mult: - newval = safe_multiply(lv, rv); - break; - case Div: - newval = 
PyNumber_TrueDivide(lv, rv); - break; - case FloorDiv: - newval = PyNumber_FloorDivide(lv, rv); - break; - case Mod: - newval = safe_mod(lv, rv); - break; - case Pow: - newval = safe_power(lv, rv); - break; - case LShift: - newval = safe_lshift(lv, rv); - break; - case RShift: - newval = PyNumber_Rshift(lv, rv); - break; - case BitOr: - newval = PyNumber_Or(lv, rv); - break; - case BitXor: - newval = PyNumber_Xor(lv, rv); - break; - case BitAnd: - newval = PyNumber_And(lv, rv); - break; - // No builtin constants implement the following operators - case MatMult: - return 1; - // No default case, so the compiler will emit a warning if new binary - // operators are added without being handled here - } - - return make_const(node, newval, arena); + return 1; } static PyObject* @@ -670,7 +426,6 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) break; case UnaryOp_kind: CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand); - CALL(fold_unaryop, expr_ty, node_); break; case Lambda_kind: CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args); @@ -962,6 +717,44 @@ astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) } static int +fold_const_match_patterns(expr_ty node, PyArena *ctx_, _PyASTOptimizeState *state) +{ + switch (node->kind) + { + case UnaryOp_kind: + { + if (node->v.UnaryOp.op == USub && + node->v.UnaryOp.operand->kind == Constant_kind) + { + PyObject *operand = node->v.UnaryOp.operand->v.Constant.value; + PyObject *folded = PyNumber_Negative(operand); + return make_const(node, folded, ctx_); + } + break; + } + case BinOp_kind: + { + operator_ty op = node->v.BinOp.op; + if ((op == Add || op == Sub) && + node->v.BinOp.right->kind == Constant_kind) + { + CALL(fold_const_match_patterns, expr_ty, node->v.BinOp.left); + if (node->v.BinOp.left->kind == Constant_kind) { + PyObject *left = node->v.BinOp.left->v.Constant.value; + PyObject *right = node->v.BinOp.right->v.Constant.value; + PyObject *folded = op == Add ? 
PyNumber_Add(left, right) : PyNumber_Subtract(left, right); + return make_const(node, folded, ctx_); + } + } + break; + } + default: + break; + } + return 1; +} + +static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) { // Currently, this is really only used to form complex/negative numeric @@ -970,7 +763,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) ENTER_RECURSIVE(); switch (node_->kind) { case MatchValue_kind: - CALL(astfold_expr, expr_ty, node_->v.MatchValue.value); + CALL(fold_const_match_patterns, expr_ty, node_->v.MatchValue.value); break; case MatchSingleton_kind: break; @@ -978,7 +771,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns); break; case MatchMapping_kind: - CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys); + CALL_SEQ(fold_const_match_patterns, expr, node_->v.MatchMapping.keys); CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns); break; case MatchClass_kind: |