mirror of
https://github.com/GNS3/gns3-server.git
synced 2025-02-01 05:43:49 +02:00
Save computes to database
This commit is contained in:
parent
e607793e74
commit
566e326b57
@ -19,20 +19,23 @@
|
||||
API routes for computes.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from typing import List, Union
|
||||
from uuid import UUID
|
||||
|
||||
from gns3server.controller import Controller
|
||||
from gns3server.db.repositories.computes import ComputesRepository
|
||||
from gns3server.services.computes import ComputesService
|
||||
from gns3server import schemas
|
||||
|
||||
router = APIRouter()
|
||||
from .dependencies.database import get_repository
|
||||
|
||||
responses = {
|
||||
404: {"model": schemas.ErrorMessage, "description": "Compute not found"}
|
||||
}
|
||||
|
||||
router = APIRouter(responses=responses)
|
||||
|
||||
|
||||
@router.post("",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
@ -40,69 +43,73 @@ responses = {
|
||||
responses={404: {"model": schemas.ErrorMessage, "description": "Could not connect to compute"},
|
||||
409: {"model": schemas.ErrorMessage, "description": "Could not create compute"},
|
||||
401: {"model": schemas.ErrorMessage, "description": "Invalid authentication for compute"}})
|
||||
async def create_compute(compute_data: schemas.ComputeCreate):
|
||||
async def create_compute(
|
||||
compute_create: schemas.ComputeCreate,
|
||||
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
|
||||
) -> schemas.Compute:
|
||||
"""
|
||||
Create a new compute on the controller.
|
||||
"""
|
||||
|
||||
compute = await Controller.instance().add_compute(**jsonable_encoder(compute_data, exclude_unset=True),
|
||||
connect=False)
|
||||
return compute.__json__()
|
||||
return await ComputesService(computes_repo).create_compute(compute_create)
|
||||
|
||||
|
||||
@router.get("/{compute_id}",
|
||||
response_model=schemas.Compute,
|
||||
response_model_exclude_unset=True,
|
||||
responses=responses)
|
||||
def get_compute(compute_id: Union[str, UUID]):
|
||||
response_model_exclude_unset=True)
|
||||
async def get_compute(
|
||||
compute_id: Union[str, UUID],
|
||||
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
|
||||
) -> schemas.Compute:
|
||||
"""
|
||||
Return a compute from the controller.
|
||||
"""
|
||||
|
||||
compute = Controller.instance().get_compute(str(compute_id))
|
||||
return compute.__json__()
|
||||
return await ComputesService(computes_repo).get_compute(compute_id)
|
||||
|
||||
|
||||
@router.get("",
|
||||
response_model=List[schemas.Compute],
|
||||
response_model_exclude_unset=True)
|
||||
async def get_computes():
|
||||
async def get_computes(
|
||||
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
|
||||
) -> List[schemas.Compute]:
|
||||
"""
|
||||
Return all computes known by the controller.
|
||||
"""
|
||||
|
||||
controller = Controller.instance()
|
||||
return [c.__json__() for c in controller.computes.values()]
|
||||
return await ComputesService(computes_repo).get_computes()
|
||||
|
||||
|
||||
@router.put("/{compute_id}",
|
||||
response_model=schemas.Compute,
|
||||
response_model_exclude_unset=True,
|
||||
responses=responses)
|
||||
async def update_compute(compute_id: Union[str, UUID], compute_data: schemas.ComputeUpdate):
|
||||
response_model_exclude_unset=True)
|
||||
async def update_compute(
|
||||
compute_id: Union[str, UUID],
|
||||
compute_update: schemas.ComputeUpdate,
|
||||
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
|
||||
) -> schemas.Compute:
|
||||
"""
|
||||
Update a compute on the controller.
|
||||
"""
|
||||
|
||||
compute = Controller.instance().get_compute(str(compute_id))
|
||||
# exclude compute_id because we only use it when creating a new compute
|
||||
await compute.update(**jsonable_encoder(compute_data, exclude_unset=True, exclude={"compute_id"}))
|
||||
return compute.__json__()
|
||||
return await ComputesService(computes_repo).update_compute(compute_id, compute_update)
|
||||
|
||||
|
||||
@router.delete("/{compute_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses=responses)
|
||||
async def delete_compute(compute_id: Union[str, UUID]):
|
||||
status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_compute(
|
||||
compute_id: Union[str, UUID],
|
||||
computes_repo: ComputesRepository = Depends(get_repository(ComputesRepository))
|
||||
):
|
||||
"""
|
||||
Delete a compute from the controller.
|
||||
"""
|
||||
|
||||
await Controller.instance().delete_compute(str(compute_id))
|
||||
await ComputesService(computes_repo).delete_compute(compute_id)
|
||||
|
||||
|
||||
@router.get("/{compute_id}/{emulator}/images",
|
||||
responses=responses)
|
||||
@router.get("/{compute_id}/{emulator}/images")
|
||||
async def get_images(compute_id: Union[str, UUID], emulator: str):
|
||||
"""
|
||||
Return the list of images available on a compute for a given emulator type.
|
||||
@ -113,8 +120,7 @@ async def get_images(compute_id: Union[str, UUID], emulator: str):
|
||||
return await compute.images(emulator)
|
||||
|
||||
|
||||
@router.get("/{compute_id}/{emulator}/{endpoint_path:path}",
|
||||
responses=responses)
|
||||
@router.get("/{compute_id}/{emulator}/{endpoint_path:path}")
|
||||
async def forward_get(compute_id: Union[str, UUID], emulator: str, endpoint_path: str):
|
||||
"""
|
||||
Forward a GET request to a compute.
|
||||
@ -126,8 +132,7 @@ async def forward_get(compute_id: Union[str, UUID], emulator: str, endpoint_path
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/{compute_id}/{emulator}/{endpoint_path:path}",
|
||||
responses=responses)
|
||||
@router.post("/{compute_id}/{emulator}/{endpoint_path:path}")
|
||||
async def forward_post(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict):
|
||||
"""
|
||||
Forward a POST request to a compute.
|
||||
@ -138,8 +143,7 @@ async def forward_post(compute_id: Union[str, UUID], emulator: str, endpoint_pat
|
||||
return await compute.forward("POST", emulator, endpoint_path, data=compute_data)
|
||||
|
||||
|
||||
@router.put("/{compute_id}/{emulator}/{endpoint_path:path}",
|
||||
responses=responses)
|
||||
@router.put("/{compute_id}/{emulator}/{endpoint_path:path}")
|
||||
async def forward_put(compute_id: Union[str, UUID], emulator: str, endpoint_path: str, compute_data: dict):
|
||||
"""
|
||||
Forward a PUT request to a compute.
|
||||
@ -150,8 +154,7 @@ async def forward_put(compute_id: Union[str, UUID], emulator: str, endpoint_path
|
||||
return await compute.forward("PUT", emulator, endpoint_path, data=compute_data)
|
||||
|
||||
|
||||
@router.post("/{compute_id}/auto_idlepc",
|
||||
responses=responses)
|
||||
@router.post("/{compute_id}/auto_idlepc")
|
||||
async def autoidlepc(compute_id: Union[str, UUID], auto_idle_pc: schemas.AutoIdlePC):
|
||||
"""
|
||||
Find a suitable Idle-PC value for a given IOS image. This may take a few minutes.
|
||||
@ -162,14 +165,3 @@ async def autoidlepc(compute_id: Union[str, UUID], auto_idle_pc: schemas.AutoIdl
|
||||
auto_idle_pc.platform,
|
||||
auto_idle_pc.image,
|
||||
auto_idle_pc.ram)
|
||||
|
||||
|
||||
@router.get("/{compute_id}/ports",
|
||||
deprecated=True,
|
||||
responses=responses)
|
||||
async def ports(compute_id: Union[str, UUID]):
|
||||
"""
|
||||
Return ports information for a given compute.
|
||||
"""
|
||||
|
||||
return await Controller.instance().compute_ports(str(compute_id))
|
||||
|
@ -44,43 +44,42 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=List[schemas.User])
|
||||
async def get_users(user_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]:
|
||||
async def get_users(users_repo: UsersRepository = Depends(get_repository(UsersRepository))) -> List[schemas.User]:
|
||||
"""
|
||||
Get all users.
|
||||
"""
|
||||
|
||||
users = await user_repo.get_users()
|
||||
return users
|
||||
return await users_repo.get_users()
|
||||
|
||||
|
||||
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
new_user: schemas.UserCreate,
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
user_create: schemas.UserCreate,
|
||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
) -> schemas.User:
|
||||
"""
|
||||
Create a new user.
|
||||
"""
|
||||
|
||||
if await user_repo.get_user_by_username(new_user.username):
|
||||
raise ControllerBadRequestError(f"Username '{new_user.username}' is already registered")
|
||||
if await users_repo.get_user_by_username(user_create.username):
|
||||
raise ControllerBadRequestError(f"Username '{user_create.username}' is already registered")
|
||||
|
||||
if new_user.email and await user_repo.get_user_by_email(new_user.email):
|
||||
raise ControllerBadRequestError(f"Email '{new_user.email}' is already registered")
|
||||
if user_create.email and await users_repo.get_user_by_email(user_create.email):
|
||||
raise ControllerBadRequestError(f"Email '{user_create.email}' is already registered")
|
||||
|
||||
return await user_repo.create_user(new_user)
|
||||
return await users_repo.create_user(user_create)
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=schemas.User)
|
||||
async def get_user(
|
||||
user_id: UUID,
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
) -> schemas.User:
|
||||
"""
|
||||
Get an user.
|
||||
"""
|
||||
|
||||
user = await user_repo.get_user(user_id)
|
||||
user = await users_repo.get_user(user_id)
|
||||
if not user:
|
||||
raise ControllerNotFoundError(f"User '{user_id}' not found")
|
||||
return user
|
||||
@ -89,14 +88,14 @@ async def get_user(
|
||||
@router.put("/{user_id}", response_model=schemas.User)
|
||||
async def update_user(
|
||||
user_id: UUID,
|
||||
update_user: schemas.UserUpdate,
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
user_update: schemas.UserUpdate,
|
||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
|
||||
) -> schemas.User:
|
||||
"""
|
||||
Update an user.
|
||||
"""
|
||||
|
||||
user = await user_repo.update_user(user_id, update_user)
|
||||
user = await users_repo.update_user(user_id, user_update)
|
||||
if not user:
|
||||
raise ControllerNotFoundError(f"User '{user_id}' not found")
|
||||
return user
|
||||
@ -105,7 +104,7 @@ async def update_user(
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_user(
|
||||
user_id: UUID,
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||
current_user: schemas.User = Depends(get_current_active_user)
|
||||
) -> None:
|
||||
"""
|
||||
@ -115,21 +114,21 @@ async def delete_user(
|
||||
if current_user.is_superuser:
|
||||
raise ControllerUnauthorizedError("The super user cannot be deleted")
|
||||
|
||||
success = await user_repo.delete_user(user_id)
|
||||
success = await users_repo.delete_user(user_id)
|
||||
if not success:
|
||||
raise ControllerNotFoundError(f"User '{user_id}' not found")
|
||||
|
||||
|
||||
@router.post("/login", response_model=schemas.Token)
|
||||
async def login(
|
||||
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
|
||||
form_data: OAuth2PasswordRequestForm = Depends()
|
||||
) -> schemas.Token:
|
||||
"""
|
||||
User login.
|
||||
"""
|
||||
|
||||
user = await user_repo.authenticate_user(username=form_data.username, password=form_data.password)
|
||||
user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication was unsuccessful.",
|
||||
|
@ -37,6 +37,7 @@ from ..utils.get_resource import get_resource
|
||||
from .gns3vm.gns3_vm_error import GNS3VMError
|
||||
from .controller_error import ControllerError, ControllerNotFoundError
|
||||
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -47,6 +48,7 @@ class Controller:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
self._computes = {}
|
||||
self._projects = {}
|
||||
self._notification = Notification(self)
|
||||
@ -59,7 +61,7 @@ class Controller:
|
||||
self._config_file = Config.instance().controller_config
|
||||
log.info("Load controller configuration file {}".format(self._config_file))
|
||||
|
||||
async def start(self):
|
||||
async def start(self, computes):
|
||||
|
||||
log.info("Controller is starting")
|
||||
self.load_base_files()
|
||||
@ -78,7 +80,7 @@ class Controller:
|
||||
if name == "gns3vm":
|
||||
name = "Main server"
|
||||
|
||||
computes = self._load_controller_settings()
|
||||
self._load_controller_settings()
|
||||
|
||||
ssl_context = None
|
||||
if server_config.getboolean("ssl"):
|
||||
@ -198,22 +200,20 @@ class Controller:
|
||||
if self._config_loaded is False:
|
||||
return
|
||||
|
||||
controller_settings = {"computes": [],
|
||||
"templates": [],
|
||||
"gns3vm": self.gns3vm.__json__(),
|
||||
controller_settings = {"gns3vm": self.gns3vm.__json__(),
|
||||
"iou_license": self._iou_license_settings,
|
||||
"appliances_etag": self._appliance_manager.appliances_etag,
|
||||
"version": __version__}
|
||||
|
||||
for compute in self._computes.values():
|
||||
if compute.id != "local" and compute.id != "vm":
|
||||
controller_settings["computes"].append({"host": compute.host,
|
||||
"name": compute.name,
|
||||
"port": compute.port,
|
||||
"protocol": compute.protocol,
|
||||
"user": compute.user,
|
||||
"password": compute.password,
|
||||
"compute_id": compute.id})
|
||||
# for compute in self._computes.values():
|
||||
# if compute.id != "local" and compute.id != "vm":
|
||||
# controller_settings["computes"].append({"host": compute.host,
|
||||
# "name": compute.name,
|
||||
# "port": compute.port,
|
||||
# "protocol": compute.protocol,
|
||||
# "user": compute.user,
|
||||
# "password": compute.password,
|
||||
# "compute_id": compute.id})
|
||||
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self._config_file), exist_ok=True)
|
||||
@ -584,14 +584,3 @@ class Controller:
|
||||
await project.delete()
|
||||
self.remove_project(project)
|
||||
return res
|
||||
|
||||
async def compute_ports(self, compute_id):
|
||||
"""
|
||||
Get the ports used by a compute.
|
||||
|
||||
:param compute_id: ID of the compute
|
||||
"""
|
||||
|
||||
compute = self.get_compute(compute_id)
|
||||
response = await compute.get("/network/ports")
|
||||
return response.json
|
||||
|
@ -70,10 +70,10 @@ class Compute:
|
||||
assert controller is not None
|
||||
log.info("Create compute %s", compute_id)
|
||||
|
||||
if compute_id is None:
|
||||
self._id = str(uuid.uuid4())
|
||||
else:
|
||||
self._id = compute_id
|
||||
# if compute_id is None:
|
||||
# self._id = str(uuid.uuid4())
|
||||
# else:
|
||||
self._id = compute_id
|
||||
|
||||
self.protocol = protocol
|
||||
self._console_host = console_host
|
||||
@ -181,17 +181,8 @@ class Compute:
|
||||
|
||||
@name.setter
|
||||
def name(self, name):
|
||||
if name is not None:
|
||||
self._name = name
|
||||
else:
|
||||
if self._user:
|
||||
user = self._user
|
||||
# Due to random user generated by 1.4 it's common to have a very long user
|
||||
if len(user) > 14:
|
||||
user = user[:11] + "..."
|
||||
self._name = "{}://{}@{}:{}".format(self._protocol, user, self._host, self._port)
|
||||
else:
|
||||
self._name = "{}://{}:{}".format(self._protocol, self._host, self._port)
|
||||
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
|
@ -25,8 +25,7 @@ from gns3server.controller import Controller
|
||||
from gns3server.compute import MODULES
|
||||
from gns3server.compute.port_manager import PortManager
|
||||
from gns3server.utils.http_client import HTTPClient
|
||||
from gns3server.db.tasks import connect_to_db
|
||||
|
||||
from gns3server.db.tasks import connect_to_db, get_computes
|
||||
|
||||
import logging
|
||||
log = logging.getLogger(__name__)
|
||||
@ -57,11 +56,14 @@ def create_startup_handler(app: FastAPI) -> Callable:
|
||||
# connect to the database
|
||||
await connect_to_db(app)
|
||||
|
||||
await Controller.instance().start()
|
||||
# retrieve the computes from the database
|
||||
computes = await get_computes(app)
|
||||
|
||||
await Controller.instance().start(computes)
|
||||
|
||||
# Because with a large image collection
|
||||
# without md5sum already computed we start the
|
||||
# computing with server start
|
||||
|
||||
from gns3server.compute.qemu import Qemu
|
||||
asyncio.ensure_future(Qemu.instance().list_images())
|
||||
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
from .base import Base
|
||||
from .users import User
|
||||
from .computes import Compute
|
||||
from .templates import (
|
||||
Template,
|
||||
CloudTemplate,
|
||||
|
33
gns3server/db/models/computes.py
Normal file
33
gns3server/db/models/computes.py
Normal file
@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (C) 2021 GNS3 Technologies Inc.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from sqlalchemy import Column, String
|
||||
|
||||
from .base import BaseTable, GUID
|
||||
|
||||
|
||||
class Compute(BaseTable):
|
||||
|
||||
__tablename__ = "computes"
|
||||
|
||||
compute_id = Column(GUID, primary_key=True)
|
||||
name = Column(String, index=True)
|
||||
protocol = Column(String)
|
||||
host = Column(String)
|
||||
port = Column(String)
|
||||
user = Column(String)
|
||||
password = Column(String)
|
88
gns3server/db/repositories/computes.py
Normal file
88
gns3server/db/repositories/computes.py
Normal file
@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (C) 2021 GNS3 Technologies Inc.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from uuid import UUID
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .base import BaseRepository
|
||||
|
||||
import gns3server.db.models as models
|
||||
from gns3server.services import auth_service
|
||||
from gns3server import schemas
|
||||
|
||||
|
||||
class ComputesRepository(BaseRepository):
|
||||
|
||||
def __init__(self, db_session: AsyncSession) -> None:
|
||||
|
||||
super().__init__(db_session)
|
||||
self._auth_service = auth_service
|
||||
|
||||
async def get_compute(self, compute_id: UUID) -> Optional[models.Compute]:
|
||||
|
||||
query = select(models.Compute).where(models.Compute.compute_id == compute_id)
|
||||
result = await self._db_session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_compute_by_name(self, name: str) -> Optional[models.Compute]:
|
||||
|
||||
query = select(models.Compute).where(models.Compute.name == name)
|
||||
result = await self._db_session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_computes(self) -> List[models.Compute]:
|
||||
|
||||
query = select(models.Compute)
|
||||
result = await self._db_session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
|
||||
|
||||
db_compute = models.Compute(
|
||||
compute_id=compute_create.compute_id,
|
||||
name=compute_create.name,
|
||||
protocol=compute_create.protocol.value,
|
||||
host=compute_create.host,
|
||||
port=compute_create.port,
|
||||
user=compute_create.user,
|
||||
password=compute_create.password
|
||||
)
|
||||
self._db_session.add(db_compute)
|
||||
await self._db_session.commit()
|
||||
await self._db_session.refresh(db_compute)
|
||||
return db_compute
|
||||
|
||||
async def update_compute(self, compute_id: UUID, compute_update: schemas.ComputeUpdate) -> Optional[models.Compute]:
|
||||
|
||||
update_values = compute_update.dict(exclude_unset=True)
|
||||
|
||||
query = update(models.Compute) \
|
||||
.where(models.Compute.compute_id == compute_id) \
|
||||
.values(update_values)
|
||||
|
||||
await self._db_session.execute(query)
|
||||
await self._db_session.commit()
|
||||
return await self.get_compute(compute_id)
|
||||
|
||||
async def delete_compute(self, compute_id: UUID) -> bool:
|
||||
|
||||
query = delete(models.Compute).where(models.Compute.compute_id == compute_id)
|
||||
result = await self._db_session.execute(query)
|
||||
await self._db_session.commit()
|
||||
return result.rowcount > 0
|
@ -18,8 +18,14 @@
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import ValidationError
|
||||
|
||||
from typing import List
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from gns3server.db.repositories.computes import ComputesRepository
|
||||
from gns3server import schemas
|
||||
|
||||
from .models import Base
|
||||
from gns3server.config import Config
|
||||
@ -40,3 +46,21 @@ async def connect_to_db(app: FastAPI) -> None:
|
||||
app.state._db_engine = engine
|
||||
except SQLAlchemyError as e:
|
||||
log.error(f"Error while connecting to database '{db_url}: {e}")
|
||||
|
||||
|
||||
async def get_computes(app: FastAPI) -> List[dict]:
|
||||
|
||||
computes = []
|
||||
async with AsyncSession(app.state._db_engine) as db_session:
|
||||
db_computes = await ComputesRepository(db_session).get_computes()
|
||||
for db_compute in db_computes:
|
||||
try:
|
||||
compute = jsonable_encoder(
|
||||
schemas.Compute.from_orm(db_compute),
|
||||
exclude_unset=True,
|
||||
exclude={"created_at", "updated_at"})
|
||||
except ValidationError as e:
|
||||
log.error(f"Could not load compute '{db_compute.compute_id}' from database: {e}")
|
||||
continue
|
||||
computes.append(compute)
|
||||
return computes
|
||||
|
@ -15,12 +15,13 @@
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import List, Optional, Union
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
from enum import Enum
|
||||
|
||||
from .nodes import NodeType
|
||||
from .base import DateTimeModelMixin
|
||||
|
||||
|
||||
class Protocol(str, Enum):
|
||||
@ -37,12 +38,11 @@ class ComputeBase(BaseModel):
|
||||
Data to create a compute.
|
||||
"""
|
||||
|
||||
compute_id: Optional[Union[str, UUID]] = None
|
||||
name: Optional[str] = None
|
||||
protocol: Protocol
|
||||
host: str
|
||||
port: int = Field(..., gt=0, le=65535)
|
||||
user: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ComputeCreate(ComputeBase):
|
||||
@ -50,6 +50,7 @@ class ComputeCreate(ComputeBase):
|
||||
Data to create a compute.
|
||||
"""
|
||||
|
||||
compute_id: Union[str, UUID] = Field(default_factory=uuid4)
|
||||
password: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
@ -63,6 +64,24 @@ class ComputeCreate(ComputeBase):
|
||||
}
|
||||
}
|
||||
|
||||
@validator("name", always=True)
|
||||
def generate_name(cls, name, values):
|
||||
|
||||
if name is not None:
|
||||
return name
|
||||
else:
|
||||
protocol = values.get("protocol")
|
||||
host = values.get("host")
|
||||
port = values.get("port")
|
||||
user = values.get("user")
|
||||
if user:
|
||||
# due to random user generated by 1.4 it's common to have a very long user
|
||||
if len(user) > 14:
|
||||
user = user[:11] + "..."
|
||||
return "{}://{}@{}:{}".format(protocol, user, host, port)
|
||||
else:
|
||||
return "{}://{}:{}".format(protocol, host, port)
|
||||
|
||||
|
||||
class ComputeUpdate(ComputeBase):
|
||||
"""
|
||||
@ -96,7 +115,7 @@ class Capabilities(BaseModel):
|
||||
disk_size: int = Field(..., description="Disk size on this compute")
|
||||
|
||||
|
||||
class Compute(ComputeBase):
|
||||
class Compute(DateTimeModelMixin, ComputeBase):
|
||||
"""
|
||||
Data returned for a compute.
|
||||
"""
|
||||
@ -110,6 +129,9 @@ class Compute(ComputeBase):
|
||||
last_error: Optional[str] = Field(None, description="Last error found on the compute")
|
||||
capabilities: Optional[Capabilities] = None
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class AutoIdlePC(BaseModel):
|
||||
"""
|
||||
|
84
gns3server/services/computes.py
Normal file
84
gns3server/services/computes.py
Normal file
@ -0,0 +1,84 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright (C) 2021 GNS3 Technologies Inc.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
|
||||
from uuid import UUID
|
||||
from typing import List, Union
|
||||
|
||||
from gns3server import schemas
|
||||
import gns3server.db.models as models
|
||||
|
||||
from gns3server.db.repositories.computes import ComputesRepository
|
||||
from gns3server.controller import Controller
|
||||
from gns3server.controller.controller_error import (
|
||||
ControllerBadRequestError,
|
||||
ControllerNotFoundError,
|
||||
ControllerForbiddenError
|
||||
)
|
||||
|
||||
|
||||
class ComputesService:
|
||||
|
||||
def __init__(self, computes_repo: ComputesRepository):
|
||||
|
||||
self._computes_repo = computes_repo
|
||||
self._controller = Controller.instance()
|
||||
|
||||
async def get_computes(self) -> List[models.Compute]:
|
||||
|
||||
db_computes = await self._computes_repo.get_computes()
|
||||
return db_computes
|
||||
|
||||
async def create_compute(self, compute_create: schemas.ComputeCreate) -> models.Compute:
|
||||
|
||||
if await self._computes_repo.get_compute(compute_create.compute_id):
|
||||
raise ControllerBadRequestError(f"Compute '{compute_create.compute_id}' is already registered")
|
||||
db_compute = await self._computes_repo.create_compute(compute_create)
|
||||
await self._controller.add_compute(compute_id=str(db_compute.compute_id),
|
||||
connect=False,
|
||||
**compute_create.dict(exclude_unset=True, exclude={"compute_id"}))
|
||||
self._controller.notification.controller_emit("compute.created", db_compute.asjson())
|
||||
return db_compute
|
||||
|
||||
async def get_compute(self, compute_id: Union[str, UUID]) -> models.Compute:
|
||||
|
||||
db_compute = await self._computes_repo.get_compute(compute_id)
|
||||
if not db_compute:
|
||||
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
|
||||
return db_compute
|
||||
|
||||
async def update_compute(
|
||||
self,
|
||||
compute_id: Union[str, UUID],
|
||||
compute_update: schemas.ComputeUpdate
|
||||
) -> models.Compute:
|
||||
|
||||
compute = self._controller.get_compute(str(compute_id))
|
||||
await compute.update(**compute_update.dict(exclude_unset=True))
|
||||
db_compute = await self._computes_repo.update_compute(compute_id, compute_update)
|
||||
if not db_compute:
|
||||
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
|
||||
self._controller.notification.controller_emit("compute.updated", db_compute.asjson())
|
||||
return db_compute
|
||||
|
||||
async def delete_compute(self, compute_id: Union[str, UUID]) -> None:
|
||||
|
||||
if await self._computes_repo.delete_compute(compute_id):
|
||||
await self._controller.delete_compute(str(compute_id))
|
||||
self._controller.notification.controller_emit("compute.deleted", {"compute_id": str(compute_id)})
|
||||
else:
|
||||
raise ControllerNotFoundError(f"Compute '{compute_id}' not found")
|
@ -15,12 +15,13 @@
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import uuid
|
||||
import pytest
|
||||
|
||||
from fastapi import FastAPI, status
|
||||
from httpx import AsyncClient
|
||||
|
||||
from gns3server.controller import Controller
|
||||
from gns3server.schemas.computes import Compute
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
@ -28,234 +29,167 @@ import unittest
|
||||
from tests.utils import asyncio_patch
|
||||
|
||||
|
||||
async def test_compute_create_without_id(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
|
||||
class TestComputeRoutes:
|
||||
|
||||
params = {
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"}
|
||||
async def test_compute_create(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
response_content = response.json()
|
||||
assert response_content["user"] == "julien"
|
||||
assert response_content["compute_id"] is not None
|
||||
assert "password" not in response_content
|
||||
assert len(controller.computes) == 1
|
||||
assert controller.computes[response_content["compute_id"]].host == "localhost"
|
||||
params = {
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"}
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["compute_id"] is not None
|
||||
|
||||
del params["password"]
|
||||
for param, value in params.items():
|
||||
assert response.json()[param] == value
|
||||
|
||||
async def test_compute_create_with_id(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
compute_id = str(uuid.uuid4())
|
||||
params = {
|
||||
"compute_id": compute_id,
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"}
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["compute_id"] == compute_id
|
||||
|
||||
del params["password"]
|
||||
for param, value in params.items():
|
||||
assert response.json()[param] == value
|
||||
|
||||
async def test_compute_list(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
response = await client.get(app.url_path_for("get_computes"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert len(response.json()) > 0
|
||||
|
||||
async def test_compute_get(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None:
|
||||
|
||||
response = await client.get(app.url_path_for("get_compute", compute_id=test_compute.compute_id))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["compute_id"] == str(test_compute.compute_id)
|
||||
|
||||
async def test_compute_update(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None:
|
||||
|
||||
params = {
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
compute_id = response.json()["compute_id"]
|
||||
|
||||
params["protocol"] = "https"
|
||||
response = await client.put(app.url_path_for("update_compute", compute_id=compute_id), json=params)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
del params["password"]
|
||||
for param, value in params.items():
|
||||
assert response.json()[param] == value
|
||||
|
||||
async def test_compute_delete(self, app: FastAPI, client: AsyncClient, test_compute: Compute) -> None:
|
||||
|
||||
response = await client.delete(app.url_path_for("delete_compute", compute_id=test_compute.compute_id))
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
|
||||
async def test_compute_create_with_id(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
|
||||
class TestComputeFeatures:
|
||||
|
||||
params = {
|
||||
"compute_id": "my_compute_id",
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"}
|
||||
async def test_compute_list_images(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["user"] == "julien"
|
||||
assert "password" not in response.json()
|
||||
assert len(controller.computes) == 1
|
||||
assert controller.computes["my_compute_id"].host == "localhost"
|
||||
params = {
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
compute_id = response.json()["compute_id"]
|
||||
|
||||
async def test_compute_get(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
|
||||
with asyncio_patch("gns3server.controller.compute.Compute.images", return_value=[{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]) as mock:
|
||||
response = await client.get(app.url_path_for("delete_compute", compute_id=compute_id) + "/qemu/images")
|
||||
assert response.json() == [{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]
|
||||
mock.assert_called_with("qemu")
|
||||
|
||||
params = {
|
||||
"compute_id": "my_compute_id",
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
async def test_compute_list_vms(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
params = {
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
response = await client.post(app.url_path_for("get_computes"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
compute_id = response.json()["compute_id"]
|
||||
|
||||
response = await client.get(app.url_path_for("update_compute", compute_id="my_compute_id"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
|
||||
response = await client.get(app.url_path_for("get_compute", compute_id=compute_id) + "/virtualbox/vms")
|
||||
mock.assert_called_with("GET", "virtualbox", "vms")
|
||||
assert response.json() == []
|
||||
|
||||
async def test_compute_create_img(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
async def test_compute_update(app: FastAPI, client: AsyncClient) -> None:
|
||||
params = {
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
|
||||
params = {
|
||||
"compute_id": "my_compute_id",
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
response = await client.post(app.url_path_for("get_computes"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
compute_id = response.json()["compute_id"]
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
response = await client.get(app.url_path_for("get_compute", compute_id="my_compute_id"))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["protocol"] == "http"
|
||||
params = {"path": "/test"}
|
||||
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
|
||||
response = await client.post(app.url_path_for("get_compute", compute_id=compute_id) + "/qemu/img", json=params)
|
||||
assert response.json() == []
|
||||
mock.assert_called_with("POST", "qemu", "img", data=unittest.mock.ANY)
|
||||
|
||||
params["protocol"] = "https"
|
||||
response = await client.put(app.url_path_for("update_compute", compute_id="my_compute_id"), json=params)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json()["protocol"] == "https"
|
||||
|
||||
|
||||
async def test_compute_list(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
|
||||
|
||||
params = {
|
||||
"compute_id": "my_compute_id",
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure",
|
||||
"name": "My super server"
|
||||
}
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
assert response.json()["user"] == "julien"
|
||||
assert "password" not in response.json()
|
||||
|
||||
response = await client.get(app.url_path_for("get_computes"))
|
||||
for compute in response.json():
|
||||
if compute['compute_id'] != 'local':
|
||||
assert compute == {
|
||||
'compute_id': 'my_compute_id',
|
||||
'connected': False,
|
||||
'host': 'localhost',
|
||||
'port': 84,
|
||||
'protocol': 'http',
|
||||
'user': 'julien',
|
||||
'name': 'My super server',
|
||||
'cpu_usage_percent': 0.0,
|
||||
'memory_usage_percent': 0.0,
|
||||
'disk_usage_percent': 0.0,
|
||||
'last_error': None,
|
||||
'capabilities': {
|
||||
'version': '',
|
||||
'platform': '',
|
||||
'cpus': 0,
|
||||
'memory': 0,
|
||||
'disk_size': 0,
|
||||
'node_types': []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def test_compute_delete(app: FastAPI, client: AsyncClient, controller: Controller) -> None:
|
||||
|
||||
params = {
|
||||
"compute_id": "my_compute_id",
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
|
||||
response = await client.get(app.url_path_for("get_computes"))
|
||||
assert len(response.json()) == 1
|
||||
|
||||
response = await client.delete(app.url_path_for("delete_compute", compute_id="my_compute_id"))
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
response = await client.get(app.url_path_for("get_computes"))
|
||||
assert len(response.json()) == 0
|
||||
|
||||
|
||||
async def test_compute_list_images(app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
params = {
|
||||
"compute_id": "my_compute_id",
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
response = await client.post(app.url_path_for("create_compute"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
|
||||
with asyncio_patch("gns3server.controller.compute.Compute.images", return_value=[{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]) as mock:
|
||||
response = await client.get(app.url_path_for("delete_compute", compute_id="my_compute_id") + "/qemu/images")
|
||||
assert response.json() == [{"filename": "linux.qcow2"}, {"filename": "asav.qcow2"}]
|
||||
mock.assert_called_with("qemu")
|
||||
|
||||
|
||||
async def test_compute_list_vms(app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
params = {
|
||||
"compute_id": "my_compute",
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
response = await client.post(app.url_path_for("get_computes"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
|
||||
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
|
||||
response = await client.get(app.url_path_for("get_compute", compute_id="my_compute_id") + "/virtualbox/vms")
|
||||
mock.assert_called_with("GET", "virtualbox", "vms")
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
async def test_compute_create_img(app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
params = {
|
||||
"compute_id": "my_compute",
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
|
||||
response = await client.post(app.url_path_for("get_computes"), json=params)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
|
||||
params = {"path": "/test"}
|
||||
with asyncio_patch("gns3server.controller.compute.Compute.forward", return_value=[]) as mock:
|
||||
response = await client.post(app.url_path_for("get_compute", compute_id="my_compute_id") + "/qemu/img", json=params)
|
||||
assert response.json() == []
|
||||
mock.assert_called_with("POST", "qemu", "img", data=unittest.mock.ANY)
|
||||
|
||||
|
||||
async def test_compute_autoidlepc(app: FastAPI, client: AsyncClient) -> None:
|
||||
|
||||
params = {
|
||||
"compute_id": "my_compute_id",
|
||||
"protocol": "http",
|
||||
"host": "localhost",
|
||||
"port": 84,
|
||||
"user": "julien",
|
||||
"password": "secure"
|
||||
}
|
||||
|
||||
await client.post(app.url_path_for("get_computes"), json=params)
|
||||
|
||||
params = {
|
||||
"platform": "c7200",
|
||||
"image": "test.bin",
|
||||
"ram": 512
|
||||
}
|
||||
|
||||
with asyncio_patch("gns3server.controller.Controller.autoidlepc", return_value={"idlepc": "0x606de20c"}) as mock:
|
||||
response = await client.post(app.url_path_for("get_compute", compute_id="my_compute_id") + "/auto_idlepc", json=params)
|
||||
assert mock.called
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
# async def test_compute_autoidlepc(self, app: FastAPI, client: AsyncClient) -> None:
|
||||
#
|
||||
# params = {
|
||||
# "protocol": "http",
|
||||
# "host": "localhost",
|
||||
# "port": 84,
|
||||
# "user": "julien",
|
||||
# "password": "secure"
|
||||
# }
|
||||
#
|
||||
# response = await client.post(app.url_path_for("get_computes"), json=params)
|
||||
# assert response.status_code == status.HTTP_201_CREATED
|
||||
# compute_id = response.json()["compute_id"]
|
||||
#
|
||||
# params = {
|
||||
# "platform": "c7200",
|
||||
# "image": "test.bin",
|
||||
# "ram": 512
|
||||
# }
|
||||
#
|
||||
# with asyncio_patch("gns3server.controller.Controller.autoidlepc", return_value={"idlepc": "0x606de20c"}) as mock:
|
||||
# response = await client.post(app.url_path_for("autoidlepc", compute_id=compute_id) + "/auto_idlepc", json=params)
|
||||
# assert mock.called
|
||||
# assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
# FIXME
|
||||
|
@ -4,6 +4,7 @@ import tempfile
|
||||
import shutil
|
||||
import sys
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
@ -16,10 +17,12 @@ from gns3server.config import Config
|
||||
from gns3server.compute import MODULES
|
||||
from gns3server.compute.port_manager import PortManager
|
||||
from gns3server.compute.project_manager import ProjectManager
|
||||
from gns3server.db.models import Base, User
|
||||
from gns3server.db.models import Base, User, Compute
|
||||
from gns3server.db.repositories.users import UsersRepository
|
||||
from gns3server.db.repositories.computes import ComputesRepository
|
||||
from gns3server.api.routes.controller.dependencies.database import get_db_session
|
||||
from gns3server.schemas.users import UserCreate
|
||||
from gns3server import schemas
|
||||
from gns3server.schemas.computes import Protocol
|
||||
from gns3server.services import auth_service
|
||||
|
||||
sys._called_from_test = True
|
||||
@ -27,7 +30,7 @@ sys.original_platform = sys.platform
|
||||
|
||||
|
||||
if sys.platform.startswith("win") and sys.version_info < (3, 8):
|
||||
@pytest.yield_fixture(scope="session")
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop(request):
|
||||
"""
|
||||
Overwrite pytest_asyncio event loop on Windows for Python < 3.8
|
||||
@ -43,7 +46,7 @@ if sys.platform.startswith("win") and sys.version_info < (3, 8):
|
||||
# https://github.com/pytest-dev/pytest-asyncio/issues/68
|
||||
# this event_loop is used by pytest-asyncio, and redefining it
|
||||
# is currently the only way of changing the scope of this fixture
|
||||
@pytest.yield_fixture(scope="class")
|
||||
@pytest.fixture(scope="class")
|
||||
def event_loop(request):
|
||||
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
@ -54,9 +57,6 @@ def event_loop(request):
|
||||
@pytest.fixture(scope="class")
|
||||
async def app() -> FastAPI:
|
||||
|
||||
# async with db_engine.begin() as conn:
|
||||
# await conn.run_sync(Base.metadata.drop_all)
|
||||
# await conn.run_sync(Base.metadata.create_all)
|
||||
from gns3server.api.server import app as gns3app
|
||||
yield gns3app
|
||||
|
||||
@ -109,7 +109,7 @@ async def client(app: FastAPI, db_session: AsyncSession) -> AsyncClient:
|
||||
@pytest.fixture
|
||||
async def test_user(db_session: AsyncSession) -> User:
|
||||
|
||||
new_user = UserCreate(
|
||||
new_user = schemas.UserCreate(
|
||||
username="user1",
|
||||
email="user1@email.com",
|
||||
password="user1_password",
|
||||
@ -121,6 +121,25 @@ async def test_user(db_session: AsyncSession) -> User:
|
||||
return await user_repo.create_user(new_user)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_compute(db_session: AsyncSession) -> Compute:
|
||||
|
||||
new_compute = schemas.ComputeCreate(
|
||||
compute_id=uuid.uuid4(),
|
||||
protocol=Protocol.http,
|
||||
host="localhost",
|
||||
port=4242,
|
||||
user="julien",
|
||||
password="secure"
|
||||
)
|
||||
|
||||
compute_repo = ComputesRepository(db_session)
|
||||
existing_compute = await compute_repo.get_compute(new_compute.compute_id)
|
||||
if existing_compute:
|
||||
return existing_compute
|
||||
return await compute_repo.create_compute(new_compute)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authorized_client(client: AsyncClient, test_user: User) -> AsyncClient:
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user