aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Python/ast_opt.c
diff options
context:
space:
mode:
authorYan Yanchii <yyanchiy@gmail.com>2025-02-21 18:54:22 +0100
committerGitHub <noreply@github.com>2025-02-21 17:54:22 +0000
commit38642bff139bde5c0118bc75fda25badc76b85fc (patch)
tree7e848b6faeda7761350e134cf57df6f104016cd2 /Python/ast_opt.c
parentd88677ac20b9466387459d5adb2e87b7de64bc19 (diff)
downloadcpython-38642bff139bde5c0118bc75fda25badc76b85fc.tar.gz
cpython-38642bff139bde5c0118bc75fda25badc76b85fc.zip
gh-126835: Move constant unaryop & binop folding to CFG (#129550)
Diffstat (limited to 'Python/ast_opt.c')
-rw-r--r--Python/ast_opt.c289
1 files changed, 41 insertions, 248 deletions
diff --git a/Python/ast_opt.c b/Python/ast_opt.c
index ab1ee96b045..2c6e16817f2 100644
--- a/Python/ast_opt.c
+++ b/Python/ast_opt.c
@@ -56,199 +56,6 @@ has_starred(asdl_expr_seq *elts)
return 0;
}
-
-static PyObject*
-unary_not(PyObject *v)
-{
- int r = PyObject_IsTrue(v);
- if (r < 0)
- return NULL;
- return PyBool_FromLong(!r);
-}
-
-static int
-fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
-{
- expr_ty arg = node->v.UnaryOp.operand;
-
- if (arg->kind != Constant_kind) {
- /* Fold not into comparison */
- if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
- asdl_seq_LEN(arg->v.Compare.ops) == 1) {
- /* Eq and NotEq are often implemented in terms of one another, so
- folding not (self == other) into self != other breaks implementation
- of !=. Detecting such cases doesn't seem worthwhile.
- Python uses </> for 'is subset'/'is superset' operations on sets.
- They don't satisfy not folding laws. */
- cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
- switch (op) {
- case Is:
- op = IsNot;
- break;
- case IsNot:
- op = Is;
- break;
- case In:
- op = NotIn;
- break;
- case NotIn:
- op = In;
- break;
- // The remaining comparison operators can't be safely inverted
- case Eq:
- case NotEq:
- case Lt:
- case LtE:
- case Gt:
- case GtE:
- op = 0; // The AST enums leave "0" free as an "unused" marker
- break;
- // No default case, so the compiler will emit a warning if new
- // comparison operators are added without being handled here
- }
- if (op) {
- asdl_seq_SET(arg->v.Compare.ops, 0, op);
- COPY_NODE(node, arg);
- return 1;
- }
- }
- return 1;
- }
-
- typedef PyObject *(*unary_op)(PyObject*);
- static const unary_op ops[] = {
- [Invert] = PyNumber_Invert,
- [Not] = unary_not,
- [UAdd] = PyNumber_Positive,
- [USub] = PyNumber_Negative,
- };
- PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
- return make_const(node, newval, arena);
-}
-
-/* Check whether a collection doesn't containing too much items (including
- subcollections). This protects from creating a constant that needs
- too much time for calculating a hash.
- "limit" is the maximal number of items.
- Returns the negative number if the total number of items exceeds the
- limit. Otherwise returns the limit minus the total number of items.
-*/
-
-static Py_ssize_t
-check_complexity(PyObject *obj, Py_ssize_t limit)
-{
- if (PyTuple_Check(obj)) {
- Py_ssize_t i;
- limit -= PyTuple_GET_SIZE(obj);
- for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
- limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
- }
- return limit;
- }
- return limit;
-}
-
-#define MAX_INT_SIZE 128 /* bits */
-#define MAX_COLLECTION_SIZE 256 /* items */
-#define MAX_STR_SIZE 4096 /* characters */
-#define MAX_TOTAL_ITEMS 1024 /* including nested collections */
-
-static PyObject *
-safe_multiply(PyObject *v, PyObject *w)
-{
- if (PyLong_Check(v) && PyLong_Check(w) &&
- !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
- ) {
- int64_t vbits = _PyLong_NumBits(v);
- int64_t wbits = _PyLong_NumBits(w);
- assert(vbits >= 0);
- assert(wbits >= 0);
- if (vbits + wbits > MAX_INT_SIZE) {
- return NULL;
- }
- }
- else if (PyLong_Check(v) && PyTuple_Check(w)) {
- Py_ssize_t size = PyTuple_GET_SIZE(w);
- if (size) {
- long n = PyLong_AsLong(v);
- if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
- return NULL;
- }
- if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
- return NULL;
- }
- }
- }
- else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
- Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
- PyBytes_GET_SIZE(w);
- if (size) {
- long n = PyLong_AsLong(v);
- if (n < 0 || n > MAX_STR_SIZE / size) {
- return NULL;
- }
- }
- }
- else if (PyLong_Check(w) &&
- (PyTuple_Check(v) || PyUnicode_Check(v) || PyBytes_Check(v)))
- {
- return safe_multiply(w, v);
- }
-
- return PyNumber_Multiply(v, w);
-}
-
-static PyObject *
-safe_power(PyObject *v, PyObject *w)
-{
- if (PyLong_Check(v) && PyLong_Check(w) &&
- !_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w)
- ) {
- int64_t vbits = _PyLong_NumBits(v);
- size_t wbits = PyLong_AsSize_t(w);
- assert(vbits >= 0);
- if (wbits == (size_t)-1) {
- return NULL;
- }
- if ((uint64_t)vbits > MAX_INT_SIZE / wbits) {
- return NULL;
- }
- }
-
- return PyNumber_Power(v, w, Py_None);
-}
-
-static PyObject *
-safe_lshift(PyObject *v, PyObject *w)
-{
- if (PyLong_Check(v) && PyLong_Check(w) &&
- !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
- ) {
- int64_t vbits = _PyLong_NumBits(v);
- size_t wbits = PyLong_AsSize_t(w);
- assert(vbits >= 0);
- if (wbits == (size_t)-1) {
- return NULL;
- }
- if (wbits > MAX_INT_SIZE || (uint64_t)vbits > MAX_INT_SIZE - wbits) {
- return NULL;
- }
- }
-
- return PyNumber_Lshift(v, w);
-}
-
-static PyObject *
-safe_mod(PyObject *v, PyObject *w)
-{
- if (PyUnicode_Check(v) || PyBytes_Check(v)) {
- return NULL;
- }
-
- return PyNumber_Remainder(v, w);
-}
-
-
static expr_ty
parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
{
@@ -468,58 +275,7 @@ fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
}
- if (rhs->kind != Constant_kind) {
- return 1;
- }
-
- PyObject *rv = rhs->v.Constant.value;
- PyObject *newval = NULL;
-
- switch (node->v.BinOp.op) {
- case Add:
- newval = PyNumber_Add(lv, rv);
- break;
- case Sub:
- newval = PyNumber_Subtract(lv, rv);
- break;
- case Mult:
- newval = safe_multiply(lv, rv);
- break;
- case Div:
- newval = PyNumber_TrueDivide(lv, rv);
- break;
- case FloorDiv:
- newval = PyNumber_FloorDivide(lv, rv);
- break;
- case Mod:
- newval = safe_mod(lv, rv);
- break;
- case Pow:
- newval = safe_power(lv, rv);
- break;
- case LShift:
- newval = safe_lshift(lv, rv);
- break;
- case RShift:
- newval = PyNumber_Rshift(lv, rv);
- break;
- case BitOr:
- newval = PyNumber_Or(lv, rv);
- break;
- case BitXor:
- newval = PyNumber_Xor(lv, rv);
- break;
- case BitAnd:
- newval = PyNumber_And(lv, rv);
- break;
- // No builtin constants implement the following operators
- case MatMult:
- return 1;
- // No default case, so the compiler will emit a warning if new binary
- // operators are added without being handled here
- }
-
- return make_const(node, newval, arena);
+ return 1;
}
static PyObject*
@@ -670,7 +426,6 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
break;
case UnaryOp_kind:
CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
- CALL(fold_unaryop, expr_ty, node_);
break;
case Lambda_kind:
CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
@@ -962,6 +717,44 @@ astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
}
static int
+fold_const_match_patterns(expr_ty node, PyArena *ctx_, _PyASTOptimizeState *state)
+{
+ switch (node->kind)
+ {
+ case UnaryOp_kind:
+ {
+ if (node->v.UnaryOp.op == USub &&
+ node->v.UnaryOp.operand->kind == Constant_kind)
+ {
+ PyObject *operand = node->v.UnaryOp.operand->v.Constant.value;
+ PyObject *folded = PyNumber_Negative(operand);
+ return make_const(node, folded, ctx_);
+ }
+ break;
+ }
+ case BinOp_kind:
+ {
+ operator_ty op = node->v.BinOp.op;
+ if ((op == Add || op == Sub) &&
+ node->v.BinOp.right->kind == Constant_kind)
+ {
+ CALL(fold_const_match_patterns, expr_ty, node->v.BinOp.left);
+ if (node->v.BinOp.left->kind == Constant_kind) {
+ PyObject *left = node->v.BinOp.left->v.Constant.value;
+ PyObject *right = node->v.BinOp.right->v.Constant.value;
+ PyObject *folded = op == Add ? PyNumber_Add(left, right) : PyNumber_Subtract(left, right);
+ return make_const(node, folded, ctx_);
+ }
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ return 1;
+}
+
+static int
astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
// Currently, this is really only used to form complex/negative numeric
@@ -970,7 +763,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
ENTER_RECURSIVE();
switch (node_->kind) {
case MatchValue_kind:
- CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
+ CALL(fold_const_match_patterns, expr_ty, node_->v.MatchValue.value);
break;
case MatchSingleton_kind:
break;
@@ -978,7 +771,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
break;
case MatchMapping_kind:
- CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
+ CALL_SEQ(fold_const_match_patterns, expr, node_->v.MatchMapping.keys);
CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
break;
case MatchClass_kind: