Move more import code to the server

https://github.com/GNS3/gns3-gui/issues/1156
This commit is contained in:
Julien Duponchelle 2016-04-05 18:32:48 +02:00
parent f8ffd078a8
commit 9ed15e55af
No known key found for this signature in database
GPG Key ID: F1E2485547D4595D
4 changed files with 63 additions and 9 deletions

View File

@ -21,7 +21,6 @@ import json
import os import os
import psutil import psutil
import tempfile import tempfile
import zipfile
from ...web.route import Route from ...web.route import Route
from ...schemas.project import PROJECT_OBJECT_SCHEMA, PROJECT_CREATE_SCHEMA, PROJECT_UPDATE_SCHEMA, PROJECT_FILE_LIST_SCHEMA, PROJECT_LIST_SCHEMA from ...schemas.project import PROJECT_OBJECT_SCHEMA, PROJECT_CREATE_SCHEMA, PROJECT_UPDATE_SCHEMA, PROJECT_FILE_LIST_SCHEMA, PROJECT_LIST_SCHEMA
@ -58,6 +57,7 @@ class ProjectHandler:
description="Create a new project on the server", description="Create a new project on the server",
status_codes={ status_codes={
201: "Project created", 201: "Project created",
403: "You are not allowed to modify this property",
409: "Project already created" 409: "Project already created"
}, },
output=PROJECT_OBJECT_SCHEMA, output=PROJECT_OBJECT_SCHEMA,
@ -382,14 +382,16 @@ class ProjectHandler:
"project_id": "The UUID of the project", "project_id": "The UUID of the project",
}, },
raw=True, raw=True,
output=PROJECT_OBJECT_SCHEMA,
status_codes={ status_codes={
200: "Return the file" 200: "Project imported",
403: "You are not allowed to modify this property"
}) })
def import_project(request, response): def import_project(request, response):
pm = ProjectManager.instance() pm = ProjectManager.instance()
project_id = request.match_info["project_id"] project_id = request.match_info["project_id"]
project = pm.create_project(project_id=project_id) project = pm.get_project(project_id)
# We write the content to a temporary location # We write the content to a temporary location
# and after extract all. It could be more optimal to stream # and after extract all. It could be more optimal to stream
@ -403,10 +405,9 @@ class ProjectHandler:
if not packet: if not packet:
break break
temp.write(packet) temp.write(packet)
project.import_zip(temp)
with zipfile.ZipFile(temp) as myzip:
myzip.extractall(project.path)
except OSError as e: except OSError as e:
raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e)) raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e))
response.json(project)
response.set_status(201) response.set_status(201)

View File

@ -15,13 +15,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import aiohttp
import os import os
import aiohttp
import shutil import shutil
import asyncio import asyncio
import hashlib import hashlib
import zipstream import zipstream
import zipfile import zipfile
import json
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from .port_manager import PortManager from .port_manager import PortManager
@ -171,6 +172,8 @@ class Project:
@name.setter @name.setter
def name(self, name): def name(self, name):
if "/" in name or "\\" in name:
raise aiohttp.web.HTTPForbidden(text="Name can not contain path separator")
self._name = name self._name = name
@property @property
@ -540,3 +543,24 @@ class Project:
else: else:
z.write(path, os.path.relpath(path, self._path)) z.write(path, os.path.relpath(path, self._path))
return z return z
def import_zip(self, stream):
"""
Import a project contain in a zip file
:params: A io.BytesIO of the zifile
"""
with zipfile.ZipFile(stream) as myzip:
myzip.extractall(self.path)
project_file = os.path.join(self.path, "project.gns3")
if os.path.exists(project_file):
with open(project_file) as f:
topology = json.load(f)
topology["project_id"] = self.id
topology["name"] = self.name
with open(project_file, "w") as f:
json.dump(topology, f, indent=4)
shutil.move(project_file, os.path.join(self.path, self.name + ".gns3"))

View File

@ -306,12 +306,12 @@ def test_export(server, tmpdir, loop, project):
assert content == b"hello" assert content == b"hello"
def test_import(server, tmpdir, loop): def test_import(server, tmpdir, loop, project):
with zipfile.ZipFile(str(tmpdir / "test.zip"), 'w') as myzip: with zipfile.ZipFile(str(tmpdir / "test.zip"), 'w') as myzip:
myzip.writestr("demo", b"hello") myzip.writestr("demo", b"hello")
project_id = str(uuid.uuid4()) project_id = project.id
with open(str(tmpdir / "test.zip"), "rb") as f: with open(str(tmpdir / "test.zip"), "rb") as f:
response = server.post("/projects/{project_id}/import".format(project_id=project_id), body=f.read(), raw=True) response = server.post("/projects/{project_id}/import".format(project_id=project_id), body=f.read(), raw=True)

View File

@ -17,6 +17,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os import os
import uuid
import json
import asyncio import asyncio
import pytest import pytest
import aiohttp import aiohttp
@ -293,3 +295,30 @@ def test_export(tmpdir):
assert 'project.gns3' in myzip.namelist() assert 'project.gns3' in myzip.namelist()
assert 'project-files/snapshots/test' not in myzip.namelist() assert 'project-files/snapshots/test' not in myzip.namelist()
assert 'vm-1/dynamips/test_log.txt' not in myzip.namelist() assert 'vm-1/dynamips/test_log.txt' not in myzip.namelist()
def test_import(tmpdir):
project_id = str(uuid.uuid4())
project = Project(name="test", project_id=project_id)
with open(str(tmpdir / "project.gns3"), 'w+') as f:
f.write('{"project_id": "ddd", "name": "test"}')
with open(str(tmpdir / "b.png"), 'w+') as f:
f.write("B")
zip_path = str(tmpdir / "project.zip")
with zipfile.ZipFile(zip_path, 'w') as myzip:
myzip.write(str(tmpdir / "project.gns3"), "project.gns3")
myzip.write(str(tmpdir / "b.png"), "b.png")
with open(zip_path, "rb") as f:
project.import_zip(f)
assert os.path.exists(os.path.join(project.path, "b.png"))
assert os.path.exists(os.path.join(project.path, "test.gns3"))
with open(os.path.join(project.path, "test.gns3")) as f:
content = json.load(f)
assert content["project_id"] == project_id
assert content["name"] == project.name