summaryrefslogtreecommitdiffstatshomepage
path: root/py/mpz.c
diff options
context:
space:
mode:
Diffstat (limited to 'py/mpz.c')
-rw-r--r--py/mpz.c267
1 files changed, 199 insertions, 68 deletions
diff --git a/py/mpz.c b/py/mpz.c
index b3f8b15b60..f02b75c2be 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -29,9 +29,6 @@
#include "py/mpz.h"
-// this is only needed for mp_not_implemented, which should eventually be removed
-#include "py/runtime.h"
-
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
#define DIG_SIZE (MPZ_DIG_SIZE)
@@ -199,6 +196,14 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
return idig + 1 - oidig;
}
+STATIC mp_uint_t mpn_remove_trailing_zeros(mpz_dig_t *oidig, mpz_dig_t *idig) {
+ for (--idig; idig >= oidig && *idig == 0; --idig) {
+ }
+ return idig + 1 - oidig;
+}
+
+#if MICROPY_OPT_MPZ_BITWISE
+
/* computes i = j & k
returns number of digits in i
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen (jlen argument not needed)
@@ -211,41 +216,46 @@ STATIC mp_uint_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t
*idig = *jdig & *kdig;
}
- // remove trailing zeros
- for (--idig; idig >= oidig && *idig == 0; --idig) {
- }
-
- return idig + 1 - oidig;
+ return mpn_remove_trailing_zeros(oidig, idig);
}
-/* computes i = j & -k = j & (~k + 1)
+#endif
+
+/* i = -((-j) & (-k)) = ~((~j + 1) & (~k + 1)) + 1
+ i = (j & (-k)) = (j & (~k + 1)) = ( j & (~k + 1))
+ i = ((-j) & k) = ((~j + 1) & k) = ((~j + 1) & k )
+ computes general form:
+ i = (im ^ (((j ^ jm) + jc) & ((k ^ km) + kc))) + ic where Xm = Xc == 0 ? 0 : DIG_MASK
returns number of digits in i
- assumes enough memory in i; assumes normalised j, k
+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
can have i, j, k pointing to same memory
*/
-STATIC mp_uint_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen) {
+STATIC mp_uint_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+ mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
mpz_dig_t *oidig = idig;
- mpz_dbl_dig_t carry = 1;
+ mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
+ mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
+ mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
- for (; jlen > 0 && klen > 0; --jlen, --klen, ++idig, ++jdig, ++kdig) {
- carry += *kdig ^ DIG_MASK;
- *idig = (*jdig & carry) & DIG_MASK;
- carry >>= DIG_SIZE;
+ for (; jlen > 0; ++idig, ++jdig) {
+ carryj += *jdig ^ jmask;
+ carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
+ carryi += ((carryj & carryk) ^ imask) & DIG_MASK;
+ *idig = carryi & DIG_MASK;
+ carryk >>= DIG_SIZE;
+ carryj >>= DIG_SIZE;
+ carryi >>= DIG_SIZE;
}
- for (; jlen > 0; --jlen, ++idig, ++jdig) {
- carry += DIG_MASK;
- *idig = (*jdig & carry) & DIG_MASK;
- carry >>= DIG_SIZE;
- }
-
- // remove trailing zeros
- for (--idig; idig >= oidig && *idig == 0; --idig) {
+ if (0 != carryi) {
+ *idig++ = carryi;
}
- return idig + 1 - oidig;
+ return mpn_remove_trailing_zeros(oidig, idig);
}
+#if MICROPY_OPT_MPZ_BITWISE
+
/* computes i = j | k
returns number of digits in i
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -267,6 +277,74 @@ STATIC mp_uint_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
return idig - oidig;
}
+#endif
+
+/* i = -((-j) | (-k)) = ~((~j + 1) | (~k + 1)) + 1
+ i = -(j | (-k)) = -(j | (~k + 1)) = ~( j | (~k + 1)) + 1
+ i = -((-j) | k) = -((~j + 1) | k) = ~((~j + 1) | k ) + 1
+ computes general form:
+ i = ~(((j ^ jm) + jc) | ((k ^ km) + kc)) + 1 where Xm = Xc == 0 ? 0 : DIG_MASK
+ returns number of digits in i
+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
+ can have i, j, k pointing to same memory
+*/
+
+#if MICROPY_OPT_MPZ_BITWISE
+
+STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+ mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
+ mpz_dig_t *oidig = idig;
+ mpz_dbl_dig_t carryi = 1;
+ mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
+ mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
+
+ for (; jlen > 0; ++idig, ++jdig) {
+ carryj += *jdig ^ jmask;
+ carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
+ carryi += ((carryj | carryk) ^ DIG_MASK) & DIG_MASK;
+ *idig = carryi & DIG_MASK;
+ carryk >>= DIG_SIZE;
+ carryj >>= DIG_SIZE;
+ carryi >>= DIG_SIZE;
+ }
+
+ if (0 != carryi) {
+ *idig++ = carryi;
+ }
+
+ return mpn_remove_trailing_zeros(oidig, idig);
+}
+
+#else
+
+STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+ mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
+ mpz_dig_t *oidig = idig;
+ mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
+ mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
+ mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
+
+ for (; jlen > 0; ++idig, ++jdig) {
+ carryj += *jdig ^ jmask;
+ carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
+ carryi += ((carryj | carryk) ^ imask) & DIG_MASK;
+ *idig = carryi & DIG_MASK;
+ carryk >>= DIG_SIZE;
+ carryj >>= DIG_SIZE;
+ carryi >>= DIG_SIZE;
+ }
+
+ if (0 != carryi) {
+ *idig++ = carryi;
+ }
+
+ return mpn_remove_trailing_zeros(oidig, idig);
+}
+
+#endif
+
+#if MICROPY_OPT_MPZ_BITWISE
+
/* computes i = j ^ k
returns number of digits in i
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
@@ -285,11 +363,39 @@ STATIC mp_uint_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
*idig = *jdig;
}
- // remove trailing zeros
- for (--idig; idig >= oidig && *idig == 0; --idig) {
+ return mpn_remove_trailing_zeros(oidig, idig);
+}
+
+#endif
+
+/* i = (-j) ^ (-k) = ~(j - 1) ^ ~(k - 1) = (j - 1) ^ (k - 1)
+ i = -(j ^ (-k)) = -(j ^ ~(k - 1)) = ~(j ^ ~(k - 1)) + 1 = (j ^ (k - 1)) + 1
+ i = -((-j) ^ k) = -(~(j - 1) ^ k) = ~(~(j - 1) ^ k) + 1 = ((j - 1) ^ k) + 1
+ computes general form:
+ i = ((j - 1 + jc) ^ (k - 1 + kc)) + ic
+ returns number of digits in i
+ assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
+ can have i, j, k pointing to same memory
+*/
+STATIC mp_uint_t mpn_xor_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+ mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
+ mpz_dig_t *oidig = idig;
+
+ for (; jlen > 0; ++idig, ++jdig) {
+ carryj += *jdig + DIG_MASK;
+ carryk += (--klen <= --jlen) ? (*kdig++ + DIG_MASK) : DIG_MASK;
+ carryi += (carryj ^ carryk) & DIG_MASK;
+ *idig = carryi & DIG_MASK;
+ carryk >>= DIG_SIZE;
+ carryj >>= DIG_SIZE;
+ carryi >>= DIG_SIZE;
}
- return idig + 1 - oidig;
+ if (0 != carryi) {
+ *idig++ = carryi;
+ }
+
+ return mpn_remove_trailing_zeros(oidig, idig);
}
/* computes i = i * d1 + d2
@@ -1097,81 +1203,106 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
can have dest, lhs, rhs the same
*/
void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
- if (lhs->neg == rhs->neg) {
- if (lhs->neg == 0) {
- // make sure lhs has the most digits
- if (lhs->len < rhs->len) {
- const mpz_t *temp = lhs;
- lhs = rhs;
- rhs = temp;
- }
- // do the and'ing
- mpz_need_dig(dest, rhs->len);
- dest->len = mpn_and(dest->dig, lhs->dig, rhs->dig, rhs->len);
- dest->neg = 0;
- } else {
- // TODO both args are negative
- mp_not_implemented("bignum and with negative args");
- }
+ // make sure lhs has the most digits
+ if (lhs->len < rhs->len) {
+ const mpz_t *temp = lhs;
+ lhs = rhs;
+ rhs = temp;
+ }
+
+ #if MICROPY_OPT_MPZ_BITWISE
+
+ if ((0 == lhs->neg) && (0 == rhs->neg)) {
+ mpz_need_dig(dest, lhs->len);
+ dest->len = mpn_and(dest->dig, lhs->dig, rhs->dig, rhs->len);
+ dest->neg = 0;
} else {
- // args have different sign
- // make sure lhs is the positive arg
- if (rhs->neg == 0) {
- const mpz_t *temp = lhs;
- lhs = rhs;
- rhs = temp;
- }
mpz_need_dig(dest, lhs->len + 1);
- dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
- assert(dest->len <= dest->alloc);
- dest->neg = 0;
+ dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ lhs->neg == rhs->neg, 0 != lhs->neg, 0 != rhs->neg);
+ dest->neg = lhs->neg & rhs->neg;
}
+
+ #else
+
+ mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
+ dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ (lhs->neg == rhs->neg) ? lhs->neg : 0, lhs->neg, rhs->neg);
+ dest->neg = lhs->neg & rhs->neg;
+
+ #endif
}
/* computes dest = lhs | rhs
can have dest, lhs, rhs the same
*/
void mpz_or_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
- if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
+ // make sure lhs has the most digits
+ if (lhs->len < rhs->len) {
const mpz_t *temp = lhs;
lhs = rhs;
rhs = temp;
}
- if (lhs->neg == rhs->neg) {
+ #if MICROPY_OPT_MPZ_BITWISE
+
+ if ((0 == lhs->neg) && (0 == rhs->neg)) {
mpz_need_dig(dest, lhs->len);
dest->len = mpn_or(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ dest->neg = 0;
} else {
- mpz_need_dig(dest, lhs->len);
- // TODO
- mp_not_implemented("bignum or with negative args");
-// dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ mpz_need_dig(dest, lhs->len + 1);
+ dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ 0 != lhs->neg, 0 != rhs->neg);
+ dest->neg = 1;
}
- dest->neg = lhs->neg;
+ #else
+
+ mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
+ dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ (lhs->neg || rhs->neg), lhs->neg, rhs->neg);
+ dest->neg = lhs->neg | rhs->neg;
+
+ #endif
}
/* computes dest = lhs ^ rhs
can have dest, lhs, rhs the same
*/
void mpz_xor_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
- if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
+ // make sure lhs has the most digits
+ if (lhs->len < rhs->len) {
const mpz_t *temp = lhs;
lhs = rhs;
rhs = temp;
}
+ #if MICROPY_OPT_MPZ_BITWISE
+
if (lhs->neg == rhs->neg) {
mpz_need_dig(dest, lhs->len);
- dest->len = mpn_xor(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ if (lhs->neg == 0) {
+ dest->len = mpn_xor(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ } else {
+ dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 0, 0, 0);
+ }
+ dest->neg = 0;
} else {
- mpz_need_dig(dest, lhs->len);
- // TODO
- mp_not_implemented("bignum xor with negative args");
-// dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
+ mpz_need_dig(dest, lhs->len + 1);
+ dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 1,
+ 0 == lhs->neg, 0 == rhs->neg);
+ dest->neg = 1;
}
- dest->neg = 0;
+ #else
+
+ mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
+ dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
+ (lhs->neg != rhs->neg), 0 == lhs->neg, 0 == rhs->neg);
+ dest->neg = lhs->neg ^ rhs->neg;
+
+ #endif
}
/* computes dest = lhs * rhs