from http.client import HTTPResponse
import json
import uuid
from app.utils import (
ANALYSIS_STATUS,
STATUS,
read_data_file,
create_response_error,
create_response_message,
deregister_container_from_ensemble,
parse_response_for_triggered_analysis,
)
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship, selectinload
from app.models.ensemble_ids import EnsembleIds, get_ensemble_ids_by_ids
from app.database import Base
from app.models.ids_system import IdsSystem, update_ids_status
from app.validation.models import EnsembleUpdate
import httpx
from sqlalchemy.future import select
from app.logger import LOGGER
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.ids_system import IdsSystem, get_ids_system_by_id
from app.models.dataset import get_dataset_by_id
class Ensemble(Base):
    """ORM model for an ensemble of IDS containers.

    Containers (``IdsSystem`` rows) are attached to an ensemble through
    ``EnsembleIds`` link rows. The analysis methods fan a request out to every
    assigned container and collect the per-container responses.
    """

    __tablename__ = "ensemble"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(64), nullable=False)
    technique_id = Column(Integer, ForeignKey("ensemble_technique.id"))
    status = Column(String(32), nullable=False)
    description = Column(String(2048))
    # Set by generate_new_analysis_id() and cleared by unset_analysis_id().
    current_analysis_id = Column(String(64))
    ensemble_ids = relationship("EnsembleIds", cascade="all, delete", lazy="selectin")
    ensemble_technique = relationship(
        "EnsembleTechnique", back_populates="ensemble", lazy="selectin"
    )

    async def add_container(self, db: AsyncSession, container_id: int):
        """Register container ``container_id`` with this ensemble.

        POSTs to the container's ``/configure/ensemble/add/{ensemble_id}``
        endpoint. Only when the container answers 200 is an ``EnsembleIds``
        link row (status ``ANALYSIS_STATUS.IDLE``) added and committed.

        Returns:
            The container's ``httpx`` response.
        """
        container: IdsSystem = await get_ids_system_by_id(db, container_id)
        container_url = container.get_container_http_url()
        endpoint = f"/configure/ensemble/add/{self.id}"
        async with httpx.AsyncClient() as client:
            response: httpx.Response = await client.post(container_url + endpoint)
        if response.status_code == 200:
            # Only build the link row once the container has accepted it.
            db.add(
                EnsembleIds(
                    ensemble_id=self.id,
                    ids_system_id=container_id,
                    status=ANALYSIS_STATUS.IDLE.value,
                )
            )
            await db.commit()
        return response

    async def remove_container(self, db: AsyncSession, container_id: int):
        """Detach container ``container_id`` from this ensemble.

        Calls ``deregister_container_from_ensemble`` for the container. If that
        returns 200, the matching ``EnsembleIds`` link row is deleted and
        committed, provided such a row exists.

        Returns:
            The response returned by ``deregister_container_from_ensemble``.
        """
        ensemble_ids = await get_ensemble_ids_by_ids(db, self.id, container_id)
        container: IdsSystem = await get_ids_system_by_id(db, container_id)
        response = await deregister_container_from_ensemble(container)
        # session.delete(None) raises; a missing link simply means there is
        # nothing left to remove.
        if response.status_code == 200 and ensemble_ids is not None:
            await db.delete(ensemble_ids)
            await db.commit()
        return response

    async def get_ensemble_ids(self, db: AsyncSession):
        """Return all ``EnsembleIds`` link rows whose ``ensemble_id`` is this ensemble."""
        stmt = select(EnsembleIds).where(EnsembleIds.ensemble_id == self.id)
        result = await db.execute(stmt)
        return result.scalars().all()

    async def get_assigned_containers(self, db: AsyncSession):
        """Return the ``IdsSystem`` rows linked to this ensemble via ``EnsembleIds``."""
        ensemble_ids = await self.get_ensemble_ids(db)
        id_list = [e_ids.ids_system_id for e_ids in ensemble_ids]
        stmt = select(IdsSystem).where(IdsSystem.id.in_(id_list))
        container_result = await db.execute(stmt)
        return container_result.scalars().all()

    async def _update_container_status(self, db: AsyncSession, container, response):
        """Set ``container``'s IDS status to ACTIVE if ``response`` is 200, else IDLE."""
        new_status = (
            STATUS.ACTIVE.value if response.status_code == 200 else STATUS.IDLE.value
        )
        await update_ids_status(db, new_status, container)

    async def start_static_analysis(self, db: AsyncSession, dataset_id: int):
        """Trigger a static analysis of dataset ``dataset_id`` on every assigned container.

        The dataset file is read once and passed, together with the container,
        ensemble and dataset ids, to ``IdsSystem.start_static_analysis``. Each
        response is run through ``parse_response_for_triggered_analysis``, and
        the container's status is set from the parsed response.

        Returns:
            The parsed responses, in container order.
        """
        LOGGER.debug(f"{dataset_id}")
        dataset = await get_dataset_by_id(db, dataset_id)
        containers: list[IdsSystem] = await self.get_assigned_containers(db)
        responses = []
        data_file = await read_data_file(dataset.data_file_path)
        for container in containers:
            form_data = {
                "container_id": (None, str(container.id), "application/json"),
                "ensemble_id": (None, str(self.id), "application/json"),
                "dataset": (dataset.name, data_file, "application/octet-stream"),
                "dataset_id": (None, str(dataset.id), "application/json"),
            }
            # TODO 0: try with asyncio in background
            response = await container.start_static_analysis(form_data, dataset)
            response = await parse_response_for_triggered_analysis(
                response, container, "static", self.id
            )
            await self._update_container_status(db, container, response)
            responses.append(response)
        return responses

    async def container_is_last_one_running(self, db: AsyncSession, container):
        """Return True if no *other* container in this ensemble reports busy.

        An ensemble with exactly one assigned container always returns True,
        without querying any container.
        """
        all_containers = await self.get_assigned_containers(db)
        # With a single container it is always the last one running, so skip
        # the is_busy() round-trips entirely.
        if len(all_containers) == 1:
            return True
        other_containers_running = [
            await c.is_busy() for c in all_containers if c.id != container.id
        ]
        return True not in other_containers_running

    async def start_network_analysis(self, db: AsyncSession, network_analysis_data):
        """Trigger a network analysis on every assigned container.

        ``network_analysis_data.__dict__`` is serialized to JSON once and sent
        to each container. Each response is parsed with
        ``parse_response_for_triggered_analysis``, and the container's status
        is set from the parsed response.

        Returns:
            The parsed responses, in container order (empty if no containers).
        """
        containers: list[IdsSystem] = await self.get_assigned_containers(db)
        if not containers:
            return []
        # The payload is identical for every container, so serialize it once.
        data = json.dumps(network_analysis_data.__dict__)
        responses = []
        for container in containers:
            response = await container.start_network_analysis(data)
            response = await parse_response_for_triggered_analysis(
                response, container, "network", self.id
            )
            await self._update_container_status(db, container, response)
            responses.append(response)
        return responses

    async def stop_analysis(self, db: AsyncSession):
        """Ask every assigned container to stop its analysis.

        Returns:
            One response per container: a 200 message on success, or a 500
            error when the container did not answer 200.
        """
        containers: list[IdsSystem] = await self.get_assigned_containers(db)
        responses = []
        for container in containers:
            response = await container.stop_analysis()
            if response.status_code == 200:
                message = f"Analysis for container {container.id} successfully stopped"
                responses.append(create_response_message(message, 200))
            else:
                message = f"Analysis for container {container.id} could not be stopped"
                responses.append(create_response_error(message, 500))
        return responses

    async def is_container_running(self):
        """Return True if this ensemble's ``status`` is ACTIVE.

        ``status`` is a ``String(32)`` column, so a value loaded from the
        database is the raw string. Accept both the enum member and its
        ``.value``, so the check does not depend on whether ``STATUS`` is a
        ``str``-mixin enum.
        """
        return self.status in (STATUS.ACTIVE, STATUS.ACTIVE.value)

    async def generate_new_analysis_id(self, db: AsyncSession):
        """Assign a fresh random UUID4 string as ``current_analysis_id``, commit, and reload."""
        self.current_analysis_id = str(uuid.uuid4())
        await db.commit()
        await db.refresh(self)

    async def unset_analysis_id(self, db: AsyncSession):
        """Clear ``current_analysis_id``, commit, and reload."""
        self.current_analysis_id = None
        await db.commit()
        await db.refresh(self)
async def get_all_ensembles(db: AsyncSession):
    """Fetch every ensemble, with its ``ensemble_ids`` link rows eagerly loaded."""
    query = select(Ensemble)
    query = query.options(selectinload(Ensemble.ensemble_ids))
    rows = await db.execute(query)
    ensembles = rows.scalars()
    return ensembles.all()
async def get_ensemble_by_id(db: AsyncSession, id: int):
    """Look up one ensemble by primary key.

    Returns the ``Ensemble`` (with ``ensemble_ids`` eagerly loaded), or
    ``None`` when no row has that id.
    """
    # NOTE: ``id`` shadows the builtin but is kept because callers may pass it
    # by keyword.
    eager_links = selectinload(Ensemble.ensemble_ids)
    query = select(Ensemble).where(Ensemble.id == id).options(eager_links)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
async def remove_ensemble(db: AsyncSession, ensemble: Ensemble):
    """Delete ``ensemble`` and commit the session immediately."""
    session = db
    await session.delete(ensemble)
    return await session.commit()
async def add_ensemble(db: AsyncSession, ensemble: Ensemble):
    """Insert ``ensemble``, commit, and reload it.

    The reload populates database-generated fields such as the
    autoincrement ``id``.
    """
    session = db
    session.add(ensemble)
    await session.commit()
    await session.refresh(ensemble)
async def update_ensemble(db: AsyncSession, ensemble: EnsembleUpdate):
    """Apply an ``EnsembleUpdate`` to the stored ensemble and sync its containers.

    The update model's fields, except ``container_ids``, are copied onto the
    ``Ensemble`` row and committed first. The requested ``container_ids`` are
    then diffed against the existing ``EnsembleIds`` links. Containers that are
    no longer listed are detached with ``Ensemble.remove_container``, and newly
    listed ones are attached with ``Ensemble.add_container``.

    Returns:
        ``None`` if no ensemble with ``ensemble.id`` exists. Otherwise, the list
        of responses from each remove/add call: removals first, then
        additions.
    """
    stmt = select(Ensemble).where(Ensemble.id == ensemble.id)
    result = await db.execute(stmt)
    ensemble_db = result.scalar_one_or_none()
    if ensemble_db is None:
        return None

    former_containers = [
        link.ids_system_id for link in await ensemble_db.get_ensemble_ids(db)
    ]

    # ``container_ids`` is not a column on Ensemble; it only drives the
    # add/remove sync below, so keep it off the ORM instance.
    for key, value in ensemble.model_dump(exclude={"container_ids"}).items():
        setattr(ensemble_db, key, value)
    await db.commit()
    await db.refresh(ensemble_db)

    new_containers = ensemble.container_ids
    former_set = set(former_containers)
    new_set = set(new_containers)
    # Iterate the lists (not the sets) so the HTTP calls follow the original order.
    added_containers = [c for c in new_containers if c not in former_set]
    removed_containers = [c for c in former_containers if c not in new_set]

    responses = []
    for container_id in removed_containers:
        responses.append(await ensemble_db.remove_container(db, container_id))
    for container_id in added_containers:
        responses.append(await ensemble_db.add_container(db, container_id))
    return responses
async def update_ensemble_status(db: AsyncSession, status: STATUS, ensemble: Ensemble):
    """Set ``ensemble.status``, commit, and reload the instance.

    ``status`` may be a ``STATUS`` member or its raw value. Enum members are
    stored by ``.value`` because ``Ensemble.status`` is a ``String(32)``
    column; plain strings are stored unchanged.
    """
    ensemble.status = getattr(status, "value", status)
    await db.commit()
    await db.refresh(ensemble)