-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
18257e1
commit 8c1f71b
Showing
2 changed files
with
322 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import os | ||
import json | ||
import base64 | ||
import pytest | ||
import uuid | ||
import textwrap | ||
from datetime import datetime | ||
|
||
from fastapi.testclient import TestClient | ||
from moto import mock_aws # Use the generic AWS mock | ||
|
||
# Set the environment variables before importing your app. | ||
os.environ["DYNAMODB_CACHE_TABLE"] = "TensorgradCache" | ||
os.environ["DYNAMODB_SNIPPET_TABLE"] = "CodeSnippets" | ||
|
||
# Now import the app and helper functions from your FastAPI application. | ||
from drawTensors import app, safe_execute, create_snippet, get_snippet, ExecutionResult | ||
|
||
# Create a TestClient for FastAPI | ||
client = TestClient(app) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def setup_dynamodb(): | ||
""" | ||
Use Moto to mock AWS and create the necessary DynamoDB tables for testing. | ||
""" | ||
with mock_aws(): | ||
import boto3 # Import boto3 within the Moto context | ||
|
||
dynamodb = boto3.resource("dynamodb", region_name="us-east-1") | ||
|
||
# Create the CodeCache table | ||
try: | ||
dynamodb.create_table( | ||
TableName="CodeCache", | ||
KeySchema=[{"AttributeName": "code_hash", "KeyType": "HASH"}], | ||
AttributeDefinitions=[ | ||
{"AttributeName": "code_hash", "AttributeType": "S"} | ||
], | ||
ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, | ||
) | ||
except Exception: | ||
pass | ||
|
||
# Create the CodeSnippets table | ||
try: | ||
dynamodb.create_table( | ||
TableName="CodeSnippets", | ||
KeySchema=[{"AttributeName": "snippet_id", "KeyType": "HASH"}], | ||
AttributeDefinitions=[ | ||
{"AttributeName": "snippet_id", "AttributeType": "S"} | ||
], | ||
ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, | ||
) | ||
except Exception: | ||
pass | ||
|
||
# Yield to run tests inside the Moto context. | ||
yield | ||
|
||
|
||
# -------- Tests for safe_execute -------- | ||
|
||
|
||
def test_safe_execute_success(): | ||
code = "print('Hello, world!')" | ||
result: ExecutionResult = safe_execute(code) | ||
assert result.success is True | ||
assert "Hello, world!" in result.output | ||
# For this code, there should be no image, error, or stacktrace. | ||
assert result.image is None | ||
assert result.error == "" | ||
assert result.stacktrace is None | ||
|
||
|
||
def test_safe_execute_syntax_error(): | ||
code = "print('Hello" # unbalanced quote causes SyntaxError | ||
result: ExecutionResult = safe_execute(code) | ||
assert result.success is False | ||
assert "SyntaxError" in result.error or "unterminated" in result.error | ||
|
||
|
||
def test_safe_execute_unsafe_code(): | ||
# Attempt to import os should be blocked. | ||
code = "import os\nprint('Unsafe')" | ||
result: ExecutionResult = safe_execute(code) | ||
assert result.success is False | ||
assert "unsafe operations" in result.error | ||
|
||
|
||
# -------- Tests for /execute endpoint -------- | ||
|
||
|
||
def test_execute_endpoint_success(): | ||
payload = {"code": "print('Test Execute')"} | ||
response = client.post("/execute", json=payload) | ||
assert response.status_code == 200 | ||
data = response.json() | ||
# Validate using the ExecutionResult schema fields. | ||
assert data["success"] is True | ||
assert "Test Execute" in data["output"] | ||
assert data["error"] == "" | ||
|
||
|
||
def test_execute_good_code_with_image(): | ||
code = textwrap.dedent( | ||
""" | ||
i = sp.symbols('i'); | ||
x = tg.Delta(i, 'i', 'j'); | ||
y = x * 2; | ||
save_steps(y); | ||
""" | ||
) | ||
response = client.post("/execute", json={"code": code}) | ||
assert response.status_code == 200 | ||
data = response.json() | ||
assert data["success"] is True | ||
assert "image" in data | ||
# parse image | ||
image_data = data["image"] | ||
assert image_data.startswith("data:image/png;base64,") | ||
|
||
|
||
def test_execute_endpoint_invalid_payload(): | ||
# Send an empty payload. | ||
response = client.post("/execute", json={"code": ""}) | ||
assert response.status_code == 400 | ||
data = response.json() | ||
assert "No code provided" in data["detail"] | ||
|
||
|
||
# -------- Tests for Snippet Endpoints -------- | ||
|
||
|
||
def test_create_and_fetch_snippet(): | ||
code = "print('Snippet Test')" | ||
# Create a snippet. | ||
response_post = client.post("/snippets", json={"code": code}) | ||
assert response_post.status_code == 200 | ||
post_data = response_post.json() | ||
assert "snippet_id" in post_data | ||
snippet_id = post_data["snippet_id"] | ||
|
||
# Retrieve the snippet. | ||
response_get = client.get(f"/snippets/{snippet_id}") | ||
assert response_get.status_code == 200 | ||
get_data = response_get.json() | ||
# Verify that the snippet contains the correct code. | ||
assert get_data["snippet_id"] == snippet_id | ||
assert get_data["code"] == code | ||
# Verify created_at and author_id are present. | ||
assert "created_at" in get_data | ||
assert "author_id" in get_data | ||
|
||
|
||
def test_fetch_nonexistent_snippet(): | ||
# Attempt to fetch a snippet that doesn't exist. | ||
fake_snippet_id = str(uuid.uuid4()) | ||
response_get = client.get(f"/snippets/{fake_snippet_id}") | ||
assert response_get.status_code == 404 | ||
data = response_get.json() | ||
assert "not found" in data["detail"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
import os | ||
import sys | ||
import subprocess | ||
import time | ||
import requests | ||
import pytest | ||
import json | ||
import socket | ||
import time | ||
from typing import Optional | ||
from pydantic import BaseModel | ||
|
||
from drawTensors import CodePayload, ExecutionResult, SnippetCreationResponse, Snippet | ||
|
||
PORT = 9000 # Host port | ||
LAMBDA_URL = f"http://localhost:{PORT}/2015-03-31/functions/function/invocations" | ||
|
||
|
||
def wait_for_port(host: str, port: int, timeout: float = 10.0): | ||
"""Wait for a network port to become available.""" | ||
deadline = time.time() + timeout | ||
while time.time() < deadline: | ||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: | ||
result = sock.connect_ex((host, port)) | ||
if result == 0: | ||
return True | ||
time.sleep(0.5) | ||
raise TimeoutError(f"Timeout waiting for {host}:{port} to become available.") | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def docker_container(): | ||
""" | ||
Build the Docker image once per session, run the container, yield the container ID, | ||
and then stop and remove the container after the session ends. | ||
""" | ||
|
||
# Run the build command with live output. | ||
build_proc = subprocess.run( | ||
[ | ||
"docker", | ||
"buildx", | ||
"build", | ||
"-t", | ||
"tensorgrad", | ||
"-f", | ||
"docker/Dockerfile", | ||
".", | ||
], | ||
capture_output=True, | ||
text=True, | ||
cwd="..", | ||
) | ||
assert build_proc.returncode == 0, f"Docker build failed:\n{build_proc.stderr}" | ||
|
||
# Run the Docker container in detached mode (mapping host port PORT to container port 8080). | ||
run_proc = subprocess.run( | ||
[ | ||
"docker", | ||
"run", | ||
"-d", | ||
"--name", | ||
"lambda_local_test", | ||
"-p", | ||
f"{PORT}:8080", | ||
"tensorgrad", | ||
], | ||
capture_output=True, | ||
text=True, | ||
cwd="..", | ||
) | ||
assert run_proc.returncode == 0, f"Docker run failed:\n{run_proc.stderr}" | ||
container_id = run_proc.stdout.strip() | ||
|
||
# Allow time for the container to start. | ||
wait_for_port("localhost", PORT, timeout=10) | ||
time.sleep(1) | ||
print("Docker container is running.") | ||
|
||
try: | ||
yield container_id | ||
|
||
finally: | ||
print("Cleaning up Docker container...") | ||
logs = subprocess.run( | ||
["docker", "logs", container_id], capture_output=True, text=True | ||
) | ||
print("\nDocker container logs:\n", logs.stdout, logs.stderr) | ||
|
||
# Teardown: stop and remove the container. | ||
subprocess.run( | ||
["docker", "stop", container_id], stdout=sys.stdout, stderr=sys.stderr | ||
) | ||
subprocess.run( | ||
["docker", "rm", container_id], stdout=sys.stdout, stderr=sys.stderr | ||
) | ||
|
||
|
||
def invoke_api( | ||
body: Optional[BaseModel] = None, | ||
path: str = "/", | ||
method: str = "GET", | ||
pathParameters: dict = None, | ||
) -> requests.Response: | ||
""" | ||
Helper function that creates a simulated API Gateway event and invokes the Lambda endpoint. | ||
""" | ||
response = requests.post( | ||
LAMBDA_URL, | ||
json={ | ||
"resource": path, | ||
"path": path, | ||
"httpMethod": method, | ||
"queryStringParameters": None, | ||
"pathParameters": pathParameters, | ||
"body": None if body is None else body.model_dump_json(), | ||
"isBase64Encoded": False, | ||
"requestContext": { | ||
"resourcePath": path, | ||
"httpMethod": method, | ||
"path": path, | ||
}, | ||
}, | ||
headers={"Content-Type": "application/json"}, | ||
) | ||
assert response.status_code == 200, f"API failed: {response.text}" | ||
print(response) | ||
data = response.json() | ||
assert data["statusCode"] == 200, f"API failed: {data['body']}" | ||
print("reponse.json()=", data) | ||
return json.loads(data["body"]) | ||
|
||
|
||
def test_docker_image_execute_endpoint(docker_container): | ||
payload_obj = CodePayload(code="print('Hello from Docker')") | ||
response = invoke_api(body=payload_obj, path="/execute", method="POST") | ||
result = ExecutionResult.model_validate(response) | ||
assert result.success is True, f"Execution did not succeed: {result}" | ||
|
||
|
||
def test_docker_snippet_endpoints(docker_container): | ||
# ---- Snippet Creation ---- | ||
payload_obj = CodePayload(code="print('Hello, snippet!')") | ||
response_body = invoke_api(body=payload_obj, path="/snippets", method="POST") | ||
snippet_creation = SnippetCreationResponse.model_validate(response_body) | ||
assert snippet_creation.snippet_id is not None | ||
snippet_id = snippet_creation.snippet_id | ||
|
||
# ---- Snippet Retrieval ---- | ||
response_get = invoke_api( | ||
path=f"/snippets/{snippet_id}", | ||
method="GET", | ||
pathParameters={"snippet_id": snippet_id}, | ||
) | ||
snippet_obj = Snippet.model_validate(response_get) | ||
assert snippet_obj.snippet_id == snippet_id | ||
assert snippet_obj.code == payload_obj.code | ||
assert snippet_obj.created_at is not None | ||
assert snippet_obj.author_id is not None |