Save computes to database

This commit is contained in:
grossmj 2021-04-05 14:21:41 +09:30
parent e607793e74
commit 566e326b57
13 changed files with 515 additions and 337 deletions

View File

@ -19,20 +19,23 @@
API routes for computes.
"""
from fastapi import APIRouter, status
from fastapi.encoders import jsonable_encoder
from fastapi import APIRouter, Depends, status
from typing import List, Union
from uuid import UUID
from gns3server.controller import Controller
from gns3server.db.repositories.computes import ComputesRepository
from gns3server.services.computes import ComputesService
from gns3server import schemas
router = APIRouter()
from .dependencies.database import get_repository
responses = {
404: {"model": schemas.ErrorMessage, "description": "Compute not found"}
}
router = APIRouter(responses=responses)
@router.post("",
status_code=status.HTTP_201_CREATED,
@ -40,69 +43,73 @@ responses = {
responses={404: {"model": schemas.ErrorMessage, "description": "Could not connect to compute"},
409: {"model": schemas.ErrorMessage, "description": "Could not create compute"},
401: {"model": schemas.ErrorMessage, "description": "Invalid authentication for compute"}})
async def create_compute(compute_data: schemas.ComputeCreate):
async def create_compute(
compute_create: schemas.ComputeCreate,
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
) -> schemas.Compute:
"""
Create a new compute on the controller.
"""
compute = await Controller.instance().add_compute(**jsonable_encoder(compute_data, exclude_unset=True),
connect=False)
return compute.__json__()
return await ComputesService(computes_repo).create_compute(compute_create)
@router.get("/{compute_id}",
response_model=schemas.Compute,
response_model_exclude_unset=True,
responses=responses)
def get_compute(compute_id: Union[str, UUID]):
response_model_exclude_unset=True)
async def get_compute(
compute_id: Union[str, UUID],
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
) -> schemas.Compute:
"""
Return a compute from the controller.
"""
compute = Controller.instance().get_compute(str(compute_id))
return compute.__json__()
return await ComputesService(computes_repo).get_compute(compute_id)
@router.get("",
response_model=List[schemas.Compute],
response_model_exclude_unset=True)
async def get_computes():
async def get_computes(
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
) -> List[schemas.Compute]:
"""
Return all computes known by the controller.
"""
controller = Controller.instance()
return [c.__json__() for c in controller.computes.values()]
return await ComputesService(computes_repo).get_computes()
@router.put("/{compute_id}",
response_model=schemas.Compute,
response_model_exclude_unset=True,
responses=responses)
async def update_compute(compute_id: Union[str, UUID], compute_data: schemas.ComputeUpdate):
response_model_exclude_unset=True)
async def update_compute(
compute_id: Union[str, UUID],
compute_update: schemas.ComputeUpdate,
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
) -> schemas.Compute:
"""
Update a compute on the controller.
"""
compute = Controller.instance().get_compute(str(compute_id))
# exclude compute_id because we only use it when creating a new compute
await compute.update(**jsonable_encoder(compute_data, exclude_unset=True, exclude={"compute_id"}))
return compute.__json__()
return await ComputesService(computes_repo).update_compute(compute_id, compute_update)
@router.delete("/{compute_id}",
status_code=status.HTTP_204_NO_CONTENT,
responses=responses)
async def delete_compute(compute_id: Union[str, UUID]):
status_code=status.HTTP_204_NO_CONTENT)
async def delete_compute(
compute_id: Union[str, UUID],
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
):
"""
Delete a compute from the controller.
"""
await Controller.instance().delete_compute(str(compute_id))
await ComputesService(computes_repo).delete_compute(compute_id)
@router.get("/{compute_id}/{emulator}/images",
responses=responses)
@router.get("/{compute_id}/{emulator}/images")
async def get_images(compute_id: Union[str, UUID], emulator: str):
"""
Return the list of images available on a compute for a given emulator type.
@ -113,8 +120,7 @@ async def get_images(compute_id: Union[str, UUID], emulator: str):
return await compute.images(emulator)
@router.get("/{compute_id}/{emulator}/{endpoint_path:path}",
responses=responses)
@router.get("/{compute_id}/{emulator}/{endpoint_path:path}")
async def forward_get(compute_id: Union[str, UUID], emulator: str, endpoint_path: str):
"""
Forward a GET request to a compute.
@ -126,8 +132,7 @@ async def forward_get(compute_id: Union[str, UUID], emulator: str, endpoint_path
return result
@router.post("/{compute_id}/{emulator}/{endpoint_path:path}",
responses=responses)
@router.post("/{compute_id}/{emulator}/{endpoint_path:path}")
async def forward_post(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict):
"""
Forward a POST request to a compute.
@ -138,8 +143,7 @@ async def forward_post(compute_id: Union[str, UUID], emulator: str, endpoint_pat
return await compute.forward("POST", emulator, endpoint_path, data=compute_data)
@router.put("/{compute_id}/{emulator}/{endpoint_path:path}",
responses=responses)
@router.put("/{compute_id}/{emulator}/{endpoint_path:path}")
async def forward_put(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict):
"""
Forward a PUT request to a compute.
@ -150,8 +154,7 @@ async def forward_put(compute_id: Union[str, UUID], emulator: str, endpoint_path
return await compute.forward("PUT", emulator, endpoint_path, data=compute_data)
@router.post("/{compute_id}/auto_idlepc",
responses=responses)
@router.post("/{compute_id}/auto_idlepc")
async def autoidlepc(compute_id: Union[str, UUID], auto_idle_pc: schemas.AutoIdlePC):
"""
Find a suitable Idle-PC value for a given IOS image. This may take a few minutes.
@ -162,14 +165,3 @@ async def autoidlepc(compute_id: Union[str, UUID], auto_idle_pc: schemas.AutoIdl
auto_idle_pc.platform,
auto_idle_pc.image,
auto_idle_pc.ram)
@router.get("/{compute_id}/ports",
deprecated=True,
responses=responses)
async def ports(compute_id: Union[str, UUID]):
"""
Return ports information for a given compute.
"""
return await Controller.instance().compute_ports(str(compute_id))

View File

@ -44,43 +44,42 @@ router = APIRouter()
@router.get("", response_model=List[schemas.User])
async def get_users(user_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]:
async def get_users(users_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]:
"""
Get all users.
"""
users = await user_repo.get_users()
return users
return await users_repo.get_users()
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
async def create_user(
new_user: schemas.UserCreate,
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
user_create: schemas.UserCreate,
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Create a new user.
"""
if await user_repo.get_user_by_username(new_user.username):
raise ControllerBadRequestError(f"Username '{new_user.username}' is already registered")
if await users_repo.get_user_by_username(user_create.username):
raise ControllerBadRequestError(f"Username '{user_create.username}' is already registered")
if new_user.email and await user_repo.get_user_by_email(new_user.email):
raise ControllerBadRequestError(f"Email '{new_user.email}' is already registered")
if user_create.email and await users_repo.get_user_by_email(user_create.email):
raise ControllerBadRequestError(f"Email '{user_create.email}' is already registered")
return await user_repo.create_user(new_user)
return await users_repo.create_user(user_create)
@router.get("/{user_id}", response_model=schemas.User)
async def get_user(
user_id: UUID,
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Get an user.
"""
user = await user_repo.get_user(user_id)
user = await users_repo.get_user(user_id)
if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found")
return user
@ -89,14 +88,14 @@ async def get_user(
@router.put("/{user_id}", response_model=schemas.User)
async def update_user(
user_id: UUID,
update_user: schemas.UserUpdate,
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
user_update: schemas.UserUpdate,
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Update an user.
"""
user = await user_repo.update_user(user_id, update_user)
user = await users_repo.update_user(user_id, user_update)
if not user:
raise ControllerNotFoundError(f"User '{user_id}' not found")
return user
@ -105,7 +104,7 @@ async def update_user(
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
user_id: UUID,
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
) -> None:
"""
@ -115,21 +114,21 @@ async def delete_user(
if current_user.is_superuser:
raise ControllerUnauthorizedError("The super user cannot be deleted")
success = await user_repo.delete_user(user_id)
success = await users_repo.delete_user(user_id)
if not success:
raise ControllerNotFoundError(f"User '{user_id}' not found")
@router.post("/login", response_model=schemas.Token)
async def login(
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
form_data: OAuth2PasswordRequestForm = Depends()
) -> schemas.Token:
"""
User login.
"""
user = await user_repo.authenticate_user(username=form_data.username, password=form_data.password)
user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication was unsuccessful.",

View File

@ -37,6 +37,7 @@ from ..utils.get_resource import get_resource
from .gns3vm.gns3_vm_error import GNS3VMError
from .controller_error import ControllerError, ControllerNotFoundError
import logging
log = logging.getLogger(__name__)
@ -47,6 +48,7 @@ class Controller:
"""
def __init__(self):
self._computes = {}
self._projects = {}
self._notification = Notification(self)
@ -59,7 +61,7 @@ class Controller:
self._config_file = Config.instance().controller_config
log.info("Load controller configuration file {}".format(self._config_file))
async def start(self):
async def start(self, computes):
log.info("Controller is starting")
self.load_base_files()
@ -78,7 +80,7 @@ class Controller:
if name == "gns3vm":
name = "Main server"
computes = self._load_controller_settings()
self._load_controller_settings()
ssl_context = None
if server_config.getboolean("ssl"):
@ -198,22 +200,20 @@ class Controller:
if self._config_loaded is False:
return
controller_settings = {"computes": [],
"templates": [],
"gns3vm": self.gns3vm.__json__(),
controller_settings = {"gns3vm": self.gns3vm.__json__(),
"iou_license": self._iou_license_settings,
"appliances_etag": self._appliance_manager.appliances_etag,
"version": __version__}
for compute in self._computes.values():
if compute.id != "local" and compute.id != "vm":
controller_settings["computes"].append({"host": compute.host,
"name": compute.name,
"port": compute.port,
"protocol": compute.protocol,
"user": compute.user,
"password": compute.password,
"compute_id": compute.id})
# for compute in self._computes.values():
# if compute.id != "local" and compute.id != "vm":
# controller_settings["computes"].append({"host": compute.host,
# "name": compute.name,
# "port": compute.port,
# "protocol": compute.protocol,
# "user": compute.user,
# "password": compute.password,
# "compute_id": compute.id})
try:
os.makedirs(os.path.dirname(self._config_file), exist_ok=True)
@ -584,14 +584,3 @@ class Controller:
await project.delete()
self.remove_project(project)
return res
async def compute_ports(self, compute_id):
"""
Get the ports used by a compute.
:param compute_id: ID of the compute
"""
compute = self.get_compute(compute_id)
response = await compute.get("/network/ports")
return response.json

View File

@ -70,10 +70,10 @@ class Compute:
assert controller is not None
log.info("Create compute %s", compute_id)
if compute_id is None:
self._id = str(uuid.uuid4())
else:
self._id = compute_id
# if compute_id is None:
# self._id = str(uuid.uuid4())
# else:
self._id = compute_id
self.protocol = protocol
self._console_host = console_host
@ -181,17 +181,8 @@ class Compute:
@name.setter
def name(self, name):
if name is not None:
self._name = name
else:
if self._user:
user = self._user
# Due to random user generated by 1.4 it's common to have a very long user
if len(user) > 14:
user = user[:11] + "..."
self._name = "{}://{}@{}:{}".format(self._protocol, user, self._host, self._port)
else:
self._name = "{}://{}:{}".format(self._protocol, self._host, self._port)
self._name = name
@property
def connected(self):

View File

@ -25,8 +25,7 @@ from gns3server.controller import Controller
from gns3server.compute import MODULES
from gns3server.compute.port_manager import PortManager
from gns3server.utils.http_client import HTTPClient
from gns3server.db.tasks import connect_to_db
from gns3server.db.tasks import connect_to_db, get_computes
import logging
log = logging.getLogger(__name__)
@ -57,11 +56,14 @@ def create_startup_handler(app: FastAPI) -> Callable:
# connect to the database
await connect_to_db(app)
await Controller.instance().start()
# retrieve the computes from the database
computes = await get_computes(app)
await Controller.instance().start(computes)
# Because with a large image collection
# without md5sum already computed we start the
# computing with server start
from gns3server.compute.qemu import Qemu
asyncio.ensure_future(Qemu.instance().list_images())

View File

@ -17,6 +17,7 @@
from .base import Base
from .users import User
from .computes import Compute
from .templates import (
Template,
CloudTemplate,

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from sqlalchemy import Column, String
from .base import BaseTable, GUID
class Compute(BaseTable):
__tablename__ = "computes"
compute_id = Column(GUID, primary_key=True)
name = Column(String, index=True)
protocol = Column(String)
host = Column(String)
port = Column(String)
user = Column(String)
password = Column(String)

View File

@ -0,0 +1,88 @@
#!/usr/bin/env python
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from uuid import UUID
from typing import Optional, List
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from .base import BaseRepository
import gns3server.db.models as models
from gns3server.services import auth_service
from gns3server import schemas
class ComputesRepository(BaseRepository):
def __init__(self, db_session: AsyncSession) -> None:
super().__init__(db_session)
self._auth_service = auth_service
async def get_compute(self, compute_id: UUID) -> Optional[models.Compute]:
query = select(models.Compute).where(models.Compute.compute_id == compute_id)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_compute_by_name(self, name: str) -> Optional[models.Compute]:
query = select(models.Compute).where(models.Compute.name == name)
result = await self._db_session.execute(query)
return result.scalars().first()
async def get_computes(self) -> List[models.Compute]:
query = select(models.Compute)
result = await self._db_session.execute(query)
return result.scalars().all()
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
db_compute = models.Compute(
compute_id=compute_create.compute_id,
name=compute_create.name,
protocol=compute_create.protocol.value,
host=compute_create.host,
port=compute_create.port,
user=compute_create.user,
password=compute_create.password
)
self._db_session.add(db_compute)
await self._db_session.commit()
await self._db_session.refresh(db_compute)
return db_compute
async def update_compute(self, compute_id: UUID, compute_update: schemas.ComputeUpdate) -> Optional[models.Compute]:
update_values = compute_update.dict(exclude_unset=True)
query = update(models.Compute) \
.where(models.Compute.compute_id == compute_id) \
.values(update_values)
await self._db_session.execute(query)
await self._db_session.commit()
return await self.get_compute(compute_id)
async def delete_compute(self, compute_id: UUID) -> bool:
query = delete(models.Compute).where(models.Compute.compute_id == compute_id)
result = await self._db_session.execute(query)
await self._db_session.commit()
return result.rowcount > 0

View File

@ -18,8 +18,14 @@
import os
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from pydantic import ValidationError
from typing import List
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from gns3server.db.repositories.computes import ComputesRepository
from gns3server import schemas
from .models import Base
from gns3server.config import Config
@ -40,3 +46,21 @@ async def connect_to_db(app: FastAPI) -> None:
app.state._db_engine = engine
except SQLAlchemyError as e:
log.error(f"Error while connecting to database '{db_url}: {e}")
async def get_computes(app: FastAPI) -> List[dict]:
computes = []
async with AsyncSession(app.state._db_engine) as db_session:
db_computes = await ComputesRepository(db_session).get_computes()
for db_compute in db_computes:
try:
compute = jsonable_encoder(
schemas.Compute.from_orm(db_compute),
exclude_unset=True,
exclude={"created_at", "updated_at"})
except ValidationError as e:
log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}")
continue
computes.append(compute)
return computes

View File

@ -15,12 +15,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Union
from uuid import UUID
from uuid import UUID, uuid4
from enum import Enum
from .nodes import NodeType
from .base import DateTimeModelMixin
class Protocol(str, Enum):
@ -37,12 +38,11 @@ class ComputeBase(BaseModel):
Data to create a compute.
"""
compute_id: Optional[Union[str, UUID]] = None
name: Optional[str] = None
protocol: Protocol
host: str
port: int = Field(..., gt=0, le=65535)
user: Optional[str] = None
name: Optional[str] = None
class ComputeCreate(ComputeBase):
@ -50,6 +50,7 @@ class ComputeCreate(ComputeBase):
Data to create a compute.
"""
compute_id: Union[str, UUID] = Field(default_factory=uuid4)
password: Optional[str] = None
class Config:
@ -63,6 +64,24 @@ class ComputeCreate(ComputeBase):
}
}
@validator("name", always=True)
def generate_name(cls, name, values):
if name is not None:
return name
else:
protocol = values.get("protocol")
host = values.get("host")
port = values.get("port")
user = values.get("user")
if user:
# due to random user generated by 1.4 it's common to have a very long user
if len(user) > 14:
user = user[:11] + "..."
return "{}://{}@{}:{}".format(protocol, user, host, port)
else:
return "{}://{}:{}".format(protocol, host, port)
class ComputeUpdate(ComputeBase):
"""
@ -96,7 +115,7 @@ class Capabilities(BaseModel):
disk_size: int = Field(..., description="Disk size on this compute")
class Compute(ComputeBase):
class Compute(DateTimeModelMixin, ComputeBase):
"""
Data returned for a compute.
"""
@ -110,6 +129,9 @@ class Compute(ComputeBase):
last_error: Optional[str] = Field(None, description="Last error found on the compute")
capabilities: Optional[Capabilities] = None
class Config:
orm_mode = True
class AutoIdlePC(BaseModel):
"""

View File

@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2021 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from uuid import UUID
from typing import List, Union
from gns3server import schemas
import gns3server.db.models as models
from gns3server.db.repositories.computes import ComputesRepository
from gns3server.controller import Controller
from gns3server.controller.controller_error import (
ControllerBadRequestError,
ControllerNotFoundError,
ControllerForbiddenError
)
class ComputesService:
def __init__(self, computes_repo: ComputesRepository):
self._computes_repo = computes_repo
self._controller = Controller.instance()
async def get_computes(self) -> List[models.Compute]:
db_computes = await self._computes_repo.get_computes()
return db_computes
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
if await self._computes_repo.get_compute(compute_create.compute_id):
raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered")
db_compute = await self._computes_repo.create_compute(compute_create)
await self._controller.add_compute(compute_id=str(db_compute.compute_id),
connect=False,
**compute_create.dict(exclude_unset=True, exclude={"compute_id"}))
self._controller.notification.controller_emit("compute.created", db_compute.asjson())
return db_compute
async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute:
db_compute = await self._computes_repo.get_compute(compute_id)
if not db_compute:
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
return db_compute
async def update_compute(
self,
compute_id: Union[str, UUID],
compute_update: schemas.ComputeUpdate
) -> models.Compute:
compute = self._controller.get_compute(str(compute_id))
await compute.update(**compute_update.dict(exclude_unset=True))
db_compute = await self._computes_repo.update_compute(compute_id, compute_update)
if not db_compute:
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
self._controller.notification.controller_emit("compute.updated", db_compute.asjson())
return db_compute
async def delete_compute(self, compute_id: Union[str, UUID]) -> None:
if await self._computes_repo.delete_compute(compute_id):
await self._controller.delete_compute(str(compute_id))
self._controller.notification.controller_emit("compute.deleted", {"compute_id": str(compute_id)})
else:
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")

View File

@ -15,12 +15,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import uuid
import pytest
from fastapi import FastAPI, status
from httpx import AsyncClient
from gns3server.controller import Controller
from gns3server.schemas.computes import Compute
pytestmark = pytest.mark.asyncio
@ -28,234 +29,167 @@ import unittest
from tests.utils import asyncio_patch
async def test_compute_create_without_id(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
class TestComputeRoutes:
params = {
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"}
async def test_compute_create(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
response_content = response.json()
assert response_content["user"] == "julien"
assert response_content["compute_id"] is not None
assert "password" not in response_content
assert len(controller.computes) == 1
assert controller.computes[response_content["compute_id"]].host == "localhost"
params = {
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["compute_id"] is not None
del params["password"]
for param, value in params.items():
assert response.json()[param] == value
async def test_compute_create_with_id(self, app: FastAPI, client: AsyncClient) -> None:
compute_id = str(uuid.uuid4())
params = {
"compute_id": compute_id,
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["compute_id"] == compute_id
del params["password"]
for param, value in params.items():
assert response.json()[param] == value
async def test_compute_list(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_computes"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) > 0
async def test_compute_get(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None:
response = await client.get(app.url_path_for("get_compute", compute_id=test_compute.compute_id))
assert response.status_code == status.HTTP_200_OK
assert response.json()["compute_id"] == str(test_compute.compute_id)
async def test_compute_update(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None:
params = {
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
compute_id = response.json()["compute_id"]
params["protocol"] = "https"
response = await client.put(app.url_path_for("update_compute", compute_id=compute_id), json=params)
assert response.status_code == status.HTTP_200_OK
del params["password"]
for param, value in params.items():
assert response.json()[param] == value
async def test_compute_delete(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None:
response = await client.delete(app.url_path_for("delete_compute", compute_id=test_compute.compute_id))
assert response.status_code == status.HTTP_204_NO_CONTENT
async def test_compute_create_with_id(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
class TestComputeFeatures:
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"}
async def test_compute_list_images(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["user"] == "julien"
assert "password" not in response.json()
assert len(controller.computes) == 1
assert controller.computes["my_compute_id"].host == "localhost"
params = {
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
compute_id = response.json()["compute_id"]
async def test_compute_get(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
with asyncio_patch("gns3server.controller.compute.Compute.images", return_value=[{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]) as mock:
response = await client.get(app.url_path_for("delete_compute", compute_id=compute_id) + "/qemu/images")
assert response.json() == [{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]
mock.assert_called_with("qemu")
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
async def test_compute_list_vms(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
params = {
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("get_computes"), json=params)
assert response.status_code == status.HTTP_201_CREATED
compute_id = response.json()["compute_id"]
response = await client.get(app.url_path_for("update_compute", compute_id="my_compute_id"))
assert response.status_code == status.HTTP_200_OK
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
response = await client.get(app.url_path_for("get_compute", compute_id=compute_id) + "/virtualbox/vms")
mock.assert_called_with("GET", "virtualbox", "vms")
assert response.json() == []
async def test_compute_create_img(self, app: FastAPI, client: AsyncClient) -> None:
async def test_compute_update(app: FastAPI, client: AsyncClient) -> None:
params = {
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("get_computes"), json=params)
assert response.status_code == status.HTTP_201_CREATED
compute_id = response.json()["compute_id"]
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
response = await client.get(app.url_path_for("get_compute", compute_id="my_compute_id"))
assert response.status_code == status.HTTP_200_OK
assert response.json()["protocol"] == "http"
params = {"path": "/test"}
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
response = await client.post(app.url_path_for("get_compute", compute_id=compute_id) + "/qemu/img", json=params)
assert response.json() == []
mock.assert_called_with("POST", "qemu", "img", data=unittest.mock.ANY)
params["protocol"] = "https"
response = await client.put(app.url_path_for("update_compute", compute_id="my_compute_id"), json=params)
assert response.status_code == status.HTTP_200_OK
assert response.json()["protocol"] == "https"
async def test_compute_list(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure",
"name": "My super server"
}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
assert response.json()["user"] == "julien"
assert "password" not in response.json()
response = await client.get(app.url_path_for("get_computes"))
for compute in response.json():
if compute['compute_id'] != 'local':
assert compute == {
'compute_id': 'my_compute_id',
'connected': False,
'host': 'localhost',
'port': 84,
'protocol': 'http',
'user': 'julien',
'name': 'My super server',
'cpu_usage_percent': 0.0,
'memory_usage_percent': 0.0,
'disk_usage_percent': 0.0,
'last_error': None,
'capabilities': {
'version': '',
'platform': '',
'cpus': 0,
'memory': 0,
'disk_size': 0,
'node_types': []
}
}
async def test_compute_delete(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
response = await client.get(app.url_path_for("get_computes"))
assert len(response.json()) == 1
response = await client.delete(app.url_path_for("delete_compute", compute_id="my_compute_id"))
assert response.status_code == status.HTTP_204_NO_CONTENT
response = await client.get(app.url_path_for("get_computes"))
assert len(response.json()) == 0
async def test_compute_list_images(app: FastAPI, client: AsyncClient) -> None:
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("create_compute"), json=params)
assert response.status_code == status.HTTP_201_CREATED
with asyncio_patch("gns3server.controller.compute.Compute.images", return_value=[{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]) as mock:
response = await client.get(app.url_path_for("delete_compute", compute_id="my_compute_id") + "/qemu/images")
assert response.json() == [{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]
mock.assert_called_with("qemu")
async def test_compute_list_vms(app: FastAPI, client: AsyncClient) -> None:
params = {
"compute_id": "my_compute",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("get_computes"), json=params)
assert response.status_code == status.HTTP_201_CREATED
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
response = await client.get(app.url_path_for("get_compute", compute_id="my_compute_id") + "/virtualbox/vms")
mock.assert_called_with("GET", "virtualbox", "vms")
assert response.json() == []
async def test_compute_create_img(app: FastAPI, client: AsyncClient) -> None:
params = {
"compute_id": "my_compute",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
response = await client.post(app.url_path_for("get_computes"), json=params)
assert response.status_code == status.HTTP_201_CREATED
params = {"path": "/test"}
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
response = await client.post(app.url_path_for("get_compute", compute_id="my_compute_id") + "/qemu/img", json=params)
assert response.json() == []
mock.assert_called_with("POST", "qemu", "img", data=unittest.mock.ANY)
async def test_compute_autoidlepc(app: FastAPI, client: AsyncClient) -> None:
params = {
"compute_id": "my_compute_id",
"protocol": "http",
"host": "localhost",
"port": 84,
"user": "julien",
"password": "secure"
}
await client.post(app.url_path_for("get_computes"), json=params)
params = {
"platform": "c7200",
"image": "test.bin",
"ram": 512
}
with asyncio_patch("gns3server.controller.Controller.autoidlepc", return_value={"idlepc": "0x606de20c"}) as mock:
response = await client.post(app.url_path_for("get_compute", compute_id="my_compute_id") + "/auto_idlepc", json=params)
assert mock.called
assert response.status_code == status.HTTP_200_OK
# async def test_compute_autoidlepc(self, app: FastAPI, client: AsyncClient) -> None:
#
# params = {
# "protocol": "http",
# "host": "localhost",
# "port": 84,
# "user": "julien",
# "password": "secure"
# }
#
# response = await client.post(app.url_path_for("get_computes"), json=params)
# assert response.status_code == status.HTTP_201_CREATED
# compute_id = response.json()["compute_id"]
#
# params = {
# "platform": "c7200",
# "image": "test.bin",
# "ram": 512
# }
#
# with asyncio_patch("gns3server.controller.Controller.autoidlepc", return_value={"idlepc": "0x606de20c"}) as mock:
# response = await client.post(app.url_path_for("autoidlepc", compute_id=compute_id) + "/auto_idlepc", json=params)
# assert mock.called
# assert response.status_code == status.HTTP_200_OK
# FIXME

View File

@ -4,6 +4,7 @@ import tempfile
import shutil
import sys
import os
import uuid
from fastapi import FastAPI
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
@ -16,10 +17,12 @@ from gns3server.config import Config
from gns3server.compute import MODULES
from gns3server.compute.port_manager import PortManager
from gns3server.compute.project_manager import ProjectManager
from gns3server.db.models import Base, User
from gns3server.db.models import Base, User, Compute
from gns3server.db.repositories.users import UsersRepository
from gns3server.db.repositories.computes import ComputesRepository
from gns3server.api.routes.controller.dependencies.database import get_db_session
from gns3server.schemas.users import UserCreate
from gns3server import schemas
from gns3server.schemas.computes import Protocol
from gns3server.services import auth_service
sys._called_from_test = True
@ -27,7 +30,7 @@ sys.original_platform = sys.platform
if sys.platform.startswith("win") and sys.version_info < (3, 8):
@pytest.yield_fixture(scope="session")
@pytest.fixture(scope="session")
def event_loop(request):
"""
Overwrite pytest_asyncio event loop on Windows for Python < 3.8
@ -43,7 +46,7 @@ if sys.platform.startswith("win") and sys.version_info < (3, 8):
# https://github.com/pytest-dev/pytest-asyncio/issues/68
# this event_loop is used by pytest-asyncio, and redefining it
# is currently the only way of changing the scope of this fixture
@pytest.yield_fixture(scope="class")
@pytest.fixture(scope="class")
def event_loop(request):
loop = asyncio.get_event_loop_policy().new_event_loop()
@ -54,9 +57,6 @@ def event_loop(request):
@pytest.fixture(scope="class")
async def app() -> FastAPI:
# async with db_engine.begin() as conn:
# await conn.run_sync(Base.metadata.drop_all)
# await conn.run_sync(Base.metadata.create_all)
from gns3server.api.server import app as gns3app
yield gns3app
@ -109,7 +109,7 @@ async def client(app: FastAPI, db_session: AsyncSession) -> AsyncClient:
@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
new_user = UserCreate(
new_user = schemas.UserCreate(
username="user1",
email="user1@email.com",
password="user1_password",
@ -121,6 +121,25 @@ async def test_user(db_session: AsyncSession) -> User:
return await user_repo.create_user(new_user)
@pytest.fixture
async def test_compute(db_session: AsyncSession) -> Compute:
new_compute = schemas.ComputeCreate(
compute_id=uuid.uuid4(),
protocol=Protocol.http,
host="localhost",
port=4242,
user="julien",
password="secure"
)
compute_repo = ComputesRepository(db_session)
existing_compute = await compute_repo.get_compute(new_compute.compute_id)
if existing_compute:
return existing_compute
return await compute_repo.create_compute(new_compute)
@pytest.fixture
def authorized_client(client: AsyncClient, test_user: User) -> AsyncClient: