summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/objstr.c14
-rw-r--r--tests/basics/string_count.py26
2 files changed, 33 insertions, 7 deletions
diff --git a/py/objstr.c b/py/objstr.c
index 6a2625b621..64ba6c5fad 100644
--- a/py/objstr.c
+++ b/py/objstr.c
@@ -495,8 +495,8 @@ STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) {
GET_STR_DATA_LEN(args[0], haystack, haystack_len);
GET_STR_DATA_LEN(args[1], needle, needle_len);
- size_t start = 0;
- size_t end = haystack_len;
+ machine_uint_t start = 0;
+ machine_uint_t end = haystack_len;
/* TODO use a non-exception-throwing mp_get_index */
if (n_args >= 3 && args[2] != mp_const_none) {
start = mp_get_index(&str_type, haystack_len, args[2], true);
@@ -505,13 +505,13 @@ STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) {
end = mp_get_index(&str_type, haystack_len, args[3], true);
}
- machine_int_t num_occurrences = 0;
-
- // needle won't exist in haystack if it's longer, so nothing to count
- if (needle_len > haystack_len) {
- MP_OBJ_NEW_SMALL_INT(0);
+ // if needle_len is zero then we count each gap between characters as an occurrence
+ if (needle_len == 0) {
+ return MP_OBJ_NEW_SMALL_INT(end - start + 1);
}
+ // count the occurrences
+ machine_int_t num_occurrences = 0;
for (machine_uint_t haystack_index = start; haystack_index + needle_len <= end; haystack_index++) {
if (memcmp(&haystack[haystack_index], needle, needle_len) == 0) {
num_occurrences++;
diff --git a/tests/basics/string_count.py b/tests/basics/string_count.py
index bac99e78d8..0da1b1fcae 100644
--- a/tests/basics/string_count.py
+++ b/tests/basics/string_count.py
@@ -1,3 +1,29 @@
+print("".count(""))
+print("".count("a"))
+print("a".count(""))
+print("a".count("a"))
+print("a".count("b"))
+print("b".count("a"))
+
+print("aaa".count(""))
+print("aaa".count("a"))
+print("aaa".count("aa"))
+print("aaa".count("aaa"))
+print("aaa".count("aaaa"))
+
+print("aaaa".count(""))
+print("aaaa".count("a"))
+print("aaaa".count("aa"))
+print("aaaa".count("aaa"))
+print("aaaa".count("aaaa"))
+print("aaaa".count("aaaaa"))
+
+print("aaa".count("", 1))
+print("aaa".count("", 2))
+print("aaa".count("", 3))
+
+print("aaa".count("", 1, 2))
+
print("asdfasdfaaa".count("asdf", -100))
print("asdfasdfaaa".count("asdf", -8))
print("asdf".count('s', True))