diff options
Diffstat (limited to 'extmod/modussl_mbedtls.c')
-rw-r--r-- | extmod/modussl_mbedtls.c | 64 |
1 files changed, 44 insertions, 20 deletions
diff --git a/extmod/modussl_mbedtls.c b/extmod/modussl_mbedtls.c index 40dd8c049f..12ec60a756 100644 --- a/extmod/modussl_mbedtls.c +++ b/extmod/modussl_mbedtls.c @@ -29,11 +29,12 @@ #include <stdio.h> #include <string.h> -#include <errno.h> +#include <errno.h> // needed because mp_is_nonblocking_error uses system error codes #include "py/nlr.h" #include "py/runtime.h" #include "py/stream.h" +#include "py/obj.h" // mbedtls_time_t #include "mbedtls/platform.h" @@ -84,6 +85,9 @@ int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) { int out_sz = sock_stream->write(sock, buf, len, &err); if (out_sz == MP_STREAM_ERROR) { + if (mp_is_nonblocking_error(err)) { + return MBEDTLS_ERR_SSL_WANT_WRITE; + } return -err; } else { return out_sz; @@ -98,6 +102,9 @@ int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) { int out_sz = sock_stream->read(sock, buf, len, &err); if (out_sz == MP_STREAM_ERROR) { + if (mp_is_nonblocking_error(err)) { + return MBEDTLS_ERR_SSL_WANT_READ; + } return -err; } else { return out_sz; @@ -128,7 +135,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { } ret = mbedtls_ssl_config_defaults(&o->conf, - MBEDTLS_SSL_IS_CLIENT, + args->server_side.u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); if (ret != 0) { @@ -172,21 +179,27 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { assert(ret == 0); } - if (args->server_side.u_bool) { - assert(0); - } else { - while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) { - if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { - //assert(0); - printf("mbedtls_ssl_handshake error: -%x\n", -ret); - mp_raise_OSError(MP_EIO); - } + while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + //assert(0); + printf("mbedtls_ssl_handshake error: -%x\n", -ret); + mp_raise_OSError(MP_EIO); } } return o; } +STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) { + mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); + if (!mp_obj_is_true(binary_form)) { + mp_raise_NotImplementedError(NULL); + } + const mbedtls_x509_crt* peer_cert = mbedtls_ssl_get_peer_cert(&o->ssl); + return mp_obj_new_bytes(peer_cert->raw.p, peer_cert->raw.len); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_getpeercert_obj, mod_ssl_getpeercert); + STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) { (void)kind; mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in); @@ -197,9 +210,16 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); int ret = mbedtls_ssl_read(&o->ssl, buf, size); + if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + // end of stream + return 0; + } if (ret >= 0) { return ret; } + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + ret = MP_EWOULDBLOCK; + } *errcode = ret; return MP_STREAM_ERROR; } @@ -211,32 +231,35 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in if (ret >= 0) { return ret; } + if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + ret = MP_EWOULDBLOCK; + } *errcode = ret; return MP_STREAM_ERROR; } STATIC mp_obj_t socket_setblocking(mp_obj_t self_in, mp_obj_t flag_in) { - // Currently supports only blocking mode - (void)self_in; - if (!mp_obj_is_true(flag_in)) { - mp_not_implemented(""); - } - return mp_const_none; + mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in); + mp_obj_t sock = o->sock; + mp_obj_t dest[3]; + mp_load_method(sock, MP_QSTR_setblocking, dest); + dest[2] = flag_in; + return mp_call_method_n_kw(1, 0, dest); } STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking); STATIC mp_obj_t socket_close(mp_obj_t self_in) { mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in); + mbedtls_pk_free(&self->pkey); + mbedtls_x509_crt_free(&self->cert); mbedtls_x509_crt_free(&self->cacert); mbedtls_ssl_free(&self->ssl); mbedtls_ssl_config_free(&self->conf); mbedtls_ctr_drbg_free(&self->ctr_drbg); mbedtls_entropy_free(&self->entropy); - mp_obj_t dest[2]; - mp_load_method(self->sock, MP_QSTR_close, dest); - return mp_call_method_n_kw(0, 0, dest); + return mp_stream_close(self->sock); } STATIC MP_DEFINE_CONST_FUN_OBJ_1(socket_close_obj, socket_close); @@ -247,6 +270,7 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) }, { MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) }, { MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&socket_close_obj) }, + { MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) }, }; STATIC MP_DEFINE_CONST_DICT(ussl_socket_locals_dict, ussl_socket_locals_dict_table); |