#include "API.hpp"

#include <cstdio>
#include <cstdlib>
#include <cstring>

namespace fw {

namespace {

/// Escapes a string so it can be embedded inside a JSON string literal.
///
/// Double quotes, backslashes and ASCII control characters (< 0x20) are
/// escaped as required by the JSON grammar. Rule fields are taken verbatim
/// from HTTP request arguments, so without this a client could break the
/// response document (or inject extra attributes) by sending a `"` in e.g.
/// the `source` parameter.
String escape_json_string(const String &raw) {
  String escaped;
  escaped.reserve(raw.length());
  for (unsigned int i = 0; i < raw.length(); ++i) {
    const char c = raw.charAt(i);
    switch (c) {
      case '"':
        escaped += "\\\"";
        break;
      case '\\':
        escaped += "\\\\";
        break;
      case '\n':
        escaped += "\\n";
        break;
      case '\r':
        escaped += "\\r";
        break;
      case '\t':
        escaped += "\\t";
        break;
      default:
        if (static_cast<unsigned char>(c) < 0x20) {
          // backslash + 'u' + 4 hex digits + terminating NUL = 7 bytes
          char unicode_escape[7];
          snprintf(unicode_escape, sizeof(unicode_escape), "\\u%04x",
                   static_cast<unsigned int>(static_cast<unsigned char>(c)));
          escaped += unicode_escape;
        } else {
          escaped += c;
        }
        break;
    }
  }
  return escaped;
}

/// Copies `src` into the fixed-size char array `dest`, truncating if needed
/// and ALWAYS NUL-terminating. Plain `strncpy(dest, src, sizeof(dest))`
/// leaves `dest` unterminated when `src` is at least as long as the buffer,
/// which later turns into an out-of-bounds read when the field is printed.
template <size_t N>
void copy_truncated(char (&dest)[N], const String &src) {
  static_assert(N > 0, "destination buffer must not be empty");
  strncpy(dest, src.c_str(), N - 1);
  dest[N - 1] = '\0';
}

}  // namespace

/// Validates and stores the HTTP credentials, creates the platform-specific
/// web server on `port`, registers all routes and starts listening.
/// Halts in `endless_loop()` if the credentials are invalid.
API::API(const char *username, const char *password, const uint16_t port) {
  if (this->setup_auth(username, password) == ERROR) endless_loop();

#ifdef ESP32
  this->server = new WebServer(port);
#elif defined(ESP8266)
  this->server = new ESP8266WebServerSecure(port);
  this->serverCache = new ServerSessions(5);
#endif

  this->setup_routing();
  Serial.println("Starting server...");
  this->server->begin();
}

/// Releases the server objects allocated in the constructor.
/// NOTE(review): the class has no user-declared copy operations; copying an
/// API instance would now lead to a double delete. Consider deleting the copy
/// constructor/assignment in API.hpp.
API::~API() {
  // Delete the server first: on ESP8266 it holds a pointer to serverCache.
  delete this->server;
#ifdef ESP8266
  delete this->serverCache;
#endif
}

/// Processes pending HTTP requests; must be called regularly from loop().
void API::handle_client() { this->server->handleClient(); }

/// Validates `username` and `password` (non-null, non-empty, at most
/// CREDENTIALS_LENGTH characters) and copies them into `credentials`.
/// Both values are validated before anything is copied, so a failure leaves
/// `credentials` untouched.
/// @return SUCCESS on success, ERROR if either value is missing or too long.
ok_t API::setup_auth(const char *username, const char *password) {
  if (!username || *username == 0x00 || strlen(username) > CREDENTIALS_LENGTH) {
    Serial.println("Username too long or missing!");
    return ERROR;
  }
  if (!password || *password == 0x00 || strlen(password) > CREDENTIALS_LENGTH) {
    Serial.println("Password too long or missing!");
    return ERROR;
  }
  // NOTE(review): strncpy does not terminate a string of exactly
  // CREDENTIALS_LENGTH characters; this presumably relies on the credential
  // buffers being CREDENTIALS_LENGTH + 1 bytes and zero-initialised —
  // confirm in API.hpp.
  strncpy(credentials.username, username, CREDENTIALS_LENGTH);
  strncpy(credentials.password, password, CREDENTIALS_LENGTH);
  return SUCCESS;
}

/// Checks the request's HTTP authentication against the stored credentials.
/// On failure, a JSON "unauthorized" response (status 403) has already been
/// sent when this returns DENIED.
/// NOTE(review): 401 (+ WWW-Authenticate) is the conventional status for
/// failed authentication; kept at 403 to avoid changing the client contract.
auth_t API::check_auth() {
  if (server->authenticate(this->credentials.username, this->credentials.password)) {
    return AUTHENTICATED;
  }
  this->json_message_response("unauthorized", 403);
  return DENIED;
}

/// Configures TLS (ESP8266 only) and registers every REST route.
void API::setup_routing() {
#ifdef ESP8266
  // NOTE(review): these certificate/key objects are never freed; they must
  // outlive the server, so they would need to be stored as members to be
  // released in the destructor.
  this->server->getServer().setRSACert(new BearSSL::X509List(serverCert),
                                       new BearSSL::PrivateKey(serverKey));
  this->server->getServer().setCache(serverCache);
#endif

  this->server->on(UriRegex("/api/v1/firewall/([0-9]+)"), HTTP_GET,
                   [this]() { this->get_firewall_rule_handler(); });
  this->server->on("/api/v1/firewall", HTTP_GET,
                   [this]() { this->get_firewall_rules_handler(); });
  this->server->on("/api/v1/firewall", HTTP_POST,
                   [this]() { this->post_firewall_handler(); });
  this->server->on(UriRegex("/api/v1/firewall/([0-9]+)"), HTTP_DELETE,
                   [this]() { this->delete_firewall_handler(); });
  this->server->onNotFound([this]() { this->not_found_handler(); });
}

/// Fallback for unknown routes: responds 404 with a JSON message.
void API::not_found_handler() { this->json_message_response("not found", 404); }

/// GET /api/v1/firewall/<n>: returns rule <n> as JSON, or 404 if absent.
void API::get_firewall_rule_handler() {
  if (this->check_auth() == DENIED) return;

  // The route regex only matches digits, so pathArg(0) is numeric.
  String param = this->server->pathArg(0);
  int rule_number = atoi(param.c_str());

  firewall_rule_t *rule_ptr = get_rule_from_firewall(rule_number);
  if (rule_ptr == nullptr) {
    this->json_message_response("rule not found", 404);
  } else {
    this->json_generic_response(construct_json_firewall_rule(rule_ptr), 200);
  }
}

/// GET /api/v1/firewall: returns the whole rule list as JSON.
void API::get_firewall_rules_handler() {
  if (this->check_auth() == DENIED) return;
  this->json_generic_response(this->construct_json_firewall(), 200);
}

/// POST /api/v1/firewall: creates a rule from the `source`, `destination`,
/// `protocol` and `target` request arguments and echoes it back as JSON.
/// Responds 400 if none of those arguments is present, 500 if the rule
/// cannot be allocated.
void API::post_firewall_handler() {
  if (this->check_auth() == DENIED) return;

  if (!request_has_firewall_parameter()) {
    this->json_message_response("not enough parameter", 400);
    return;
  }

  // calloc rather than malloc so every field (including the `next` link)
  // starts zeroed instead of holding heap garbage.
  firewall_rule_t *rule_ptr =
      static_cast<firewall_rule_t *>(calloc(1, sizeof(firewall_rule_t)));
  if (rule_ptr == nullptr) {
    this->json_message_response("out of memory", 500);
    return;
  }

  // NOTE(review): if delete_rule_from_firewall decrements amount_of_rules,
  // deriving the key from the rule count can reuse an existing key — verify.
  rule_ptr->key = ++amount_of_rules;

  String source = this->server->arg("source");
  copy_truncated(rule_ptr->source, source);

  String destination = this->server->arg("destination");
  copy_truncated(rule_ptr->destination, destination);

  String protocol = this->server->arg("protocol");
  rule_ptr->protocol = string_to_protocol(protocol);

  String target = this->server->arg("target");
  rule_ptr->target = string_to_target(target);

  add_rule_to_firewall(rule_ptr);
  this->json_generic_response(this->construct_json_firewall_rule(rule_ptr), 200);
}

/// DELETE /api/v1/firewall/<n>: removes rule <n>; responds 500 if the
/// firewall reports the deletion as unsuccessful.
void API::delete_firewall_handler() {
  if (this->check_auth() == DENIED) return;

  // The route regex only matches digits, so pathArg(0) is numeric.
  String param = this->server->pathArg(0);
  int rule_number = atoi(param.c_str());

  if (delete_rule_from_firewall(rule_number) == SUCCESS) {
    this->json_message_response("firewall rule deleted", 200);
  } else {
    this->json_message_response("cannot delete firewall rule", 500);
  }
}

/// @return true if the request carries at least one of the rule arguments
///         (`source`, `destination`, `protocol`, `target`). Missing ones are
///         read as empty strings by post_firewall_handler.
bool API::request_has_firewall_parameter() {
  if (!this->server->args()) {
    return false;
  }
  return this->server->hasArg("source") || this->server->hasArg("destination") ||
         this->server->hasArg("protocol") || this->server->hasArg("target");
}

/// Serialises one `"key": "value"` pair, JSON-escaping both parts, followed
/// by a comma unless `last` is true.
String API::json_new_attribute(String key, String value, bool last) {
  String json_string = "\"";
  json_string += escape_json_string(key);
  json_string += "\": \"";
  json_string += escape_json_string(value);
  json_string += "\"";
  if (!last) json_string += ",";
  return json_string;
}

/// Numeric overload: the value is emitted as a quoted decimal string.
String API::json_new_attribute(String key, uint8_t value, bool last) {
  return json_new_attribute(key, String(value), last);
}

/// Sends `serialized_string` as a UTF-8 JSON body with `response_code`.
void API::json_generic_response(String serialized_string, const uint16_t response_code) {
  this->server->send(response_code, "application/json; charset=utf-8", serialized_string);
}

/// Sends `{"message": "<message>"}` with `response_code`.
void API::json_message_response(String message, const uint16_t response_code) {
  String serialized_string = "{";
  serialized_string += json_new_attribute("message", message, true);
  serialized_string += "}";
  this->json_generic_response(serialized_string, response_code);
}

/// Serialises a single rule as a JSON object.
String API::construct_json_firewall_rule(firewall_rule_t *rule_ptr) {
  String serialized_string = "{";
  serialized_string += json_new_attribute("key", rule_ptr->key);
  serialized_string += json_new_attribute("source", rule_ptr->source);
  serialized_string += json_new_attribute("destination", rule_ptr->destination);
  serialized_string += json_new_attribute("protocol", protocol_to_string(rule_ptr->protocol));
  serialized_string += json_new_attribute("target", target_to_string(rule_ptr->target), true);
  serialized_string += "}";
  return serialized_string;
}

/// Serialises the rule count and the full linked list starting at `head`
/// as `{"amount_of_rules": "...", "rules": [ ... ]}`.
String API::construct_json_firewall() {
  String serialized_string = "{";
  serialized_string += json_new_attribute("amount_of_rules", amount_of_rules);
  serialized_string += "\"rules\": [";
  for (firewall_rule_t *rule_ptr = head; rule_ptr != nullptr; rule_ptr = rule_ptr->next) {
    // Separator goes *between* elements: a trailing comma is invalid JSON.
    if (rule_ptr != head) serialized_string += ",";
    serialized_string += construct_json_firewall_rule(rule_ptr);
  }
  serialized_string += "]}";
  return serialized_string;
}

}  // namespace fw