Skip to content

Commit

Permalink
Add tests for server
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Feb 16, 2025
1 parent 18257e1 commit 8c1f71b
Show file tree
Hide file tree
Showing 2 changed files with 322 additions and 0 deletions.
163 changes: 163 additions & 0 deletions docker/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import os
import json
import base64
import pytest
import uuid
import textwrap
from datetime import datetime

from fastapi.testclient import TestClient
from moto import mock_aws # Use the generic AWS mock

# Set the environment variables before importing your app.
os.environ["DYNAMODB_CACHE_TABLE"] = "TensorgradCache"
os.environ["DYNAMODB_SNIPPET_TABLE"] = "CodeSnippets"

# Now import the app and helper functions from your FastAPI application.
from drawTensors import app, safe_execute, create_snippet, get_snippet, ExecutionResult

# Create a TestClient for FastAPI
client = TestClient(app)


@pytest.fixture(autouse=True)
def setup_dynamodb():
"""
Use Moto to mock AWS and create the necessary DynamoDB tables for testing.
"""
with mock_aws():
import boto3 # Import boto3 within the Moto context

dynamodb = boto3.resource("dynamodb", region_name="us-east-1")

# Create the CodeCache table
try:
dynamodb.create_table(
TableName="CodeCache",
KeySchema=[{"AttributeName": "code_hash", "KeyType": "HASH"}],
AttributeDefinitions=[
{"AttributeName": "code_hash", "AttributeType": "S"}
],
ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
)
except Exception:
pass

# Create the CodeSnippets table
try:
dynamodb.create_table(
TableName="CodeSnippets",
KeySchema=[{"AttributeName": "snippet_id", "KeyType": "HASH"}],
AttributeDefinitions=[
{"AttributeName": "snippet_id", "AttributeType": "S"}
],
ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
)
except Exception:
pass

# Yield to run tests inside the Moto context.
yield


# -------- Tests for safe_execute --------


def test_safe_execute_success():
code = "print('Hello, world!')"
result: ExecutionResult = safe_execute(code)
assert result.success is True
assert "Hello, world!" in result.output
# For this code, there should be no image, error, or stacktrace.
assert result.image is None
assert result.error == ""
assert result.stacktrace is None


def test_safe_execute_syntax_error():
code = "print('Hello" # unbalanced quote causes SyntaxError
result: ExecutionResult = safe_execute(code)
assert result.success is False
assert "SyntaxError" in result.error or "unterminated" in result.error


def test_safe_execute_unsafe_code():
# Attempt to import os should be blocked.
code = "import os\nprint('Unsafe')"
result: ExecutionResult = safe_execute(code)
assert result.success is False
assert "unsafe operations" in result.error


# -------- Tests for /execute endpoint --------


def test_execute_endpoint_success():
payload = {"code": "print('Test Execute')"}
response = client.post("/execute", json=payload)
assert response.status_code == 200
data = response.json()
# Validate using the ExecutionResult schema fields.
assert data["success"] is True
assert "Test Execute" in data["output"]
assert data["error"] == ""


def test_execute_good_code_with_image():
code = textwrap.dedent(
"""
i = sp.symbols('i');
x = tg.Delta(i, 'i', 'j');
y = x * 2;
save_steps(y);
"""
)
response = client.post("/execute", json={"code": code})
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "image" in data
# parse image
image_data = data["image"]
assert image_data.startswith("data:image/png;base64,")


def test_execute_endpoint_invalid_payload():
# Send an empty payload.
response = client.post("/execute", json={"code": ""})
assert response.status_code == 400
data = response.json()
assert "No code provided" in data["detail"]


# -------- Tests for Snippet Endpoints --------


def test_create_and_fetch_snippet():
code = "print('Snippet Test')"
# Create a snippet.
response_post = client.post("/snippets", json={"code": code})
assert response_post.status_code == 200
post_data = response_post.json()
assert "snippet_id" in post_data
snippet_id = post_data["snippet_id"]

# Retrieve the snippet.
response_get = client.get(f"/snippets/{snippet_id}")
assert response_get.status_code == 200
get_data = response_get.json()
# Verify that the snippet contains the correct code.
assert get_data["snippet_id"] == snippet_id
assert get_data["code"] == code
# Verify created_at and author_id are present.
assert "created_at" in get_data
assert "author_id" in get_data


def test_fetch_nonexistent_snippet():
# Attempt to fetch a snippet that doesn't exist.
fake_snippet_id = str(uuid.uuid4())
response_get = client.get(f"/snippets/{fake_snippet_id}")
assert response_get.status_code == 404
data = response_get.json()
assert "not found" in data["detail"]
159 changes: 159 additions & 0 deletions docker/tests/test_docker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import os
import sys
import subprocess
import time
import requests
import pytest
import json
import socket
import time
from typing import Optional
from pydantic import BaseModel

from drawTensors import CodePayload, ExecutionResult, SnippetCreationResponse, Snippet

PORT = 9000 # Host port
LAMBDA_URL = f"http://localhost:{PORT}/2015-03-31/functions/function/invocations"


def wait_for_port(host: str, port: int, timeout: float = 10.0):
"""Wait for a network port to become available."""
deadline = time.time() + timeout
while time.time() < deadline:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
result = sock.connect_ex((host, port))
if result == 0:
return True
time.sleep(0.5)
raise TimeoutError(f"Timeout waiting for {host}:{port} to become available.")


@pytest.fixture(scope="session")
def docker_container():
"""
Build the Docker image once per session, run the container, yield the container ID,
and then stop and remove the container after the session ends.
"""

# Run the build command with live output.
build_proc = subprocess.run(
[
"docker",
"buildx",
"build",
"-t",
"tensorgrad",
"-f",
"docker/Dockerfile",
".",
],
capture_output=True,
text=True,
cwd="..",
)
assert build_proc.returncode == 0, f"Docker build failed:\n{build_proc.stderr}"

# Run the Docker container in detached mode (mapping host port PORT to container port 8080).
run_proc = subprocess.run(
[
"docker",
"run",
"-d",
"--name",
"lambda_local_test",
"-p",
f"{PORT}:8080",
"tensorgrad",
],
capture_output=True,
text=True,
cwd="..",
)
assert run_proc.returncode == 0, f"Docker run failed:\n{run_proc.stderr}"
container_id = run_proc.stdout.strip()

# Allow time for the container to start.
wait_for_port("localhost", PORT, timeout=10)
time.sleep(1)
print("Docker container is running.")

try:
yield container_id

finally:
print("Cleaning up Docker container...")
logs = subprocess.run(
["docker", "logs", container_id], capture_output=True, text=True
)
print("\nDocker container logs:\n", logs.stdout, logs.stderr)

# Teardown: stop and remove the container.
subprocess.run(
["docker", "stop", container_id], stdout=sys.stdout, stderr=sys.stderr
)
subprocess.run(
["docker", "rm", container_id], stdout=sys.stdout, stderr=sys.stderr
)


def invoke_api(
body: Optional[BaseModel] = None,
path: str = "/",
method: str = "GET",
pathParameters: dict = None,
) -> requests.Response:
"""
Helper function that creates a simulated API Gateway event and invokes the Lambda endpoint.
"""
response = requests.post(
LAMBDA_URL,
json={
"resource": path,
"path": path,
"httpMethod": method,
"queryStringParameters": None,
"pathParameters": pathParameters,
"body": None if body is None else body.model_dump_json(),
"isBase64Encoded": False,
"requestContext": {
"resourcePath": path,
"httpMethod": method,
"path": path,
},
},
headers={"Content-Type": "application/json"},
)
assert response.status_code == 200, f"API failed: {response.text}"
print(response)
data = response.json()
assert data["statusCode"] == 200, f"API failed: {data['body']}"
print("reponse.json()=", data)
return json.loads(data["body"])


def test_docker_image_execute_endpoint(docker_container):
payload_obj = CodePayload(code="print('Hello from Docker')")
response = invoke_api(body=payload_obj, path="/execute", method="POST")
result = ExecutionResult.model_validate(response)
assert result.success is True, f"Execution did not succeed: {result}"


def test_docker_snippet_endpoints(docker_container):
# ---- Snippet Creation ----
payload_obj = CodePayload(code="print('Hello, snippet!')")
response_body = invoke_api(body=payload_obj, path="/snippets", method="POST")
snippet_creation = SnippetCreationResponse.model_validate(response_body)
assert snippet_creation.snippet_id is not None
snippet_id = snippet_creation.snippet_id

# ---- Snippet Retrieval ----
response_get = invoke_api(
path=f"/snippets/{snippet_id}",
method="GET",
pathParameters={"snippet_id": snippet_id},
)
snippet_obj = Snippet.model_validate(response_get)
assert snippet_obj.snippet_id == snippet_id
assert snippet_obj.code == payload_obj.code
assert snippet_obj.created_at is not None
assert snippet_obj.author_id is not None

0 comments on commit 8c1f71b

Please # to comment.