#include "API.hpp"

namespace fw
{
    namespace
    {
        /**
         * Escape a string so it can be embedded between double quotes in a
         * JSON document (RFC 8259, section 7): double quotes, backslashes and
         * control characters (< 0x20) are replaced by escape sequences.
         *
         * Values serialized by this API (e.g. rule fields) may originate from
         * HTTP request arguments, so they must not be able to terminate the
         * surrounding JSON string.
         */
        String json_escape(const String &raw)
        {
            static const char hex_digits[] = "0123456789abcdef";
            String escaped;
            escaped.reserve(raw.length());
            for (unsigned int i = 0; i < raw.length(); i++)
            {
                const char c = raw[i];
                switch (c)
                {
                case '"':
                    escaped += "\\\"";
                    break;
                case '\\':
                    escaped += "\\\\";
                    break;
                case '\n':
                    escaped += "\\n";
                    break;
                case '\r':
                    escaped += "\\r";
                    break;
                case '\t':
                    escaped += "\\t";
                    break;
                default:
                {
                    const unsigned char code = static_cast<unsigned char>(c);
                    if (code < 0x20)
                    {
                        // Remaining control characters use the \u00XX form.
                        escaped += "\\u00";
                        escaped += hex_digits[code >> 4];
                        escaped += hex_digits[code & 0x0F];
                    }
                    else
                    {
                        escaped += c;
                    }
                    break;
                }
                }
            }
            return escaped;
        }
    }

    /**
     * Create the REST API for a firewall, configure Basic-Auth credentials,
     * register all routes and start the HTTP(S) server.
     *
     * @param firewall  firewall whose rules are exposed; this class never frees it
     * @param cert      server certificate, only used on ESP8266 (TLS)
     * @param key       server private key, only used on ESP8266 (TLS)
     * @param username  Basic-Auth username (non-empty, must fit the credentials buffer)
     * @param password  Basic-Auth password (non-empty, must fit the credentials buffer)
     * @param ip        address used to build the endpoint URLs advertised by /api
     * @param port      TCP port to listen on
     *
     * If the credentials are rejected, endless_loop() is called.
     */
    API::API(fw::Firewall *firewall,
             const char *cert,
             const char *key,
             const char *username,
             const char *password,
             const String ip,
             const uint16_t port)
    {
        this->firewall = firewall;
        this->api_ip = ip;
        this->api_port = port;
        if (this->setup_auth(username, password) == ERROR)
            endless_loop();
#ifdef ESP32
        this->server = new WebServer(port);
#elif defined(ESP8266)
        this->server = new ESP8266WebServerSecure(port);
        this->serverCache = new ServerSessions(5);
#endif
        this->setup_routing(cert, key);
        this->server->begin();
        Serial.printf("%s listening on port %i\n", log, port);
    }

    /**
     * Stop the server and release everything this object allocated:
     * the server instance, the ESP8266 TLS session cache and the
     * malloc'ed endpoint list.
     */
    API::~API()
    {
        this->server->stop();
        // Delete the server before the session cache it references.
        delete this->server;
        this->server = NULL;
#ifdef ESP8266
        delete this->serverCache;
        this->serverCache = NULL;
#endif
        api_endpoint_t *node = this->endpoint_head;
        while (node != NULL)
        {
            api_endpoint_t *next = node->next;
            free(node); // nodes are allocated with malloc in add_endpoint_to_list
            node = next;
        }
        this->endpoint_head = NULL;
    }

    /** Process pending HTTP requests; call this regularly from the main loop. */
    void API::handle_client()
    {
        this->server->handleClient();
    }

    /**
     * Build the base URL (scheme://ip:port) of this API.
     * ESP8266 serves HTTPS; every other build (ESP32) serves plain HTTP.
     */
    String API::get_url_base()
    {
#if defined(ESP8266)
        const char *scheme = "https://";
#else
        const char *scheme = "http://";
#endif
        return scheme + this->api_ip + ":" + this->api_port;
    }

    /**
     * Validate and store the Basic-Auth credentials.
     *
     * A value is rejected if it is NULL, empty, or does not fit into the
     * destination buffer together with its NUL terminator.
     *
     * @return SUCCESS when both values were stored, ERROR otherwise
     */
    ok_t API::setup_auth(const char *username, const char *password)
    {
        // Bound by the real buffer size so strncpy always leaves room for '\0'.
        if (!username || *username == 0x00 || strlen(username) >= sizeof(this->credentials.username))
        {
            Serial.printf("%s Username too long or missing!\n", log);
            return ERROR;
        }
        strncpy(this->credentials.username, username, sizeof(this->credentials.username));

        if (!password || *password == 0x00 || strlen(password) >= sizeof(this->credentials.password))
        {
            Serial.printf("%s Password too long or missing!\n", log);
            return ERROR;
        }
        strncpy(this->credentials.password, password, sizeof(this->credentials.password));
        return SUCCESS;
    }

    /**
     * Check the request's Basic-Auth credentials.
     *
     * On failure a JSON error is sent immediately, so callers only need to
     * return.
     *
     * NOTE(review): 401 with a WWW-Authenticate header is the usual status
     * for missing credentials. 403 is kept here in case clients rely on it.
     */
    auth_t API::check_auth()
    {
        if (server->authenticate(this->credentials.username, this->credentials.password))
        {
            return AUTHENTICATED;
        }
        this->json_message_response("unauthorised", 403);
        return DENIED;
    }

    /**
     * Configure TLS (ESP8266 only), register all HTTP routes and record the
     * public endpoints for the /api listing.
     */
    void API::setup_routing(const char *cert, const char *key)
    {
#ifdef ESP8266
        // NOTE(review): these objects are handed to the TLS server and never
        // freed here; presumably they must live as long as the server.
        this->server->getServer().setRSACert(new BearSSL::X509List(cert), new BearSSL::PrivateKey(key));
        this->server->getServer().setCache(serverCache);
#endif
        this->server->on("/api/firewall/rules", HTTP_GET, std::bind(&API::get_firewall_rules_handler, this));
        this->server->on(UriRegex("/api/firewall/rules/([0-9]+)"), HTTP_GET, std::bind(&API::get_firewall_rule_handler, this));
        this->server->on("/api/firewall/rules", HTTP_POST, std::bind(&API::post_firewall_handler, this));
        this->server->on(UriRegex("/api/firewall/rules/([0-9]+)"), HTTP_DELETE, std::bind(&API::delete_firewall_handler, this));
        this->server->on("/api", HTTP_GET, std::bind(&API::get_endpoint_list_handler, this));
        this->server->onNotFound(std::bind(&API::not_found_handler, this));

        add_endpoint_to_list("/api/firewall/rules", "GET", "Get all Firewall Rules");
        add_endpoint_to_list("/api/firewall/rules/", "GET", "Get Firewall Rule by key");
        add_endpoint_to_list("/api/firewall/rules", "POST", "Create Firewall Rule");
        add_endpoint_to_list("/api/firewall/rules/", "DELETE", "Delete Firewall Rule by key");
    }

    /**
     * Append an endpoint description to the linked list served by /api.
     *
     * Strings longer than the node's fixed-size fields are truncated. The
     * fields are always NUL-terminated. Nodes are malloc'ed and released in
     * the destructor.
     */
    void API::add_endpoint_to_list(const String uri, const char *method, const char *description)
    {
        const String url = get_url_base() + uri;
        api_endpoint_t *api_ptr = (api_endpoint_t *)malloc(sizeof(api_endpoint_t));
        if (api_ptr == NULL)
        {
            Serial.printf("%s Cannot allocate API endpoint!\n", log);
            return;
        }

        // strncpy does not terminate on truncation, so terminate explicitly.
        strncpy(api_ptr->uri, url.c_str(), sizeof(api_ptr->uri));
        api_ptr->uri[sizeof(api_ptr->uri) - 1] = '\0';
        strncpy(api_ptr->method, method, sizeof(api_ptr->method));
        api_ptr->method[sizeof(api_ptr->method) - 1] = '\0';
        strncpy(api_ptr->description, description, sizeof(api_ptr->description));
        api_ptr->description[sizeof(api_ptr->description) - 1] = '\0';
        api_ptr->next = NULL;

        if (this->endpoint_head == NULL)
        {
            this->endpoint_head = api_ptr;
            return;
        }
        api_endpoint_t *tail = this->endpoint_head;
        while (tail->next != NULL)
        {
            tail = tail->next;
        }
        tail->next = api_ptr;
    }

    /** Fallback for unknown routes: point the client at the /api listing. */
    void API::not_found_handler()
    {
        this->json_message_response("see " + get_url_base() + "/api for available routes", 404);
    }

    /** GET /api — list all registered endpoints (no authentication required). */
    void API::get_endpoint_list_handler()
    {
        this->json_array_response(this->construct_json_api(), 200);
    }

    /** GET /api/firewall/rules/{key} — return a single rule, or 404 if unknown. */
    void API::get_firewall_rule_handler()
    {
        if (this->check_auth() == DENIED)
            return;
        // The route regex only matches digits, so atoi is given a numeric string.
        String param = this->server->pathArg(0);
        int rule_number = atoi(param.c_str());
        firewall_rule_t *rule_ptr = firewall->get_rule_from_firewall(rule_number);
        if (rule_ptr == NULL)
            this->json_message_response("rule does not exist", 404);
        else
            this->json_generic_response(this->construct_json_firewall_rule(rule_ptr), 200);
    }

    /** GET /api/firewall/rules — return all rules as a JSON array. */
    void API::get_firewall_rules_handler()
    {
        if (this->check_auth() == DENIED)
            return;
        this->json_array_response(this->construct_json_firewall(), 200);
    }

    /**
     * POST /api/firewall/rules — create a rule from the request arguments.
     *
     * Responses: 201 with the new rule; 400 if parameters are missing; 500
     * if the firewall did not return a rule.
     */
    void API::post_firewall_handler()
    {
        if (this->check_auth() == DENIED)
            return;
        if (!request_has_all_firewall_parameter())
        {
            this->json_message_response("not enough parameter provided", 400);
            return;
        }

        // NOTE(review): sized by IPV4ADDRESS_LENGTH rather than
        // firewall_fields_amount; assumes firewall_fields_amount <=
        // IPV4ADDRESS_LENGTH — confirm.
        String args[IPV4ADDRESS_LENGTH] = {};
        for (uint8_t i = 0; i < firewall_fields_amount; i++)
        {
            args[i] = this->server->arg(firewall_fields[i]);
        }
        firewall_rule_t *rule_ptr = firewall->add_rule_to_firewall(args);
        if (rule_ptr == NULL)
        {
            // Do not dereference a missing rule while serializing the response.
            this->json_message_response("cannot create firewall rule", 500);
            return;
        }
        this->json_generic_response(this->construct_json_firewall_rule(rule_ptr), 201);
    }

    /** DELETE /api/firewall/rules/{key} — remove a rule. */
    void API::delete_firewall_handler()
    {
        if (this->check_auth() == DENIED)
            return;
        String param = this->server->pathArg(0);
        int rule_number = atoi(param.c_str());
        if (firewall->delete_rule_from_firewall(rule_number) == SUCCESS)
            this->json_message_response("firewall rule deleted", 200);
        else
            this->json_message_response("cannot delete firewall rule", 500);
    }

    /**
     * True when the request carries every firewall field except KEY. KEY is
     * skipped because it is not read from the request.
     */
    bool API::request_has_all_firewall_parameter()
    {
        if (!this->server->args())
            return false;
        for (uint8_t i = 0; i < firewall_fields_amount; i++)
        {
            if (i != KEY && !this->server->hasArg(firewall_fields[i]))
                return false;
        }
        return true;
    }

    /**
     * Serialize one `"key": "value"` JSON member, with a trailing comma
     * unless `last` is set. Key and value are JSON-escaped.
     */
    String API::json_new_attribute(String key, String value, bool last)
    {
        String json_string;
        json_string += "\"" + json_escape(key) + "\": \"" + json_escape(value) + "\"";
        if (!last)
            json_string += ",";
        return json_string;
    }

    /** Numeric overload; the value is still emitted as a JSON string. */
    String API::json_new_attribute(String key, uint32_t value, bool last)
    {
        return json_new_attribute(key, String(value), last);
    }

    /** Send an already-serialized JSON document. */
    void API::json_generic_response(String serialized_string, const uint16_t response_code)
    {
        this->server->send(response_code, json_response_type, serialized_string);
    }

    /** Send comma-separated JSON elements wrapped in a JSON array. */
    void API::json_array_response(String serialized_string, const uint16_t response_code)
    {
        this->server->send(response_code, json_response_type, "[" + serialized_string + "]");
    }

    /** Send `{"message": "<message>"}` with the given status code. */
    void API::json_message_response(String message, const uint16_t response_code)
    {
        String serialized_string = "{";
        serialized_string += json_new_attribute("message", message, true);
        serialized_string += "}";
        this->server->send(response_code, json_response_type, serialized_string);
    }

    /** Serialize one firewall rule as a JSON object. rule_ptr must not be NULL. */
    String API::construct_json_firewall_rule(firewall_rule_t *rule_ptr)
    {
        String serialized_string = "{";
        serialized_string += json_new_attribute(firewall_fields[KEY], rule_ptr->key);
        serialized_string += json_new_attribute(firewall_fields[IP], rule_ptr->ip);
        serialized_string += json_new_attribute(firewall_fields[PORT_FROM], rule_ptr->port_from);
        serialized_string += json_new_attribute(firewall_fields[PORT_TO], rule_ptr->port_to);
        serialized_string += json_new_attribute(firewall_fields[PROTOCOL], protocol_to_string(rule_ptr->protocol));
        serialized_string += json_new_attribute(firewall_fields[TARGET], target_to_string(rule_ptr->target), true);
        serialized_string += "}";
        return serialized_string;
    }

    /** Serialize all rules as comma-separated JSON objects, without brackets. */
    String API::construct_json_firewall()
    {
        firewall_rule_t *rule_ptr = firewall->get_rule_head();
        String serialized_string;
        while (rule_ptr != NULL)
        {
            serialized_string += construct_json_firewall_rule(rule_ptr);
            rule_ptr = rule_ptr->next;
            if (rule_ptr != NULL)
                serialized_string += ",";
        }
        return serialized_string;
    }

    /** Serialize one endpoint description as a JSON object. */
    String API::construct_json_api_endpoint(api_endpoint_t *api_ptr)
    {
        String serialized_string = "{";
        serialized_string += json_new_attribute("endpoint", api_ptr->uri);
        serialized_string += json_new_attribute("description", api_ptr->description);
        serialized_string += json_new_attribute("method", api_ptr->method, true);
        serialized_string += "}";
        return serialized_string;
    }

    /** Serialize all endpoints as comma-separated JSON objects, without brackets. */
    String API::construct_json_api()
    {
        api_endpoint_t *api_ptr = this->endpoint_head;
        String serialized_string;
        while (api_ptr != NULL)
        {
            serialized_string += construct_json_api_endpoint(api_ptr);
            api_ptr = api_ptr->next;
            if (api_ptr != NULL)
                serialized_string += ",";
        }
        return serialized_string;
    }
}