aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_long.py11
-rw-r--r--Objects/longobject.c24
2 files changed, 30 insertions, 5 deletions
diff --git a/Lib/test/test_long.py b/Lib/test/test_long.py
index 3c8e9e22e17..f2a622b5868 100644
--- a/Lib/test/test_long.py
+++ b/Lib/test/test_long.py
@@ -1502,6 +1502,17 @@ class LongTest(unittest.TestCase):
self.assertEqual(type(numerator), int)
self.assertEqual(type(denominator), int)
+ def test_square(self):
+ # Multiplication makes a special case of multiplying an int with
+ # itself, using a special, faster algorithm. This test is mostly
+ # to ensure that no asserts in the implementation trigger, in
+ # cases with a maximal amount of carries.
+ for bitlen in range(1, 400):
+ n = (1 << bitlen) - 1 # solid string of 1 bits
+ with self.subTest(bitlen=bitlen, n=n):
+ # (2**i - 1)**2 = 2**(2*i) - 2*2**i + 1
+ self.assertEqual(n**2,
+ (1 << (2 * bitlen)) - (1 << (bitlen + 1)) + 1)
if __name__ == "__main__":
unittest.main()
diff --git a/Objects/longobject.c b/Objects/longobject.c
index b5648fca7dc..2db8701a841 100644
--- a/Objects/longobject.c
+++ b/Objects/longobject.c
@@ -3237,12 +3237,12 @@ x_mul(PyLongObject *a, PyLongObject *b)
* via exploiting that each entry in the multiplication
* pyramid appears twice (except for the size_a squares).
*/
+ digit *paend = a->ob_digit + size_a;
for (i = 0; i < size_a; ++i) {
twodigits carry;
twodigits f = a->ob_digit[i];
digit *pz = z->ob_digit + (i << 1);
digit *pa = a->ob_digit + i + 1;
- digit *paend = a->ob_digit + size_a;
SIGCHECK({
Py_DECREF(z);
@@ -3265,13 +3265,27 @@ x_mul(PyLongObject *a, PyLongObject *b)
assert(carry <= (PyLong_MASK << 1));
}
if (carry) {
+ /* See comment below. pz points at the highest possible
+ * carry position from the last outer loop iteration, so
+ * *pz is at most 1.
+ */
+ assert(*pz <= 1);
carry += *pz;
- *pz++ = (digit)(carry & PyLong_MASK);
+ *pz = (digit)(carry & PyLong_MASK);
carry >>= PyLong_SHIFT;
+ if (carry) {
+ /* If there's still a carry, it must be into a position
+ * that still holds a 0. Where the base
+ ^ B is 1 << PyLong_SHIFT, the last add was of a carry no
+ * more than 2*B - 2 to a stored digit no more than 1.
+ * So the sum was no more than 2*B - 1, so the current
+ * carry no more than floor((2*B - 1)/B) = 1.
+ */
+ assert(carry == 1);
+ assert(pz[1] == 0);
+ pz[1] = (digit)carry;
+ }
}
- if (carry)
- *pz += (digit)(carry & PyLong_MASK);
- assert((carry >> PyLong_SHIFT) == 0);
}
}
else { /* a is not the same as b -- gradeschool int mult */