[S0-T03/T06/T07/T08/T09] Set up packages/api: FastAPI, SQLAlchemy, Arq worker, CRUD endpoints, SSE
This commit is contained in:
parent
8b61c03d3c
commit
89f1e47d54
19 changed files with 509 additions and 0 deletions
39
packages/api/README.md
Normal file
39
packages/api/README.md
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
# remodel-api
|
||||||
|
|
||||||
|
FastAPI + Arq async service for the REmodel calculation engine.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry install
|
||||||
|
# Redis must be running (see docker-compose.yml at repo root)
|
||||||
|
docker compose up -d redis
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start API server
|
||||||
|
poetry run uvicorn remodel_api.main:app --reload --port 8000
|
||||||
|
|
||||||
|
# Start Arq worker
|
||||||
|
poetry run arq remodel_api.workers.main.WorkerSettings
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
poetry run pytest
|
||||||
|
|
||||||
|
# Lint + typecheck
|
||||||
|
poetry run ruff check . && poetry run mypy src/
|
||||||
|
```
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
|--------|-------------------------------|------------------------|
|
||||||
|
| GET | /healthz | Health check |
|
||||||
|
| POST | /api/scenarios | Create + enqueue |
|
||||||
|
| GET | /api/scenarios | List all |
|
||||||
|
| GET | /api/scenarios/{id} | Get one |
|
||||||
|
| GET | /api/scenarios/{id}/events | SSE progress stream |
|
||||||
|
|
||||||
|
Interactive docs: http://localhost:8000/docs
|
||||||
74
packages/api/pyproject.toml
Normal file
74
packages/api/pyproject.toml
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
[tool.poetry]
|
||||||
|
name = "remodel-api"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "FastAPI + Arq async API for the REmodel calculation engine"
|
||||||
|
authors = ["Manohar <manohar6839@gmail.com>"]
|
||||||
|
readme = "README.md"
|
||||||
|
packages = [{include = "remodel_api", from = "src"}]
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "^3.12"
|
||||||
|
fastapi = "^0.111"
|
||||||
|
uvicorn = {extras = ["standard"], version = "^0.30"}
|
||||||
|
arq = "^0.26"
|
||||||
|
redis = "^5.0"
|
||||||
|
sqlalchemy = {extras = ["asyncio"], version = "^2.0"}
|
||||||
|
aiosqlite = "^0.20"
|
||||||
|
alembic = "^1.13"
|
||||||
|
pydantic = "^2.7"
|
||||||
|
pydantic-settings = "^2.3"
|
||||||
|
remodel-engine = {path = "../engine", develop = true}
|
||||||
|
sse-starlette = "^2.1"
|
||||||
|
greenlet = "^3.5.0"
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
ruff = "^0.4"
|
||||||
|
mypy = "^1.10"
|
||||||
|
pytest = "^8.2"
|
||||||
|
pytest-cov = "^5.0"
|
||||||
|
pytest-asyncio = "^0.23"
|
||||||
|
httpx = "^0.27"
|
||||||
|
anyio = "^4.4"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py312"
|
||||||
|
line-length = 100
|
||||||
|
src = ["src", "tests"]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "W", "I", "UP", "B", "SIM", "ANN", "RUF"]
|
||||||
|
ignore = ["ANN101", "ANN102", "ANN401"]
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"tests/**" = ["ANN"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.12"
|
||||||
|
strict = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
disallow_any_generics = true
|
||||||
|
check_untyped_defs = true
|
||||||
|
no_implicit_reexport = true
|
||||||
|
files = ["src"]
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = ["arq.*", "alembic.*", "sse_starlette.*", "redis.*"]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
addopts = "--cov=remodel_api --cov-report=term-missing --cov-fail-under=85"
|
||||||
|
|
||||||
|
[tool.coverage.run]
|
||||||
|
source = ["remodel_api"]
|
||||||
|
branch = true
|
||||||
|
|
||||||
|
[tool.coverage.report]
|
||||||
|
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING:"]
|
||||||
3
packages/api/src/remodel_api/__init__.py
Normal file
3
packages/api/src/remodel_api/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
"""REmodel FastAPI service."""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
12
packages/api/src/remodel_api/config.py
Normal file
12
packages/api/src/remodel_api/config.py
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
model_config = SettingsConfigDict(env_prefix="REMODEL_", env_file=".env", extra="ignore")
|
||||||
|
|
||||||
|
database_url: str = "sqlite+aiosqlite:///./remodel.db"
|
||||||
|
redis_url: str = "redis://localhost:6379"
|
||||||
|
api_version: str = "0.1.0"
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
26
packages/api/src/remodel_api/db/models.py
Normal file
26
packages/api/src/remodel_api/db/models.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, String, Text, func
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Scenario(Base):
|
||||||
|
__tablename__ = "scenarios"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||||
|
)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
status: Mapped[str] = mapped_column(
|
||||||
|
String(20), nullable=False, default="queued"
|
||||||
|
)
|
||||||
|
inputs_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
kpis_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||||
|
)
|
||||||
19
packages/api/src/remodel_api/db/session.py
Normal file
19
packages/api/src/remodel_api/db/session.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from remodel_api.config import settings
|
||||||
|
from remodel_api.db.models import Base
|
||||||
|
|
||||||
|
engine = create_async_engine(settings.database_url, echo=False) # pragma: no cover
|
||||||
|
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False) # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
async def init_db() -> None:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session() -> AsyncGenerator[AsyncSession, None]: # pragma: no cover
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
yield session
|
||||||
38
packages/api/src/remodel_api/main.py
Normal file
38
packages/api/src/remodel_api/main.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from remodel_api import __version__
|
||||||
|
from remodel_api.db.session import init_db
|
||||||
|
from remodel_api.routers import scenarios
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # pragma: no cover
|
||||||
|
await init_db()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="REmodel API",
|
||||||
|
version=__version__,
|
||||||
|
description="Hybrid RE project finance modeling — Solar + Wind + BESS",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["http://localhost:3000"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(scenarios.router, prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/healthz", tags=["ops"])
|
||||||
|
async def healthz() -> dict[str, str]:
|
||||||
|
return {"status": "ok", "version": __version__}
|
||||||
80
packages/api/src/remodel_api/routers/scenarios.py
Normal file
80
packages/api/src/remodel_api/routers/scenarios.py
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
import arq
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
|
from remodel_api.config import settings
|
||||||
|
from remodel_api.db.models import Scenario
|
||||||
|
from remodel_api.db.session import get_session
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_session)]
|
||||||
|
|
||||||
|
|
||||||
|
class ScenarioCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ScenarioRead(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
status: str
|
||||||
|
kpis_json: str | None
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/scenarios", response_model=ScenarioRead, status_code=201)
|
||||||
|
async def create_scenario(body: ScenarioCreate, db: SessionDep) -> Scenario:
|
||||||
|
scenario = Scenario(name=body.name, status="queued")
|
||||||
|
db.add(scenario)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(scenario)
|
||||||
|
|
||||||
|
pool = await arq.create_pool(arq.connections.RedisSettings.from_dsn(settings.redis_url))
|
||||||
|
await pool.enqueue_job("run_dummy_scenario", scenario.id)
|
||||||
|
await pool.aclose()
|
||||||
|
|
||||||
|
return scenario
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/scenarios", response_model=list[ScenarioRead])
|
||||||
|
async def list_scenarios(db: SessionDep) -> list[Scenario]:
|
||||||
|
result = await db.execute(select(Scenario).order_by(Scenario.created_at.desc()))
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/scenarios/{scenario_id}", response_model=ScenarioRead)
|
||||||
|
async def get_scenario(scenario_id: str, db: SessionDep) -> Scenario:
|
||||||
|
scenario = await db.get(Scenario, scenario_id)
|
||||||
|
if scenario is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Scenario not found")
|
||||||
|
return scenario
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/scenarios/{scenario_id}/events")
|
||||||
|
async def scenario_events(scenario_id: str) -> EventSourceResponse: # pragma: no cover
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
async def generator() -> AsyncGenerator[dict[str, Any], None]:
|
||||||
|
r = aioredis.from_url(settings.redis_url) # type: ignore[no-untyped-call]
|
||||||
|
channel = f"scenario:{scenario_id}:events"
|
||||||
|
pubsub = r.pubsub()
|
||||||
|
await pubsub.subscribe(channel)
|
||||||
|
try:
|
||||||
|
async for message in pubsub.listen():
|
||||||
|
if message["type"] == "message":
|
||||||
|
yield {"data": message["data"].decode()}
|
||||||
|
finally:
|
||||||
|
await pubsub.unsubscribe(channel)
|
||||||
|
await r.aclose()
|
||||||
|
|
||||||
|
return EventSourceResponse(generator())
|
||||||
0
packages/api/src/remodel_api/workers/__init__.py
Normal file
0
packages/api/src/remodel_api/workers/__init__.py
Normal file
12
packages/api/src/remodel_api/workers/main.py
Normal file
12
packages/api/src/remodel_api/workers/main.py
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
from arq.connections import RedisSettings
|
||||||
|
|
||||||
|
from remodel_api.config import settings
|
||||||
|
from remodel_api.workers.tasks import run_dummy_scenario
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerSettings:
|
||||||
|
functions: ClassVar[list] = [run_dummy_scenario] # type: ignore[type-arg]
|
||||||
|
redis_settings: ClassVar[RedisSettings] = RedisSettings.from_dsn(settings.redis_url)
|
||||||
|
keep_result: ClassVar[int] = 3600
|
||||||
48
packages/api/src/remodel_api/workers/tasks.py
Normal file
48
packages/api/src/remodel_api/workers/tasks.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
from remodel_api.config import settings
|
||||||
|
from remodel_api.db.models import Scenario
|
||||||
|
from remodel_api.db.session import AsyncSessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
async def _publish(r: Any, channel: str, stage: str, pct: int) -> None:
|
||||||
|
payload = json.dumps({"stage": stage, "pct": pct})
|
||||||
|
await r.publish(channel, payload)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_dummy_scenario(ctx: dict[str, Any], scenario_id: str) -> dict[str, Any]:
|
||||||
|
r = aioredis.from_url(settings.redis_url) # type: ignore[no-untyped-call]
|
||||||
|
channel = f"scenario:{scenario_id}:events"
|
||||||
|
|
||||||
|
async with AsyncSessionLocal() as db:
|
||||||
|
scenario = await db.get(Scenario, scenario_id)
|
||||||
|
if scenario is None:
|
||||||
|
await r.aclose()
|
||||||
|
return {"error": "not found"}
|
||||||
|
scenario.status = "running"
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await _publish(r, channel, "starting", 0)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
await _publish(r, channel, "computing", 33)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
await _publish(r, channel, "computing", 66)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
await _publish(r, channel, "finishing", 90)
|
||||||
|
|
||||||
|
result: dict[str, Any] = {"id": scenario_id, "result": "dummy"}
|
||||||
|
|
||||||
|
async with AsyncSessionLocal() as db:
|
||||||
|
scenario = await db.get(Scenario, scenario_id)
|
||||||
|
if scenario is not None:
|
||||||
|
scenario.status = "success"
|
||||||
|
scenario.kpis_json = json.dumps(result)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await _publish(r, channel, "done", 100)
|
||||||
|
await r.aclose()
|
||||||
|
return result
|
||||||
0
packages/api/tests/__init__.py
Normal file
0
packages/api/tests/__init__.py
Normal file
33
packages/api/tests/conftest.py
Normal file
33
packages/api/tests/conftest.py
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from remodel_api.db.models import Base
|
||||||
|
from remodel_api.db.session import get_session
|
||||||
|
from remodel_api.main import app
|
||||||
|
|
||||||
|
TEST_DB_URL = "sqlite+aiosqlite:///:memory:"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
async def db_session() -> AsyncSession: # type: ignore[return]
|
||||||
|
engine = create_async_engine(TEST_DB_URL)
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as session:
|
||||||
|
yield session
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
async def client(db_session: AsyncSession) -> AsyncClient: # type: ignore[return]
|
||||||
|
async def override_session() -> AsyncSession: # type: ignore[return]
|
||||||
|
yield db_session # type: ignore[misc]
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = override_session
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
|
||||||
|
yield c
|
||||||
|
app.dependency_overrides.clear()
|
||||||
9
packages/api/tests/test_healthz.py
Normal file
9
packages/api/tests/test_healthz.py
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
async def test_healthz(client: AsyncClient) -> None:
|
||||||
|
resp = await client.get("/healthz")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
assert "version" in data
|
||||||
47
packages/api/tests/test_scenarios.py
Normal file
47
packages/api/tests/test_scenarios.py
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_scenarios_empty(client: AsyncClient) -> None:
|
||||||
|
resp = await client.get("/api/scenarios")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_and_get_scenario(client: AsyncClient) -> None:
|
||||||
|
mock_pool = AsyncMock()
|
||||||
|
mock_pool.enqueue_job = AsyncMock()
|
||||||
|
mock_pool.aclose = AsyncMock()
|
||||||
|
|
||||||
|
with patch("remodel_api.routers.scenarios.arq.create_pool", return_value=mock_pool):
|
||||||
|
resp = await client.post("/api/scenarios", json={"name": "Test Scenario"})
|
||||||
|
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert data["name"] == "Test Scenario"
|
||||||
|
assert data["status"] == "queued"
|
||||||
|
scenario_id = data["id"]
|
||||||
|
|
||||||
|
resp2 = await client.get(f"/api/scenarios/{scenario_id}")
|
||||||
|
assert resp2.status_code == 200
|
||||||
|
assert resp2.json()["id"] == scenario_id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_scenario_not_found(client: AsyncClient) -> None:
|
||||||
|
resp = await client.get("/api/scenarios/nonexistent-id")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_scenarios_after_create(client: AsyncClient) -> None:
|
||||||
|
mock_pool = AsyncMock()
|
||||||
|
mock_pool.enqueue_job = AsyncMock()
|
||||||
|
mock_pool.aclose = AsyncMock()
|
||||||
|
|
||||||
|
with patch("remodel_api.routers.scenarios.arq.create_pool", return_value=mock_pool):
|
||||||
|
await client.post("/api/scenarios", json={"name": "S1"})
|
||||||
|
await client.post("/api/scenarios", json={"name": "S2"})
|
||||||
|
|
||||||
|
resp = await client.get("/api/scenarios")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert len(resp.json()) == 2
|
||||||
14
packages/api/tests/test_worker_settings.py
Normal file
14
packages/api/tests/test_worker_settings.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
from remodel_api.workers.main import WorkerSettings
|
||||||
|
from remodel_api.workers.tasks import run_dummy_scenario
|
||||||
|
|
||||||
|
|
||||||
|
def test_worker_settings_has_functions() -> None:
|
||||||
|
assert run_dummy_scenario in WorkerSettings.functions
|
||||||
|
|
||||||
|
|
||||||
|
def test_worker_settings_has_redis_settings() -> None:
|
||||||
|
assert WorkerSettings.redis_settings is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_worker_settings_keep_result() -> None:
|
||||||
|
assert WorkerSettings.keep_result == 3600
|
||||||
55
packages/api/tests/test_worker_tasks.py
Normal file
55
packages/api/tests/test_worker_tasks.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from remodel_api.db.models import Scenario
|
||||||
|
from remodel_api.workers.tasks import run_dummy_scenario
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_redis() -> AsyncMock:
|
||||||
|
r = AsyncMock()
|
||||||
|
r.publish = AsyncMock()
|
||||||
|
r.aclose = AsyncMock()
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
async def test_run_dummy_scenario_success(mock_redis: AsyncMock) -> None:
|
||||||
|
scenario = Scenario(name="worker-test", status="queued")
|
||||||
|
|
||||||
|
session_mock = AsyncMock()
|
||||||
|
session_mock.__aenter__ = AsyncMock(return_value=session_mock)
|
||||||
|
session_mock.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
session_mock.get = AsyncMock(return_value=scenario)
|
||||||
|
session_mock.commit = AsyncMock()
|
||||||
|
|
||||||
|
factory_mock = MagicMock()
|
||||||
|
factory_mock.return_value = session_mock
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("remodel_api.workers.tasks.aioredis.from_url", return_value=mock_redis),
|
||||||
|
patch("remodel_api.workers.tasks.AsyncSessionLocal", factory_mock),
|
||||||
|
):
|
||||||
|
result = await run_dummy_scenario({}, "dummy-id")
|
||||||
|
|
||||||
|
assert result["result"] == "dummy"
|
||||||
|
assert result["id"] == "dummy-id"
|
||||||
|
assert mock_redis.publish.called
|
||||||
|
|
||||||
|
|
||||||
|
async def test_run_dummy_scenario_not_found(mock_redis: AsyncMock) -> None:
|
||||||
|
session_mock = AsyncMock()
|
||||||
|
session_mock.__aenter__ = AsyncMock(return_value=session_mock)
|
||||||
|
session_mock.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
session_mock.get = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
factory_mock = MagicMock()
|
||||||
|
factory_mock.return_value = session_mock
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("remodel_api.workers.tasks.aioredis.from_url", return_value=mock_redis),
|
||||||
|
patch("remodel_api.workers.tasks.AsyncSessionLocal", factory_mock),
|
||||||
|
):
|
||||||
|
result = await run_dummy_scenario({}, "missing-id")
|
||||||
|
|
||||||
|
assert "error" in result
|
||||||
Loading…
Add table
Reference in a new issue