diff --git a/gns3server/compute/project.py b/gns3server/compute/project.py index b9567c94..b8da58e6 100644 --- a/gns3server/compute/project.py +++ b/gns3server/compute/project.py @@ -29,7 +29,7 @@ from .port_manager import PortManager from .notification_manager import NotificationManager from ..config import Config from ..utils.asyncio import wait_run_in_executor -from ..utils.path import check_path_allowed +from ..utils.path import check_path_allowed, get_default_project_directory import logging @@ -63,7 +63,7 @@ class Project: self._used_udp_ports = set() if path is None: - location = self._config().get("project_directory", self._get_default_project_directory()) + location = get_default_project_directory() path = os.path.join(location, self._id) try: os.makedirs(path, exist_ok=True) @@ -95,22 +95,6 @@ class Project: return self._config().getboolean("local", False) - @classmethod - def _get_default_project_directory(cls): - """ - Return the default location for the project directory - depending of the operating system - """ - - server_config = Config.instance().get_section_config("Server") - path = os.path.expanduser(server_config.get("projects_path", "~/GNS3/projects")) - path = os.path.normpath(path) - try: - os.makedirs(path, exist_ok=True) - except OSError as e: - raise aiohttp.web.HTTPInternalServerError(text="Could not create project directory: {}".format(e)) - return path - @property def id(self): @@ -418,8 +402,7 @@ class Project: At startup drop old temporary project. After a crash for example """ - config = Config.instance().get_section_config("Server") - directory = config.get("project_directory", cls._get_default_project_directory()) + directory = get_default_project_directory() if os.path.exists(directory): for project in os.listdir(directory): path = os.path.join(directory, project) diff --git a/gns3server/controller/project.py b/gns3server/controller/project.py index 26fe6cb3..92ed2189 100644 --- a/gns3server/controller/project.py +++ b/gns3server/controller/project.py @@ -25,7 +25,7 @@ from .vm import VM from .udp_link import UDPLink from ..notification_queue import NotificationQueue from ..config import Config -from ..utils.path import check_path_allowed +from ..utils.path import check_path_allowed, get_default_project_directory class Project: @@ -50,8 +50,7 @@ class Project: self._id = project_id if path is None: - location = self._config().get("project_directory", self._get_default_project_directory()) - path = os.path.join(location, self._id) + path = os.path.join(get_default_project_directory(), self._id) self.path = path self._temporary = temporary @@ -205,22 +204,6 @@ class Project: for listener in self._listeners: listener.put_nowait((action, event, kwargs)) - @classmethod - def _get_default_project_directory(cls): - """ - Return the default location for the project directory - depending of the operating system - """ - - server_config = Config.instance().get_section_config("Server") - path = os.path.expanduser(server_config.get("projects_path", "~/GNS3/projects")) - path = os.path.normpath(path) - try: - os.makedirs(path, exist_ok=True) - except OSError as e: - raise aiohttp.web.HTTPInternalServerError(text="Could not create project directory: {}".format(e)) - return path - def __json__(self): return { diff --git a/gns3server/utils/path.py b/gns3server/utils/path.py index ebafe9fe..94a1ee64 100644 --- a/gns3server/utils/path.py +++ b/gns3server/utils/path.py @@ -21,6 +21,22 @@ import aiohttp from ..config import Config +def get_default_project_directory(): + """ + Return the default location for the project directory + depending of the operating system + """ + + server_config = Config.instance().get_section_config("Server") + path = os.path.expanduser(server_config.get("projects_path", "~/GNS3/projects")) + path = os.path.normpath(path) + try: + os.makedirs(path, exist_ok=True) + except OSError as e: + raise aiohttp.web.HTTPInternalServerError(text="Could not create project directory: {}".format(e)) + return path + + def check_path_allowed(path): """ If the server is non local raise an error if @@ -30,10 +46,10 @@ def check_path_allowed(path): """ config = Config.instance().get_section_config("Server") - project_directory = config.get("project_directory") + project_directory = get_default_project_directory() if len(os.path.commonprefix([project_directory, path])) == len(project_directory): return - if config.getboolean("local") is False: + if "local" in config and config.getboolean("local") is False: raise aiohttp.web.HTTPForbidden(text="The path is not allowed") diff --git a/tests/compute/test_project.py b/tests/compute/test_project.py index 8f171308..638c5196 100644 --- a/tests/compute/test_project.py +++ b/tests/compute/test_project.py @@ -69,10 +69,10 @@ def test_clean_tmp_directory(async_run): def test_path(tmpdir): - directory = Config.instance().get_section_config("Server").get("project_directory") + directory = Config.instance().get_section_config("Server").get("projects_path") with patch("gns3server.compute.project.Project.is_local", return_value=True): - with patch("gns3server.compute.project.Project._get_default_project_directory", return_value=directory): + with patch("gns3server.utils.path.get_default_project_directory", return_value=directory): p = Project(project_id=str(uuid4())) assert p.path == os.path.join(directory, p.id) assert os.path.exists(os.path.join(directory, p.id)) @@ -124,7 +124,7 @@ def test_json(tmpdir): def test_vm_working_directory(tmpdir, vm): - directory = Config.instance().get_section_config("Server").get("project_directory") + directory = Config.instance().get_section_config("Server").get("projects_path") with patch("gns3server.compute.project.Project.is_local", return_value=True): p = Project(project_id=str(uuid4())) @@ -211,15 +211,6 @@ def test_project_close_temporary_project(loop, manager): assert os.path.exists(directory) is False -def test_get_default_project_directory(monkeypatch): - - monkeypatch.undo() - project = Project(project_id=str(uuid4())) - path = os.path.normpath(os.path.expanduser("~/GNS3/projects")) - assert project._get_default_project_directory() == path - assert os.path.exists(path) - - def test_clean_project_directory(tmpdir): # A non anonymous project with uuid. @@ -237,7 +228,7 @@ def test_clean_project_directory(tmpdir): with open(str(tmp), 'w+') as f: f.write("1") - with patch("gns3server.config.Config.get_section_config", return_value={"project_directory": str(tmpdir)}): + with patch("gns3server.config.Config.get_section_config", return_value={"projects_path": str(tmpdir)}): Project.clean_project_directory() assert os.path.exists(str(project1)) @@ -247,7 +238,7 @@ def test_clean_project_directory(tmpdir): def test_list_files(tmpdir, loop): - with patch("gns3server.config.Config.get_section_config", return_value={"project_directory": str(tmpdir)}): + with patch("gns3server.config.Config.get_section_config", return_value={"projects_path": str(tmpdir)}): project = Project(project_id=str(uuid4())) path = project.path os.makedirs(os.path.join(path, "vm-1", "dynamips")) diff --git a/tests/conftest.py b/tests/conftest.py index 85f2c194..7c0de0bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,7 +203,7 @@ def run_around_tests(monkeypatch, port_manager, controller, config): port_manager._instance = port_manager os.makedirs(os.path.join(tmppath, 'projects')) - config.set("Server", "project_directory", os.path.join(tmppath, 'projects')) + config.set("Server", "projects_path", os.path.join(tmppath, 'projects')) config.set("Server", "images_path", os.path.join(tmppath, 'images')) config.set("Server", "auth", False) config.set("Server", "controller", True) @@ -216,7 +216,7 @@ def run_around_tests(monkeypatch, port_manager, controller, config): # Force turn off KVM because it's not available on CI config.set("Qemu", "enable_kvm", False) - monkeypatch.setattr("gns3server.compute.project.Project._get_default_project_directory", lambda *args: os.path.join(tmppath, 'projects')) + monkeypatch.setattr("gns3server.utils.path.get_default_project_directory", lambda *args: os.path.join(tmppath, 'projects')) # Force sys.platform to the original value. Because it seem not be restore correctly at each tests sys.platform = sys.original_platform diff --git a/tests/controller/test_project.py b/tests/controller/test_project.py index cb74d959..c2d524ec 100644 --- a/tests/controller/test_project.py +++ b/tests/controller/test_project.py @@ -43,9 +43,9 @@ def test_json(tmpdir): def test_path(tmpdir): - directory = Config.instance().get_section_config("Server").get("project_directory") + directory = Config.instance().get_section_config("Server").get("projects_path") - with patch("gns3server.compute.project.Project._get_default_project_directory", return_value=directory): + with patch("gns3server.utils.path.get_default_project_directory", return_value=directory): p = Project(project_id=str(uuid4())) assert p.path == os.path.join(directory, p.id) assert os.path.exists(os.path.join(directory, p.id)) diff --git a/tests/handlers/api/compute/test_project.py b/tests/handlers/api/compute/test_project.py index 4e7a8e14..1b8ea231 100644 --- a/tests/handlers/api/compute/test_project.py +++ b/tests/handlers/api/compute/test_project.py @@ -201,7 +201,7 @@ def test_close_project_invalid_uuid(http_compute): def test_get_file(http_compute, tmpdir): - with patch("gns3server.config.Config.get_section_config", return_value={"project_directory": str(tmpdir)}): + with patch("gns3server.config.Config.get_section_config", return_value={"projects_path": str(tmpdir)}): project = ProjectManager.instance().create_project(project_id="01010203-0405-0607-0809-0a0b0c0d0e0b") with open(os.path.join(project.path, "hello"), "w+") as f: @@ -220,7 +220,7 @@ def test_get_file(http_compute, tmpdir): def test_stream_file(http_compute, tmpdir): - with patch("gns3server.config.Config.get_section_config", return_value={"project_directory": str(tmpdir)}): + with patch("gns3server.config.Config.get_section_config", return_value={"projects_path": str(tmpdir)}): project = ProjectManager.instance().create_project(project_id="01010203-0405-0607-0809-0a0b0c0d0e0b") with open(os.path.join(project.path, "hello"), "w+") as f: diff --git a/tests/utils/test_path.py b/tests/utils/test_path.py index 20bed40c..08c7042f 100644 --- a/tests/utils/test_path.py +++ b/tests/utils/test_path.py @@ -15,18 +15,28 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import os import pytest import aiohttp -from gns3server.utils.path import check_path_allowed +from gns3server.utils.path import check_path_allowed, get_default_project_directory def test_check_path_allowed(config, tmpdir): config.set("Server", "local", False) - config.set("Server", "project_directory", str(tmpdir)) + config.set("Server", "projects_path", str(tmpdir)) with pytest.raises(aiohttp.web.HTTPForbidden): check_path_allowed("/private") config.set("Server", "local", True) check_path_allowed(str(tmpdir / "hello" / "world")) check_path_allowed("/private") + + +def test_get_default_project_directory(config): + + config.clear() + + path = os.path.normpath(os.path.expanduser("~/GNS3/projects")) + assert get_default_project_directory() == path + assert os.path.exists(path)