Use aiofiles where relevant.

This commit is contained in:
grossmj 2019-03-06 23:00:01 +07:00
parent b0df7ecabf
commit af80b0bb6e
17 changed files with 90 additions and 234 deletions

View File

@ -20,6 +20,7 @@ import os
import struct
import stat
import asyncio
import aiofiles
import aiohttp
import socket
@ -46,6 +47,8 @@ from .nios.nio_ethernet import NIOEthernet
from ..utils.images import md5sum, remove_checksum, images_directories, default_images_directory, list_images
from .error import NodeError, ImageMissingError
CHUNK_SIZE = 1024 * 8 # 8KB
class BaseManager:
@ -456,7 +459,7 @@ class BaseManager:
with open(path, "rb") as f:
await response.prepare(request)
while nio.capturing:
data = f.read(4096)
data = f.read(CHUNK_SIZE)
if not data:
await asyncio.sleep(0.1)
continue
@ -594,18 +597,18 @@ class BaseManager:
path = os.path.abspath(os.path.join(directory, *os.path.split(filename)))
if os.path.commonprefix([directory, path]) != directory:
raise aiohttp.web.HTTPForbidden(text="Could not write image: {}, {} is forbidden".format(filename, path))
log.info("Writing image file %s", path)
log.info("Writing image file to '{}'".format(path))
try:
remove_checksum(path)
# We store the file under his final name only when the upload is finished
tmp_path = path + ".tmp"
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(tmp_path, 'wb') as f:
async with aiofiles.open(tmp_path, 'wb') as f:
while True:
packet = await stream.read(4096)
if not packet:
chunk = await stream.read(CHUNK_SIZE)
if not chunk:
break
f.write(packet)
await f.write(chunk)
os.chmod(tmp_path, stat.S_IWRITE | stat.S_IREAD | stat.S_IEXEC)
shutil.move(tmp_path, path)
await cancellable_wait_run_in_executor(md5sum, path)

View File

@ -37,6 +37,7 @@ log = logging.getLogger(__name__)
DOCKER_MINIMUM_API_VERSION = "1.25"
DOCKER_MINIMUM_VERSION = "1.13"
DOCKER_PREFERRED_API_VERSION = "1.30"
CHUNK_SIZE = 1024 * 8 # 8KB
class Docker(BaseManager):
@ -206,7 +207,7 @@ class Docker(BaseManager):
content = ""
while True:
try:
chunk = await response.content.read(1024)
chunk = await response.content.read(CHUNK_SIZE)
except aiohttp.ServerDisconnectedError:
log.error("Disconnected from server while pulling Docker image '{}' from docker hub".format(image))
break

View File

@ -320,28 +320,6 @@ class Compute:
raise aiohttp.web.HTTPNotFound(text="{} not found on compute".format(image))
return response
async def stream_file(self, project, path, timeout=None):
"""
Read file of a project and stream it
:param project: A project object
:param path: The path of the file in the project
:param timeout: timeout
:returns: A file stream
"""
url = self._getUrl("/projects/{}/stream/{}".format(project.id, path))
response = await self._session().request("GET", url, auth=self._auth, timeout=timeout)
if response.status == 404:
raise aiohttp.web.HTTPNotFound(text="file '{}' not found on compute".format(path))
elif response.status == 403:
raise aiohttp.web.HTTPForbidden(text="forbidden to open '{}' on compute".format(path))
elif response.status != 200:
raise aiohttp.web.HTTPInternalServerError(text="Unexpected error {}: {}: while opening {} on compute".format(response.status,
response.reason,
path))
return response
async def http_query(self, method, path, data=None, dont_connect=False, **kwargs):
"""
:param dont_connect: If true do not reconnect if not connected

View File

@ -19,6 +19,7 @@ import os
import sys
import json
import asyncio
import aiofiles
import aiohttp
import zipfile
import tempfile
@ -28,6 +29,8 @@ from datetime import datetime
import logging
log = logging.getLogger(__name__)
CHUNK_SIZE = 1024 * 8 # 8KB
async def export_project(zstream, project, temporary_dir, include_images=False, keep_compute_id=False, allow_all_nodes=False, reset_mac_addresses=False):
"""
@ -36,13 +39,13 @@ async def export_project(zstream, project, temporary_dir, include_images=False,
The file will be read chunk by chunk when you iterate over the zip stream.
Some files like snapshots and packet captures are ignored.
:param zstream: ZipStream object
:param project: Project instance
:param temporary_dir: A temporary dir where to store intermediate data
:param include images: save OS images to the zip file
:param keep_compute_id: If false replace all compute id by local (standard behavior for .gns3project to make it portable)
:param allow_all_nodes: Allow all nodes type to be include in the zip even if not portable
:param reset_mac_addresses: Reset MAC addresses for every nodes.
:returns: ZipStream object
"""
# To avoid issue with data not saved we disallow the export of a running project
@ -80,28 +83,28 @@ async def export_project(zstream, project, temporary_dir, include_images=False,
zstream.write(path, os.path.relpath(path, project._path))
# Export files from remote computes
downloaded_files = set()
for compute in project.computes:
if compute.id != "local":
compute_files = await compute.list_files(project)
for compute_file in compute_files:
if _is_exportable(compute_file["path"]):
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
f = open(fd, "wb", closefd=True)
log.debug("Downloading file '{}' from compute '{}'".format(compute_file["path"], compute.id))
response = await compute.download_file(project, compute_file["path"])
while True:
try:
data = await response.content.read(1024)
except asyncio.TimeoutError:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading file '{}' from remote compute {}:{}".format(compute_file["path"], compute.host, compute.port))
if not data:
break
f.write(data)
#if response.status != 200:
# raise aiohttp.web.HTTPConflict(text="Cannot export file from compute '{}'. Compute returned status code {}.".format(compute.id, response.status))
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
async with aiofiles.open(fd, 'wb') as f:
while True:
try:
data = await response.content.read(CHUNK_SIZE)
except asyncio.TimeoutError:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading file '{}' from remote compute {}:{}".format(compute_file["path"], compute.host, compute.port))
if not data:
break
await f.write(data)
response.close()
f.close()
_patch_mtime(temp_path)
zstream.write(temp_path, arcname=compute_file["path"])
downloaded_files.add(compute_file['path'])
def _patch_mtime(path):
@ -262,30 +265,26 @@ async def _export_remote_images(project, compute_id, image_type, image, project_
Export specific image from remote compute.
"""
log.info("Downloading image '{}' from compute '{}'".format(image, compute_id))
log.debug("Downloading image '{}' from compute '{}'".format(image, compute_id))
try:
compute = [compute for compute in project.computes if compute.id == compute_id][0]
except IndexError:
raise aiohttp.web.HTTPConflict(text="Cannot export image from '{}' compute. Compute doesn't exist.".format(compute_id))
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
f = open(fd, "wb", closefd=True)
response = await compute.download_image(image_type, image)
if response.status != 200:
raise aiohttp.web.HTTPConflict(text="Cannot export image from '{}' compute. Compute returned status code {}.".format(compute_id, response.status))
raise aiohttp.web.HTTPConflict(text="Cannot export image from compute '{}'. Compute returned status code {}.".format(compute_id, response.status))
while True:
try:
data = await response.content.read(1024)
except asyncio.TimeoutError:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading image '{}' from remote compute {}:{}".format(image, compute.host, compute.port))
if not data:
break
f.write(data)
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
async with aiofiles.open(fd, 'wb') as f:
while True:
try:
data = await response.content.read(CHUNK_SIZE)
except asyncio.TimeoutError:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading image '{}' from remote compute {}:{}".format(image, compute.host, compute.port))
if not data:
break
await f.write(data)
response.close()
f.close()
arcname = os.path.join("images", image_type, image)
log.info("Saved {}".format(arcname))
project_zipfile.write(temp_path, arcname=arcname, compress_type=zipfile.ZIP_DEFLATED)

View File

@ -20,7 +20,6 @@ import sys
import json
import uuid
import shutil
import asyncio
import zipfile
import aiohttp
import itertools

View File

@ -971,7 +971,8 @@ class Project:
try:
begin = time.time()
with tempfile.TemporaryDirectory() as tmpdir:
with aiozipstream.ZipFile(compression=zipfile.ZIP_DEFLATED) as zstream:
# Do not compress the exported project when duplicating
with aiozipstream.ZipFile(compression=zipfile.ZIP_STORED) as zstream:
await export_project(zstream, self, tmpdir, keep_compute_id=True, allow_all_nodes=True, reset_mac_addresses=True)
# export the project to a temporary location
@ -985,7 +986,7 @@ class Project:
with open(project_path, "rb") as f:
project = await import_project(self._controller, str(uuid.uuid4()), f, location=location, name=name, keep_compute_id=True)
log.info("Project '{}' duplicated in {:.4f} seconds".format(project.id, time.time() - begin))
log.info("Project '{}' duplicated in {:.4f} seconds".format(project.name, time.time() - begin))
except (ValueError, OSError, UnicodeEncodeError) as e:
raise aiohttp.web.HTTPConflict(text="Cannot duplicate project: {}".format(str(e)))

View File

@ -101,7 +101,7 @@ class Snapshot:
async with aiofiles.open(self.path, 'wb') as f:
async for chunk in zstream:
await f.write(chunk)
log.info("Snapshot '{}' created in {:.4f} seconds".format(self.path, time.time() - begin))
log.info("Snapshot '{}' created in {:.4f} seconds".format(self.name, time.time() - begin))
except (ValueError, OSError, RuntimeError) as e:
raise aiohttp.web.HTTPConflict(text="Could not create snapshot file '{}': {}".format(self.path, e))

View File

@ -493,7 +493,7 @@ class DynamipsVMHandler:
if filename[0] == ".":
raise aiohttp.web.HTTPForbidden()
await response.file(image_path)
await response.stream_file(image_path)
@Route.post(
r"/projects/{project_id}/dynamips/nodes/{node_id}/duplicate",

View File

@ -451,4 +451,4 @@ class IOUHandler:
if filename[0] == ".":
raise aiohttp.web.HTTPForbidden()
await response.file(image_path)
await response.stream_file(image_path)

View File

@ -37,6 +37,8 @@ from gns3server.schemas.project import (
import logging
log = logging.getLogger()
CHUNK_SIZE = 1024 * 8 # 8KB
class ProjectHandler:
@ -248,64 +250,7 @@ class ProjectHandler:
raise aiohttp.web.HTTPForbidden()
path = os.path.join(project.path, path)
response.content_type = "application/octet-stream"
response.set_status(200)
response.enable_chunked_encoding()
try:
with open(path, "rb") as f:
await response.prepare(request)
while True:
data = f.read(4096)
if not data:
break
await response.write(data)
except FileNotFoundError:
raise aiohttp.web.HTTPNotFound()
except PermissionError:
raise aiohttp.web.HTTPForbidden()
@Route.get(
r"/projects/{project_id}/stream/{path:.+}",
description="Stream a file from a project",
parameters={
"project_id": "Project UUID",
},
status_codes={
200: "File returned",
403: "Permission denied",
404: "The file doesn't exist"
})
async def stream_file(request, response):
pm = ProjectManager.instance()
project = pm.get_project(request.match_info["project_id"])
path = request.match_info["path"]
path = os.path.normpath(path)
# Raise an error if user try to escape
if path[0] == ".":
raise aiohttp.web.HTTPForbidden()
path = os.path.join(project.path, path)
response.content_type = "application/octet-stream"
response.set_status(200)
response.enable_chunked_encoding()
# FIXME: file streaming is never stopped
try:
with open(path, "rb") as f:
await response.prepare(request)
while True:
data = f.read(4096)
if not data:
await asyncio.sleep(0.1)
await response.write(data)
except FileNotFoundError:
raise aiohttp.web.HTTPNotFound()
except PermissionError:
raise aiohttp.web.HTTPForbidden()
await response.stream_file(path)
@Route.post(
r"/projects/{project_id}/files/{path:.+}",
@ -338,7 +283,7 @@ class ProjectHandler:
with open(path, 'wb+') as f:
while True:
try:
chunk = await request.content.read(1024)
chunk = await request.content.read(CHUNK_SIZE)
except asyncio.TimeoutError:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when writing to file '{}'".format(path))
if not chunk:
@ -349,64 +294,3 @@ class ProjectHandler:
raise aiohttp.web.HTTPNotFound()
except PermissionError:
raise aiohttp.web.HTTPForbidden()
@Route.get(
r"/projects/{project_id}/export",
description="Export a project as a portable archive",
parameters={
"project_id": "Project UUID",
},
raw=True,
status_codes={
200: "File returned",
404: "The project doesn't exist"
})
async def export_project(request, response):
pm = ProjectManager.instance()
project = pm.get_project(request.match_info["project_id"])
response.content_type = 'application/gns3project'
response.headers['CONTENT-DISPOSITION'] = 'attachment; filename="{}.gns3project"'.format(project.name)
response.enable_chunked_encoding()
await response.prepare(request)
include_images = bool(int(request.json.get("include_images", "0")))
for data in project.export(include_images=include_images):
await response.write(data)
#await response.write_eof() #FIXME: shound't be needed anymore
@Route.post(
r"/projects/{project_id}/import",
description="Import a project from a portable archive",
parameters={
"project_id": "Project UUID",
},
raw=True,
output=PROJECT_OBJECT_SCHEMA,
status_codes={
200: "Project imported",
403: "Forbidden to import project"
})
async def import_project(request, response):
pm = ProjectManager.instance()
project_id = request.match_info["project_id"]
project = pm.create_project(project_id=project_id)
# We write the content to a temporary location and after we extract it all.
# It could be more optimal to stream this but it is not implemented in Python.
# Spooled means the file is temporary kept in memory until max_size is reached
try:
with tempfile.SpooledTemporaryFile(max_size=10000) as temp:
while True:
chunk = await request.content.read(1024)
if not chunk:
break
temp.write(chunk)
project.import_zip(temp, gns3vm=bool(int(request.GET.get("gns3vm", "1"))))
except OSError as e:
raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e))
response.json(project)
response.set_status(201)

View File

@ -576,4 +576,4 @@ class QEMUHandler:
if filename[0] == ".":
raise aiohttp.web.HTTPForbidden()
await response.file(image_path)
await response.stream_file(image_path)

View File

@ -16,11 +16,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import sys
import aiohttp
import asyncio
import tempfile
import zipfile
import aiofiles
import time
from gns3server.web.route import Route
@ -51,6 +51,8 @@ async def process_websocket(ws):
except aiohttp.WSServerHandshakeError:
pass
CHUNK_SIZE = 1024 * 8 # 8KB
class ProjectHandler:
@ -304,7 +306,6 @@ class ProjectHandler:
controller = Controller.instance()
project = await controller.get_loaded_project(request.match_info["project_id"])
try:
begin = time.time()
with tempfile.TemporaryDirectory() as tmp_dir:
@ -321,8 +322,8 @@ class ProjectHandler:
async for chunk in zstream:
await response.write(chunk)
log.info("Project '{}' exported in {:.4f} seconds".format(project.id, time.time() - begin))
#await response.write_eof() #FIXME: shound't be needed anymore
log.info("Project '{}' exported in {:.4f} seconds".format(project.name, time.time() - begin))
# Will be raise if you have no space left or permission issue on your temporary directory
# RuntimeError: something was wrong during the zip process
except (ValueError, OSError, RuntimeError) as e:
@ -354,29 +355,23 @@ class ProjectHandler:
# We write the content to a temporary location and after we extract it all.
# It could be more optimal to stream this but it is not implemented in Python.
# Spooled means the file is temporary kept in memory until max_size is reached
# Cannot use tempfile.SpooledTemporaryFile(max_size=10000) in Python 3.7 due
# to a bug https://bugs.python.org/issue26175
try:
if sys.version_info >= (3, 7) and sys.version_info < (3, 8):
with tempfile.TemporaryFile() as temp:
begin = time.time()
with tempfile.TemporaryDirectory() as tmpdir:
temp_project_path = os.path.join(tmpdir, "project.zip")
async with aiofiles.open(temp_project_path, 'wb') as f:
while True:
chunk = await request.content.read(1024)
chunk = await request.content.read(CHUNK_SIZE)
if not chunk:
break
temp.write(chunk)
project = await import_project(controller, request.match_info["project_id"], temp, location=path, name=name)
else:
with tempfile.SpooledTemporaryFile(max_size=10000) as temp:
while True:
chunk = await request.content.read(1024)
if not chunk:
break
temp.write(chunk)
project = await import_project(controller, request.match_info["project_id"], temp, location=path, name=name)
await f.write(chunk)
with open(temp_project_path, "rb") as f:
project = await import_project(controller, request.match_info["project_id"], f, location=path, name=name)
log.info("Project '{}' imported in {:.4f} seconds".format(project.name, time.time() - begin))
except OSError as e:
raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e))
response.json(project)
response.set_status(201)
@ -443,7 +438,7 @@ class ProjectHandler:
with open(path, "rb") as f:
await response.prepare(request)
while True:
data = f.read(4096)
data = f.read(CHUNK_SIZE)
if not data:
break
await response.write(data)
@ -483,7 +478,7 @@ class ProjectHandler:
with open(path, 'wb+') as f:
while True:
try:
chunk = await request.content.read(1024)
chunk = await request.content.read(CHUNK_SIZE)
except asyncio.TimeoutError:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when writing to file '{}'".format(path))
if not chunk:

View File

@ -53,7 +53,7 @@ class SymbolHandler:
controller = Controller.instance()
try:
await response.file(controller.symbols.get_path(request.match_info["symbol_id"]))
await response.stream_file(controller.symbols.get_path(request.match_info["symbol_id"]))
except (KeyError, OSError) as e:
log.warning("Could not get symbol file: {}".format(e))
response.set_status(404)

View File

@ -92,7 +92,7 @@ class IndexHandler:
if not os.path.exists(static):
static = get_static_path(os.path.join('web-ui', 'index.html'))
await response.file(static)
await response.stream_file(static)
@Route.get(
r"/v1/version",

View File

@ -20,7 +20,7 @@ import jsonschema
import aiohttp
import aiohttp.web
import mimetypes
import asyncio
import aiofiles
import logging
import jinja2
import sys
@ -32,6 +32,8 @@ from ..version import __version__
log = logging.getLogger(__name__)
renderer = jinja2.Environment(loader=jinja2.FileSystemLoader(get_resource('templates')))
CHUNK_SIZE = 1024 * 8 # 8KB
class Response(aiohttp.web.Response):
@ -112,16 +114,21 @@ class Response(aiohttp.web.Response):
raise aiohttp.web.HTTPBadRequest(text="{}".format(e))
self.body = json.dumps(answer, indent=4, sort_keys=True).encode('utf-8')
async def file(self, path, status=200, set_content_length=True):
async def stream_file(self, path, status=200, set_content_type=None, set_content_length=True):
"""
Return a file as a response
Stream a file as a response
"""
if not os.path.exists(path):
raise aiohttp.web.HTTPNotFound()
ct, encoding = mimetypes.guess_type(path)
if not ct:
ct = 'application/octet-stream'
if not set_content_type:
ct, encoding = mimetypes.guess_type(path)
if not ct:
ct = 'application/octet-stream'
else:
ct = set_content_type
if encoding:
self.headers[aiohttp.hdrs.CONTENT_ENCODING] = encoding
self.content_type = ct
@ -136,16 +143,13 @@ class Response(aiohttp.web.Response):
self.set_status(status)
try:
with open(path, 'rb') as fobj:
async with aiofiles.open(path, 'rb') as f:
await self.prepare(self._request)
while True:
data = fobj.read(4096)
data = await f.read(CHUNK_SIZE)
if not data:
break
await self.write(data)
# await self.drain()
except FileNotFoundError:
raise aiohttp.web.HTTPNotFound()
except PermissionError:

View File

@ -293,15 +293,6 @@ def test_json(compute):
}
def test_streamFile(project, async_run, compute):
response = MagicMock()
response.status = 200
with asyncio_patch("aiohttp.ClientSession.request", return_value=response) as mock:
async_run(compute.stream_file(project, "test/titi", timeout=120))
mock.assert_called_with("GET", "https://example.com:84/v2/compute/projects/{}/stream/test/titi".format(project.id), auth=None, timeout=120)
async_run(compute.close())
def test_downloadFile(project, async_run, compute):
response = MagicMock()
response.status = 200
@ -310,6 +301,7 @@ def test_downloadFile(project, async_run, compute):
mock.assert_called_with("GET", "https://example.com:84/v2/compute/projects/{}/files/test/titi".format(project.id), auth=None)
async_run(compute.close())
def test_close(compute, async_run):
assert compute.connected is True
async_run(compute.close())

View File

@ -34,11 +34,11 @@ def test_response_file(async_run, tmpdir, response):
with open(filename, 'w+') as f:
f.write('world')
async_run(response.file(filename))
async_run(response.stream_file(filename))
assert response.status == 200
def test_response_file_not_found(async_run, tmpdir, response):
filename = str(tmpdir / 'hello-not-found')
pytest.raises(HTTPNotFound, lambda: async_run(response.file(filename)))
pytest.raises(HTTPNotFound, lambda: async_run(response.stream_file(filename)))