diff --git a/ESPFirewall/lib/Firewall/src/API.cpp b/ESPFirewall/lib/Firewall/src/API.cpp index f049967..75b9dcd 100644 --- a/ESPFirewall/lib/Firewall/src/API.cpp +++ b/ESPFirewall/lib/Firewall/src/API.cpp @@ -2,10 +2,11 @@ namespace fw { - API::API(const char *cert, const char *key, const char *username, const char *password, const String ip, const uint16_t port) + API::API(fw::Firewall *firewall, const char *cert, const char *key, const char *username, const char *password, const uint16_t port) { - this->server_ip = ip; - this->server_port = port; + this->firewall = firewall; + this->api_ip = WiFi.localIP().toString(); + this->api_port = port; if (this->setup_auth(username, password) == ERROR) endless_loop(); #ifdef ESP32 @@ -32,9 +33,9 @@ namespace fw String API::get_url_base() { #ifdef ESP32 - return "http://" + this->server_ip + ":" + this->server_port; + return "http://" + this->api_ip + ":" + this->api_port; #elif defined(ESP8266) - return "https://" + this->server_ip + ":" + this->server_port; + return "https://" + this->api_ip + ":" + this->api_port; #endif } @@ -75,22 +76,19 @@ namespace fw this->server->getServer().setCache(serverCache); #endif this->server->on("/firewall", HTTP_GET, std::bind(&API::get_firewall_rules_handler, this)); - add_api_endpoint("/firewall", "GET", "Get all Firewall Rules"); - this->server->on(UriRegex("/firewall/([0-9]+)"), HTTP_GET, std::bind(&API::get_firewall_rule_handler, this)); - add_api_endpoint("/firewall/1", "GET", "Get Firewall Rule by key"); - this->server->on("/firewall", HTTP_POST, std::bind(&API::post_firewall_handler, this)); - add_api_endpoint("/firewall", "POST", "Create Firewall Rule"); - this->server->on(UriRegex("/firewall/([0-9]+)"), HTTP_DELETE, std::bind(&API::delete_firewall_handler, this)); - add_api_endpoint("/firewall/1", "DELETE", "Delete Firewall Rule by key"); - - this->server->on("/api", HTTP_GET, std::bind(&API::api_endpoints_handler, this)); + this->server->on("/api", HTTP_GET, std::bind(&API::get_endpoint_list_handler, this)); this->server->onNotFound(std::bind(&API::not_found_handler, this)); + + add_endpoint_to_list("/firewall", "GET", "Get all Firewall Rules"); + add_endpoint_to_list("/firewall/1", "GET", "Get Firewall Rule by key"); + add_endpoint_to_list("/firewall", "POST", "Create Firewall Rule"); + add_endpoint_to_list("/firewall/1", "DELETE", "Delete Firewall Rule by key"); } - void API::add_api_endpoint(const String uri, const char *method, const char *description) + void API::add_endpoint_to_list(const String uri, const char *method, const char *description) { api_endpoint_t *temp; const String url = get_url_base() + uri; @@ -123,7 +121,7 @@ namespace fw 404); } - void API::api_endpoints_handler() + void API::get_endpoint_list_handler() { this->json_generic_response(this->construct_json_api(), 200); } @@ -134,7 +132,7 @@ namespace fw return; String param = this->server->pathArg(0); int rule_number = atoi(param.c_str()); - firewall_rule_t *rule_ptr = get_rule_from_firewall(rule_number); + firewall_rule_t *rule_ptr = firewall->get_rule_from_firewall(rule_number); if (rule_ptr == NULL) this->json_message_response("rule does not exist", 404); else @@ -152,9 +150,9 @@ namespace fw { if (this->check_auth() == DENIED) return; - if (request_has_firewall_parameter()) + if (request_has_all_firewall_parameter()) { - firewall_rule_t *rule_ptr = add_rule_to_firewall( + firewall_rule_t *rule_ptr = firewall->add_rule_to_firewall( this->server->arg("source"), this->server->arg("destination"), this->server->arg("port_from"), @@ -175,13 +173,13 @@ namespace fw return; String param = this->server->pathArg(0); int rule_number = atoi(param.c_str()); - if (delete_rule_from_firewall(rule_number) == SUCCESS) + if (firewall->delete_rule_from_firewall(rule_number) == SUCCESS) this->json_message_response("firewall rule deleted", 200); else this->json_message_response("cannot delete firewall rule", 500); } - bool API::request_has_firewall_parameter() + bool API::request_has_all_firewall_parameter() { if (!this->server->args()) return false; @@ -242,7 +240,7 @@ namespace fw String API::construct_json_firewall() { - firewall_rule_t *rule_ptr = rule_head; + firewall_rule_t *rule_ptr = firewall->get_rule_head(); String serialized_string; while (rule_ptr != NULL) { diff --git a/ESPFirewall/lib/Firewall/src/API.hpp b/ESPFirewall/lib/Firewall/src/API.hpp index f7377ff..d3339da 100644 --- a/ESPFirewall/lib/Firewall/src/API.hpp +++ b/ESPFirewall/lib/Firewall/src/API.hpp @@ -8,14 +8,18 @@ #endif #include "uri/UriRegex.h" - -#include "Rules.hpp" +#include "Firewall.hpp" #include "Utils.hpp" namespace fw { - class API : public Rules + class API { + public: + API(Firewall *, const char *cert, const char *key, const char *username, const char *password, const uint16_t port = 8080); + ~API(); + void handle_client(); + private: #ifdef ESP32 WebServer *server; @@ -23,41 +27,36 @@ namespace fw BearSSL::ESP8266WebServerSecure *server; BearSSL::ServerSessions *serverCache; #endif + Firewall *firewall; credential_t credentials; api_endpoint_t *endpoint_head = NULL; + String api_ip = "0.0.0.0"; + uint16_t api_port; + String get_url_base(); ok_t setup_auth(const char *username, const char *password); auth_t check_auth(); void setup_routing(const char *cert, const char *key); - void add_api_endpoint(const String uri, const char *method, const char *description); + void add_endpoint_to_list(const String uri, const char *method, const char *description); + void not_found_handler(); + void get_endpoint_list_handler(); void get_firewall_rule_handler(); void get_firewall_rules_handler(); void post_firewall_handler(); void delete_firewall_handler(); - void not_found_handler(); - void api_endpoints_handler(); - bool request_has_firewall_parameter(); + bool request_has_all_firewall_parameter(); String json_new_attribute(String key, String value, bool last = false); String json_new_attribute(String key, uint32_t value, bool last = false); void json_generic_response(String serialized_string, const uint16_t response_code); void json_message_response(String message, const uint16_t response_code); - String construct_json_firewall_rule(firewall_rule_t *); + + String construct_json_firewall_rule(firewall_rule_t *rule_ptr); String construct_json_firewall(); - String construct_json_api_endpoint(api_endpoint_t *); + String construct_json_api_endpoint(api_endpoint_t *api_ptr); String construct_json_api(); String construct_json_begin(const uint16_t response_code); - - protected: - String server_ip; - uint16_t server_port; - void handle_client(); - String get_url_base(); - - public: - API(const char *cert, const char *key, const char *username, const char *password, const String ip, const uint16_t port); - ~API(); }; }