summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/objtype.c24
-rw-r--r--tests/basics/subclass_native_buffer.py16
2 files changed, 39 insertions, 1 deletions
diff --git a/py/objtype.c b/py/objtype.c
index 07aa56434a..d10d6cbd5a 100644
--- a/py/objtype.c
+++ b/py/objtype.c
@@ -666,6 +666,25 @@ STATIC mp_obj_t instance_getiter(mp_obj_t self_in) {
return mp_call_function_n_kw(meth, 0, 0, NULL);
}
+STATIC mp_int_t instance_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) {
+ mp_obj_instance_t *self = self_in;
+ mp_obj_t member[2] = {MP_OBJ_NULL};
+ struct class_lookup_data lookup = {
+ .obj = self,
+ .attr = MP_QSTR_, // don't actually look for a method
+ .meth_offset = offsetof(mp_obj_type_t, buffer_p.get_buffer),
+ .dest = member,
+ .is_type = false,
+ };
+ mp_obj_class_lookup(&lookup, self->base.type);
+ if (member[0] == MP_OBJ_SENTINEL) {
+ mp_obj_type_t *type = mp_obj_get_type(self->subobj[0]);
+ return type->buffer_p.get_buffer(self->subobj[0], bufinfo, flags);
+ } else {
+ return 1; // object does not support buffer protocol
+ }
+}
+
/******************************************************************************/
// type object
// - the struct is mp_obj_type_t and is defined in obj.h so const types can be made
@@ -807,13 +826,16 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict)
o->name = name;
o->print = instance_print;
o->make_new = instance_make_new;
+ o->call = mp_obj_instance_call;
o->unary_op = instance_unary_op;
o->binary_op = instance_binary_op;
o->load_attr = mp_obj_instance_load_attr;
o->store_attr = mp_obj_instance_store_attr;
o->subscr = instance_subscr;
- o->call = mp_obj_instance_call;
o->getiter = instance_getiter;
+ //o->iternext = ; not implemented
+ o->buffer_p.get_buffer = instance_get_buffer;
+ //o->stream_p = ; not implemented
o->bases_tuple = bases_tuple;
o->locals_dict = locals_dict;
diff --git a/tests/basics/subclass_native_buffer.py b/tests/basics/subclass_native_buffer.py
new file mode 100644
index 0000000000..43c3819657
--- /dev/null
+++ b/tests/basics/subclass_native_buffer.py
@@ -0,0 +1,16 @@
+# test when we subclass a type with the buffer protocol
+
+class my_bytes(bytes):
+ pass
+
+b1 = my_bytes([0, 1])
+b2 = my_bytes([2, 3])
+b3 = bytes([4, 5])
+
+# addition will use the buffer protocol on the RHS
+print(b1 + b2)
+print(b1 + b3)
+print(b3 + b1)
+
+# bytearray construction will use the buffer protocol
+print(bytearray(b1))