summaryrefslogtreecommitdiffstatshomepage
path: root/py/mpz.c
diff options
context:
space:
mode:
Diffstat (limited to 'py/mpz.c')
-rw-r--r--py/mpz.c116
1 files changed, 73 insertions, 43 deletions
diff --git a/py/mpz.c b/py/mpz.c
index e503927d09..f5675a2917 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -49,11 +49,17 @@
Definition of normalise: ?
*/
+STATIC size_t mpn_remove_trailing_zeros(mpz_dig_t *oidig, mpz_dig_t *idig) {
+ for (--idig; idig >= oidig && *idig == 0; --idig) {
+ }
+ return idig + 1 - oidig;
+}
+
/* compares i with j
returns sign(i - j)
assumes i, j are normalised
*/
-STATIC int mpn_cmp(const mpz_dig_t *idig, mp_uint_t ilen, const mpz_dig_t *jdig, mp_uint_t jlen) {
+STATIC int mpn_cmp(const mpz_dig_t *idig, size_t ilen, const mpz_dig_t *jdig, size_t jlen) {
if (ilen < jlen) { return -1; }
if (ilen > jlen) { return 1; }
@@ -71,7 +77,7 @@ STATIC int mpn_cmp(const mpz_dig_t *idig, mp_uint_t ilen, const mpz_dig_t *jdig,
assumes enough memory in i; assumes normalised j; assumes n > 0
can have i, j pointing to same memory
*/
-STATIC mp_uint_t mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_uint_t n) {
+STATIC size_t mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, size_t jlen, mp_uint_t n) {
mp_uint_t n_whole = (n + DIG_SIZE - 1) / DIG_SIZE;
mp_uint_t n_part = n % DIG_SIZE;
if (n_part == 0) {
@@ -84,7 +90,7 @@ STATIC mp_uint_t mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui
// shift the digits
mpz_dbl_dig_t d = 0;
- for (mp_uint_t i = jlen; i > 0; i--, idig--, jdig--) {
+ for (size_t i = jlen; i > 0; i--, idig--, jdig--) {
d |= *jdig;
*idig = (d >> (DIG_SIZE - n_part)) & DIG_MASK;
d <<= DIG_SIZE;
@@ -110,7 +116,7 @@ STATIC mp_uint_t mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui
assumes enough memory in i; assumes normalised j; assumes n > 0
can have i, j pointing to same memory
*/
-STATIC mp_uint_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_uint_t n) {
+STATIC size_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, size_t jlen, mp_uint_t n) {
mp_uint_t n_whole = n / DIG_SIZE;
mp_uint_t n_part = n % DIG_SIZE;
@@ -121,7 +127,7 @@ STATIC mp_uint_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui
jdig += n_whole;
jlen -= n_whole;
- for (mp_uint_t i = jlen; i > 0; i--, idig++, jdig++) {
+ for (size_t i = jlen; i > 0; i--, idig++, jdig++) {
mpz_dbl_dig_t d = *jdig;
if (i > 1) {
d |= (mpz_dbl_dig_t)jdig[1] << DIG_SIZE;
@@ -142,7 +148,7 @@ STATIC mp_uint_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
can have i, j, k pointing to same memory
*/
-STATIC mp_uint_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen) {
+STATIC size_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
mpz_dig_t *oidig = idig;
mpz_dbl_dig_t carry = 0;
@@ -172,7 +178,7 @@ STATIC mp_uint_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
assumes enough memory in i; assumes normalised j, k; assumes j >= k
can have i, j, k pointing to same memory
*/
-STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen) {
+STATIC size_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
mpz_dig_t *oidig = idig;
mpz_dbl_dig_signed_t borrow = 0;
@@ -190,16 +196,7 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
borrow >>= DIG_SIZE;
}
- for (--idig; idig >= oidig && *idig == 0; --idig) {
- }
-
- return idig + 1 - oidig;
-}
-
-STATIC mp_uint_t mpn_remove_trailing_zeros(mpz_dig_t *oidig, mpz_dig_t *idig) {
- for (--idig; idig >= oidig && *idig == 0; --idig) {
- }
- return idig + 1 - oidig;
+ return mpn_remove_trailing_zeros(oidig, idig);
}
#if MICROPY_OPT_MPZ_BITWISE
@@ -209,7 +206,7 @@ STATIC mp_uint_t mpn_remove_trailing_zeros(mpz_dig_t *oidig, mpz_dig_t *idig) {
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen (jlen argument not needed)
can have i, j, k pointing to same memory
*/
-STATIC mp_uint_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t *kdig, mp_uint_t klen) {
+STATIC size_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t *kdig, size_t klen) {
mpz_dig_t *oidig = idig;
for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
@@ -230,7 +227,7 @@ STATIC mp_uint_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t
assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
can have i, j, k pointing to same memory
*/
-STATIC mp_uint_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+STATIC size_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
mpz_dig_t *oidig = idig;
mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
@@ -261,7 +258,7 @@ STATIC mp_uint_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t j
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
can have i, j, k pointing to same memory
*/
-STATIC mp_uint_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen) {
+STATIC size_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
mpz_dig_t *oidig = idig;
jlen -= klen;
@@ -291,7 +288,7 @@ STATIC mp_uint_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
#if MICROPY_OPT_MPZ_BITWISE
-STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+STATIC size_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
mpz_dig_t *oidig = idig;
mpz_dbl_dig_t carryi = 1;
@@ -321,7 +318,7 @@ STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jl
#else
-STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+STATIC size_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
mpz_dig_t *oidig = idig;
mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
@@ -353,7 +350,7 @@ STATIC mp_uint_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jl
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
can have i, j, k pointing to same memory
*/
-STATIC mp_uint_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen) {
+STATIC size_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
mpz_dig_t *oidig = idig;
jlen -= klen;
@@ -380,7 +377,7 @@ STATIC mp_uint_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
can have i, j, k pointing to same memory
*/
-STATIC mp_uint_t mpn_xor_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen, const mpz_dig_t *kdig, mp_uint_t klen,
+STATIC size_t mpn_xor_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
mpz_dig_t *oidig = idig;
@@ -405,7 +402,7 @@ STATIC mp_uint_t mpn_xor_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t j
returns number of digits in i
assumes enough memory in i; assumes normalised i; assumes dmul != 0
*/
-STATIC mp_uint_t mpn_mul_dig_add_dig(mpz_dig_t *idig, mp_uint_t ilen, mpz_dig_t dmul, mpz_dig_t dadd) {
+STATIC size_t mpn_mul_dig_add_dig(mpz_dig_t *idig, size_t ilen, mpz_dig_t dmul, mpz_dig_t dadd) {
mpz_dig_t *oidig = idig;
mpz_dbl_dig_t carry = dadd;
@@ -427,15 +424,15 @@ STATIC mp_uint_t mpn_mul_dig_add_dig(mpz_dig_t *idig, mp_uint_t ilen, mpz_dig_t
assumes enough memory in i; assumes i is zeroed; assumes normalised j, k
can have j, k point to same memory
*/
-STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_dig_t *kdig, mp_uint_t klen) {
+STATIC size_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, size_t jlen, mpz_dig_t *kdig, size_t klen) {
mpz_dig_t *oidig = idig;
- mp_uint_t ilen = 0;
+ size_t ilen = 0;
for (; klen > 0; --klen, ++idig, ++kdig) {
mpz_dig_t *id = idig;
mpz_dbl_dig_t carry = 0;
- mp_uint_t jl = jlen;
+ size_t jl = jlen;
for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) {
carry += (mpz_dbl_dig_t)*id + (mpz_dbl_dig_t)*jd * (mpz_dbl_dig_t)*kdig; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
*id = carry & DIG_MASK;
@@ -458,7 +455,7 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d
assumes quo_dig has enough memory (as many digits as num)
assumes quo_dig is filled with zeros
*/
-STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, const mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) {
+STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_dig, size_t den_len, mpz_dig_t *quo_dig, size_t *quo_len) {
mpz_dig_t *orig_num_dig = num_dig;
mpz_dig_t *orig_quo_dig = quo_dig;
mpz_dig_t norm_shift = 0;
@@ -661,7 +658,7 @@ void mpz_init_from_int(mpz_t *z, mp_int_t val) {
mpz_set_from_int(z, val);
}
-void mpz_init_fixed_from_int(mpz_t *z, mpz_dig_t *dig, mp_uint_t alloc, mp_int_t val) {
+void mpz_init_fixed_from_int(mpz_t *z, mpz_dig_t *dig, size_t alloc, mp_int_t val) {
z->neg = 0;
z->fixed_dig = 1;
z->alloc = alloc;
@@ -705,7 +702,7 @@ mpz_t *mpz_from_float(mp_float_t val) {
}
#endif
-mpz_t *mpz_from_str(const char *str, mp_uint_t len, bool neg, mp_uint_t base) {
+mpz_t *mpz_from_str(const char *str, size_t len, bool neg, unsigned int base) {
mpz_t *z = mpz_zero();
mpz_set_from_str(z, str, len, neg, base);
return z;
@@ -719,7 +716,7 @@ STATIC void mpz_free(mpz_t *z) {
}
}
-STATIC void mpz_need_dig(mpz_t *z, mp_uint_t need) {
+STATIC void mpz_need_dig(mpz_t *z, size_t need) {
if (need < MIN_ALLOC) {
need = MIN_ALLOC;
}
@@ -873,7 +870,7 @@ typedef uint32_t mp_float_int_t;
#endif
// returns number of bytes from str that were processed
-mp_uint_t mpz_set_from_str(mpz_t *z, const char *str, mp_uint_t len, bool neg, mp_uint_t base) {
+size_t mpz_set_from_str(mpz_t *z, const char *str, size_t len, bool neg, unsigned int base) {
assert(base <= 36);
const char *cur = str;
@@ -909,6 +906,39 @@ mp_uint_t mpz_set_from_str(mpz_t *z, const char *str, mp_uint_t len, bool neg, m
return cur - str;
}
+void mpz_set_from_bytes(mpz_t *z, bool big_endian, size_t len, const byte *buf) {
+ int delta = 1;
+ if (big_endian) {
+ buf += len - 1;
+ delta = -1;
+ }
+
+ mpz_need_dig(z, (len * 8 + DIG_SIZE - 1) / DIG_SIZE);
+
+ mpz_dig_t d = 0;
+ int num_bits = 0;
+ z->neg = 0;
+ z->len = 0;
+ while (len) {
+ while (len && num_bits < DIG_SIZE) {
+ d |= *buf << num_bits;
+ num_bits += 8;
+ buf += delta;
+ len--;
+ }
+ z->dig[z->len++] = d & DIG_MASK;
+ // Need this #if because it's C undefined behavior to do: uint32_t >> 32
+ #if DIG_SIZE != 8 && DIG_SIZE != 16 && DIG_SIZE != 32
+ d >>= DIG_SIZE;
+ #else
+ d = 0;
+ #endif
+ num_bits -= DIG_SIZE;
+ }
+
+ z->len = mpn_remove_trailing_zeros(z->dig, z->dig + z->len);
+}
+
bool mpz_is_zero(const mpz_t *z) {
return z->len == 0;
}
@@ -1120,7 +1150,7 @@ void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, mp_uint_t rhs) {
mp_uint_t n_whole = rhs / DIG_SIZE;
mp_uint_t n_part = rhs % DIG_SIZE;
mpz_dig_t round_up = 0;
- for (mp_uint_t i = 0; i < lhs->len && i < n_whole; i++) {
+ for (size_t i = 0; i < lhs->len && i < n_whole; i++) {
if (lhs->dig[i] != 0) {
round_up = 1;
break;
@@ -1364,9 +1394,6 @@ void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
mpz_free(n);
}
-#if 0
-these functions are unused
-
/* computes dest = (lhs ** rhs) % mod
can have dest, lhs, rhs the same; mod can't be the same as dest
*/
@@ -1405,6 +1432,9 @@ void mpz_pow3_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs, const mpz_t
mpz_free(n);
}
+#if 0
+these functions are unused
+
/* computes gcd(z1, z2)
based on Knuth's modified gcd algorithm (I think?)
gcd(z1, z2) >= 0
@@ -1593,7 +1623,7 @@ bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) {
}
// writes at most len bytes to buf (so buf should be zeroed before calling)
-void mpz_as_bytes(const mpz_t *z, bool big_endian, mp_uint_t len, byte *buf) {
+void mpz_as_bytes(const mpz_t *z, bool big_endian, size_t len, byte *buf) {
byte *b = buf;
if (big_endian) {
b += len;
@@ -1602,7 +1632,7 @@ void mpz_as_bytes(const mpz_t *z, bool big_endian, mp_uint_t len, byte *buf) {
int bits = 0;
mpz_dbl_dig_t d = 0;
mpz_dbl_dig_t carry = 1;
- for (mp_uint_t zlen = z->len; zlen > 0; --zlen) {
+ for (size_t zlen = z->len; zlen > 0; --zlen) {
bits += DIG_SIZE;
d = (d << DIG_SIZE) | *zdig++;
for (; bits >= 8; bits -= 8, d >>= 8) {
@@ -1645,8 +1675,8 @@ mp_float_t mpz_as_float(const mpz_t *i) {
#if 0
this function is unused
-char *mpz_as_str(const mpz_t *i, mp_uint_t base) {
- char *s = m_new(char, mpz_as_str_size(i, base, NULL, '\0'));
+char *mpz_as_str(const mpz_t *i, unsigned int base) {
+ char *s = m_new(char, mp_int_format_size(mpz_max_num_bits(i), base, NULL, '\0'));
mpz_as_str_inpl(i, base, NULL, 'a', '\0', s);
return s;
}
@@ -1654,7 +1684,7 @@ char *mpz_as_str(const mpz_t *i, mp_uint_t base) {
// assumes enough space as calculated by mp_int_format_size
// returns length of string, not including null byte
-mp_uint_t mpz_as_str_inpl(const mpz_t *i, mp_uint_t base, const char *prefix, char base_char, char comma, char *str) {
+size_t mpz_as_str_inpl(const mpz_t *i, unsigned int base, const char *prefix, char base_char, char comma, char *str) {
if (str == NULL) {
return 0;
}
@@ -1663,7 +1693,7 @@ mp_uint_t mpz_as_str_inpl(const mpz_t *i, mp_uint_t base, const char *prefix, ch
return 0;
}
- mp_uint_t ilen = i->len;
+ size_t ilen = i->len;
char *s = str;
if (ilen == 0) {