99 lines
2.6 KiB
Python
99 lines
2.6 KiB
Python
|
|
import os
|
||
|
|
import uuid
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import pytest_asyncio
|
||
|
|
from dotenv import load_dotenv
|
||
|
|
|
||
|
|
# Load variables from a .env file into os.environ (python-dotenv does not
# override variables that are already set). This runs before the src.*
# imports below and before TEST_DATABASE_URL is read via os.getenv.
# NOTE(review): presumably placed mid-imports because src modules read env
# vars at import time — confirm before moving it.
load_dotenv()
|
||
|
|
|
||
|
|
from httpx import AsyncClient, ASGITransport
|
||
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||
|
|
|
||
|
|
from src.database.config import get_db
|
||
|
|
from src.database.models import Base, Tenant, User
|
||
|
|
from src.auth.password import hash_password
|
||
|
|
|
||
|
|
# Async connection URL for the test database. The TEST_DATABASE_URL
# environment variable (or .env entry) takes precedence over this local
# default.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql+asyncpg://factoryops:factoryops@localhost:5432/factoryops_v2_test",
)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest_asyncio.fixture(scope="function")
async def db_session():
    """Yield an ``AsyncSession`` bound to a freshly created test schema.

    Every table in ``Base.metadata`` is (re)created on ``TEST_DATABASE_URL``
    before the test and dropped after it, so each test starts from empty
    tables. The engine is disposed even if schema setup or teardown raises.
    """
    engine = create_async_engine(TEST_DATABASE_URL, echo=False, pool_size=5)
    try:
        async with engine.begin() as conn:
            # Drop first: create_all skips tables that already exist, so
            # tables left behind by an interrupted run would otherwise carry
            # their old rows (and old column layout) into this test.
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

        session_factory = async_sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )
        # Exiting this context closes the session and releases its connection
        # before the schema is dropped below.
        async with session_factory() as session:
            yield session

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
    finally:
        # Always release pooled connections, even when setup/teardown failed.
        await engine.dispose()
|
||
|
|
|
||
|
|
|
||
|
|
@pytest_asyncio.fixture(scope="function")
async def client(db_session: AsyncSession):
    """Yield an ``AsyncClient`` that sends requests to the app in-process.

    ``get_db`` is overridden so every request made during the test uses the
    same ``db_session`` the test itself sees. All of the app's dependency
    overrides are cleared on teardown, even if closing the client raises.
    """
    # Imported here rather than at module top, so ``main`` is only loaded
    # when this fixture actually runs.
    from main import app

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        # Clear every override (not only get_db) so nothing installed during
        # this test leaks into the next one.
        app.dependency_overrides.clear()
|
||
|
|
|
||
|
|
|
||
|
|
@pytest_asyncio.fixture(scope="function")
async def seeded_db(db_session: AsyncSession):
    """Populate ``db_session`` with two tenants and two users, then return it.

    Seeded rows:
      * tenants ``test-co`` (manufacturing) and ``other-co`` (general);
      * ``super@test.com``: role ``superadmin``, no tenant;
      * ``admin@test-co.com``: role ``tenant_admin`` in tenant ``test-co``.
    Both users have the password ``pass1234``.
    """
    db_session.add_all(
        [
            Tenant(id="test-co", name="Test Company", industry_type="manufacturing"),
            Tenant(id="other-co", name="Other Company", industry_type="general"),
        ]
    )
    # Tenants are committed in their own transaction before the users that
    # reference them via tenant_id are added.
    await db_session.commit()

    superadmin = User(
        id=uuid.uuid4(),
        email="super@test.com",
        password_hash=hash_password("pass1234"),
        name="Super Admin",
        role="superadmin",
        tenant_id=None,
    )
    tenant_admin = User(
        id=uuid.uuid4(),
        email="admin@test-co.com",
        password_hash=hash_password("pass1234"),
        name="Tenant Admin",
        role="tenant_admin",
        tenant_id="test-co",
    )
    db_session.add_all([superadmin, tenant_admin])
    await db_session.commit()

    return db_session
|
||
|
|
|
||
|
|
|
||
|
|
async def get_auth_headers(
    client: AsyncClient, email: str = "super@test.com", password: str = "pass1234"
) -> dict:
    """Log in via ``/api/auth/login`` and return a bearer-token header dict.

    Args:
        client: Client for the app under test (e.g. the ``client`` fixture).
        email: Account to log in as; defaults to the superadmin that
            ``seeded_db`` creates.
        password: Password for ``email``.

    Returns:
        ``{"Authorization": "Bearer <access_token>"}``, ready to pass as
        request headers.

    Raises:
        AssertionError: If the login response is not 2xx. The message
            includes the status code and response body.
        KeyError: If a successful response has no ``access_token`` field.
    """
    resp = await client.post(
        "/api/auth/login", json={"email": email, "password": password}
    )
    # Surface the server's answer instead of an opaque KeyError/JSON decode
    # error when the login is rejected (e.g. the user was never seeded).
    # Explicit raise rather than `assert` so it survives `python -O`.
    if not resp.is_success:
        raise AssertionError(
            f"login as {email!r} failed: HTTP {resp.status_code}: {resp.text}"
        )
    token = resp.json()["access_token"]
    return {"Authorization": f"Bearer {token}"}
|