diff --git a/gns3server/api/routes/controller/__init__.py b/gns3server/api/routes/controller/__init__.py index cb1e6ddc..8d9009be 100644 --- a/gns3server/api/routes/controller/__init__.py +++ b/gns3server/api/routes/controller/__init__.py @@ -29,6 +29,7 @@ from . import snapshots from . import symbols from . import templates from . import users +from . import groups from .dependencies.authentication import get_current_active_user @@ -37,6 +38,13 @@ router = APIRouter() router.include_router(controller.router, tags=["Controller"]) router.include_router(users.router, prefix="/users", tags=["Users"]) +router.include_router( + groups.router, + dependencies=[Depends(get_current_active_user)], + prefix="/groups", + tags=["Users groups"] +) + router.include_router( appliances.router, dependencies=[Depends(get_current_active_user)], diff --git a/gns3server/api/routes/controller/groups.py b/gns3server/api/routes/controller/groups.py new file mode 100644 index 00000000..b3a96f66 --- /dev/null +++ b/gns3server/api/routes/controller/groups.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" +API routes for user groups. +""" + +from fastapi import APIRouter, Depends, status +from uuid import UUID +from typing import List + +from gns3server import schemas +from gns3server.controller.controller_error import ( + ControllerBadRequestError, + ControllerNotFoundError, + ControllerForbiddenError, +) + +from gns3server.db.repositories.users import UsersRepository +from .dependencies.database import get_repository + +import logging + +log = logging.getLogger(__name__) + +router = APIRouter() + + +@router.get("", response_model=List[schemas.UserGroup]) +async def get_user_groups( + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) +) -> List[schemas.UserGroup]: + """ + Get all user groups. + """ + + return await users_repo.get_user_groups() + + +@router.post( + "", + response_model=schemas.UserGroup, + status_code=status.HTTP_201_CREATED +) +async def create_user_group( + user_group_create: schemas.UserGroupCreate, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) +) -> schemas.UserGroup: + """ + Create a new user group. + """ + + if await users_repo.get_user_group_by_name(user_group_create.name): + raise ControllerBadRequestError(f"User group '{user_group_create.name}' already exists") + + return await users_repo.create_user_group(user_group_create) + + +@router.get("/{user_group_id}", response_model=schemas.UserGroup) +async def get_user_group( + user_group_id: UUID, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)), +) -> schemas.UserGroup: + """ + Get an user group. + """ + + user_group = await users_repo.get_user_group(user_group_id) + if not user_group: + raise ControllerNotFoundError(f"User group '{user_group_id}' not found") + return user_group + + +@router.put("/{user_group_id}", response_model=schemas.UserGroup) +async def update_user_group( + user_group_id: UUID, + user_group_update: schemas.UserGroupUpdate, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) +) -> schemas.UserGroup: + """ + Update an user group. + """ + user_group = await users_repo.get_user_group(user_group_id) + if not user_group: + raise ControllerNotFoundError(f"User group '{user_group_id}' not found") + + if not user_group.is_updatable: + raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be updated") + + return await users_repo.update_user_group(user_group_id, user_group_update) + + +@router.delete( + "/{user_group_id}", + status_code=status.HTTP_204_NO_CONTENT +) +async def delete_user_group( + user_group_id: UUID, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)), +) -> None: + """ + Delete an user group + """ + + user_group = await users_repo.get_user_group(user_group_id) + if not user_group: + raise ControllerNotFoundError(f"User group '{user_group_id}' not found") + + if not user_group.is_updatable: + raise ControllerForbiddenError(f"User group '{user_group_id}' cannot be deleted") + + success = await users_repo.delete_user_group(user_group_id) + if not success: + raise ControllerNotFoundError(f"User group '{user_group_id}' could not be deleted") + + +@router.get("/{user_group_id}/members", response_model=List[schemas.User]) +async def get_user_group_members( + user_group_id: UUID, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) +) -> List[schemas.User]: + """ + Get all user group members. + """ + + return await users_repo.get_user_group_members(user_group_id) + + +@router.put( + "/{user_group_id}/members/{user_id}", + status_code=status.HTTP_204_NO_CONTENT +) +async def add_member_to_group( + user_group_id: UUID, + user_id: UUID, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) +) -> None: + """ + Add member to an user group. + """ + + user = await users_repo.get_user(user_id) + if not user: + raise ControllerNotFoundError(f"User '{user_id}' not found") + + user_group = await users_repo.add_member_to_user_group(user_group_id, user) + if not user_group: + raise ControllerNotFoundError(f"User group '{user_group_id}' not found") + + +@router.delete( + "/{user_group_id}/members/{user_id}", + status_code=status.HTTP_204_NO_CONTENT +) +async def remove_member_from_group( + user_group_id: UUID, + user_id: UUID, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)), +) -> None: + """ + Remove member from an user group. + """ + + user = await users_repo.get_user(user_id) + if not user: + raise ControllerNotFoundError(f"User '{user_id}' not found") + + user_group = await users_repo.remove_member_from_user_group(user_group_id, user) + if not user_group: + raise ControllerNotFoundError(f"User group '{user_group_id}' not found") diff --git a/gns3server/api/routes/controller/users.py b/gns3server/api/routes/controller/users.py index 4e9161b0..fc0af3d5 100644 --- a/gns3server/api/routes/controller/users.py +++ b/gns3server/api/routes/controller/users.py @@ -185,3 +185,15 @@ async def get_current_active_user(current_user: schemas.User = Depends(get_curre """ return current_user + + +@router.get("/{user_id}/groups", response_model=List[schemas.UserGroup]) +async def get_user_memberships( + user_id: UUID, + users_repo: UsersRepository = Depends(get_repository(UsersRepository)) +) -> List[schemas.UserGroup]: + """ + Get user memberships. + """ + + return await users_repo.get_user_memberships(user_id) diff --git a/gns3server/controller/compute.py b/gns3server/controller/compute.py index ebe969b7..e3c1cca5 100644 --- a/gns3server/controller/compute.py +++ b/gns3server/controller/compute.py @@ -487,7 +487,7 @@ class Compute: # Try to reconnect after 1 second if server unavailable only if not during tests (otherwise we create a ressources usage bomb) from gns3server.api.server import app - if not app.state.exiting and not hasattr(sys, "_called_from_test") or not sys._called_from_test: + if not app.state.exiting and not hasattr(sys, "_called_from_test"): log.info(f"Reconnecting to to compute '{self._id}' WebSocket '{ws_url}'") asyncio.get_event_loop().call_later(1, lambda: asyncio.ensure_future(self.connect())) diff --git a/gns3server/db/models/__init__.py b/gns3server/db/models/__init__.py index 71ead9d8..1c644f72 100644 --- a/gns3server/db/models/__init__.py +++ b/gns3server/db/models/__init__.py @@ -16,7 +16,7 @@ # along with this program. If not, see . from .base import Base -from .users import User +from .users import User, UserGroup from .computes import Compute from .templates import ( Template, diff --git a/gns3server/db/models/base.py b/gns3server/db/models/base.py index 86017f06..1dbae1af 100644 --- a/gns3server/db/models/base.py +++ b/gns3server/db/models/base.py @@ -76,8 +76,8 @@ class BaseTable(Base): __abstract__ = True - created_at = Column(DateTime, default=func.current_timestamp()) - updated_at = Column(DateTime, default=func.current_timestamp(), onupdate=func.current_timestamp()) + created_at = Column(DateTime, server_default=func.current_timestamp()) + updated_at = Column(DateTime, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) def generate_uuid(): diff --git a/gns3server/db/models/users.py b/gns3server/db/models/users.py index 6f9400ba..2ee63a4c 100644 --- a/gns3server/db/models/users.py +++ b/gns3server/db/models/users.py @@ -15,15 +15,23 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from sqlalchemy import Boolean, Column, String, event +from sqlalchemy import Table, Boolean, Column, String, ForeignKey, event +from sqlalchemy.orm import relationship -from .base import BaseTable, generate_uuid, GUID +from .base import Base, BaseTable, generate_uuid, GUID from gns3server.services import auth_service import logging log = logging.getLogger(__name__) +users_group_members = Table( + "users_group_members", + Base.metadata, + Column("user_id", GUID, ForeignKey("users.user_id", ondelete="CASCADE")), + Column("user_group_id", GUID, ForeignKey("users_group.user_group_id", ondelete="CASCADE")) +) + class User(BaseTable): @@ -36,6 +44,7 @@ class User(BaseTable): hashed_password = Column(String) is_active = Column(Boolean, default=True) is_superadmin = Column(Boolean, default=False) + groups = relationship("UserGroup", secondary=users_group_members, back_populates="users") @event.listens_for(User.__table__, 'after_create') @@ -51,3 +60,46 @@ def create_default_super_admin(target, connection, **kw): connection.execute(stmt) connection.commit() log.info("The default super admin account has been created in the database") + + +class UserGroup(BaseTable): + + __tablename__ = "users_group" + + user_group_id = Column(GUID, primary_key=True, default=generate_uuid) + name = Column(String, unique=True, index=True) + is_updatable = Column(Boolean, default=True) + users = relationship("User", secondary=users_group_members, back_populates="groups") + + +@event.listens_for(UserGroup.__table__, 'after_create') +def create_default_user_groups(target, connection, **kw): + + default_groups = [ + {"name": "Administrators", "is_updatable": False}, + {"name": "Editors", "is_updatable": False}, + {"name": "Users", "is_updatable": False} + ] + + stmt = target.insert().values(default_groups) + connection.execute(stmt) + connection.commit() + log.info("The default user groups have been created in the database") + + +@event.listens_for(users_group_members, 'after_create') +def add_admin_to_group(target, connection, **kw): + + users_group_table = UserGroup.__table__ + stmt = users_group_table.select().where(users_group_table.c.name == "Administrators") + result = connection.execute(stmt) + user_group_id = result.first().user_group_id + + users_table = User.__table__ + stmt = users_table.select().where(users_table.c.username == "admin") + result = connection.execute(stmt) + user_id = result.first().user_id + + stmt = target.insert().values(user_id=user_id, user_group_id=user_group_id) + connection.execute(stmt) + connection.commit() diff --git a/gns3server/db/repositories/users.py b/gns3server/db/repositories/users.py index 96576cfc..c93c741e 100644 --- a/gns3server/db/repositories/users.py +++ b/gns3server/db/repositories/users.py @@ -16,9 +16,10 @@ # along with this program. If not, see . from uuid import UUID -from typing import Optional, List +from typing import Optional, List, Union from sqlalchemy import select, update, delete from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from .base import BaseRepository @@ -107,3 +108,105 @@ class UsersRepository(BaseRepository): if not self._auth_service.verify_password(password, user.hashed_password): return None return user + + async def get_user_memberships(self, user_id: UUID) -> List[models.UserGroup]: + + query = select(models.UserGroup).\ + join(models.UserGroup.users).\ + filter(models.User.user_id == user_id) + + result = await self._db_session.execute(query) + return result.scalars().all() + + async def get_user_group(self, user_group_id: UUID) -> Optional[models.UserGroup]: + + query = select(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def get_user_group_by_name(self, name: str) -> Optional[models.UserGroup]: + + query = select(models.UserGroup).where(models.UserGroup.name == name) + result = await self._db_session.execute(query) + return result.scalars().first() + + async def get_user_groups(self) -> List[models.UserGroup]: + + query = select(models.UserGroup) + result = await self._db_session.execute(query) + return result.scalars().all() + + async def create_user_group(self, user_group: schemas.UserGroupCreate) -> models.UserGroup: + + db_user_group = models.UserGroup(name=user_group.name) + self._db_session.add(db_user_group) + await self._db_session.commit() + await self._db_session.refresh(db_user_group) + return db_user_group + + async def update_user_group( + self, + user_group_id: UUID, + user_group_update: schemas.UserGroupUpdate + ) -> Optional[models.UserGroup]: + + update_values = user_group_update.dict(exclude_unset=True) + query = update(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id).values(update_values) + + await self._db_session.execute(query) + await self._db_session.commit() + return await self.get_user_group(user_group_id) + + async def delete_user_group(self, user_group_id: UUID) -> bool: + + query = delete(models.UserGroup).where(models.UserGroup.user_group_id == user_group_id) + result = await self._db_session.execute(query) + await self._db_session.commit() + return result.rowcount > 0 + + async def add_member_to_user_group( + self, + user_group_id: UUID, + user: models.User + ) -> Union[None, models.UserGroup]: + + query = select(models.UserGroup).\ + options(selectinload(models.UserGroup.users)).\ + where(models.UserGroup.user_group_id == user_group_id) + result = await self._db_session.execute(query) + user_group_db = result.scalars().first() + if not user_group_db: + return None + + user_group_db.users.append(user) + await self._db_session.commit() + await self._db_session.refresh(user_group_db) + return user_group_db + + async def remove_member_from_user_group( + self, + user_group_id: UUID, + user: models.User + ) -> Union[None, models.UserGroup]: + + query = select(models.UserGroup).\ + options(selectinload(models.UserGroup.users)).\ + where(models.UserGroup.user_group_id == user_group_id) + result = await self._db_session.execute(query) + user_group_db = result.scalars().first() + if not user_group_db: + return None + + user_group_db.users.remove(user) + await self._db_session.commit() + await self._db_session.refresh(user_group_db) + return user_group_db + + async def get_user_group_members(self, user_group_id: UUID) -> List[models.User]: + + query = select(models.User).\ + join(models.User.groups).\ + filter(models.UserGroup.user_group_id == user_group_id) + + result = await self._db_session.execute(query) + return result.scalars().all() diff --git a/gns3server/schemas/__init__.py b/gns3server/schemas/__init__.py index 838141f8..f32985a5 100644 --- a/gns3server/schemas/__init__.py +++ b/gns3server/schemas/__init__.py @@ -27,7 +27,7 @@ from .controller.drawings import Drawing from .controller.gns3vm import GNS3VM from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile -from .controller.users import UserCreate, UserUpdate, User, Credentials +from .controller.users import UserCreate, UserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup from .controller.tokens import Token from .controller.snapshots import SnapshotCreate, Snapshot from .controller.iou_license import IOULicense diff --git a/gns3server/schemas/controller/users.py b/gns3server/schemas/controller/users.py index 3fa0a551..28d40688 100644 --- a/gns3server/schemas/controller/users.py +++ b/gns3server/schemas/controller/users.py @@ -58,6 +58,39 @@ class User(DateTimeModelMixin, UserBase): orm_mode = True +class UserGroupBase(BaseModel): + """ + Common user group properties. + """ + + name: Optional[str] = Field(None, min_length=3, regex="[a-zA-Z0-9_-]+$") + + +class UserGroupCreate(UserGroupBase): + """ + Properties to create an user group. + """ + + name: Optional[str] = Field(..., min_length=3, regex="[a-zA-Z0-9_-]+$") + + +class UserGroupUpdate(UserGroupBase): + """ + Properties to update an user group. + """ + + pass + + +class UserGroup(DateTimeModelMixin, UserGroupBase): + + user_group_id: UUID + is_updatable: bool + + class Config: + orm_mode = True + + class Credentials(BaseModel): username: str diff --git a/gns3server/server.py b/gns3server/server.py index 3c7dedb1..e52f0834 100644 --- a/gns3server/server.py +++ b/gns3server/server.py @@ -319,6 +319,7 @@ class Server: access_log=access_log, ssl_certfile=config.Server.certfile, ssl_keyfile=config.Server.certkey, + lifespan="on" ) # overwrite uvicorn loggers with our own logger diff --git a/tests/api/routes/controller/test_groups.py b/tests/api/routes/controller/test_groups.py new file mode 100644 index 00000000..7551ab9b --- /dev/null +++ b/tests/api/routes/controller/test_groups.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python +# +# Copyright (C) 2021 GNS3 Technologies Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import pytest + +from fastapi import FastAPI, status +from httpx import AsyncClient + +from sqlalchemy.ext.asyncio import AsyncSession +from gns3server.db.repositories.users import UsersRepository +from gns3server.schemas.controller.users import User + +pytestmark = pytest.mark.asyncio + + +class TestGroupRoutes: + + async def test_create_group(self, app: FastAPI, client: AsyncClient) -> None: + + new_group = {"name": "group1"} + response = await client.post(app.url_path_for("create_user_group"), json=new_group) + assert response.status_code == status.HTTP_201_CREATED + + async def test_get_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("group1") + response = await client.get(app.url_path_for("get_user_group", user_group_id=group_in_db.user_group_id)) + assert response.status_code == status.HTTP_200_OK + assert response.json()["user_group_id"] == str(group_in_db.user_group_id) + + async def test_list_groups(self, app: FastAPI, client: AsyncClient) -> None: + + response = await client.get(app.url_path_for("get_user_groups")) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == 4 # 3 default groups + group1 + + async def test_update_group(self, app: FastAPI, client: AsyncClient, db_session: AsyncSession) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("group1") + + update_group = {"name": "group42"} + response = await client.put( + app.url_path_for("update_user_group", user_group_id=group_in_db.user_group_id), + json=update_group + ) + assert response.status_code == status.HTTP_200_OK + updated_group_in_db = await user_repo.get_user_group(group_in_db.user_group_id) + assert updated_group_in_db.name == "group42" + + async def test_cannot_update_admin_group( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("Administrators") + update_group = {"name": "Hackers"} + response = await client.put( + app.url_path_for("update_user_group", user_group_id=group_in_db.user_group_id), + json=update_group + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + async def test_delete_group( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("group42") + response = await client.delete(app.url_path_for("delete_user_group", user_group_id=group_in_db.user_group_id)) + assert response.status_code == status.HTTP_204_NO_CONTENT + + async def test_cannot_delete_admin_group( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("Administrators") + response = await client.delete(app.url_path_for("delete_user_group", user_group_id=group_in_db.user_group_id)) + assert response.status_code == status.HTTP_403_FORBIDDEN + + async def test_add_member_to_group( + self, + app: FastAPI, + client: AsyncClient, + test_user: User, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("Users") + response = await client.put( + app.url_path_for( + "add_member_to_group", + user_group_id=group_in_db.user_group_id, + user_id=str(test_user.user_id) + ) + ) + assert response.status_code == status.HTTP_204_NO_CONTENT + members = await user_repo.get_user_group_members(group_in_db.user_group_id) + assert len(members) == 1 + assert members[0].username == test_user.username + + async def test_get_user_group_members( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("Users") + response = await client.get( + app.url_path_for( + "get_user_group_members", + user_group_id=group_in_db.user_group_id) + ) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == 1 + + async def test_remove_member_from_group( + self, + app: FastAPI, + client: AsyncClient, + test_user: User, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + group_in_db = await user_repo.get_user_group_by_name("Users") + + response = await client.delete( + app.url_path_for( + "remove_member_from_group", + user_group_id=group_in_db.user_group_id, + user_id=str(test_user.user_id) + ), + ) + assert response.status_code == status.HTTP_204_NO_CONTENT + members = await user_repo.get_user_group_members(group_in_db.user_group_id) + assert len(members) == 0 diff --git a/tests/api/routes/controller/test_users.py b/tests/api/routes/controller/test_users.py index 9fa94e74..f1a5dc41 100644 --- a/tests/api/routes/controller/test_users.py +++ b/tests/api/routes/controller/test_users.py @@ -56,8 +56,8 @@ class TestUserRoutes: assert user_in_db is None # register the user - res = await client.post(app.url_path_for("create_user"), json=params) - assert res.status_code == status.HTTP_201_CREATED + response = await client.post(app.url_path_for("create_user"), json=params) + assert response.status_code == status.HTTP_201_CREATED # make sure the user does exists in the database now user_in_db = await user_repo.get_user_by_username(params["username"]) @@ -66,7 +66,7 @@ class TestUserRoutes: assert user_in_db.username == params["username"] # check that the user returned in the response is equal to the user in the database - created_user = User(**res.json()).json() + created_user = User(**response.json()).json() assert created_user == User.from_orm(user_in_db).json() @pytest.mark.parametrize( @@ -91,8 +91,8 @@ class TestUserRoutes: new_user = {"email": "not_taken@email.com", "username": "not_taken_username", "password": "test_password"} new_user[attr] = value - res = await client.post(app.url_path_for("create_user"), json=new_user) - assert res.status_code == status_code + response = await client.post(app.url_path_for("create_user"), json=new_user) + assert response.status_code == status_code async def test_users_saved_password_is_hashed( self, @@ -105,8 +105,8 @@ class TestUserRoutes: new_user = {"username": "user3", "email": "user3@email.com", "password": "test_password"} # send post request to create user and ensure it is successful - res = await client.post(app.url_path_for("create_user"), json=new_user) - assert res.status_code == status.HTTP_201_CREATED + response = await client.post(app.url_path_for("create_user"), json=new_user) + assert response.status_code == status.HTTP_201_CREATED # ensure that the users password is hashed in the db # and that we can verify it using our auth service @@ -156,7 +156,6 @@ class TestAuthTokens: username = auth_service.get_username_from_token(token) assert username == test_user.username - @pytest.mark.parametrize( "wrong_secret, wrong_token", ( @@ -200,19 +199,19 @@ class TestUserLogin: "username": test_user.username, "password": "user1_password", } - res = await unauthorized_client.post(app.url_path_for("login"), data=login_data) - assert res.status_code == status.HTTP_200_OK + response = await unauthorized_client.post(app.url_path_for("login"), data=login_data) + assert response.status_code == status.HTTP_200_OK # check that token exists in response and has user encoded within it - token = res.json().get("access_token") + token = response.json().get("access_token") payload = jwt.decode(token, jwt_secret, algorithms=["HS256"]) assert "sub" in payload username = payload.get("sub") assert username == test_user.username # check that token is proper type - assert "token_type" in res.json() - assert res.json().get("token_type") == "bearer" + assert "token_type" in response.json() + assert response.json().get("token_type") == "bearer" async def test_user_can_authenticate_using_json( self, @@ -226,9 +225,9 @@ class TestUserLogin: "username": test_user.username, "password": "user1_password", } - res = await unauthorized_client.post(app.url_path_for("authenticate"), json=credentials) - assert res.status_code == status.HTTP_200_OK - assert res.json().get("access_token") + response = await unauthorized_client.post(app.url_path_for("authenticate"), json=credentials) + assert response.status_code == status.HTTP_200_OK + assert response.json().get("access_token") @pytest.mark.parametrize( "username, password, status_code", @@ -253,9 +252,9 @@ class TestUserLogin: "username": username, "password": password, } - res = await unauthorized_client.post(app.url_path_for("login"), data=login_data) - assert res.status_code == status_code - assert "access_token" not in res.json() + response = await unauthorized_client.post(app.url_path_for("login"), data=login_data) + assert response.status_code == status_code + assert "access_token" not in response.json() class TestUserMe: @@ -267,9 +266,9 @@ class TestUserMe: test_user: User, ) -> None: - res = await authorized_client.get(app.url_path_for("get_current_active_user")) - assert res.status_code == status.HTTP_200_OK - user = User(**res.json()) + response = await authorized_client.get(app.url_path_for("get_current_active_user")) + assert response.status_code == status.HTTP_200_OK + user = User(**response.json()) assert user.username == test_user.username assert user.email == test_user.email assert user.user_id == test_user.user_id @@ -280,8 +279,8 @@ class TestUserMe: test_user: User, ) -> None: - res = await unauthorized_client.get(app.url_path_for("get_current_active_user")) - assert res.status_code == status.HTTP_401_UNAUTHORIZED + response = await unauthorized_client.get(app.url_path_for("get_current_active_user")) + assert response.status_code == status.HTTP_401_UNAUTHORIZED class TestSuperAdmin: @@ -307,8 +306,8 @@ class TestSuperAdmin: user_repo = UsersRepository(db_session) admin_in_db = await user_repo.get_user_by_username("admin") - res = await client.delete(app.url_path_for("delete_user", user_id=admin_in_db.user_id)) - assert res.status_code == status.HTTP_403_FORBIDDEN + response = await client.delete(app.url_path_for("delete_user", user_id=admin_in_db.user_id)) + assert response.status_code == status.HTTP_403_FORBIDDEN async def test_admin_can_login_after_password_recovery( self, @@ -327,5 +326,18 @@ class TestSuperAdmin: "username": "admin", "password": "whatever", } - res = await unauthorized_client.post(app.url_path_for("login"), data=login_data) - assert res.status_code == status.HTTP_200_OK + response = await unauthorized_client.post(app.url_path_for("login"), data=login_data) + assert response.status_code == status.HTTP_200_OK + + async def test_super_admin_belongs_to_admin_group( + self, + app: FastAPI, + client: AsyncClient, + db_session: AsyncSession + ) -> None: + + user_repo = UsersRepository(db_session) + admin_in_db = await user_repo.get_user_by_username("admin") + response = await client.get(app.url_path_for("get_user_memberships", user_id=admin_in_db.user_id)) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()) == 1