diff --git a/SourceCode/arduino/include/theSecrets-example.h b/SourceCode/arduino/include/theSecrets-example.h new file mode 100644 index 0000000..1746311 --- /dev/null +++ b/SourceCode/arduino/include/theSecrets-example.h @@ -0,0 +1,9 @@ +#ifndef THESECRETS_H +#define THESECRETS_H + +const char *ssid = "Wifi"; +const char *psk = "password"; +const char *api_username = "username"; +const char *api_password = "password"; + +#endif diff --git a/SourceCode/arduino/lib/Firewall/FirewallTypes.h b/SourceCode/arduino/lib/Firewall/FirewallTypes.h deleted file mode 100644 index d89cc01..0000000 --- a/SourceCode/arduino/lib/Firewall/FirewallTypes.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef FIREWALL_TYPES_H -#define FIREWALL_TYPES_H - -#include "stdint.h" - -static const uint8_t IPV4ADDRESS_LENGTH = 16; - -typedef enum firewall_target : uint8_t -{ - FW_REJECT = 0, - FW_DROP = 1, - FW_ACCEPT = 2, -} firewall_target_t; - -typedef enum firewall_protocol : uint8_t -{ - FW_TCP = 0, - FW_UDP = 1, - FW_ALL = 255, -} firewall_protocol_t; - -typedef struct firewall_rule -{ - uint8_t key; - char source[IPV4ADDRESS_LENGTH]; - char destination[IPV4ADDRESS_LENGTH]; - firewall_protocol_t protocol; - firewall_target_t target; - struct firewall_rule *next; -} firewall_rule_t; - -#endif diff --git a/SourceCode/arduino/lib/Firewall/library.json b/SourceCode/arduino/lib/Firewall/library.json new file mode 100644 index 0000000..bec6af0 --- /dev/null +++ b/SourceCode/arduino/lib/Firewall/library.json @@ -0,0 +1,11 @@ +{ + "name": "firewall", + "license": "MIT", + "version": "0.0.1", + "frameworks": "arduino", + "platforms": ["espressif32"], + "dependencies": { + "bblanchon/ArduinoJson": "^6.19.4", + "external-repo": "https://github.com/fhessel/esp32_https_server/pull/91" + } +} diff --git a/SourceCode/arduino/lib/Firewall/src/Utils.cpp b/SourceCode/arduino/lib/Firewall/src/Utils.cpp new file mode 100644 index 0000000..3446404 --- /dev/null +++ b/SourceCode/arduino/lib/Firewall/src/Utils.cpp @@ -0,0 +1,57 @@ +#include "Utils.hpp" + +namespace firewall +{ + String protocol_to_string(firewall_protocol_t &protocol) + { + switch (protocol) + { + case FW_TCP: + return "TCP"; + case FW_UDP: + return "UDP"; + default: + return "ALL"; + } + } + + firewall_protocol_t string_to_protocol(std::string &protocol) + { + if (protocol.compare("TCP") == 0) + return FW_TCP; + else if (protocol.compare("UDP") == 0) + return FW_UDP; + else + return FW_ALL; + } + + String target_to_string(firewall_target_t &target) + { + switch (target) + { + case FW_REJECT: + return "REJECT"; + case FW_DROP: + return "DROP"; + default: + return "ACCEPT"; + } + } + + firewall_target_t string_to_target(std::string &target) + { + if (target.compare("REJECT") == 0) + return FW_REJECT; + else if (target.compare("DROP") == 0) + return FW_DROP; + else + return FW_ACCEPT; + } + + void endless_loop() + { + log_e("Something went wrong. Running endless loop until fixed..."); + while (true) + sleep(500); + } +} diff --git a/SourceCode/arduino/lib/Firewall/src/Utils.hpp b/SourceCode/arduino/lib/Firewall/src/Utils.hpp new file mode 100644 index 0000000..8be4958 --- /dev/null +++ b/SourceCode/arduino/lib/Firewall/src/Utils.hpp @@ -0,0 +1,56 @@ +#ifndef UTILS_HPP +#define UTILS_HPP + +#include "string" +#include "WString.h" +#include "esp32-hal-log.h" + +static const uint8_t IPV4ADDRESS_LENGTH = 16; + +typedef enum firewall_targets : uint8_t +{ + FW_REJECT = 0, + FW_DROP = 1, + FW_ACCEPT = 2, +} firewall_target_t; + +typedef enum firewall_protocols : uint8_t +{ + FW_TCP = 0, + FW_UDP = 1, + FW_ALL = 255, +} firewall_protocol_t; + +typedef enum ok : uint8_t +{ + SUCCESS = 0, + ERROR = 1, + NO_ACTION = 2, +} ok_t; + +typedef enum auth : uint8_t +{ + AUTHENTICATED = 0, + DENIED = 1, +} auth_t; + +typedef struct firewall_rules +{ + uint8_t key; + char source[IPV4ADDRESS_LENGTH]; + char destination[IPV4ADDRESS_LENGTH]; + firewall_protocol_t protocol; + firewall_target_t target; + struct firewall_rules *next; +} firewall_rule_t; + +namespace firewall +{ + String protocol_to_string(firewall_protocol_t &protocol); + firewall_protocol_t string_to_protocol(std::string &protocol); + String target_to_string(firewall_target_t &target); + firewall_target_t string_to_target(std::string &target); + void endless_loop(); +} + +#endif diff --git a/SourceCode/arduino/lib/Firewall/esp32API.cpp b/SourceCode/arduino/lib/Firewall/src/esp32API.cpp similarity index 72% rename from SourceCode/arduino/lib/Firewall/esp32API.cpp rename to SourceCode/arduino/lib/Firewall/src/esp32API.cpp index 6403d29..e922659 100644 --- a/SourceCode/arduino/lib/Firewall/esp32API.cpp +++ b/SourceCode/arduino/lib/Firewall/src/esp32API.cpp @@ -2,17 +2,18 @@ namespace firewall { - API::API(const uint16_t port) + API::API(const char *username, const char *password, const uint16_t port) { - this->setup_certificate(); + if (this->setup_auth(username, password) == ERROR) + endless_loop(); + if (this->setup_certificate() == ERROR) + endless_loop(); this->server = new HTTPSServer(this->certificate, port, 5); this->setup_routing(); log_i("Starting server..."); this->server->start(); if (this->server->isRunning()) - { - log_i("Server ready."); - } + log_i("Server ready on port: %i", port); } API::~API() @@ -24,12 +25,42 @@ namespace firewall this->server->loop(); } - void API::setup_certificate() + ok_t API::setup_auth(const char *username, const char *password) + { + if (!username || *username == 0x00 || strlen(username) > sizeof(this->username)) + { + log_e("Username too long or missing!"); + return ERROR; + } + strncpy(this->username, username, sizeof(this->username)); + if (!password || *password == 0x00 || strlen(password) > sizeof(this->password)) + { + log_e("Password too long or missing!"); + return ERROR; + } + strncpy(this->password, password, sizeof(this->password)); + return SUCCESS; + } + + auth_t API::check_auth(HTTPRequest *request, HTTPResponse *response) + { + std::string reqUsername = request->getBasicAuthUser(); + std::string reqPassword = request->getBasicAuthPassword(); + if ((strncmp(this->username, reqUsername.c_str(), sizeof(this->username)) != 0) || + (strncmp(this->password, reqPassword.c_str(), sizeof(this->password)) != 0)) + { + this->json_message_response(response, "unauthorized", 403); + return DENIED; + } + return AUTHENTICATED; + } + + ok_t API::setup_certificate() { this->certificate = retrieve_certificate(); if (certificate != NULL) - return; - log_i("Creating the certificate..."); + return NO_ACTION; + log_i("Creating a new certificate..."); this->certificate = new SSLCert(); int createCertResult = createSelfSignedCert( *this->certificate, @@ -39,12 +70,12 @@ namespace firewall "20320101000000"); if (createCertResult != 0) { - log_e("Cerating certificate failed. Error Code = 0x%02X, check SSLCert.hpp for details", createCertResult); - while (true) - delay(500); + log_e("Cannot create a server-certificate"); + return ERROR; } store_certificate(certificate); - log_i("Creating the certificate was successful"); + log_i("Creating a server-certificate was successful"); + return SUCCESS; } void API::setup_routing() @@ -68,6 +99,8 @@ namespace firewall void API::get_firewall_rule_handler(HTTPRequest *request, HTTPResponse *response) { + if (this->check_auth(request, response) == DENIED) + return; ResourceParameters *params = request->getParams(); int rule_number = atoi(params->getPathParameter(0).c_str()); firewall_rule_t *rule_ptr = get_rule_from_firewall(rule_number); @@ -85,29 +118,35 @@ namespace firewall void API::get_firewall_rules_handler(HTTPRequest *request, HTTPResponse *response) { + if (this->check_auth(request, response) == DENIED) + return; this->json_generic_response(response, this->construct_json_firewall(), 200); } bool API::request_has_firewall_parameter(ResourceParameters *params) { - return params->isQueryParameterSet("source") || params->isQueryParameterSet("destination") || params->isQueryParameterSet("protocol") || params->isQueryParameterSet("target"); + return params->isQueryParameterSet("source") || + params->isQueryParameterSet("destination") || + params->isQueryParameterSet("protocol") || + params->isQueryParameterSet("target"); } void API::post_firewall_handler(HTTPRequest *request, HTTPResponse *response) { + if (this->check_auth(request, response) == DENIED) + return; ResourceParameters *params = request->getParams(); if (request_has_firewall_parameter(params)) { firewall_rule_t *rule_ptr = (firewall_rule_t *)malloc(sizeof(firewall_rule_t)); rule_ptr->key = ++amount_of_rules; - // carefully copying c-string that is shorter then the destination char-array length std::string source; params->getQueryParameter("source", source); - strcpy(rule_ptr->source, source.length() <= IPV4ADDRESS_LENGTH ? source.c_str() : ""); + strncpy(rule_ptr->source, source.c_str(), sizeof(rule_ptr->source)); std::string destination; params->getQueryParameter("destination", destination); - strcpy(rule_ptr->destination, destination.length() <= IPV4ADDRESS_LENGTH ? destination.c_str() : ""); + strncpy(rule_ptr->destination, destination.c_str(), sizeof(rule_ptr->destination)); std::string protocol; params->getQueryParameter("protocol", protocol); @@ -127,9 +166,11 @@ namespace firewall void API::delete_firewall_handler(HTTPRequest *request, HTTPResponse *response) { + if (this->check_auth(request, response) == DENIED) + return; ResourceParameters *params = request->getParams(); int rule_number = atoi(params->getPathParameter(0).c_str()); - if (delete_rule_from_firewall(rule_number)) + if (delete_rule_from_firewall(rule_number) == SUCCESS) { this->json_message_response(response, "firewall rule deleted", 200); } diff --git a/SourceCode/arduino/lib/Firewall/esp32API.hpp b/SourceCode/arduino/lib/Firewall/src/esp32API.hpp similarity index 80% rename from SourceCode/arduino/lib/Firewall/esp32API.hpp rename to SourceCode/arduino/lib/Firewall/src/esp32API.hpp index f001d9d..fa7a203 100644 --- a/SourceCode/arduino/lib/Firewall/esp32API.hpp +++ b/SourceCode/arduino/lib/Firewall/src/esp32API.hpp @@ -7,7 +7,7 @@ #include "HTTPResponse.hpp" #include "ArduinoJson.h" -#include "FirewallTypes.h" +#include "Utils.hpp" #include "esp32Firewall.hpp" using namespace httpsserver; @@ -19,8 +19,13 @@ namespace firewall private: HTTPSServer *server; SSLCert *certificate; + char username[32]; + char password[32]; + + ok_t setup_auth(const char *, const char *); + auth_t check_auth(HTTPRequest *, HTTPResponse *); + ok_t setup_certificate(); - void setup_certificate(); void setup_routing(); void get_firewall_rule_handler(HTTPRequest *, HTTPResponse *); void get_firewall_rules_handler(HTTPRequest *, HTTPResponse *); @@ -35,7 +40,7 @@ namespace firewall String construct_json_firewall(); public: - API(const uint16_t = 8080); + API(const char *, const char *, const uint16_t = 8080); ~API(); void handle_clients(); }; diff --git a/SourceCode/arduino/lib/Firewall/esp32Firewall.cpp b/SourceCode/arduino/lib/Firewall/src/esp32Firewall.cpp similarity index 66% rename from SourceCode/arduino/lib/Firewall/esp32Firewall.cpp rename to SourceCode/arduino/lib/Firewall/src/esp32Firewall.cpp index 223c305..07ec32e 100644 --- a/SourceCode/arduino/lib/Firewall/esp32Firewall.cpp +++ b/SourceCode/arduino/lib/Firewall/src/esp32Firewall.cpp @@ -58,21 +58,17 @@ namespace firewall return rule_ptr; } - bool Firewall::delete_rule_from_firewall(uint8_t key) + ok_t Firewall::delete_rule_from_firewall(uint8_t key) { if (this->head == NULL) - { - return false; - } + return NO_ACTION; firewall_rule_t *current_rule_ptr = this->head; firewall_rule_t *previous_rule_ptr = NULL; firewall_rule_t *temp = NULL; while (current_rule_ptr->key != key) { if (current_rule_ptr->next == NULL) - { - return false; - } + return NO_ACTION; else { previous_rule_ptr = current_rule_ptr; @@ -99,52 +95,6 @@ namespace firewall store_settings_value("amount_of_rules", this->amount_of_rules); if (this->amount_of_rules != 0) store_all_firewall_rules(head); - return true; - } - - String Firewall::protocol_to_string(firewall_protocol_t &protocol) - { - switch (protocol) - { - case FW_TCP: - return "TCP"; - case FW_UDP: - return "UDP"; - default: - return "ALL"; - } - } - - firewall_protocol_t Firewall::string_to_protocol(std::string &protocol) - { - if (protocol.compare("TCP") == 0) - return FW_TCP; - else if (protocol.compare("UDP") == 0) - return FW_UDP; - else - return FW_ALL; - } - - String Firewall::target_to_string(firewall_target_t &target) - { - switch (target) - { - case FW_REJECT: - return "REJECT"; - case FW_DROP: - return "DROP"; - default: - return "ACCEPT"; - } - } - - firewall_target_t Firewall::string_to_target(std::string &target) - { - if (target.compare("REJECT") == 0) - return FW_REJECT; - else if (target.compare("DROP") == 0) - return FW_DROP; - else - return FW_ACCEPT; + return SUCCESS; } } diff --git a/SourceCode/arduino/lib/Firewall/esp32Firewall.hpp b/SourceCode/arduino/lib/Firewall/src/esp32Firewall.hpp similarity index 51% rename from SourceCode/arduino/lib/Firewall/esp32Firewall.hpp rename to SourceCode/arduino/lib/Firewall/src/esp32Firewall.hpp index c03df2e..7afd7fc 100644 --- a/SourceCode/arduino/lib/Firewall/esp32Firewall.hpp +++ b/SourceCode/arduino/lib/Firewall/src/esp32Firewall.hpp @@ -1,7 +1,7 @@ #ifndef ESP32_FIREWALL_HPP #define ESP32_FIREWALL_HPP -#include "FirewallTypes.h" +#include "Utils.hpp" #include "esp32Storage.hpp" #include "WString.h" @@ -11,16 +11,11 @@ namespace firewall { protected: uint8_t amount_of_rules = 0; - struct firewall_rule *head = NULL; + firewall_rule_t *head = NULL; void add_rule_to_firewall(firewall_rule_t *); firewall_rule_t *get_rule_from_firewall(uint8_t); - bool delete_rule_from_firewall(uint8_t); - - String protocol_to_string(firewall_protocol_t &protocol); - firewall_protocol_t string_to_protocol(std::string &protocol); - String target_to_string(firewall_target_t &target); - firewall_target_t string_to_target(std::string &target); + ok_t delete_rule_from_firewall(uint8_t); public: Firewall(); diff --git a/SourceCode/arduino/lib/Firewall/esp32Storage.cpp b/SourceCode/arduino/lib/Firewall/src/esp32Storage.cpp similarity index 82% rename from SourceCode/arduino/lib/Firewall/esp32Storage.cpp rename to SourceCode/arduino/lib/Firewall/src/esp32Storage.cpp index cbdb75b..1ef993f 100644 --- a/SourceCode/arduino/lib/Firewall/esp32Storage.cpp +++ b/SourceCode/arduino/lib/Firewall/src/esp32Storage.cpp @@ -4,25 +4,25 @@ namespace firewall { Storage::Storage() { - this->mount_spiffs(); + if (this->mount_spiffs() == ERROR) + endless_loop(); } Storage::~Storage() { } - void Storage::mount_spiffs() + ok_t Storage::mount_spiffs() { if (!SPIFFS.begin(false)) { if (!SPIFFS.begin(true)) { log_e("SPIFFS cannot be mounted"); - while (true) - delay(500); + return ERROR; }; } - log_i("SPIFFS mounted"); + return SUCCESS; } uint8_t Storage::retrieve_settings_value(const char *key) @@ -52,8 +52,8 @@ namespace firewall sprintf(rulename, "fwRule%i", key); this->memory.begin(rulename, true); - strcpy(rule_ptr->source, this->memory.getString("source", "0.0.0.0").c_str()); - strcpy(rule_ptr->destination, this->memory.getString("destination", "0.0.0.0").c_str()); + strncpy(rule_ptr->source, this->memory.getString("source", "0.0.0.0").c_str(), sizeof(rule_ptr->source)); + strncpy(rule_ptr->destination, this->memory.getString("destination", "0.0.0.0").c_str(), sizeof(rule_ptr->source)); rule_ptr->protocol = static_cast(this->memory.getUChar("protocol", FW_ALL)); rule_ptr->target = static_cast(this->memory.getUChar("target", FW_REJECT)); this->memory.end(); @@ -90,7 +90,7 @@ namespace firewall File certFile = SPIFFS.open("/cert.der"); if (!keyFile || !certFile || keyFile.size() == 0 || certFile.size() == 0) { - log_w("No certificate found in SPIFFS"); + log_e("No server-certificate found in SPIFFS"); return NULL; } size_t keySize = keyFile.size(); @@ -99,14 +99,14 @@ namespace firewall uint8_t *keyBuffer = new uint8_t[keySize]; if (keyBuffer == NULL) { - log_w("Not enough memory to load privat key"); + log_w("Not enough memory to load private key"); return NULL; } uint8_t *certBuffer = new uint8_t[certSize]; if (certBuffer == NULL) { delete[] keyBuffer; - log_w("Not enough memory to load certificate"); + log_w("Not enough memory to load server-certificate"); return NULL; } keyFile.read(keyBuffer, keySize); @@ -126,7 +126,7 @@ namespace firewall keyFile = SPIFFS.open("/key.der", FILE_WRITE); if (!keyFile || !keyFile.write(certificate->getPKData(), certificate->getPKLength())) { - log_w("Could not write /key.der"); + log_w("Cannot write /key.der"); failure = true; } if (keyFile) @@ -135,7 +135,7 @@ namespace firewall certFile = SPIFFS.open("/cert.der", FILE_WRITE); if (!certFile || !certFile.write(certificate->getCertData(), certificate->getCertLength())) { - log_w("Could not write /key.der"); + log_w("Cannot write /cert.der"); failure = true; } if (certFile) @@ -143,7 +143,7 @@ namespace firewall if (failure) { - log_w("Certificate could not be stored permanently, generating new certificate on reboot..."); + log_w("Server-certificate could not be stored permanently, generating new certificate on reboot..."); } } } diff --git a/SourceCode/arduino/lib/Firewall/esp32Storage.hpp b/SourceCode/arduino/lib/Firewall/src/esp32Storage.hpp similarity index 92% rename from SourceCode/arduino/lib/Firewall/esp32Storage.hpp rename to SourceCode/arduino/lib/Firewall/src/esp32Storage.hpp index cd88378..04ec725 100644 --- a/SourceCode/arduino/lib/Firewall/esp32Storage.hpp +++ b/SourceCode/arduino/lib/Firewall/src/esp32Storage.hpp @@ -3,7 +3,7 @@ #include "Preferences.h" #include "SPIFFS.h" -#include "FirewallTypes.h" +#include "Utils.hpp" #include "SSLCert.hpp" namespace firewall @@ -12,7 +12,7 @@ namespace firewall { private: Preferences memory; - void mount_spiffs(); + ok_t mount_spiffs(); protected: uint8_t retrieve_settings_value(const char *); diff --git a/SourceCode/arduino/platformio.ini b/SourceCode/arduino/platformio.ini index 8ea22a8..943e5b2 100644 --- a/SourceCode/arduino/platformio.ini +++ b/SourceCode/arduino/platformio.ini @@ -14,6 +14,7 @@ board = esp32-evb framework = arduino monitor_speed = 115200 build_flags = + -DHTTPS_LOGLEVEL=1 -DCORE_DEBUG_LEVEL=3 lib_deps = bblanchon/ArduinoJson@^6.19.4 @@ -23,5 +24,6 @@ board = az-delivery-devkit-v4 framework = arduino monitor_speed = 115200 build_flags = + -DHTTPS_LOGLEVEL=1 -DCORE_DEBUG_LEVEL=3 lib_deps = bblanchon/ArduinoJson@^6.19.4 diff --git a/SourceCode/arduino/src/main.cpp b/SourceCode/arduino/src/main.cpp index 9a1af17..4670332 100644 --- a/SourceCode/arduino/src/main.cpp +++ b/SourceCode/arduino/src/main.cpp @@ -8,7 +8,7 @@ void setup_wifi() { uint8_t max_retries = 5; uint8_t retries = 1; - log_i("Attempting to connect to WPA SSID: %s", ssid); + log_d("Attempting to connect to WPA SSID: %s", ssid); WiFi.mode(WIFI_STA); WiFi.begin(ssid, psk); while (WiFi.status() != WL_CONNECTED && retries <= max_retries) @@ -16,13 +16,13 @@ void setup_wifi() delay(2000); log_d("Connecting... (%i/%i)", retries++, max_retries); } - log_i("Connected, IP Address: %s", WiFi.localIP().toString().c_str()); + log_i("IP Address: %s", WiFi.localIP().toString().c_str()); } void setup() { setup_wifi(); - firewall_api = new firewall::API; + firewall_api = new firewall::API(api_username, api_password, 8080); } void loop()