From 235f70357e3f8214897072c36470f1402d490607 Mon Sep 17 00:00:00 2001 From: Thomas Chopitea Date: Mon, 30 Sep 2024 08:43:17 +0200 Subject: [PATCH] Remove templates from database, move to filesystem (#1141) --- core/schemas/template.py | 43 ++++++++++++----- core/web/apiv2/templates.py | 96 +++++++++++++++++-------------------- tests/apiv2/tasks.py | 1 + tests/apiv2/templates.py | 85 ++++++++++++++++---------------- tests/schemas/fixture.py | 2 +- 5 files changed, 120 insertions(+), 107 deletions(-) diff --git a/core/schemas/template.py b/core/schemas/template.py index afc806b0d..95c5e7b09 100644 --- a/core/schemas/template.py +++ b/core/schemas/template.py @@ -1,29 +1,22 @@ import os -from typing import TYPE_CHECKING, ClassVar +from pathlib import Path +from typing import TYPE_CHECKING, Optional import jinja2 +from pydantic import BaseModel -from core import database_arango -from core.schemas.model import YetiModel +from core.config.config import yeti_config if TYPE_CHECKING: from core.schemas.observable import Observable -# TODO: Import Jinja functions to render templates - -class Template(YetiModel, database_arango.ArangoYetiConnector): +class Template(BaseModel): """A template for exporting data to an external system.""" - _collection_name: ClassVar[str] = "templates" - name: str template: str - @classmethod - def load(cls, object: dict) -> "Template": - return cls(**object) - def render(self, data: list["Observable"], output_file: str | None) -> None | str: """Renders the template with the given data to the output file.""" @@ -37,3 +30,29 @@ def render(self, data: list["Observable"], output_file: str | None) -> None | st return None else: return result + + def save(self) -> "Template": + directory = Path( + yeti_config.get("system", "template_dir", "/opt/yeti/templates") + ) + Path.mkdir(directory, parents=True, exist_ok=True) + file = directory / f"{self.name}.jinja2" + file.write_text(self.template) + return self + + def delete(self) -> None: + directory = Path( + yeti_config.get("system", "template_dir", "/opt/yeti/templates") + ) + file = directory / f"{self.name}.jinja2" + file.unlink() + + @classmethod + def find(cls, name: str) -> Optional["Template"]: + directory = Path( + yeti_config.get("system", "template_dir", "/opt/yeti/templates") + ) + file = directory / f"{name}.jinja2" + if file.exists(): + return Template(name=name, template=file.read_text()) + return None diff --git a/core/web/apiv2/templates.py b/core/web/apiv2/templates.py index b9bcf068b..3c0a742be 100644 --- a/core/web/apiv2/templates.py +++ b/core/web/apiv2/templates.py @@ -1,7 +1,11 @@ +import logging +from pathlib import Path + from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict +from core.config.config import yeti_config from core.schemas.observable import Observable from core.schemas.template import Template @@ -10,8 +14,7 @@ class TemplateSearchRequest(BaseModel): model_config = ConfigDict(extra="forbid") - query: dict[str, str | int | list] = {} - sorting: list[tuple[str, bool]] = [] + name: str = "" count: int = 50 page: int = 0 @@ -23,17 +26,17 @@ class TemplateSearchResponse(BaseModel): total: int -class PatchTemplateRequest(BaseModel): - model_config = ConfigDict(extra="forbid") +# class PatchTemplateRequest(BaseModel): +# model_config = ConfigDict(extra="forbid") - template: Template +# template: Template -class RenderExportRequest(BaseModel): +class RenderTemplateRequest(BaseModel): model_config = ConfigDict(extra="forbid") - template_id: str - observable_ids: list[str] | None = None + template_name: str + observable_ids: list[str] = [] search_query: str | None = None @@ -41,42 +44,33 @@ class RenderExportRequest(BaseModel): router = APIRouter() -@router.post("/") -async def new(request: PatchTemplateRequest) -> Template: - """Creates a new template.""" - # TODO: Validate template - return request.template.save() - - -@router.patch("/{template_id}") -async def update(template_id: str, request: PatchTemplateRequest) -> Template: - """Updates a template.""" - db_template = Template.get(template_id) - if not db_template: - raise HTTPException( - status_code=404, detail=f"Template {template_id} not found." - ) - update_data = request.template.model_dump(exclude_unset=True) - updated_template = db_template.model_copy(update=update_data) - new = updated_template.save() - return new - - @router.post("/search") async def search(request: TemplateSearchRequest) -> TemplateSearchResponse: """Searches for observables.""" - query = request.query - templates, total = Template.filter( - query, - offset=request.page * request.count, - count=request.count, - sorting=request.sorting, - ) + glob = "*" + if request.name: + glob = f"*{request.name}*" + + template_dir = yeti_config.get("system", "template_dir", "/opt/yeti/templates") + files = [] + total = 0 + for file in Path(template_dir).rglob(f"{glob}.jinja2"): + total += 1 + files.append(file) + + files = sorted(files) + templates = [] + for file in files[ + (request.page * request.count) : ((request.page + 1) * request.count) + ]: + template = Template(name=file.stem, template=file.read_text()) + templates.append(template) + return TemplateSearchResponse(templates=templates, total=total) @router.post("/render") -async def render(request: RenderExportRequest) -> StreamingResponse: +async def render(request: RenderTemplateRequest) -> StreamingResponse: """Renders a template.""" if not request.search_query and not request.observable_ids: raise HTTPException( @@ -84,10 +78,10 @@ async def render(request: RenderExportRequest) -> StreamingResponse: detail="Must specify either search_query or observable_ids.", ) - template = Template.get(request.template_id) + template = Template.find(name=request.template_name) if not template: raise HTTPException( - status_code=404, detail=f"Template {request.template_id} not found." + status_code=404, detail=f"Template {request.template_name} not found." ) if request.search_query: @@ -97,9 +91,16 @@ async def render(request: RenderExportRequest) -> StreamingResponse: status_code=404, detail="No observables found for search query." ) else: - observables = [ - Observable.get(observable_id) for observable_id in request.observable_ids - ] + observables = [] + for observable_id in request.observable_ids: + db_obs = Observable.get(observable_id) + if not db_obs: + logging.warning( + f"Observable with id {observable_id} not found, skipping..." + ) + continue + observables.append(db_obs) + data = template.render(observables, None) def _stream(): @@ -111,14 +112,3 @@ def _stream(): media_type="text/plain", headers={"Content-Disposition": f"attachment; filename={template.name}.txt"}, ) - - -@router.delete("/{template_id}") -async def delete(template_id: str): - """Deletes a template from the database.""" - template = Template.get(template_id) - if not template: - raise HTTPException( - status_code=404, detail=f"Template {template_id} not found." - ) - template.delete() diff --git a/tests/apiv2/tasks.py b/tests/apiv2/tasks.py index 802cb724b..df78fcd4f 100644 --- a/tests/apiv2/tasks.py +++ b/tests/apiv2/tasks.py @@ -183,3 +183,4 @@ def test_delete_export(self): def tearDown(self) -> None: database_arango.db.clear() + self.template.delete() diff --git a/tests/apiv2/templates.py b/tests/apiv2/templates.py index d5e9e2c4a..6f76cab90 100644 --- a/tests/apiv2/templates.py +++ b/tests/apiv2/templates.py @@ -1,6 +1,8 @@ +import json import logging import sys import unittest +from pathlib import Path from fastapi.testclient import TestClient @@ -30,70 +32,61 @@ def setUp(self) -> None: "/api/v2/auth/api-token", headers={"x-yeti-apikey": user.api_key} ).json() client.headers = {"Authorization": "Bearer " + token_data["access_token"]} - self.template = Template(name="FakeTemplate", template=TEST_TEMPLATE).save() + temp_path = Path("/opt/yeti/templates") + temp_path.mkdir(parents=True, exist_ok=True) + self.temp_template_path = temp_path + + Template(name="FakeTemplate", template=TEST_TEMPLATE).save() + for i in range(0, 100): + Template(name=f"template_blah_{i:02}", template=f"fake_template_{i}").save() def tearDown(self) -> None: + for file in Path(self.temp_template_path).rglob("*.jinja2"): + file.unlink() database_arango.db.clear() def test_search_template(self): - response = client.post("/api/v2/templates/search", json={"query": {"name": ""}}) + response = client.post("/api/v2/templates/search", json={"name": "Fake"}) data = response.json() self.assertEqual(response.status_code, 200, data) self.assertEqual(data["templates"][0]["name"], "FakeTemplate") self.assertEqual(data["total"], 1) - def test_delete_template(self): - response = client.delete(f"/api/v2/templates/{self.template.id}") - self.assertEqual(response.status_code, 200, response.json()) - self.assertEqual(Template.get(self.template.id), None) - - def test_create_template(self): - response = client.post( - "/api/v2/templates/", - json={"template": {"name": "FakeTemplate2", "template": ""}}, - ) + def test_pagination(self): + response = client.post("/api/v2/templates/search", json={"name": "blah"}) data = response.json() + self.assertEqual(response.status_code, 200, data) - self.assertEqual(data["name"], "FakeTemplate2") - self.assertEqual(data["template"], "") - self.assertEqual(data["id"], Template.find(name="FakeTemplate2").id) + self.assertEqual(len(data["templates"]), 50) + self.assertEqual(data["templates"][0]["name"], "template_blah_00") + self.assertEqual(data["templates"][49]["name"], "template_blah_49") + self.assertEqual(data["total"], 100) - def test_update_template(self): - response = client.patch( - f"/api/v2/templates/{self.template.id}", - json={ - "template": { - "name": "FakeTemplateFoo", - "template": "", - } - }, + response = client.post( + "/api/v2/templates/search", json={"name": "blah", "page": 3, "count": 5} ) data = response.json() - self.assertEqual(response.status_code, 200, data) - self.assertEqual(data["name"], "FakeTemplateFoo") - self.assertEqual(data["template"], "") - self.assertEqual(data["id"], self.template.id) - db_template = Template.get(self.template.id) - self.assertEqual(db_template.template, "") - self.assertEqual(db_template.name, "FakeTemplateFoo") - self.assertEqual(db_template.id, data["id"]) + self.assertEqual(len(data["templates"]), 5) + self.assertEqual(data["templates"][0]["name"], "template_blah_15") + self.assertEqual(data["templates"][4]["name"], "template_blah_19") - def test_render_template_by_id(self): + def test_render_template_by_obs_ids(self): ipv4.IPv4(value="1.1.1.1").save() ipv4.IPv4(value="2.2.2.2").save() ipv4.IPv4(value="3.3.3.3").save() response = client.post( "/api/v2/templates/render", json={ - "template_id": self.template.id, + "template_name": "FakeTemplate", "observable_ids": [o.id for o in Observable.list()], }, ) data = response.text - response.headers["Content-Disposition"] = ( - "attachment; filename=FakeTemplate.txt" - ) self.assertEqual(response.status_code, 200, data) + self.assertEqual( + response.headers["Content-Disposition"], + "attachment; filename=FakeTemplate.txt", + ) self.assertEqual(data, "\n1.1.1.1\n2.2.2.2\n3.3.3.3\n\n\n") def test_render_template_by_search(self): @@ -103,11 +96,21 @@ def test_render_template_by_search(self): hostname.Hostname(value="hacker.com").save() response = client.post( "/api/v2/templates/render", - json={"template_id": self.template.id, "search_query": "yeti"}, + json={"template_name": "FakeTemplate", "search_query": "yeti"}, ) data = response.text - response.headers["Content-Disposition"] = ( - "attachment; filename=FakeTemplate.txt" - ) self.assertEqual(response.status_code, 200, data) + self.assertEqual( + response.headers["Content-Disposition"], + "attachment; filename=FakeTemplate.txt", + ) self.assertEqual(data, "\nyeti1.com\nyeti2.com\nyeti3.com\n\n\n") + + def test_render_nonexistent(self): + response = client.post( + "/api/v2/templates/render", + json={"template_name": "NotExist", "search_query": "yeti"}, + ) + data = response.text + self.assertEqual(response.status_code, 404, data) + self.assertEqual(json.loads(data), {"detail": "Template NotExist not found."}) diff --git a/tests/schemas/fixture.py b/tests/schemas/fixture.py index 4ba63a045..0a4a48f16 100644 --- a/tests/schemas/fixture.py +++ b/tests/schemas/fixture.py @@ -21,7 +21,7 @@ def setUp(self) -> None: database_arango.db.connect(database="yeti_test") database_arango.db.clear() - def test_something(self): + def general_fixture_test(self): user = UserSensitive(username="yeti", admin=True, enabled=True) user.set_password("yeti") user.save()