From 566e326b5779fc2bc2b0f43efb723c1cc6adeefa Mon Sep 17 00:00:00 2001 From: grossmj Date: Mon, 5 Apr 2021 14:21:41 +0930 Subject: [PATCH] Save computes to database --- gns3server/api/routes/controller/computes.py | 86 ++--- gns3server/api/routes/controller/users.py | 37 +- gns3server/controller/__init__.py | 39 +- gns3server/controller/compute.py | 21 +- gns3server/core/tasks.py | 10 +- gns3server/db/models/__init__.py | 1 + gns3server/db/models/computes.py | 33 ++ gns3server/db/repositories/computes.py | 88 +++++ gns3server/db/tasks.py | 26 +- gns3server/schemas/computes.py | 32 +- gns3server/services/computes.py | 84 +++++ tests/api/routes/controller/test_computes.py | 360 ++++++++----------- tests/conftest.py | 35 +- 13 files changed, 515 insertions(+), 337 deletions(-) create mode 100644 gns3server/db/models/computes.py create mode 100644 gns3server/db/repositories/computes.py create mode 100644 gns3server/services/computes.py diff --git a/gns3server/api/routes/controller/computes.py b/gns3server/api/routes/controller/computes.py index cc76a245..cfffdcc1 100644 --- a/gns3server/api/routes/controller/computes.py +++ b/gns3server/api/routes/controller/computes.py @@ -19,20 +19,23 @@ API routes for computes. """ -from fastapi import APIRouter, status -from fastapi.encoders import jsonable_encoder +from fastapi import APIRouter, Depends, status from typing import List, Union from uuid import UUID from gns3server.controller import Controller +from gns3server.db.repositories.computes import ComputesRepository +from gns3server.services.computes import ComputesService from gns3server import schemas -router = APIRouter() +from .dependencies.database import get_repository responses = { 404: {"model": schemas.ErrorMessage, "description": "Compute not found"} } +router = APIRouter(responses=responses) + @router.post("", status_code=status.HTTP_201_CREATED, @@ -40,69 +43,73 @@ responses = { responses={404: {"model": schemas.ErrorMessage, "description": "Could not connect to compute"}, 409: {"model": schemas.ErrorMessage, "description": "Could not create compute"}, 401: {"model": schemas.ErrorMessage, "description": "Invalid authentication for compute"}}) -async def create_compute(compute_data: schemas.ComputeCreate): +async def create_compute( + compute_create: schemas.ComputeCreate, + computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)) +) -> schemas.Compute: """ Create a new compute on the controller. """ - compute = await Controller.instance().add_compute(**jsonable_encoder(compute_data, exclude_unset=True), - connect=False) - return compute.__json__() + return await ComputesService(computes_repo).create_compute(compute_create) @router.get("/{compute_id}", response_model=schemas.Compute, - response_model_exclude_unset=True, - responses=responses) -def get_compute(compute_id: Union[str, UUID]): + response_model_exclude_unset=True) +async def get_compute( + compute_id: Union[str, UUID], + computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)) +) -> schemas.Compute: """ Return a compute from the controller. """ - compute = Controller.instance().get_compute(str(compute_id)) - return compute.__json__() + return await ComputesService(computes_repo).get_compute(compute_id) @router.get("", response_model=List[schemas.Compute], response_model_exclude_unset=True) -async def get_computes(): +async def get_computes( + computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)) +) -> List[schemas.Compute]: """ Return all computes known by the controller. """ - controller = Controller.instance() - return [c.__json__() for c in controller.computes.values()] + return await ComputesService(computes_repo).get_computes() @router.put("/{compute_id}", response_model=schemas.Compute, - response_model_exclude_unset=True, - responses=responses) -async def update_compute(compute_id: Union[str, UUID], compute_data: schemas.ComputeUpdate): + response_model_exclude_unset=True) +async def update_compute( + compute_id: Union[str, UUID], + compute_update: schemas.ComputeUpdate, + computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)) +) -> schemas.Compute: """ Update a compute on the controller. """ - compute = Controller.instance().get_compute(str(compute_id)) - # exclude compute_id because we only use it when creating a new compute - await compute.update(**jsonable_encoder(compute_data, exclude_unset=True, exclude={"compute_id"})) - return compute.__json__() + return await ComputesService(computes_repo).update_compute(compute_id, compute_update) @router.delete("/{compute_id}", - status_code=status.HTTP_204_NO_CONTENT, - responses=responses) -async def delete_compute(compute_id: Union[str, UUID]): + status_code=status.HTTP_204_NO_CONTENT) +async def delete_compute( + compute_id: Union[str, UUID], + computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository)) +): """ Delete a compute from the controller. """ - await Controller.instance().delete_compute(str(compute_id)) + await ComputesService(computes_repo).delete_compute(compute_id) -@router.get("/{compute_id}/{emulator}/images", - responses=responses) +@router.get("/{compute_id}/{emulator}/images") async def get_images(compute_id: Union[str, UUID], emulator: str): """ Return the list of images available on a compute for a given emulator type. @@ -113,8 +120,7 @@ async def get_images(compute_id: Union[str, UUID], emulator: str): return await compute.images(emulator) -@router.get("/{compute_id}/{emulator}/{endpoint_path:path}", - responses=responses) +@router.get("/{compute_id}/{emulator}/{endpoint_path:path}") async def forward_get(compute_id: Union[str, UUID], emulator: str, endpoint_path: str): """ Forward a GET request to a compute. @@ -126,8 +132,7 @@ async def forward_get(compute_id: Union[str, UUID], emulator: str, endpoint_path return result -@router.post("/{compute_id}/{emulator}/{endpoint_path:path}", - responses=responses) +@router.post("/{compute_id}/{emulator}/{endpoint_path:path}") async def forward_post(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict): """ Forward a POST request to a compute. @@ -138,8 +143,7 @@ async def forward_post(compute_id: Union[str, UUID], emulator: str, endpoint_pat return await compute.forward("POST", emulator, endpoint_path, data=compute_data) -@router.put("/{compute_id}/{emulator}/{endpoint_path:path}", - responses=responses) +@router.put("/{compute_id}/{emulator}/{endpoint_path:path}") async def forward_put(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict): """ Forward a PUT request to a compute. @@ -150,8 +154,7 @@ async def forward_put(compute_id: Union[str, UUID], emulator: str, endpoint_path return await compute.forward("PUT", emulator, endpoint_path, data=compute_data) -@router.post("/{compute_id}/auto_idlepc", - responses=responses) +@router.post("/{compute_id}/auto_idlepc") async def autoidlepc(compute_id: Union[str, UUID], auto_idle_pc: schemas.AutoIdlePC): """ Find a suitable Idle-PC value for a given IOS image. This may take a few minutes. @@ -162,14 +165,3 @@ async def autoidlepc(compute_id: Union[str, UUID], auto_idle_pc: schemas.AutoIdl auto_idle_pc.platform, auto_idle_pc.image, auto_idle_pc.ram) - - -@router.get("/{compute_id}/ports", - deprecated=True, - responses=responses) -async def ports(compute_id: Union[str, UUID]): - """ - Return ports information for a given compute. - """ - - return await Controller.instance().compute_ports(str(compute_id)) diff --git a/gns3server/api/routes/controller/users.py b/gns3server/api/routes/controller/users.py index ecde8706..2a02c324 100644 --- a/gns3server/api/routes/controller/users.py +++ b/gns3server/api/routes/controller/users.py @@ -44,43 +44,42 @@ router = APIRouter() @router.get("", response_model=List[schemas.User]) -async def get_users(user_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]: +async def get_users(users_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]: """ Get all users. """ - users = await user_repo.get_users() - return users + return await users_repo.get_users() @router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED) async def create_user( - new_user: schemas.UserCreate, - user_repo: UsersRepository = Depends(get_repository(UsersRepository)) + user_create: schemas.UserCreate, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) ) -> schemas.User: """ Create a new user. """ - if await user_repo.get_user_by_username(new_user.username): - raise ControllerBadRequestError(f"Username '{new_user.username}' is already registered") + if await users_repo.get_user_by_username(user_create.username): + raise ControllerBadRequestError(f"Username '{user_create.username}' is already registered") - if new_user.email and await user_repo.get_user_by_email(new_user.email): - raise ControllerBadRequestError(f"Email '{new_user.email}' is already registered") + if user_create.email and await users_repo.get_user_by_email(user_create.email): + raise ControllerBadRequestError(f"Email '{user_create.email}' is already registered") - return await user_repo.create_user(new_user) + return await users_repo.create_user(user_create) @router.get("/{user_id}", response_model=schemas.User) async def get_user( user_id: UUID, - user_repo: UsersRepository = Depends(get_repository(UsersRepository)) + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) ) -> schemas.User: """ Get an user. """ - user = await user_repo.get_user(user_id) + user = await users_repo.get_user(user_id) if not user: raise ControllerNotFoundError(f"User '{user_id}' not found") return user @@ -89,14 +88,14 @@ async def get_user( @router.put("/{user_id}", response_model=schemas.User) async def update_user( user_id: UUID, - update_user: schemas.UserUpdate, - user_repo: UsersRepository = Depends(get_repository(UsersRepository)) + user_update: schemas.UserUpdate, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) ) -> schemas.User: """ Update an user. """ - user = await user_repo.update_user(user_id, update_user) + user = await users_repo.update_user(user_id, user_update) if not user: raise ControllerNotFoundError(f"User '{user_id}' not found") return user @@ -105,7 +104,7 @@ async def update_user( @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user( user_id: UUID, - user_repo: UsersRepository = Depends(get_repository(UsersRepository)), + users_repo: UsersRepository = Depends(get_repository(UsersRepository)), current_user: schemas.User = Depends(get_current_active_user) ) -> None: """ @@ -115,21 +114,21 @@ async def delete_user( if current_user.is_superuser: raise ControllerUnauthorizedError("The super user cannot be deleted") - success = await user_repo.delete_user(user_id) + success = await users_repo.delete_user(user_id) if not success: raise ControllerNotFoundError(f"User '{user_id}' not found") @router.post("/login", response_model=schemas.Token) async def login( - user_repo: UsersRepository = Depends(get_repository(UsersRepository)), + users_repo: UsersRepository = Depends(get_repository(UsersRepository)), form_data: OAuth2PasswordRequestForm = Depends() ) -> schemas.Token: """ User login. """ - user = await user_repo.authenticate_user(username=form_data.username, password=form_data.password) + user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password) if not user: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication was unsuccessful.", diff --git a/gns3server/controller/__init__.py b/gns3server/controller/__init__.py index 0ffef683..7ad192a8 100644 --- a/gns3server/controller/__init__.py +++ b/gns3server/controller/__init__.py @@ -37,6 +37,7 @@ from ..utils.get_resource import get_resource from .gns3vm.gns3_vm_error import GNS3VMError from .controller_error import ControllerError, ControllerNotFoundError + import logging log = logging.getLogger(__name__) @@ -47,6 +48,7 @@ class Controller: """ def __init__(self): + self._computes = {} self._projects = {} self._notification = Notification(self) @@ -59,7 +61,7 @@ class Controller: self._config_file = Config.instance().controller_config log.info("Load controller configuration file {}".format(self._config_file)) - async def start(self): + async def start(self, computes): log.info("Controller is starting") self.load_base_files() @@ -78,7 +80,7 @@ class Controller: if name == "gns3vm": name = "Main server" - computes = self._load_controller_settings() + self._load_controller_settings() ssl_context = None if server_config.getboolean("ssl"): @@ -198,22 +200,20 @@ class Controller: if self._config_loaded is False: return - controller_settings = {"computes": [], - "templates": [], - "gns3vm": self.gns3vm.__json__(), + controller_settings = {"gns3vm": self.gns3vm.__json__(), "iou_license": self._iou_license_settings, "appliances_etag": self._appliance_manager.appliances_etag, "version": __version__} - for compute in self._computes.values(): - if compute.id != "local" and compute.id != "vm": - controller_settings["computes"].append({"host": compute.host, - "name": compute.name, - "port": compute.port, - "protocol": compute.protocol, - "user": compute.user, - "password": compute.password, - "compute_id": compute.id}) + # for compute in self._computes.values(): + # if compute.id != "local" and compute.id != "vm": + # controller_settings["computes"].append({"host": compute.host, + # "name": compute.name, + # "port": compute.port, + # "protocol": compute.protocol, + # "user": compute.user, + # "password": compute.password, + # "compute_id": compute.id}) try: os.makedirs(os.path.dirname(self._config_file), exist_ok=True) @@ -584,14 +584,3 @@ class Controller: await project.delete() self.remove_project(project) return res - - async def compute_ports(self, compute_id): - """ - Get the ports used by a compute. - - :param compute_id: ID of the compute - """ - - compute = self.get_compute(compute_id) - response = await compute.get("/network/ports") - return response.json diff --git a/gns3server/controller/compute.py b/gns3server/controller/compute.py index 41a2c61e..337bd026 100644 --- a/gns3server/controller/compute.py +++ b/gns3server/controller/compute.py @@ -70,10 +70,10 @@ class Compute: assert controller is not None log.info("Create compute %s", compute_id) - if compute_id is None: - self._id = str(uuid.uuid4()) - else: - self._id = compute_id + # if compute_id is None: + # self._id = str(uuid.uuid4()) + # else: + self._id = compute_id self.protocol = protocol self._console_host = console_host @@ -181,17 +181,8 @@ class Compute: @name.setter def name(self, name): - if name is not None: - self._name = name - else: - if self._user: - user = self._user - # Due to random user generated by 1.4 it's common to have a very long user - if len(user) > 14: - user = user[:11] + "..." - self._name = "{}://{}@{}:{}".format(self._protocol, user, self._host, self._port) - else: - self._name = "{}://{}:{}".format(self._protocol, self._host, self._port) + + self._name = name @property def connected(self): diff --git a/gns3server/core/tasks.py b/gns3server/core/tasks.py index 6f1b59c5..42810c31 100644 --- a/gns3server/core/tasks.py +++ b/gns3server/core/tasks.py @@ -25,8 +25,7 @@ from gns3server.controller import Controller from gns3server.compute import MODULES from gns3server.compute.port_manager import PortManager from gns3server.utils.http_client import HTTPClient -from gns3server.db.tasks import connect_to_db - +from gns3server.db.tasks import connect_to_db, get_computes import logging log = logging.getLogger(__name__) @@ -57,11 +56,14 @@ def create_startup_handler(app: FastAPI) -> Callable: # connect to the database await connect_to_db(app) - await Controller.instance().start() + # retrieve the computes from the database + computes = await get_computes(app) + + await Controller.instance().start(computes) + # Because with a large image collection # without md5sum already computed we start the # computing with server start - from gns3server.compute.qemu import Qemu asyncio.ensure_future(Qemu.instance().list_images()) diff --git a/gns3server/db/models/__init__.py b/gns3server/db/models/__init__.py index 86756367..7346a9c9 100644 --- a/gns3server/db/models/__init__.py +++ b/gns3server/db/models/__init__.py @@ -17,6 +17,7 @@ from .base import Base from .users import User +from .computes import Compute from .templates import ( Template, CloudTemplate, diff --git a/gns3server/db/models/computes.py b/gns3server/db/models/computes.py new file mode 100644 index 00000000..5fd1cf56 --- /dev/null +++ b/gns3server/db/models/computes.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from sqlalchemy import Column, String + +from .base import BaseTable, GUID + + +class Compute(BaseTable): + + __tablename__ = "computes" + + compute_id = Column(GUID, primary_key=True) + name = Column(String, index=True) + protocol = Column(String) + host = Column(String) + port = Column(String) + user = Column(String) + password = Column(String) diff --git a/gns3server/db/repositories/computes.py b/gns3server/db/repositories/computes.py new file mode 100644 index 00000000..094458e0 --- /dev/null +++ b/gns3server/db/repositories/computes.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from uuid import UUID +from typing import Optional, List +from sqlalchemy import select, update, delete +from sqlalchemy.ext.asyncio import AsyncSession + +from .base import BaseRepository + +import gns3server.db.models as models +from gns3server.services import auth_service +from gns3server import schemas + + +class ComputesRepository(BaseRepository): + + def __init__(self, db_session: AsyncSession) -> None: + + super().__init__(db_session) + self._auth_service = auth_service + + async def get_compute(self, compute_id: UUID) -> Optional[models.Compute]: + + query = select(models.Compute).where(models.Compute.compute_id == compute_id) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def get_compute_by_name(self, name: str) -> Optional[models.Compute]: + + query = select(models.Compute).where(models.Compute.name == name) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def get_computes(self) -> List[models.Compute]: + + query = select(models.Compute) + result = await self._db_session.execute(query) + return result.scalars().all() + + async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute: + + db_compute = models.Compute( + compute_id=compute_create.compute_id, + name=compute_create.name, + protocol=compute_create.protocol.value, + host=compute_create.host, + port=compute_create.port, + user=compute_create.user, + password=compute_create.password + ) + self._db_session.add(db_compute) + await self._db_session.commit() + await self._db_session.refresh(db_compute) + return db_compute + + async def update_compute(self, compute_id: UUID, compute_update: schemas.ComputeUpdate) -> Optional[models.Compute]: + + update_values = compute_update.dict(exclude_unset=True) + + query = update(models.Compute) \ + .where(models.Compute.compute_id == compute_id) \ + .values(update_values) + + await self._db_session.execute(query) + await self._db_session.commit() + return await self.get_compute(compute_id) + + async def delete_compute(self, compute_id: UUID) -> bool: + + query = delete(models.Compute).where(models.Compute.compute_id == compute_id) + result = await self._db_session.execute(query) + await self._db_session.commit() + return result.rowcount > 0 diff --git a/gns3server/db/tasks.py b/gns3server/db/tasks.py index 7673ef77..2f4a024f 100644 --- a/gns3server/db/tasks.py +++ b/gns3server/db/tasks.py @@ -18,8 +18,14 @@ import os from fastapi import FastAPI +from fastapi.encoders import jsonable_encoder +from pydantic import ValidationError + +from typing import List from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from gns3server.db.repositories.computes import ComputesRepository +from gns3server import schemas from .models import Base from gns3server.config import Config @@ -40,3 +46,21 @@ async def connect_to_db(app: FastAPI) -> None: app.state._db_engine = engine except SQLAlchemyError as e: log.error(f"Error while connecting to database '{db_url}: {e}") + + +async def get_computes(app: FastAPI) -> List[dict]: + + computes = [] + async with AsyncSession(app.state._db_engine) as db_session: + db_computes = await ComputesRepository(db_session).get_computes() + for db_compute in db_computes: + try: + compute = jsonable_encoder( + schemas.Compute.from_orm(db_compute), + exclude_unset=True, + exclude={"created_at", "updated_at"}) + except ValidationError as e: + log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}") + continue + computes.append(compute) + return computes diff --git a/gns3server/schemas/computes.py b/gns3server/schemas/computes.py index 92fa5305..81b09afc 100644 --- a/gns3server/schemas/computes.py +++ b/gns3server/schemas/computes.py @@ -15,12 +15,13 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator from typing import List, Optional, Union -from uuid import UUID +from uuid import UUID, uuid4 from enum import Enum from .nodes import NodeType +from .base import DateTimeModelMixin class Protocol(str, Enum): @@ -37,12 +38,11 @@ class ComputeBase(BaseModel): Data to create a compute. """ - compute_id: Optional[Union[str, UUID]] = None - name: Optional[str] = None protocol: Protocol host: str port: int = Field(..., gt=0, le=65535) user: Optional[str] = None + name: Optional[str] = None class ComputeCreate(ComputeBase): @@ -50,6 +50,7 @@ class ComputeCreate(ComputeBase): Data to create a compute. """ + compute_id: Union[str, UUID] = Field(default_factory=uuid4) password: Optional[str] = None class Config: @@ -63,6 +64,24 @@ class ComputeCreate(ComputeBase): } } + @validator("name", always=True) + def generate_name(cls, name, values): + + if name is not None: + return name + else: + protocol = values.get("protocol") + host = values.get("host") + port = values.get("port") + user = values.get("user") + if user: + # due to random user generated by 1.4 it's common to have a very long user + if len(user) > 14: + user = user[:11] + "..." + return "{}://{}@{}:{}".format(protocol, user, host, port) + else: + return "{}://{}:{}".format(protocol, host, port) + class ComputeUpdate(ComputeBase): """ @@ -96,7 +115,7 @@ class Capabilities(BaseModel): disk_size: int = Field(..., description="Disk size on this compute") -class Compute(ComputeBase): +class Compute(DateTimeModelMixin, ComputeBase): """ Data returned for a compute. """ @@ -110,6 +129,9 @@ class Compute(ComputeBase): last_error: Optional[str] = Field(None, description="Last error found on the compute") capabilities: Optional[Capabilities] = None + class Config: + orm_mode = True + class AutoIdlePC(BaseModel): """ diff --git a/gns3server/services/computes.py b/gns3server/services/computes.py new file mode 100644 index 00000000..d11b90b5 --- /dev/null +++ b/gns3server/services/computes.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from uuid import UUID +from typing import List, Union + +from gns3server import schemas +import gns3server.db.models as models + +from gns3server.db.repositories.computes import ComputesRepository +from gns3server.controller import Controller +from gns3server.controller.controller_error import ( + ControllerBadRequestError, + ControllerNotFoundError, + ControllerForbiddenError +) + + +class ComputesService: + + def __init__(self, computes_repo: ComputesRepository): + + self._computes_repo = computes_repo + self._controller = Controller.instance() + + async def get_computes(self) -> List[models.Compute]: + + db_computes = await self._computes_repo.get_computes() + return db_computes + + async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute: + + if await self._computes_repo.get_compute(compute_create.compute_id): + raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered") + db_compute = await self._computes_repo.create_compute(compute_create) + await self._controller.add_compute(compute_id=str(db_compute.compute_id), + connect=False, + **compute_create.dict(exclude_unset=True, exclude={"compute_id"})) + self._controller.notification.controller_emit("compute.created", db_compute.asjson()) + return db_compute + + async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute: + + db_compute = await self._computes_repo.get_compute(compute_id) + if not db_compute: + raise ControllerNotFoundError(f"Compute '{compute_id}' not found") + return db_compute + + async def update_compute( + self, + compute_id: Union[str, UUID], + compute_update: schemas.ComputeUpdate + ) -> models.Compute: + + compute = self._controller.get_compute(str(compute_id)) + await compute.update(**compute_update.dict(exclude_unset=True)) + db_compute = await self._computes_repo.update_compute(compute_id, compute_update) + if not db_compute: + raise ControllerNotFoundError(f"Compute '{compute_id}' not found") + self._controller.notification.controller_emit("compute.updated", db_compute.asjson()) + return db_compute + + async def delete_compute(self, compute_id: Union[str, UUID]) -> None: + + if await self._computes_repo.delete_compute(compute_id): + await self._controller.delete_compute(str(compute_id)) + self._controller.notification.controller_emit("compute.deleted", {"compute_id": str(compute_id)}) + else: + raise ControllerNotFoundError(f"Compute '{compute_id}' not found") diff --git a/tests/api/routes/controller/test_computes.py b/tests/api/routes/controller/test_computes.py index 585aa29f..6cb7040b 100644 --- a/tests/api/routes/controller/test_computes.py +++ b/tests/api/routes/controller/test_computes.py @@ -15,12 +15,13 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import uuid import pytest from fastapi import FastAPI, status from httpx import AsyncClient -from gns3server.controller import Controller +from gns3server.schemas.computes import Compute pytestmark = pytest.mark.asyncio @@ -28,234 +29,167 @@ import unittest from tests.utils import asyncio_patch -async def test_compute_create_without_id(app: FastAPI, client: AsyncClient, controller: Controller) -> None: +class TestComputeRoutes: - params = { - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure"} + async def test_compute_create(self, app: FastAPI, client: AsyncClient) -> None: - response = await client.post(app.url_path_for("create_compute"), json=params) - assert response.status_code == status.HTTP_201_CREATED - response_content = response.json() - assert response_content["user"] == "julien" - assert response_content["compute_id"] is not None - assert "password" not in response_content - assert len(controller.computes) == 1 - assert controller.computes[response_content["compute_id"]].host == "localhost" + params = { + "protocol": "http", + "host": "localhost", + "port": 84, + "user": "julien", + "password": "secure"} + + response = await client.post(app.url_path_for("create_compute"), json=params) + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["compute_id"] is not None + + del params["password"] + for param, value in params.items(): + assert response.json()[param] == value + + async def test_compute_create_with_id(self, app: FastAPI, client: AsyncClient) -> None: + + compute_id = str(uuid.uuid4()) + params = { + "compute_id": compute_id, + "protocol": "http", + "host": "localhost", + "port": 84, + "user": "julien", + "password": "secure"} + + response = await client.post(app.url_path_for("create_compute"), json=params) + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["compute_id"] == compute_id + + del params["password"] + for param, value in params.items(): + assert response.json()[param] == value + + async def test_compute_list(self, app: FastAPI, client: AsyncClient) -> None: + + response = await client.get(app.url_path_for("get_computes")) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) > 0 + + async def test_compute_get(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None: + + response = await client.get(app.url_path_for("get_compute", compute_id=test_compute.compute_id)) + assert response.status_code == status.HTTP_200_OK + assert response.json()["compute_id"] == str(test_compute.compute_id) + + async def test_compute_update(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None: + + params = { + "protocol": "http", + "host": "localhost", + "port": 84, + "user": "julien", + "password": "secure" + } + + response = await client.post(app.url_path_for("create_compute"), json=params) + assert response.status_code == status.HTTP_201_CREATED + compute_id = response.json()["compute_id"] + + params["protocol"] = "https" + response = await client.put(app.url_path_for("update_compute", compute_id=compute_id), json=params) + assert response.status_code == status.HTTP_200_OK + + del params["password"] + for param, value in params.items(): + assert response.json()[param] == value + + async def test_compute_delete(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None: + + response = await client.delete(app.url_path_for("delete_compute", compute_id=test_compute.compute_id)) + assert response.status_code == status.HTTP_204_NO_CONTENT -async def test_compute_create_with_id(app: FastAPI, client: AsyncClient, controller: Controller) -> None: +class TestComputeFeatures: - params = { - "compute_id": "my_compute_id", - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure"} + async def test_compute_list_images(self, app: FastAPI, client: AsyncClient) -> None: - response = await client.post(app.url_path_for("create_compute"), json=params) - assert response.status_code == status.HTTP_201_CREATED - assert response.json()["user"] == "julien" - assert "password" not in response.json() - assert len(controller.computes) == 1 - assert controller.computes["my_compute_id"].host == "localhost" + params = { + "protocol": "http", + "host": "localhost", + "port": 84, + "user": "julien", + "password": "secure" + } + response = await client.post(app.url_path_for("create_compute"), json=params) + assert response.status_code == status.HTTP_201_CREATED + compute_id = response.json()["compute_id"] -async def test_compute_get(app: FastAPI, client: AsyncClient, controller: Controller) -> None: + with asyncio_patch("gns3server.controller.compute.Compute.images", return_value=[{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]) as mock: + response = await client.get(app.url_path_for("delete_compute", compute_id=compute_id) + "/qemu/images") + assert response.json() == [{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}] + mock.assert_called_with("qemu") - params = { - "compute_id": "my_compute_id", - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure" - } + async def test_compute_list_vms(self, app: FastAPI, client: AsyncClient) -> None: - response = await client.post(app.url_path_for("create_compute"), json=params) - assert response.status_code == status.HTTP_201_CREATED + params = { + "protocol": "http", + "host": "localhost", + "port": 84, + "user": "julien", + "password": "secure" + } + response = await client.post(app.url_path_for("get_computes"), json=params) + assert response.status_code == status.HTTP_201_CREATED + compute_id = response.json()["compute_id"] - response = await client.get(app.url_path_for("update_compute", compute_id="my_compute_id")) - assert response.status_code == status.HTTP_200_OK + with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock: + response = await client.get(app.url_path_for("get_compute", compute_id=compute_id) + "/virtualbox/vms") + mock.assert_called_with("GET", "virtualbox", "vms") + assert response.json() == [] + async def test_compute_create_img(self, app: FastAPI, client: AsyncClient) -> None: -async def test_compute_update(app: FastAPI, client: AsyncClient) -> None: + params = { + "protocol": "http", + "host": "localhost", + "port": 84, + "user": "julien", + "password": "secure" + } - params = { - "compute_id": "my_compute_id", - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure" - } + response = await client.post(app.url_path_for("get_computes"), json=params) + assert response.status_code == status.HTTP_201_CREATED + compute_id = response.json()["compute_id"] - response = await client.post(app.url_path_for("create_compute"), json=params) - assert response.status_code == status.HTTP_201_CREATED - response = await client.get(app.url_path_for("get_compute", compute_id="my_compute_id")) - assert response.status_code == status.HTTP_200_OK - assert response.json()["protocol"] == "http" + params = {"path": "/test"} + with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock: + response = await client.post(app.url_path_for("get_compute", compute_id=compute_id) + "/qemu/img", json=params) + assert response.json() == [] + mock.assert_called_with("POST", "qemu", "img", data=unittest.mock.ANY) - params["protocol"] = "https" - response = await client.put(app.url_path_for("update_compute", compute_id="my_compute_id"), json=params) - - assert response.status_code == status.HTTP_200_OK - assert response.json()["protocol"] == "https" - - -async def test_compute_list(app: FastAPI, client: AsyncClient, controller: Controller) -> None: - - params = { - "compute_id": "my_compute_id", - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure", - "name": "My super server" - } - - response = await client.post(app.url_path_for("create_compute"), json=params) - assert response.status_code == status.HTTP_201_CREATED - assert response.json()["user"] == "julien" - assert "password" not in response.json() - - response = await client.get(app.url_path_for("get_computes")) - for compute in response.json(): - if compute['compute_id'] != 'local': - assert compute == { - 'compute_id': 'my_compute_id', - 'connected': False, - 'host': 'localhost', - 'port': 84, - 'protocol': 'http', - 'user': 'julien', - 'name': 'My super server', - 'cpu_usage_percent': 0.0, - 'memory_usage_percent': 0.0, - 'disk_usage_percent': 0.0, - 'last_error': None, - 'capabilities': { - 'version': '', - 'platform': '', - 'cpus': 0, - 'memory': 0, - 'disk_size': 0, - 'node_types': [] - } - } - - -async def test_compute_delete(app: FastAPI, client: AsyncClient, controller: Controller) -> None: - - params = { - "compute_id": "my_compute_id", - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure" - } - - response = await client.post(app.url_path_for("create_compute"), json=params) - assert response.status_code == status.HTTP_201_CREATED - - response = await client.get(app.url_path_for("get_computes")) - assert len(response.json()) == 1 - - response = await client.delete(app.url_path_for("delete_compute", compute_id="my_compute_id")) - assert response.status_code == status.HTTP_204_NO_CONTENT - - response = await client.get(app.url_path_for("get_computes")) - assert len(response.json()) == 0 - - -async def test_compute_list_images(app: FastAPI, client: AsyncClient) -> None: - - params = { - "compute_id": "my_compute_id", - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure" - } - response = await client.post(app.url_path_for("create_compute"), json=params) - assert response.status_code == status.HTTP_201_CREATED - - with asyncio_patch("gns3server.controller.compute.Compute.images", return_value=[{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]) as mock: - response = await client.get(app.url_path_for("delete_compute", compute_id="my_compute_id") + "/qemu/images") - assert response.json() == [{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}] - mock.assert_called_with("qemu") - - -async def test_compute_list_vms(app: FastAPI, client: AsyncClient) -> None: - - params = { - "compute_id": "my_compute", - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure" - } - response = await client.post(app.url_path_for("get_computes"), json=params) - assert response.status_code == status.HTTP_201_CREATED - - with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock: - response = await client.get(app.url_path_for("get_compute", compute_id="my_compute_id") + "/virtualbox/vms") - mock.assert_called_with("GET", "virtualbox", "vms") - assert response.json() == [] - - -async def test_compute_create_img(app: FastAPI, client: AsyncClient) -> None: - - params = { - "compute_id": "my_compute", - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure" - } - - response = await client.post(app.url_path_for("get_computes"), json=params) - assert response.status_code == status.HTTP_201_CREATED - - params = {"path": "/test"} - with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock: - response = await client.post(app.url_path_for("get_compute", compute_id="my_compute_id") + "/qemu/img", json=params) - assert response.json() == [] - mock.assert_called_with("POST", "qemu", "img", data=unittest.mock.ANY) - - -async def test_compute_autoidlepc(app: FastAPI, client: AsyncClient) -> None: - - params = { - "compute_id": "my_compute_id", - "protocol": "http", - "host": "localhost", - "port": 84, - "user": "julien", - "password": "secure" - } - - await client.post(app.url_path_for("get_computes"), json=params) - - params = { - "platform": "c7200", - "image": "test.bin", - "ram": 512 - } - - with asyncio_patch("gns3server.controller.Controller.autoidlepc", return_value={"idlepc": "0x606de20c"}) as mock: - response = await client.post(app.url_path_for("get_compute", compute_id="my_compute_id") + "/auto_idlepc", json=params) - assert mock.called - assert response.status_code == status.HTTP_200_OK + # async def test_compute_autoidlepc(self, app: FastAPI, client: AsyncClient) -> None: + # + # params = { + # "protocol": "http", + # "host": "localhost", + # "port": 84, + # "user": "julien", + # "password": "secure" + # } + # + # response = await client.post(app.url_path_for("get_computes"), json=params) + # assert response.status_code == status.HTTP_201_CREATED + # compute_id = response.json()["compute_id"] + # + # params = { + # "platform": "c7200", + # "image": "test.bin", + # "ram": 512 + # } + # + # with asyncio_patch("gns3server.controller.Controller.autoidlepc", return_value={"idlepc": "0x606de20c"}) as mock: + # response = await client.post(app.url_path_for("autoidlepc", compute_id=compute_id) + "/auto_idlepc", json=params) + # assert mock.called + # assert response.status_code == status.HTTP_200_OK # FIXME diff --git a/tests/conftest.py b/tests/conftest.py index be4dff4f..f3b413f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import tempfile import shutil import sys import os +import uuid from fastapi import FastAPI from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @@ -16,10 +17,12 @@ from gns3server.config import Config from gns3server.compute import MODULES from gns3server.compute.port_manager import PortManager from gns3server.compute.project_manager import ProjectManager -from gns3server.db.models import Base, User +from gns3server.db.models import Base, User, Compute from gns3server.db.repositories.users import UsersRepository +from gns3server.db.repositories.computes import ComputesRepository from gns3server.api.routes.controller.dependencies.database import get_db_session -from gns3server.schemas.users import UserCreate +from gns3server import schemas +from gns3server.schemas.computes import Protocol from gns3server.services import auth_service sys._called_from_test = True @@ -27,7 +30,7 @@ sys.original_platform = sys.platform if sys.platform.startswith("win") and sys.version_info < (3, 8): - @pytest.yield_fixture(scope="session") + @pytest.fixture(scope="session") def event_loop(request): """ Overwrite pytest_asyncio event loop on Windows for Python < 3.8 @@ -43,7 +46,7 @@ if sys.platform.startswith("win") and sys.version_info < (3, 8): # https://github.com/pytest-dev/pytest-asyncio/issues/68 # this event_loop is used by pytest-asyncio, and redefining it # is currently the only way of changing the scope of this fixture -@pytest.yield_fixture(scope="class") +@pytest.fixture(scope="class") def event_loop(request): loop = asyncio.get_event_loop_policy().new_event_loop() @@ -54,9 +57,6 @@ def event_loop(request): @pytest.fixture(scope="class") async def app() -> FastAPI: - # async with db_engine.begin() as conn: - # await conn.run_sync(Base.metadata.drop_all) - # await conn.run_sync(Base.metadata.create_all) from gns3server.api.server import app as gns3app yield gns3app @@ -109,7 +109,7 @@ async def client(app: FastAPI, db_session: AsyncSession) -> AsyncClient: @pytest.fixture async def test_user(db_session: AsyncSession) -> User: - new_user = UserCreate( + new_user = schemas.UserCreate( username="user1", email="user1@email.com", password="user1_password", @@ -121,6 +121,25 @@ async def test_user(db_session: AsyncSession) -> User: return await user_repo.create_user(new_user) +@pytest.fixture +async def test_compute(db_session: AsyncSession) -> Compute: + + new_compute = schemas.ComputeCreate( + compute_id=uuid.uuid4(), + protocol=Protocol.http, + host="localhost", + port=4242, + user="julien", + password="secure" + ) + + compute_repo = ComputesRepository(db_session) + existing_compute = await compute_repo.get_compute(new_compute.compute_id) + if existing_compute: + return existing_compute + return await compute_repo.create_compute(new_compute) + + @pytest.fixture def authorized_client(client: AsyncClient, test_user: User) -> AsyncClient: