diff --git a/SourceCode/arduino/lib/Firewall/Firewall.cpp b/SourceCode/arduino/lib/Firewall/Firewall.cpp index 276e2c9..6df0b24 100644 --- a/SourceCode/arduino/lib/Firewall/Firewall.cpp +++ b/SourceCode/arduino/lib/Firewall/Firewall.cpp @@ -2,34 +2,152 @@ ESPFirewall::ESPFirewall(int port) { + this->setup_eeprom(); log_i("Starting Firewall-API on %i", port); this->firewall_api = new AsyncWebServer(port); this->setup_routing(); } -void ESPFirewall::add_rule_to_firewall(firewall_rule_t *rule) +String ESPFirewall::protocol_to_string(firewall_protocol_t &protocol) +{ + switch (protocol) + { + case FW_TCP: + return "TCP"; + case FW_UDP: + return "UDP"; + default: + return "ALL"; + } +} + +firewall_protocol_t ESPFirewall::string_to_protocol(String &protocol) +{ + if (protocol.equals("TCP")) + return FW_TCP; + else if (protocol.equals("UDP")) + return FW_UDP; + else + return FW_ALL; +} + +String ESPFirewall::target_to_string(firewall_target_t &target) +{ + switch (target) + { + case FW_REJECT: + return "REJECT"; + case FW_DROP: + return "DROP"; + default: + return "ACCEPT"; + } +} + +firewall_target_t ESPFirewall::string_to_target(String &target) +{ + if (target.equals("REJECT")) + return FW_REJECT; + else if (target.equals("DROP")) + return FW_DROP; + else + return FW_ACCEPT; +} + +void ESPFirewall::setup_eeprom() +{ + EEPROM.begin(this->eeprom_size); + this->amount_of_rules = EEPROM.read(this->eeprom_settings_head); + uint8_t security_number = EEPROM.read(this->eeprom_settings_head + 1); + log_i("Amount of existing Rules %i", this->amount_of_rules); + if (this->amount_of_rules > 50 || security_number != this->security_number) + { + this->amount_of_rules = 0; + EEPROM.write(this->eeprom_settings_head, this->amount_of_rules); + EEPROM.write(this->eeprom_settings_head + 1, this->security_number); + EEPROM.commit(); + } + log_i("Amount of existing Rules %i", this->amount_of_rules); + this->eeprom_read_firewall_rules(); +} + +void ESPFirewall::eeprom_write_firewall_rule(firewall_rule_t *rule_ptr) +{ + EEPROM.write(this->eeprom_settings_head, this->amount_of_rules); + EEPROM.writeString(this->eeprom_rules_head, rule_ptr->source); + this->eeprom_rules_head += IP4ADDR_STRLEN_MAX; + EEPROM.writeString(this->eeprom_rules_head, rule_ptr->destination); + this->eeprom_rules_head += IP4ADDR_STRLEN_MAX; + EEPROM.write(this->eeprom_rules_head, rule_ptr->protocol); + this->eeprom_rules_head += sizeof(firewall_protocol_t); + EEPROM.write(this->eeprom_rules_head, rule_ptr->target); + this->eeprom_rules_head += sizeof(firewall_target_t); + EEPROM.commit(); +} + +void ESPFirewall::eeprom_write_firewall_rules() +{ + this->eeprom_rules_head = eeprom_start_firewall_rules; + firewall_rule_t *rule_ptr = this->head; + while (rule_ptr != NULL) + { + this->eeprom_write_firewall_rule(rule_ptr); + rule_ptr = rule_ptr->next; + } +} + +void ESPFirewall::eeprom_read_firewall_rule(uint8_t &eeprom_address, uint8_t &rule_nr) +{ + firewall_rule_t *rule_ptr = (firewall_rule_t *)malloc(sizeof(firewall_rule_t)); + rule_ptr->key = rule_nr; + strcpy(rule_ptr->source, EEPROM.readString(eeprom_address).c_str()); + eeprom_address += IP4ADDR_STRLEN_MAX; + strcpy(rule_ptr->destination, EEPROM.readString(eeprom_address).c_str()); + eeprom_address += IP4ADDR_STRLEN_MAX; + rule_ptr->protocol = static_cast(EEPROM.read(eeprom_address)); + eeprom_address += sizeof(firewall_protocol_t); + rule_ptr->target = static_cast(EEPROM.read(eeprom_address)); + eeprom_address += sizeof(firewall_target_t); + add_rule_to_firewall(rule_ptr); + log_i("%s, %s, %s, %s", + rule_ptr->source, + rule_ptr->destination, + protocol_to_string(rule_ptr->protocol), + target_to_string(rule_ptr->target)); +} + +void ESPFirewall::eeprom_read_firewall_rules() +{ + uint8_t eeprom_address = eeprom_start_firewall_rules; + for (uint8_t i = 1; i <= this->amount_of_rules; i++) + { + eeprom_read_firewall_rule(eeprom_address, i); + } +} + +void ESPFirewall::add_rule_to_firewall(firewall_rule_t *rule_ptr) { firewall_rule_t *temp; - if (head == NULL) + if (this->head == NULL) { - head = rule; - rule->next = NULL; + this->head = rule_ptr; + rule_ptr->next = NULL; return; } - temp = head; + temp = this->head; while (temp->next != NULL) { temp = temp->next; } - temp->next = rule; - rule->next = NULL; + temp->next = rule_ptr; + rule_ptr->next = NULL; return; } -firewall_rule_t *ESPFirewall::get_rule_from_firewall(int key) +firewall_rule_t *ESPFirewall::get_rule_from_firewall(uint8_t key) { firewall_rule_t *rule_ptr = this->head; - if (head == NULL) + if (this->head == NULL) { return NULL; } @@ -47,9 +165,9 @@ firewall_rule_t *ESPFirewall::get_rule_from_firewall(int key) return rule_ptr; } -bool ESPFirewall::delete_rule_from_firewall(int key) +bool ESPFirewall::delete_rule_from_firewall(uint8_t key) { - if (head == NULL) + if (this->head == NULL) { return false; } @@ -68,10 +186,10 @@ bool ESPFirewall::delete_rule_from_firewall(int key) current_rule_ptr = current_rule_ptr->next; } } - if (current_rule_ptr == head) + if (current_rule_ptr == this->head) { - head = head->next; - temp = head; + this->head = head->next; + temp = this->head; } else { @@ -84,7 +202,8 @@ bool ESPFirewall::delete_rule_from_firewall(int key) temp = temp->next; } free(current_rule_ptr); - amount_of_rules--; + this->amount_of_rules--; + this->eeprom_write_firewall_rules(); return true; } @@ -94,6 +213,8 @@ void ESPFirewall::setup_routing() firewall_api->on("/api/v1/firewall", HTTP_GET, std::bind(&ESPFirewall::get_firewall_rules_handler, this, std::placeholders::_1)); firewall_api->on("/api/v1/firewall", HTTP_POST, std::bind(&ESPFirewall::post_firewall_handler, this, std::placeholders::_1)); firewall_api->on("^\\/api/v1/firewall\\/([0-9]+)$", HTTP_DELETE, std::bind(&ESPFirewall::delete_firewall_handler, this, std::placeholders::_1)); + + firewall_api->on("/api/v1/device/restart", HTTP_GET, std::bind(&ESPFirewall::restart_device_handler, this, std::placeholders::_1)); firewall_api->onNotFound(std::bind(&ESPFirewall::not_found, this, std::placeholders::_1)); this->firewall_api->begin(); } @@ -113,8 +234,8 @@ String ESPFirewall::construct_json_firewall_rule(firewall_rule_t *rule_ptr) doc["key"] = rule_ptr->key; doc["source"] = rule_ptr->source; doc["destination"] = rule_ptr->destination; - doc["protocol"] = rule_ptr->protocol; - doc["target"] = rule_ptr->target; + doc["protocol"] = protocol_to_string(rule_ptr->protocol); + doc["target"] = target_to_string(rule_ptr->target); String response; serializeJson(doc, response); return response; @@ -123,10 +244,10 @@ String ESPFirewall::construct_json_firewall_rule(firewall_rule_t *rule_ptr) String ESPFirewall::construct_json_firewall() { firewall_rule_t *rule_ptr = this->head; - // Size for 12 Rules + // Size for max 12 Rules StaticJsonDocument<2048> doc; String response; - doc["amount"] = amount_of_rules; + doc["amount_of_rules"] = this->amount_of_rules; JsonArray rules = doc.createNestedArray("rules"); while (rule_ptr != NULL) { @@ -134,8 +255,8 @@ String ESPFirewall::construct_json_firewall() rule["key"] = rule_ptr->key; rule["source"] = rule_ptr->source; rule["destination"] = rule_ptr->destination; - rule["protocol"] = rule_ptr->protocol; - rule["target"] = rule_ptr->target; + rule["protocol"] = protocol_to_string(rule_ptr->protocol); + rule["target"] = target_to_string(rule_ptr->target); rule_ptr = rule_ptr->next; } serializeJson(doc, response); @@ -147,6 +268,13 @@ void ESPFirewall::not_found(AsyncWebServerRequest *request) json_message_response(request, "not found", 404); } +void ESPFirewall::restart_device_handler(AsyncWebServerRequest *request) +{ + json_message_response(request, "restarting device in 2 sec", 200); + sleep(2000); + esp_restart(); +} + void ESPFirewall::get_firewall_rule_handler(AsyncWebServerRequest *request) { int rule_number = request->pathArg(0).toInt(); @@ -184,12 +312,14 @@ void ESPFirewall::post_firewall_handler(AsyncWebServerRequest *request) strcpy(rule_ptr->source, source.length() <= IP4ADDR_STRLEN_MAX ? source.c_str() : ""); String destination = request->getParam("destination")->value(); strcpy(rule_ptr->destination, destination.length() <= IP4ADDR_STRLEN_MAX ? destination.c_str() : ""); + String protocol = request->getParam("protocol")->value(); - strcpy(rule_ptr->protocol, protocol.length() <= PROTOCOL_LENGTH ? protocol.c_str() : ""); + rule_ptr->protocol = string_to_protocol(protocol); String target = request->getParam("target")->value(); - strcpy(rule_ptr->target, target.length() <= TARGET_LENGTH ? target.c_str() : ""); + rule_ptr->target = string_to_target(target); add_rule_to_firewall(rule_ptr); + eeprom_write_firewall_rule(rule_ptr); request->send(200, "application/json", construct_json_firewall_rule(rule_ptr)); } else diff --git a/SourceCode/arduino/lib/Firewall/Firewall.h b/SourceCode/arduino/lib/Firewall/Firewall.h index 7ed2ab7..65033c4 100644 --- a/SourceCode/arduino/lib/Firewall/Firewall.h +++ b/SourceCode/arduino/lib/Firewall/Firewall.h @@ -4,6 +4,7 @@ #include "Arduino.h" #include "AsyncJson.h" #include "ArduinoJson.h" +#include "EEPROM.h" #ifdef ESP32 #include "WiFi.h" #include "AsyncTCP.h" @@ -13,30 +14,60 @@ #endif #include "ESPAsyncWebServer.h" -#define PROTOCOL_LENGTH 4 -#define TARGET_LENGTH 7 +#define eeprom_start_firewall_rules 4 + +typedef enum firewall_target : uint8_t +{ + FW_REJECT = 0, + FW_DROP = 1, + FW_ACCEPT = 2, +} firewall_target_t; + +typedef enum firewall_protocol : uint8_t +{ + FW_TCP = 0, + FW_UDP = 1, + FW_ALL = 255, +} firewall_protocol_t; typedef struct firewall_rule { - int key; + uint8_t key; char source[IP4ADDR_STRLEN_MAX]; char destination[IP4ADDR_STRLEN_MAX]; - char protocol[PROTOCOL_LENGTH]; - char target[TARGET_LENGTH]; + firewall_protocol_t protocol; + firewall_target_t target; struct firewall_rule *next; } firewall_rule_t; class ESPFirewall { - unsigned int amount_of_rules = 0; + uint16_t eeprom_size = 512; + uint8_t amount_of_rules = 0; + uint8_t security_number = 93; + int eeprom_settings_head = 0; + int eeprom_rules_head = eeprom_start_firewall_rules; struct firewall_rule *head = NULL; AsyncWebServer *firewall_api; + // Protocol / Target conversion + String protocol_to_string(firewall_protocol_t &); + firewall_protocol_t string_to_protocol(String &); + String target_to_string(firewall_target_t &); + firewall_target_t string_to_target(String &); + + // EEPROM + void setup_eeprom(); + void eeprom_write_firewall_rule(firewall_rule_t *rule); + void eeprom_write_firewall_rules(); + void eeprom_read_firewall_rule(uint8_t &, uint8_t &); + void eeprom_read_firewall_rules(); + // Firewall Actions void add_rule_to_firewall(firewall_rule_t *); - firewall_rule_t *get_rule_from_firewall(int key); - bool delete_rule_from_firewall(int key); + firewall_rule_t *get_rule_from_firewall(uint8_t); + bool delete_rule_from_firewall(uint8_t); // Firewall-API Actions void setup_routing(); @@ -44,6 +75,7 @@ class ESPFirewall String construct_json_firewall_rule(firewall_rule_t *); String construct_json_firewall(); void not_found(AsyncWebServerRequest *); + void restart_device_handler(AsyncWebServerRequest *); void get_firewall_rule_handler(AsyncWebServerRequest *); void get_firewall_rules_handler(AsyncWebServerRequest *); bool request_has_firewall_parameter(AsyncWebServerRequest *); diff --git a/SourceCode/arduino/src/main.cpp b/SourceCode/arduino/src/main.cpp index 618f54f..08e942a 100644 --- a/SourceCode/arduino/src/main.cpp +++ b/SourceCode/arduino/src/main.cpp @@ -8,12 +8,15 @@ ESPFirewall *firewall; void setup_wifi() { + uint8_t max_retries = 10; + uint8_t retries = 1; log_i("Attempting to connect to WPA SSID: %s", ssid); WiFi.mode(WIFI_STA); WiFi.begin(ssid, psk); - while (WiFi.status() != WL_CONNECTED) + while (WiFi.status() != WL_CONNECTED && retries <= max_retries) { delay(1000); + log_i("Connecting... (%i/%i)", retries++, max_retries); } esp_ip_address = WiFi.localIP().toString().c_str(); log_i("Connected, IP Address: %s", esp_ip_address);