summaryrefslogtreecommitdiffstatshomepage
path: root/extmod/modussl_mbedtls.c
diff options
context:
space:
mode:
Diffstat (limited to 'extmod/modussl_mbedtls.c')
-rw-r--r--extmod/modussl_mbedtls.c64
1 files changed, 44 insertions, 20 deletions
diff --git a/extmod/modussl_mbedtls.c b/extmod/modussl_mbedtls.c
index 40dd8c049f..12ec60a756 100644
--- a/extmod/modussl_mbedtls.c
+++ b/extmod/modussl_mbedtls.c
@@ -29,11 +29,12 @@
#include <stdio.h>
#include <string.h>
-#include <errno.h>
+#include <errno.h> // needed because mp_is_nonblocking_error uses system error codes
#include "py/nlr.h"
#include "py/runtime.h"
#include "py/stream.h"
+#include "py/obj.h"
// mbedtls_time_t
#include "mbedtls/platform.h"
@@ -84,6 +85,9 @@ int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
int out_sz = sock_stream->write(sock, buf, len, &err);
if (out_sz == MP_STREAM_ERROR) {
+ if (mp_is_nonblocking_error(err)) {
+ return MBEDTLS_ERR_SSL_WANT_WRITE;
+ }
return -err;
} else {
return out_sz;
@@ -98,6 +102,9 @@ int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
int out_sz = sock_stream->read(sock, buf, len, &err);
if (out_sz == MP_STREAM_ERROR) {
+ if (mp_is_nonblocking_error(err)) {
+ return MBEDTLS_ERR_SSL_WANT_READ;
+ }
return -err;
} else {
return out_sz;
@@ -128,7 +135,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
}
ret = mbedtls_ssl_config_defaults(&o->conf,
- MBEDTLS_SSL_IS_CLIENT,
+ args->server_side.u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT);
if (ret != 0) {
@@ -172,21 +179,27 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
assert(ret == 0);
}
- if (args->server_side.u_bool) {
- assert(0);
- } else {
- while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) {
- if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
- //assert(0);
- printf("mbedtls_ssl_handshake error: -%x\n", -ret);
- mp_raise_OSError(MP_EIO);
- }
+ while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) {
+ if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
+ //assert(0);
+ printf("mbedtls_ssl_handshake error: -%x\n", -ret);
+ mp_raise_OSError(MP_EIO);
}
}
return o;
}
+STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) {
+ mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
+ if (!mp_obj_is_true(binary_form)) {
+ mp_raise_NotImplementedError(NULL);
+ }
+ const mbedtls_x509_crt* peer_cert = mbedtls_ssl_get_peer_cert(&o->ssl);
+ return mp_obj_new_bytes(peer_cert->raw.p, peer_cert->raw.len);
+}
+STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_getpeercert_obj, mod_ssl_getpeercert);
+
STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) {
(void)kind;
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in);
@@ -197,9 +210,16 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
int ret = mbedtls_ssl_read(&o->ssl, buf, size);
+ if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
+ // end of stream
+ return 0;
+ }
if (ret >= 0) {
return ret;
}
+ if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
+ ret = MP_EWOULDBLOCK;
+ }
*errcode = ret;
return MP_STREAM_ERROR;
}
@@ -211,32 +231,35 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
if (ret >= 0) {
return ret;
}
+ if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+ ret = MP_EWOULDBLOCK;
+ }
*errcode = ret;
return MP_STREAM_ERROR;
}
STATIC mp_obj_t socket_setblocking(mp_obj_t self_in, mp_obj_t flag_in) {
- // Currently supports only blocking mode
- (void)self_in;
- if (!mp_obj_is_true(flag_in)) {
- mp_not_implemented("");
- }
- return mp_const_none;
+ mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in);
+ mp_obj_t sock = o->sock;
+ mp_obj_t dest[3];
+ mp_load_method(sock, MP_QSTR_setblocking, dest);
+ dest[2] = flag_in;
+ return mp_call_method_n_kw(1, 0, dest);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking);
STATIC mp_obj_t socket_close(mp_obj_t self_in) {
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in);
+ mbedtls_pk_free(&self->pkey);
+ mbedtls_x509_crt_free(&self->cert);
mbedtls_x509_crt_free(&self->cacert);
mbedtls_ssl_free(&self->ssl);
mbedtls_ssl_config_free(&self->conf);
mbedtls_ctr_drbg_free(&self->ctr_drbg);
mbedtls_entropy_free(&self->entropy);
- mp_obj_t dest[2];
- mp_load_method(self->sock, MP_QSTR_close, dest);
- return mp_call_method_n_kw(0, 0, dest);
+ return mp_stream_close(self->sock);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_1(socket_close_obj, socket_close);
@@ -247,6 +270,7 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) },
{ MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) },
{ MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&socket_close_obj) },
+ { MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) },
};
STATIC MP_DEFINE_CONST_DICT(ussl_socket_locals_dict, ussl_socket_locals_dict_table);