Protect the API and add alternative authentication endpoint.

This commit is contained in:
grossmj 2021-04-20 11:59:02 +09:30
parent e28452f09a
commit 0465cb87f6
7 changed files with 199 additions and 80 deletions

View File

@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from . import controller
from . import appliances
@ -30,17 +30,80 @@ from . import symbols
from . import templates
from . import users
from .dependencies.authentication import get_current_active_user
router = APIRouter()
router.include_router(controller.router, tags=["Controller"])
router.include_router(users.router, prefix="/users", tags=["Users"])
router.include_router(appliances.router, prefix="/appliances", tags=["Appliances"])
router.include_router(computes.router, prefix="/computes", tags=["Computes"])
router.include_router(drawings.router, prefix="/projects/{project_id}/drawings", tags=["Drawings"])
router.include_router(gns3vm.router, prefix="/gns3vm", tags=["GNS3 VM"])
router.include_router(links.router, prefix="/projects/{project_id}/links", tags=["Links"])
router.include_router(nodes.router, prefix="/projects/{project_id}/nodes", tags=["Nodes"])
router.include_router(notifications.router, prefix="/notifications", tags=["Notifications"])
router.include_router(projects.router, prefix="/projects", tags=["Projects"])
router.include_router(snapshots.router, prefix="/projects/{project_id}/snapshots", tags=["Snapshots"])
router.include_router(symbols.router, prefix="/symbols", tags=["Symbols"])
router.include_router(templates.router, tags=["Templates"])
router.include_router(
appliances.router,
dependencies=[Depends(get_current_active_user)],
prefix="/appliances",
tags=["Appliances"]
)
router.include_router(
computes.router,
dependencies=[Depends(get_current_active_user)],
prefix="/computes",
tags=["Computes"]
)
router.include_router(
drawings.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/drawings",
tags=["Drawings"])
router.include_router(
gns3vm.router,
dependencies=[Depends(get_current_active_user)],
prefix="/gns3vm",
tags=["GNS3 VM"]
)
router.include_router(
links.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/links",
tags=["Links"]
)
router.include_router(
nodes.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/nodes",
tags=["Nodes"]
)
router.include_router(
notifications.router,
dependencies=[Depends(get_current_active_user)],
prefix="/notifications",
tags=["Notifications"])
router.include_router(
projects.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects",
tags=["Projects"])
router.include_router(
snapshots.router,
dependencies=[Depends(get_current_active_user)],
prefix="/projects/{project_id}/snapshots",
tags=["Snapshots"])
router.include_router(
symbols.router,
dependencies=[Depends(get_current_active_user)],
prefix="/symbols", tags=["Symbols"]
)
router.include_router(
templates.router,
dependencies=[Depends(get_current_active_user)],
tags=["Templates"]
)

View File

@ -18,7 +18,7 @@ import asyncio
import signal
import os
from fastapi import APIRouter, status
from fastapi import APIRouter, Depends, status
from fastapi.encoders import jsonable_encoder
from typing import List
@ -28,6 +28,7 @@ from gns3server.version import __version__
from gns3server.controller.controller_error import ControllerError, ControllerForbiddenError
from gns3server import schemas
from .dependencies.authentication import get_current_active_user
import logging
@ -36,8 +37,39 @@ log = logging.getLogger(__name__)
router = APIRouter()
@router.get(
"/version",
response_model=schemas.Version,
)
def get_version() -> dict:
"""
Return the server version number.
"""
local_server = Config.instance().settings.Server.local
return {"version": __version__, "local": local_server}
@router.post(
"/version",
response_model=schemas.Version,
response_model_exclude_defaults=True,
responses={409: {"model": schemas.ErrorMessage, "description": "Invalid version"}},
)
def check_version(version: schemas.Version) -> dict:
"""
Check if version is the same as the server.
"""
print(version.version)
if version.version != __version__:
raise ControllerError(f"Client version {version.version} is not the same as server version {__version__}")
return {"version": __version__}
@router.post(
"/shutdown",
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_204_NO_CONTENT,
responses={403: {"model": schemas.ErrorMessage, "description": "Server shutdown not allowed"}},
)
@ -71,38 +103,11 @@ async def shutdown() -> None:
os.kill(os.getpid(), signal.SIGTERM)
@router.get("/version", response_model=schemas.Version)
def get_version() -> dict:
"""
Return the server version number.
"""
local_server = Config.instance().settings.Server.local
return {"version": __version__, "local": local_server}
@router.post(
"/version",
response_model=schemas.Version,
response_model_exclude_defaults=True,
responses={409: {"model": schemas.ErrorMessage, "description": "Invalid version"}},
@router.get(
"/iou_license",
dependencies=[Depends(get_current_active_user)],
response_model=schemas.IOULicense
)
def check_version(version: schemas.Version) -> dict:
"""
Check if version is the same as the server.
:param request:
:param response:
:return:
"""
print(version.version)
if version.version != __version__:
raise ControllerError(f"Client version {version.version} is not the same as server version {__version__}")
return {"version": __version__}
@router.get("/iou_license", response_model=schemas.IOULicense)
def get_iou_license() -> schemas.IOULicense:
"""
Return the IOU license settings
@ -111,7 +116,12 @@ def get_iou_license() -> schemas.IOULicense:
return Controller.instance().iou_license
@router.put("/iou_license", status_code=status.HTTP_201_CREATED, response_model=schemas.IOULicense)
@router.put(
"/iou_license",
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_201_CREATED,
response_model=schemas.IOULicense
)
async def update_iou_license(iou_license: schemas.IOULicense) -> schemas.IOULicense:
"""
Update the IOU license settings.
@ -124,7 +134,7 @@ async def update_iou_license(iou_license: schemas.IOULicense) -> schemas.IOULice
return current_iou_license
@router.get("/statistics")
@router.get("/statistics", dependencies=[Depends(get_current_active_user)])
async def statistics() -> List[dict]:
"""
Return server statistics.

View File

@ -24,7 +24,7 @@ from gns3server.services import auth_service
from .database import get_repository
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login") # FIXME: URL prefix
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/v3/users/login")
async def get_user_from_token(

View File

@ -44,10 +44,53 @@ log = logging.getLogger(__name__)
router = APIRouter()
@router.get("", response_model=List[schemas.User])
async def get_users(
@router.post("/login", response_model=schemas.Token)
async def login(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
form_data: OAuth2PasswordRequestForm = Depends(),
) -> schemas.Token:
"""
Default user login method using forms (x-www-form-urlencoded).
Example: curl http://host:port/v3/users/login -H "Content-Type: application/x-www-form-urlencoded" -d "username=admin&password=admin"
"""
user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication was unsuccessful.",
headers={"WWW-Authenticate": "Bearer"},
)
token = schemas.Token(access_token=auth_service.create_access_token(user.username), token_type="bearer")
return token
@router.post("/authenticate", response_model=schemas.Token)
async def authenticate(
user_credentials: schemas.Credentials,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
) -> schemas.Token:
"""
Alternative authentication method using json.
Example: curl http://host:port/v3/users/authenticate -d '{"username": "admin", "password": "admin"}'
"""
user = await users_repo.authenticate_user(username=user_credentials.username, password=user_credentials.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication was unsuccessful.",
headers={"WWW-Authenticate": "Bearer"},
)
token = schemas.Token(access_token=auth_service.create_access_token(user.username), token_type="bearer")
return token
@router.get("", response_model=List[schemas.User], dependencies=[Depends(get_current_active_user)])
async def get_users(
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> List[schemas.User]:
"""
Get all users.
@ -56,11 +99,15 @@ async def get_users(
return await users_repo.get_users()
@router.post("", response_model=schemas.User, status_code=status.HTTP_201_CREATED)
@router.post(
"",
response_model=schemas.User,
dependencies=[Depends(get_current_active_user)],
status_code=status.HTTP_201_CREATED
)
async def create_user(
user_create: schemas.UserCreate,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Create a new user.
@ -75,11 +122,10 @@ async def create_user(
return await users_repo.create_user(user_create)
@router.get("/{user_id}",response_model=schemas.User)
@router.get("/{user_id}", dependencies=[Depends(get_current_active_user)], response_model=schemas.User)
async def get_user(
user_id: UUID,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
) -> schemas.User:
"""
Get an user.
@ -91,12 +137,11 @@ async def get_user(
return user
@router.put("/{user_id}", response_model=schemas.User)
@router.put("/{user_id}", dependencies=[Depends(get_current_active_user)], response_model=schemas.User)
async def update_user(
user_id: UUID,
user_update: schemas.UserUpdate,
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
users_repo: UsersRepository = Depends(get_repository(UsersRepository))
) -> schemas.User:
"""
Update an user.
@ -126,27 +171,6 @@ async def delete_user(
raise ControllerNotFoundError(f"User '{user_id}' not found")
@router.post("/login", response_model=schemas.Token)
async def login(
users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
form_data: OAuth2PasswordRequestForm = Depends(),
) -> schemas.Token:
"""
User login.
"""
user = await users_repo.authenticate_user(username=form_data.username, password=form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication was unsuccessful.",
headers={"WWW-Authenticate": "Bearer"},
)
token = schemas.Token(access_token=auth_service.create_access_token(user.username), token_type="bearer")
return token
@router.get("/users/me/", response_model=schemas.User)
async def get_current_active_user(current_user: schemas.User = Depends(get_current_active_user)) -> schemas.User:
"""

View File

@ -27,7 +27,7 @@ from .controller.drawings import Drawing
from .controller.gns3vm import GNS3VM
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile
from .controller.users import UserCreate, UserUpdate, User
from .controller.users import UserCreate, UserUpdate, User, Credentials
from .controller.tokens import Token
from .controller.snapshots import SnapshotCreate, Snapshot
from .controller.iou_license import IOULicense

View File

@ -56,3 +56,9 @@ class User(DateTimeModelMixin, UserBase):
class Config:
orm_mode = True
class Credentials(BaseModel):
username: str
password: str

View File

@ -214,6 +214,22 @@ class TestUserLogin:
assert "token_type" in res.json()
assert res.json().get("token_type") == "bearer"
async def test_user_can_authenticate_using_json(
self,
app: FastAPI,
unauthorized_client: AsyncClient,
test_user: User,
config: Config
) -> None:
credentials = {
"username": test_user.username,
"password": "user1_password",
}
res = await unauthorized_client.post(app.url_path_for("authenticate"), json=credentials)
assert res.status_code == status.HTTP_200_OK
assert res.json().get("access_token")
@pytest.mark.parametrize(
"username, password, status_code",
(