summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--extmod/modussl_axtls.c22
-rw-r--r--tests/extmod/ussl_basic.py8
-rw-r--r--tests/extmod/ussl_basic.py.exp3
3 files changed, 29 insertions, 4 deletions
diff --git a/extmod/modussl_axtls.c b/extmod/modussl_axtls.c
index a27f0f1fe5..a5ab8896c0 100644
--- a/extmod/modussl_axtls.c
+++ b/extmod/modussl_axtls.c
@@ -102,6 +102,11 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin
STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
+ if (o->ssl_sock == NULL) {
+ *errcode = EBADF;
+ return MP_STREAM_ERROR;
+ }
+
while (o->bytes_left == 0) {
mp_int_t r = ssl_read(o->ssl_sock, &o->buf);
if (r == SSL_OK) {
@@ -131,6 +136,12 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
+
+ if (o->ssl_sock == NULL) {
+ *errcode = EBADF;
+ return MP_STREAM_ERROR;
+ }
+
mp_int_t r = ssl_write(o->ssl_sock, buf, size);
if (r < 0) {
*errcode = r;
@@ -151,9 +162,14 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking);
STATIC mp_obj_t socket_close(mp_obj_t self_in) {
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in);
- ssl_free(self->ssl_sock);
- ssl_ctx_free(self->ssl_ctx);
- return mp_stream_close(self->sock);
+ if (self->ssl_sock != NULL) {
+ ssl_free(self->ssl_sock);
+ ssl_ctx_free(self->ssl_ctx);
+ self->ssl_sock = NULL;
+ return mp_stream_close(self->sock);
+ }
+
+ return mp_const_none;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_1(socket_close_obj, socket_close);
diff --git a/tests/extmod/ussl_basic.py b/tests/extmod/ussl_basic.py
index 9f8019a0bc..e8710ed51a 100644
--- a/tests/extmod/ussl_basic.py
+++ b/tests/extmod/ussl_basic.py
@@ -43,6 +43,14 @@ except OSError as er:
# close
ss.close()
+# close 2nd time
+ss.close()
+
+# read on closed socket
+try:
+ ss.read(10)
+except OSError as er:
+ print('read:', repr(er))
# write on closed socket
try:
diff --git a/tests/extmod/ussl_basic.py.exp b/tests/extmod/ussl_basic.py.exp
index b4dd038606..cb9c51f7a1 100644
--- a/tests/extmod/ussl_basic.py.exp
+++ b/tests/extmod/ussl_basic.py.exp
@@ -5,4 +5,5 @@ setblocking: NotImplementedError
4
b''
read: OSError(-261,)
-write: OSError(-256,)
+read: OSError(9,)
+write: OSError(9,)