"""Synchronise local ``Machine`` rows with a remote digital-twin equipment API."""

import logging
import os
from datetime import datetime, timezone
from typing import Optional
from uuid import UUID

import httpx
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from src.database.models import Machine, MachineChangeHistory, Tenant

logger = logging.getLogger(__name__)

# Machine attributes that pull_from_remote() overwrites with remote values.
UPSTREAM_FIELDS = {
    "name",
    "equipment_code",
    "model",
    "manufacturer",
    "installation_date",
    "location",
    "rated_capacity",
    "power_rating",
}

# Additional Machine attributes considered by push_to_remote(); only those
# that also appear as a value in FIELD_MAPPING are actually sent.
LOCAL_FIELDS = {"criticality", "area", "description"}

# Remote payload key -> local Machine attribute name.
FIELD_MAPPING = {
    "name": "name",
    "equipmentId": "equipment_code",
    "model": "model",
    "manufacturer": "manufacturer",
    "installationDate": "installation_date",
    "location": "location",
    "capacity": "rated_capacity",
    "powerConsumption": "power_rating",
    "description": "description",
}


def _parse_iso_datetime(raw: object) -> Optional[datetime]:
    """Parse an ISO-8601 value, returning ``None`` when empty or malformed.

    A trailing ``Z`` is rewritten to ``+00:00`` because
    ``datetime.fromisoformat`` only accepts it on newer Python versions.
    """
    if not raw:
        return None
    try:
        return datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
    except ValueError:
        return None


def _to_utc(value: datetime) -> datetime:
    """Return *value* as an aware UTC datetime.

    Naive values are taken to already be in UTC.
    NOTE(review): confirm that naive DB timestamps are stored as UTC.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


class ImportResult(BaseModel):
    """Outcome of :meth:`EquipmentSyncService.import_equipment`."""

    imported_count: int = 0
    skipped_count: int = 0
    errors: list[str] = []


class PullResult(BaseModel):
    """Outcome of :meth:`EquipmentSyncService.pull_from_remote`."""

    synced_count: int = 0
    fields_updated: int = 0
    errors: list[str] = []


class PushResult(BaseModel):
    """Outcome of :meth:`EquipmentSyncService.push_to_remote`."""

    success: bool = True
    fields_pushed: list[str] = []
    error: Optional[str] = None


class SyncResult(BaseModel):
    """Combined outcome of :meth:`EquipmentSyncService.sync`."""

    pull: PullResult
    push_count: int = 0
    push_errors: list[str] = []


class EquipmentSyncService:
    """Import, pull and push equipment between the local DB and the remote API.

    The remote base URL and API key are read from the
    ``DIGITAL_TWIN_API_URL`` and ``DIGITAL_TWIN_API_KEY`` environment
    variables when the service is constructed.
    """

    def __init__(self, db: AsyncSession, tenant_id: str):
        self.db = db
        self.tenant_id = tenant_id
        self.api_url = os.getenv("DIGITAL_TWIN_API_URL", "")
        self.api_key = os.getenv("DIGITAL_TWIN_API_KEY", "")

    def _get_headers(self) -> dict:
        """Build request headers, adding a bearer token when a key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _http_client(self) -> httpx.AsyncClient:
        """Create an HTTP client bound to the API URL with a 30 s timeout."""
        return httpx.AsyncClient(
            base_url=self.api_url,
            headers=self._get_headers(),
            timeout=30.0,
        )

    @staticmethod
    def _remote_id(remote: dict) -> str:
        """Return the remote record id as a string, or ``""`` when absent.

        An explicit ``null`` id must not become the truthy string ``"None"``.
        """
        raw_id = remote.get("id")
        return "" if raw_id is None else str(raw_id)

    def _extract_equipment(self, data) -> list[dict]:
        """Pull the equipment list out of a response body.

        Handles a bare list, ``{"data": [...]}``, ``{"equipment": [...]}``,
        ``{"data": {"equipment": [...]}}`` and ``{"data": {"data": [...]}}``.
        Anything that does not resolve to a list yields ``[]``.
        """
        if isinstance(data, dict):
            inner = data.get("data", data)
            if isinstance(inner, dict):
                inner = inner.get("equipment", inner.get("data", []))
            data = inner
        return data if isinstance(data, list) else []

    def _extract_pagination(self, data: dict) -> Optional[dict]:
        """Return the ``pagination`` object of a response body, if it is a dict."""
        if not isinstance(data, dict):
            return None
        inner = data.get("data", data)
        if not isinstance(inner, dict):
            return None
        pagination = inner.get("pagination")
        return pagination if isinstance(pagination, dict) else None

    async def _get_company_id(self) -> Optional[str]:
        """Look up the tenant's ``digital_twin_company_id``."""
        result = await self.db.execute(
            select(Tenant.digital_twin_company_id).where(Tenant.id == self.tenant_id)
        )
        return result.scalar_one_or_none()

    async def fetch_remote_equipment(self) -> list[dict]:
        """Fetch every equipment record of the tenant's company, page by page.

        Returns:
            The concatenated records, or ``[]`` when the API URL is unset,
            the tenant has no company id, or any request/parse step fails
            (the failure is logged, not raised).
        """
        if not self.api_url:
            return []

        try:
            company_id = await self._get_company_id()
            if not company_id:
                logger.warning(
                    "Tenant %s has no digital_twin_company_id mapped",
                    self.tenant_id,
                )
                return []

            all_equipment: list[dict] = []
            page = 1
            max_limit = 500

            async with self._http_client() as client:
                while True:
                    resp = await client.get(
                        "/api/v1/aas/equipment",
                        params={
                            "page": page,
                            "limit": max_limit,
                            "companyId": company_id,
                        },
                    )
                    resp.raise_for_status()
                    data = resp.json()

                    batch = self._extract_equipment(data)
                    if not batch:
                        # An empty page means there is nothing more to read;
                        # stopping here also prevents an endless loop if the
                        # server keeps reporting hasNextPage.
                        break
                    all_equipment.extend(batch)

                    pagination = self._extract_pagination(data)
                    if not pagination or not pagination.get("hasNextPage"):
                        break
                    page += 1

            return all_equipment
        except Exception as e:
            # Best-effort boundary: callers treat [] as "nothing available".
            logger.error("Failed to fetch remote equipment: %s", e)
            return []

    def _map_remote_to_local(self, remote: dict) -> dict:
        """Translate a remote record into local attribute names.

        Keys whose remote value is ``None`` are omitted. Other values are
        stringified, except falsy ones (``""``, ``0``), which are stored as
        ``None``. When ``location`` contains a space, the text before the
        first space is also stored under ``area``.
        """
        mapped = {}
        for remote_key, local_key in FIELD_MAPPING.items():
            val = remote.get(remote_key)
            if val is not None:
                mapped[local_key] = str(val) if val else None

        location = remote.get("location", "")
        if location and " " in str(location):
            mapped["area"] = str(location).split(" ")[0]
        return mapped

    async def import_equipment(
        self, external_ids: Optional[list[str]] = None
    ) -> ImportResult:
        """Create local machines for remote equipment not yet imported.

        Args:
            external_ids: When non-empty, only remote records whose id is in
                this list are considered; otherwise all records are.

        Returns:
            Counts of imported and skipped (already present) records plus
            per-record error messages. The session is committed only when at
            least one machine was added.
        """
        result = ImportResult()
        remote_list = await self.fetch_remote_equipment()
        if not remote_list:
            result.errors.append("디지털 트윈에서 설비 데이터를 가져올 수 없습니다.")
            return result

        # Build the filter once; an empty/None list means "no filter".
        wanted_ids = set(external_ids) if external_ids else None

        for eq in remote_list:
            remote_id = self._remote_id(eq)
            if not remote_id:
                continue
            if wanted_ids is not None and remote_id not in wanted_ids:
                continue

            existing = await self.db.execute(
                select(Machine).where(
                    Machine.tenant_id == self.tenant_id,
                    Machine.external_id == remote_id,
                )
            )
            if existing.scalar_one_or_none():
                result.skipped_count += 1
                continue

            try:
                mapped = self._map_remote_to_local(eq)
                machine = Machine(
                    tenant_id=self.tenant_id,
                    # `or` rather than a .get() default: the mapper stores
                    # None for empty remote values, and a present-but-None
                    # key would bypass the default.
                    name=mapped.get("name") or f"Equipment-{remote_id[:8]}",
                    equipment_code=mapped.get("equipment_code") or "",
                    model=mapped.get("model"),
                    manufacturer=mapped.get("manufacturer"),
                    installation_date=_parse_iso_datetime(
                        mapped.get("installation_date")
                    ),
                    location=mapped.get("location"),
                    area=mapped.get("area"),
                    criticality="major",
                    rated_capacity=mapped.get("rated_capacity"),
                    power_rating=mapped.get("power_rating"),
                    source="digital-twin",
                    external_id=remote_id,
                    sync_version=1,
                    last_synced_at=datetime.now(timezone.utc),
                )
                self.db.add(machine)
                result.imported_count += 1
            except Exception as e:
                result.errors.append(f"설비 {remote_id}: {str(e)}")

        if result.imported_count > 0:
            await self.db.commit()
        return result

    async def pull_from_remote(self) -> PullResult:
        """Overwrite upstream-owned fields of synced machines with remote values.

        A machine is skipped when the remote record's ``updatedAt`` /
        ``updated_at`` is not newer than the machine's ``updated_at``. Each
        changed field produces a ``MachineChangeHistory`` row with
        ``change_source="sync"``. The session is committed only when at least
        one machine changed.
        """
        result = PullResult()
        remote_list = await self.fetch_remote_equipment()
        if not remote_list:
            return result

        remote_by_id = {}
        for eq in remote_list:
            remote_id = self._remote_id(eq)
            if remote_id:
                remote_by_id[remote_id] = eq

        stmt = select(Machine).where(
            Machine.tenant_id == self.tenant_id,
            Machine.source == "digital-twin",
            Machine.external_id.isnot(None),
        )
        local_machines = (await self.db.execute(stmt)).scalars().all()

        for machine in local_machines:
            remote = remote_by_id.get(str(machine.external_id))
            if not remote:
                continue

            remote_dt = _parse_iso_datetime(
                remote.get("updatedAt") or remote.get("updated_at")
            )
            local_updated = machine.updated_at
            # Normalise both sides to UTC: ordering a naive datetime against
            # an aware one raises TypeError.
            if (
                remote_dt
                and local_updated
                and _to_utc(remote_dt) <= _to_utc(local_updated)
            ):
                continue

            mapped = self._map_remote_to_local(remote)
            fields_changed = 0

            # Sorted so history rows are written in a deterministic order.
            for field in sorted(UPSTREAM_FIELDS):
                remote_val = mapped.get(field)
                if remote_val is None:
                    continue

                if field == "installation_date":
                    current = machine.installation_date
                    old_text = current.isoformat() if current else None
                    if old_text == remote_val:
                        continue
                    new_attr = _parse_iso_datetime(remote_val)
                    if new_attr is None:
                        # Unparsable remote date: nothing can be applied, so
                        # do not record a change that never happened.
                        continue
                    # Same instant written differently (e.g. "Z" versus
                    # "+00:00") is not a change.
                    if isinstance(current, datetime) and _to_utc(
                        current
                    ) == _to_utc(new_attr):
                        continue
                else:
                    old_text = str(getattr(machine, field, "") or "")
                    if old_text == str(remote_val or ""):
                        continue
                    new_attr = remote_val

                self.db.add(
                    MachineChangeHistory(
                        tenant_id=self.tenant_id,
                        machine_id=machine.id,
                        field_name=field,
                        old_value=old_text or None,
                        new_value=str(remote_val) if remote_val else None,
                        change_source="sync",
                        changed_at=datetime.now(timezone.utc),
                    )
                )
                setattr(machine, field, new_attr)
                fields_changed += 1

            if fields_changed > 0:
                machine.sync_version = (machine.sync_version or 0) + 1
                machine.last_synced_at = datetime.now(timezone.utc)
                result.synced_count += 1
                result.fields_updated += fields_changed

        if result.synced_count > 0:
            await self.db.commit()
        return result

    async def push_to_remote(self, machine_id: UUID) -> PushResult:
        """PUT one synced machine's mapped, non-null fields to the remote API.

        Args:
            machine_id: Primary key of the machine to push.

        Returns:
            ``success=False`` with an error message when the API URL is
            unset, the machine is not a digital-twin machine of this tenant
            with an external id, or the request fails; otherwise the list of
            local field names that were sent.
        """
        if not self.api_url:
            return PushResult(
                success=False, error="DIGITAL_TWIN_API_URL이 설정되지 않았습니다."
            )

        stmt = select(Machine).where(
            Machine.id == machine_id,
            Machine.tenant_id == self.tenant_id,
            Machine.source == "digital-twin",
        )
        machine = (await self.db.execute(stmt)).scalar_one_or_none()
        if not machine or not machine.external_id:
            return PushResult(success=False, error="동기화 대상 설비가 아닙니다.")

        reverse_mapping = {local: remote for remote, local in FIELD_MAPPING.items()}
        payload = {}
        pushed_fields = []
        # Sorted so fields_pushed has a deterministic order.
        for local_field in sorted(UPSTREAM_FIELDS | LOCAL_FIELDS):
            remote_key = reverse_mapping.get(local_field)
            if not remote_key:
                continue
            val = getattr(machine, local_field, None)
            if val is None:
                continue
            # Dates/datetimes go out as ISO-8601 text, everything else as str.
            payload[remote_key] = (
                val.isoformat() if hasattr(val, "isoformat") else str(val)
            )
            pushed_fields.append(local_field)

        try:
            async with self._http_client() as client:
                resp = await client.put(
                    f"/api/v1/aas/equipment/{machine.external_id}",
                    json=payload,
                )
                resp.raise_for_status()
            return PushResult(success=True, fields_pushed=pushed_fields)
        except Exception as e:
            logger.error("Failed to push to digital-twin: %s", e)
            return PushResult(success=False, error=str(e))

    async def sync(self) -> SyncResult:
        """Pull remote changes, then push every synced machine back.

        Returns:
            The pull result, the number of successful pushes and one
            ``"<machine name>: <error>"`` entry per failed push.
        """
        if not self.api_url:
            return SyncResult(
                pull=PullResult(errors=["DIGITAL_TWIN_API_URL이 설정되지 않았습니다."]),
            )

        pull_result = await self.pull_from_remote()

        push_count = 0
        push_errors: list[str] = []

        stmt = select(Machine).where(
            Machine.tenant_id == self.tenant_id,
            Machine.source == "digital-twin",
            Machine.external_id.isnot(None),
        )
        synced_machines = (await self.db.execute(stmt)).scalars().all()

        for machine in synced_machines:
            push_result = await self.push_to_remote(machine.id)
            if push_result.success:
                push_count += 1
            elif push_result.error:
                push_errors.append(f"{machine.name}: {push_result.error}")

        return SyncResult(
            pull=pull_result,
            push_count=push_count,
            push_errors=push_errors,
        )

    async def record_change(
        self,
        machine_id: UUID,
        field_name: str,
        old_value: Optional[str],
        new_value: Optional[str],
        change_source: str,
        changed_by: Optional[UUID] = None,
    ) -> None:
        """Add a change-history row to the session (the caller commits)."""
        history = MachineChangeHistory(
            tenant_id=self.tenant_id,
            machine_id=machine_id,
            field_name=field_name,
            old_value=old_value,
            new_value=new_value,
            change_source=change_source,
            changed_by=changed_by,
            changed_at=datetime.now(timezone.utc),
        )
        self.db.add(history)