Generate JWT secret key if none is configured in the config file.

Change location of the database.
This commit is contained in:
grossmj 2020-12-16 18:24:21 +10:30
parent 509e762cda
commit bde706d19a
8 changed files with 130 additions and 40 deletions

View File

@ -4,10 +4,15 @@ host = 0.0.0.0
; HTTP port for controlling the servers
port = 3080
; Option to enable SSL encryption
; Options to enable SSL encryption
ssl = False
certfile=/home/gns3/.config/GNS3/ssl/server.cert
certkey=/home/gns3/.config/GNS3/ssl/server.key
certfile = /home/gns3/.config/GNS3/ssl/server.cert
certkey = /home/gns3/.config/GNS3/ssl/server.key
; Options for JWT tokens (user authentication)
jwt_secret_key = efd08eccec3bd0a1be2e086670e5efa90969c68d07e072d7354a76cea5e33d4e
jwt_algorithm = HS256
jwt_access_token_expire_minutes = 1440
; Path where devices images are stored
images_path = /home/gns3/GNS3/images

View File

@ -25,7 +25,12 @@ from uuid import UUID
from typing import List
from gns3server import schemas
from gns3server.controller.controller_error import ControllerBadRequestError, ControllerNotFoundError
from gns3server.controller.controller_error import (
ControllerBadRequestError,
ControllerNotFoundError,
ControllerUnauthorizedError
)
from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service
@ -98,11 +103,18 @@ async def update_user(
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(user_id: UUID, user_repo: UsersRepository = Depends(get_repository(UsersRepository))):
async def delete_user(
user_id: UUID,
user_repo: UsersRepository = Depends(get_repository(UsersRepository)),
current_user: schemas.User = Depends(get_current_active_user)
) -> None:
"""
Delete an user.
"""
if current_user.is_superuser:
raise ControllerUnauthorizedError("The super user cannot be deleted")
success = await user_repo.delete_user(user_id)
if not success:
raise ControllerNotFoundError(f"User '{user_id}' not found")

View File

@ -182,9 +182,21 @@ class Config:
controller_config_filename = "gns3_controller.conf"
return os.path.join(self.config_dir, controller_config_filename)
@property
def server_config(self):
if sys.platform.startswith("win"):
server_config_filename = "gns3_server.ini"
else:
server_config_filename = "gns3_server.conf"
return os.path.join(self.config_dir, server_config_filename)
def clear(self):
"""Restart with a clean config"""
self._config = configparser.RawConfigParser()
"""
Restart with a clean config
"""
self._config = configparser.ConfigParser(interpolation=None)
# Override config from command line even if we modify the config file and live reload it.
self._override_config = {}
@ -231,6 +243,18 @@ class Config:
log.info("Load configuration file {}".format(file))
self._watched_files[file] = os.stat(file).st_mtime
def write_config(self):
"""
Write the server configuration file.
"""
try:
os.makedirs(os.path.dirname(self.server_config), exist_ok=True)
with open(self.server_config, 'w+') as fd:
self._config.write(fd)
except OSError as e:
log.error("Cannot write server configuration file '{}': {}".format(self.server_config, e))
def get_default_section(self):
"""
Get the default configuration section.

View File

@ -74,6 +74,7 @@ class BaseTable(Base):
def generate_uuid():
return str(uuid.uuid4())
class User(BaseTable):
__tablename__ = "users"

View File

@ -22,6 +22,7 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import create_async_engine
from .models import Base
from gns3server.config import Config
import logging
log = logging.getLogger(__name__)
@ -29,12 +30,13 @@ log = logging.getLogger(__name__)
async def connect_to_db(app: FastAPI) -> None:
db_url = os.environ.get("GNS3_DATABASE_URI", "sqlite:///./sql_app.db")
db_path = os.path.join(Config.instance().config_dir, "gns3_controller.db")
db_url = os.environ.get("GNS3_DATABASE_URI", f"sqlite:///{db_path}")
engine = create_async_engine(db_url, connect_args={"check_same_thread": False}, future=True)
try:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
log.info("Successfully connected to the database")
log.info(f"Successfully connected to database '{db_url}'")
app.state._db_engine = engine
except SQLAlchemyError as e:
log.error(f"Error while connecting to the database: {e}")
log.error(f"Error while connecting to database '{db_url}: {e}")

View File

@ -31,6 +31,7 @@ import asyncio
import signal
import functools
import uvicorn
import secrets
from gns3server.controller import Controller
from gns3server.compute.port_manager import PortManager
@ -122,7 +123,7 @@ def parse_arguments(argv):
config = Config.instance().get_section_config("Server")
defaults = {
"host": config.get("host", "0.0.0.0"),
"port": config.get("port", 3080),
"port": config.getint("port", 3080),
"ssl": config.getboolean("ssl", False),
"certfile": config.get("certfile", ""),
"certkey": config.get("certkey", ""),
@ -132,8 +133,8 @@ def parse_arguments(argv):
"quiet": config.getboolean("quiet", False),
"debug": config.getboolean("debug", False),
"logfile": config.getboolean("logfile", ""),
"logmaxsize": config.get("logmaxsize", 10000000), # default is 10MB
"logbackupcount": config.get("logbackupcount", 10),
"logmaxsize": config.getint("logmaxsize", 10000000), # default is 10MB
"logbackupcount": config.getint("logbackupcount", 10),
"logcompression": config.getboolean("logcompression", False)
}
@ -145,6 +146,13 @@ def set_config(args):
config = Config.instance()
server_config = config.get_section_config("Server")
jwt_secret_key = server_config.get("jwt_secret_key", None)
if not jwt_secret_key:
log.info("No JWT secret key configured, generating one...")
if not config._config.has_section("Server"):
config._config.add_section("Server")
config._config.set("Server", "jwt_secret_key", secrets.token_hex(32))
config.write_config()
server_config["local"] = str(args.local)
server_config["allow_remote_console"] = str(args.allow)
server_config["host"] = args.host

View File

@ -16,7 +16,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import bcrypt
from jose import JWTError, jwt
from datetime import datetime, timedelta
from passlib.context import CryptContext
@ -24,19 +23,22 @@ from passlib.context import CryptContext
from typing import Optional
from fastapi import HTTPException, status
from gns3server.schemas.tokens import TokenData
from gns3server.controller.controller_error import ControllerError
from gns3server.config import Config
from pydantic import ValidationError
# FIXME: temporary variables to move to config
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
import logging
log = logging.getLogger(__name__)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class AuthService:
def __init__(self):
self._server_config = Config.instance().get_section_config("Server")
def hash_password(self, password: str) -> str:
return pwd_context.hash(password)
@ -45,19 +47,40 @@ class AuthService:
return pwd_context.verify(password, hashed_password)
def get_secret_key(self):
"""
Should only be used by tests.
"""
return self._server_config.get("jwt_secret_key", None)
def get_algorithm(self):
"""
Should only be used by tests.
"""
return self._server_config.get("jwt_algorithm", None)
def create_access_token(
self,
username,
secret_key: str = SECRET_KEY,
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES
secret_key: str = None,
expires_in: int = 0
) -> str:
if not expires_in:
expires_in = self._server_config.getint("jwt_access_token_expire_minutes", 1440)
expire = datetime.utcnow() + timedelta(minutes=expires_in)
to_encode = {"sub": username, "exp": expire}
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
if secret_key is None:
secret_key = self._server_config.get("jwt_secret_key", None)
if secret_key is None:
raise ControllerError("No JWT secret key has been configured")
algorithm = self._server_config.get("jwt_algorithm", "HS256")
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
return encoded_jwt
def get_username_from_token(self, token: str, secret_key: str = SECRET_KEY) -> Optional[str]:
def get_username_from_token(self, token: str, secret_key: str = None) -> Optional[str]:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -65,7 +88,12 @@ class AuthService:
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
if secret_key is None:
secret_key = self._server_config.get("jwt_secret_key", None)
if secret_key is None:
raise ControllerError("No JWT secret key has been configured")
algorithm = self._server_config.get("jwt_algorithm", "HS256")
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
username: str = payload.get("sub")
if username is None:
raise credentials_exception

View File

@ -17,16 +17,15 @@
import pytest
from typing import Optional, Union
from typing import Optional
from fastapi import FastAPI, HTTPException, status
from starlette.datastructures import Secret
from httpx import AsyncClient
from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession
from gns3server.db.repositories.users import UsersRepository
from gns3server.services import auth_service
from gns3server.services.authentication import SECRET_KEY, ALGORITHM
from gns3server.config import Config
from gns3server.schemas.users import User
pytestmark = pytest.mark.asyncio
@ -36,7 +35,7 @@ class TestUserRoutes:
async def test_route_exist(self, app: FastAPI, client: AsyncClient) -> None:
new_user = {"username": "test_user1", "email": "user1@email.com", "password": "test_password"}
new_user = {"username": "user1", "email": "user1@email.com", "password": "test_password"}
response = await client.post(app.url_path_for("create_user"), json=new_user)
assert response.status_code != status.HTTP_404_NOT_FOUND
@ -48,7 +47,7 @@ class TestUserRoutes:
) -> None:
user_repo = UsersRepository(db_session)
params = {"username": "test_user2", "email": "user2@email.com", "password": "test_password"}
params = {"username": "user2", "email": "user2@email.com", "password": "test_password"}
# make sure the user doesn't exist in the database
user_in_db = await user_repo.get_user_by_username(params["username"])
@ -72,7 +71,7 @@ class TestUserRoutes:
"attr, value, status_code",
(
("email", "user2@email.com", status.HTTP_400_BAD_REQUEST),
("username", "test_user2", status.HTTP_400_BAD_REQUEST),
("username", "user2", status.HTTP_400_BAD_REQUEST),
("email", "invalid_email@one@two.io", status.HTTP_422_UNPROCESSABLE_ENTITY),
("password", "short", status.HTTP_422_UNPROCESSABLE_ENTITY),
("username", "user2@#$%^<>", status.HTTP_422_UNPROCESSABLE_ENTITY),
@ -101,7 +100,7 @@ class TestUserRoutes:
) -> None:
user_repo = UsersRepository(db_session)
new_user = {"username": "test_user3", "email": "user3@email.com", "password": "test_password"}
new_user = {"username": "user3", "email": "user3@email.com", "password": "test_password"}
# send post request to create user and ensure it is successful
res = await client.post(app.url_path_for("create_user"), json=new_user)
@ -114,6 +113,12 @@ class TestUserRoutes:
assert user_in_db.hashed_password != new_user["password"]
assert auth_service.verify_password(new_user["password"], user_in_db.hashed_password)
async def test_get_users(self, app: FastAPI, client: AsyncClient) -> None:
response = await client.get(app.url_path_for("get_users"))
assert response.status_code == status.HTTP_200_OK
assert len(response.json()) == 3 # user1, user2 and user3 should exist
class TestAuthTokens:
@ -124,16 +129,18 @@ class TestAuthTokens:
test_user: User
) -> None:
secret_key = auth_service._server_config.get("jwt_secret_key")
token = auth_service.create_access_token(test_user.username)
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
payload = jwt.decode(token, secret_key, algorithms=["HS256"])
username = payload.get("sub")
assert username == test_user.username
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient) -> None:
async def test_token_missing_user_is_invalid(self, app: FastAPI, client: AsyncClient, config: Config) -> None:
secret_key = auth_service._server_config.get("jwt_secret_key")
token = auth_service.create_access_token(None)
with pytest.raises(jwt.JWTError):
jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
jwt.decode(token, secret_key, algorithms=["HS256"])
async def test_can_retrieve_username_from_token(
self,
@ -148,10 +155,10 @@ class TestAuthTokens:
@pytest.mark.parametrize(
"secret, wrong_token",
"wrong_secret, wrong_token",
(
(SECRET_KEY, "asdf"), # use wrong token
(SECRET_KEY, ""), # use wrong token
("use correct secret", "asdf"), # use wrong token
("use correct secret", ""), # use wrong token
("ABC123", "use correct token"), # use wrong secret
),
)
@ -160,15 +167,17 @@ class TestAuthTokens:
app: FastAPI,
client: AsyncClient,
test_user: User,
secret: Union[Secret, str],
wrong_secret: str,
wrong_token: Optional[str],
) -> None:
token = auth_service.create_access_token(test_user.username)
if wrong_secret == "use correct secret":
wrong_secret = auth_service._server_config.get("jwt_secret_key")
if wrong_token == "use correct token":
wrong_token = token
with pytest.raises(HTTPException):
auth_service.get_username_from_token(wrong_token, secret_key=str(secret))
auth_service.get_username_from_token(wrong_token, secret_key=wrong_secret)
class TestUserLogin:
@ -189,8 +198,9 @@ class TestUserLogin:
assert res.status_code == status.HTTP_200_OK
# check that token exists in response and has user encoded within it
secret_key = auth_service._server_config.get("jwt_secret_key")
token = res.json().get("access_token")
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
payload = jwt.decode(token, secret_key, algorithms=["HS256"])
assert "sub" in payload
username = payload.get("sub")
assert username == test_user.username