diff --git a/SourceCode/arduino/lib/Firewall/Firewall.cpp b/SourceCode/arduino/lib/Firewall/Firewall.cpp index 06d482c..281d1c8 100644 --- a/SourceCode/arduino/lib/Firewall/Firewall.cpp +++ b/SourceCode/arduino/lib/Firewall/Firewall.cpp @@ -7,19 +7,13 @@ ESPFirewall::ESPFirewall(int port) this->setup_routing(); } -void ESPFirewall::add_rule_to_firewall(const char *source, const char *destination, const char *protocol, const char *target) +void ESPFirewall::add_rule_to_firewall(firewall_rule_t *rule) { firewall_rule_t *temp; - firewall_rule_t *link = (firewall_rule_t *)malloc(sizeof(firewall_rule_t)); - link->key = ++amount_of_rules; - strcpy(link->source, source); - strcpy(link->destination, destination); - strcpy(link->protocol, protocol); - strcpy(link->target, target); if (head == NULL) { - head = link; - link->next = NULL; + head = rule; + rule->next = NULL; return; } temp = head; @@ -27,8 +21,8 @@ void ESPFirewall::add_rule_to_firewall(const char *source, const char *destinati { temp = temp->next; } - temp->next = link; - link->next = NULL; + temp->next = rule; + rule->next = NULL; return; } @@ -58,17 +52,30 @@ void ESPFirewall::post_firewall_handler(AsyncWebServerRequest *request) DynamicJsonDocument json(1024); String response; int response_code; - if (request->hasArg("source") || request->hasArg("destination") || request->hasArg("protocol") || request->hasArg("target")) + if (request_has_firewall_parameter(request)) { + firewall_rule_t *rule = (firewall_rule_t *)malloc(sizeof(firewall_rule_t)); + rule->key = ++amount_of_rules; + const char *source = request->arg("source").c_str(); + strcpy(rule->source, strlen(source) <= IP4ADDR_STRLEN_MAX ? source : "-"); + const char *destination = request->arg("destination").c_str(); + strcpy(rule->destination, strlen(destination) <= IP4ADDR_STRLEN_MAX ? destination : "-"); + const char *protocol = request->arg("protocol").c_str(); + strcpy(rule->protocol, strlen(protocol) <= PROTOCOL_LENGTH ? protocol : "-"); + const char *target = request->arg("target").c_str(); - json["source"] = source; - json["destination"] = destination; - json["protocol"] = protocol; - json["target"] = target; - add_rule_to_firewall(source, destination, protocol, target); + strcpy(rule->target, strlen(target) <= TARGET_LENGTH ? target : "-"); + + add_rule_to_firewall(rule); + + json["key"] = rule->key; + json["source"] = rule->source; + json["destination"] = rule->destination; + json["protocol"] = rule->protocol; + json["target"] = rule->target; response_code = 200; } else @@ -89,6 +96,11 @@ void ESPFirewall::not_found(AsyncWebServerRequest *request) request->send(404, "application/json", response); } +bool ESPFirewall::request_has_firewall_parameter(AsyncWebServerRequest *request) +{ + return request->hasArg("source") || request->hasArg("destination") || request->hasArg("protocol") || request->hasArg("target"); +} + void ESPFirewall::setup_routing() { firewall_api->on("/api/v1/firewall", HTTP_GET, std::bind(&ESPFirewall::get_firewall_handler, this, std::placeholders::_1)); diff --git a/SourceCode/arduino/lib/Firewall/Firewall.h b/SourceCode/arduino/lib/Firewall/Firewall.h index 662d601..77c916d 100644 --- a/SourceCode/arduino/lib/Firewall/Firewall.h +++ b/SourceCode/arduino/lib/Firewall/Firewall.h @@ -13,13 +13,16 @@ #endif #include "ESPAsyncWebServer.h" +#define PROTOCOL_LENGTH 4 +#define TARGET_LENGTH 7 + typedef struct firewall_rule { int key; char source[IP4ADDR_STRLEN_MAX]; char destination[IP4ADDR_STRLEN_MAX]; - char protocol[4]; - char target[7]; + char protocol[PROTOCOL_LENGTH]; + char target[TARGET_LENGTH]; struct firewall_rule *next; } firewall_rule_t; @@ -29,10 +32,11 @@ class ESPFirewall unsigned int amount_of_rules = 0; struct firewall_rule *head = NULL; - void add_rule_to_firewall(const char *source, const char *destination, const char *protocol, const char *target); + void add_rule_to_firewall(firewall_rule_t *rule); void get_firewall_handler(AsyncWebServerRequest *request); void post_firewall_handler(AsyncWebServerRequest *request); void not_found(AsyncWebServerRequest *request); + bool request_has_firewall_parameter(AsyncWebServerRequest *request); void setup_routing(); public: