Use aiofiles where relevant.

This commit is contained in:
grossmj 2019-03-06 23:00:01 +07:00
parent b0df7ecabf
commit af80b0bb6e
17 changed files with 90 additions and 234 deletions

View File

@ -20,6 +20,7 @@ import os
import struct import struct
import stat import stat
import asyncio import asyncio
import aiofiles
import aiohttp import aiohttp
import socket import socket
@ -46,6 +47,8 @@ from .nios.nio_ethernet import NIOEthernet
from ..utils.images import md5sum, remove_checksum, images_directories, default_images_directory, list_images from ..utils.images import md5sum, remove_checksum, images_directories, default_images_directory, list_images
from .error import NodeError, ImageMissingError from .error import NodeError, ImageMissingError
CHUNK_SIZE = 1024 * 8 # 8KB
class BaseManager: class BaseManager:
@ -456,7 +459,7 @@ class BaseManager:
with open(path, "rb") as f: with open(path, "rb") as f:
await response.prepare(request) await response.prepare(request)
while nio.capturing: while nio.capturing:
data = f.read(4096) data = f.read(CHUNK_SIZE)
if not data: if not data:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
continue continue
@ -594,18 +597,18 @@ class BaseManager:
path = os.path.abspath(os.path.join(directory, *os.path.split(filename))) path = os.path.abspath(os.path.join(directory, *os.path.split(filename)))
if os.path.commonprefix([directory, path]) != directory: if os.path.commonprefix([directory, path]) != directory:
raise aiohttp.web.HTTPForbidden(text="Could not write image: {}, {} is forbidden".format(filename, path)) raise aiohttp.web.HTTPForbidden(text="Could not write image: {}, {} is forbidden".format(filename, path))
log.info("Writing image file %s", path) log.info("Writing image file to '{}'".format(path))
try: try:
remove_checksum(path) remove_checksum(path)
# We store the file under its final name only when the upload is finished # We store the file under its final name only when the upload is finished
tmp_path = path + ".tmp" tmp_path = path + ".tmp"
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
with open(tmp_path, 'wb') as f: async with aiofiles.open(tmp_path, 'wb') as f:
while True: while True:
packet = await stream.read(4096) chunk = await stream.read(CHUNK_SIZE)
if not packet: if not chunk:
break break
f.write(packet) await f.write(chunk)
os.chmod(tmp_path, stat.S_IWRITE | stat.S_IREAD | stat.S_IEXEC) os.chmod(tmp_path, stat.S_IWRITE | stat.S_IREAD | stat.S_IEXEC)
shutil.move(tmp_path, path) shutil.move(tmp_path, path)
await cancellable_wait_run_in_executor(md5sum, path) await cancellable_wait_run_in_executor(md5sum, path)

View File

@ -37,6 +37,7 @@ log = logging.getLogger(__name__)
DOCKER_MINIMUM_API_VERSION = "1.25" DOCKER_MINIMUM_API_VERSION = "1.25"
DOCKER_MINIMUM_VERSION = "1.13" DOCKER_MINIMUM_VERSION = "1.13"
DOCKER_PREFERRED_API_VERSION = "1.30" DOCKER_PREFERRED_API_VERSION = "1.30"
CHUNK_SIZE = 1024 * 8 # 8KB
class Docker(BaseManager): class Docker(BaseManager):
@ -206,7 +207,7 @@ class Docker(BaseManager):
content = "" content = ""
while True: while True:
try: try:
chunk = await response.content.read(1024) chunk = await response.content.read(CHUNK_SIZE)
except aiohttp.ServerDisconnectedError: except aiohttp.ServerDisconnectedError:
log.error("Disconnected from server while pulling Docker image '{}' from docker hub".format(image)) log.error("Disconnected from server while pulling Docker image '{}' from docker hub".format(image))
break break

View File

@ -320,28 +320,6 @@ class Compute:
raise aiohttp.web.HTTPNotFound(text="{} not found on compute".format(image)) raise aiohttp.web.HTTPNotFound(text="{} not found on compute".format(image))
return response return response
async def stream_file(self, project, path, timeout=None):
"""
Read file of a project and stream it
:param project: A project object
:param path: The path of the file in the project
:param timeout: timeout
:returns: A file stream
"""
url = self._getUrl("/projects/{}/stream/{}".format(project.id, path))
response = await self._session().request("GET", url, auth=self._auth, timeout=timeout)
if response.status == 404:
raise aiohttp.web.HTTPNotFound(text="file '{}' not found on compute".format(path))
elif response.status == 403:
raise aiohttp.web.HTTPForbidden(text="forbidden to open '{}' on compute".format(path))
elif response.status != 200:
raise aiohttp.web.HTTPInternalServerError(text="Unexpected error {}: {}: while opening {} on compute".format(response.status,
response.reason,
path))
return response
async def http_query(self, method, path, data=None, dont_connect=False, **kwargs): async def http_query(self, method, path, data=None, dont_connect=False, **kwargs):
""" """
:param dont_connect: If true do not reconnect if not connected :param dont_connect: If true do not reconnect if not connected

View File

@ -19,6 +19,7 @@ import os
import sys import sys
import json import json
import asyncio import asyncio
import aiofiles
import aiohttp import aiohttp
import zipfile import zipfile
import tempfile import tempfile
@ -28,6 +29,8 @@ from datetime import datetime
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
CHUNK_SIZE = 1024 * 8 # 8KB
async def export_project(zstream, project, temporary_dir, include_images=False, keep_compute_id=False, allow_all_nodes=False, reset_mac_addresses=False): async def export_project(zstream, project, temporary_dir, include_images=False, keep_compute_id=False, allow_all_nodes=False, reset_mac_addresses=False):
""" """
@ -36,13 +39,13 @@ async def export_project(zstream, project, temporary_dir, include_images=False,
The file will be read chunk by chunk when you iterate over the zip stream. The file will be read chunk by chunk when you iterate over the zip stream.
Some files like snapshots and packet captures are ignored. Some files like snapshots and packet captures are ignored.
:param zstream: ZipStream object
:param project: Project instance
:param temporary_dir: A temporary dir where to store intermediate data :param temporary_dir: A temporary dir where to store intermediate data
:param include_images: save OS images to the zip file :param include_images: save OS images to the zip file
:param keep_compute_id: If false replace all compute id by local (standard behavior for .gns3project to make it portable) :param keep_compute_id: If false replace all compute id by local (standard behavior for .gns3project to make it portable)
:param allow_all_nodes: Allow all nodes type to be include in the zip even if not portable :param allow_all_nodes: Allow all nodes type to be include in the zip even if not portable
:param reset_mac_addresses: Reset MAC addresses for every node. :param reset_mac_addresses: Reset MAC addresses for every node.
:returns: ZipStream object
""" """
# To avoid issue with data not saved we disallow the export of a running project # To avoid issue with data not saved we disallow the export of a running project
@ -80,28 +83,28 @@ async def export_project(zstream, project, temporary_dir, include_images=False,
zstream.write(path, os.path.relpath(path, project._path)) zstream.write(path, os.path.relpath(path, project._path))
# Export files from remote computes # Export files from remote computes
downloaded_files = set()
for compute in project.computes: for compute in project.computes:
if compute.id != "local": if compute.id != "local":
compute_files = await compute.list_files(project) compute_files = await compute.list_files(project)
for compute_file in compute_files: for compute_file in compute_files:
if _is_exportable(compute_file["path"]): if _is_exportable(compute_file["path"]):
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir) log.debug("Downloading file '{}' from compute '{}'".format(compute_file["path"], compute.id))
f = open(fd, "wb", closefd=True)
response = await compute.download_file(project, compute_file["path"]) response = await compute.download_file(project, compute_file["path"])
while True: #if response.status != 200:
try: # raise aiohttp.web.HTTPConflict(text="Cannot export file from compute '{}'. Compute returned status code {}.".format(compute.id, response.status))
data = await response.content.read(1024) (fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
except asyncio.TimeoutError: async with aiofiles.open(fd, 'wb') as f:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading file '{}' from remote compute {}:{}".format(compute_file["path"], compute.host, compute.port)) while True:
if not data: try:
break data = await response.content.read(CHUNK_SIZE)
f.write(data) except asyncio.TimeoutError:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading file '{}' from remote compute {}:{}".format(compute_file["path"], compute.host, compute.port))
if not data:
break
await f.write(data)
response.close() response.close()
f.close()
_patch_mtime(temp_path) _patch_mtime(temp_path)
zstream.write(temp_path, arcname=compute_file["path"]) zstream.write(temp_path, arcname=compute_file["path"])
downloaded_files.add(compute_file['path'])
def _patch_mtime(path): def _patch_mtime(path):
@ -262,30 +265,26 @@ async def _export_remote_images(project, compute_id, image_type, image, project_
Export specific image from remote compute. Export specific image from remote compute.
""" """
log.info("Downloading image '{}' from compute '{}'".format(image, compute_id)) log.debug("Downloading image '{}' from compute '{}'".format(image, compute_id))
try: try:
compute = [compute for compute in project.computes if compute.id == compute_id][0] compute = [compute for compute in project.computes if compute.id == compute_id][0]
except IndexError: except IndexError:
raise aiohttp.web.HTTPConflict(text="Cannot export image from '{}' compute. Compute doesn't exist.".format(compute_id)) raise aiohttp.web.HTTPConflict(text="Cannot export image from '{}' compute. Compute doesn't exist.".format(compute_id))
(fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
f = open(fd, "wb", closefd=True)
response = await compute.download_image(image_type, image) response = await compute.download_image(image_type, image)
if response.status != 200: if response.status != 200:
raise aiohttp.web.HTTPConflict(text="Cannot export image from '{}' compute. Compute returned status code {}.".format(compute_id, response.status)) raise aiohttp.web.HTTPConflict(text="Cannot export image from compute '{}'. Compute returned status code {}.".format(compute_id, response.status))
while True: (fd, temp_path) = tempfile.mkstemp(dir=temporary_dir)
try: async with aiofiles.open(fd, 'wb') as f:
data = await response.content.read(1024) while True:
except asyncio.TimeoutError: try:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading image '{}' from remote compute {}:{}".format(image, compute.host, compute.port)) data = await response.content.read(CHUNK_SIZE)
if not data: except asyncio.TimeoutError:
break raise aiohttp.web.HTTPRequestTimeout(text="Timeout when downloading image '{}' from remote compute {}:{}".format(image, compute.host, compute.port))
f.write(data) if not data:
break
await f.write(data)
response.close() response.close()
f.close()
arcname = os.path.join("images", image_type, image) arcname = os.path.join("images", image_type, image)
log.info("Saved {}".format(arcname))
project_zipfile.write(temp_path, arcname=arcname, compress_type=zipfile.ZIP_DEFLATED) project_zipfile.write(temp_path, arcname=arcname, compress_type=zipfile.ZIP_DEFLATED)

View File

@ -20,7 +20,6 @@ import sys
import json import json
import uuid import uuid
import shutil import shutil
import asyncio
import zipfile import zipfile
import aiohttp import aiohttp
import itertools import itertools

View File

@ -971,7 +971,8 @@ class Project:
try: try:
begin = time.time() begin = time.time()
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
with aiozipstream.ZipFile(compression=zipfile.ZIP_DEFLATED) as zstream: # Do not compress the exported project when duplicating
with aiozipstream.ZipFile(compression=zipfile.ZIP_STORED) as zstream:
await export_project(zstream, self, tmpdir, keep_compute_id=True, allow_all_nodes=True, reset_mac_addresses=True) await export_project(zstream, self, tmpdir, keep_compute_id=True, allow_all_nodes=True, reset_mac_addresses=True)
# export the project to a temporary location # export the project to a temporary location
@ -985,7 +986,7 @@ class Project:
with open(project_path, "rb") as f: with open(project_path, "rb") as f:
project = await import_project(self._controller, str(uuid.uuid4()), f, location=location, name=name, keep_compute_id=True) project = await import_project(self._controller, str(uuid.uuid4()), f, location=location, name=name, keep_compute_id=True)
log.info("Project '{}' duplicated in {:.4f} seconds".format(project.id, time.time() - begin)) log.info("Project '{}' duplicated in {:.4f} seconds".format(project.name, time.time() - begin))
except (ValueError, OSError, UnicodeEncodeError) as e: except (ValueError, OSError, UnicodeEncodeError) as e:
raise aiohttp.web.HTTPConflict(text="Cannot duplicate project: {}".format(str(e))) raise aiohttp.web.HTTPConflict(text="Cannot duplicate project: {}".format(str(e)))

View File

@ -101,7 +101,7 @@ class Snapshot:
async with aiofiles.open(self.path, 'wb') as f: async with aiofiles.open(self.path, 'wb') as f:
async for chunk in zstream: async for chunk in zstream:
await f.write(chunk) await f.write(chunk)
log.info("Snapshot '{}' created in {:.4f} seconds".format(self.path, time.time() - begin)) log.info("Snapshot '{}' created in {:.4f} seconds".format(self.name, time.time() - begin))
except (ValueError, OSError, RuntimeError) as e: except (ValueError, OSError, RuntimeError) as e:
raise aiohttp.web.HTTPConflict(text="Could not create snapshot file '{}': {}".format(self.path, e)) raise aiohttp.web.HTTPConflict(text="Could not create snapshot file '{}': {}".format(self.path, e))

View File

@ -493,7 +493,7 @@ class DynamipsVMHandler:
if filename[0] == ".": if filename[0] == ".":
raise aiohttp.web.HTTPForbidden() raise aiohttp.web.HTTPForbidden()
await response.file(image_path) await response.stream_file(image_path)
@Route.post( @Route.post(
r"/projects/{project_id}/dynamips/nodes/{node_id}/duplicate", r"/projects/{project_id}/dynamips/nodes/{node_id}/duplicate",

View File

@ -451,4 +451,4 @@ class IOUHandler:
if filename[0] == ".": if filename[0] == ".":
raise aiohttp.web.HTTPForbidden() raise aiohttp.web.HTTPForbidden()
await response.file(image_path) await response.stream_file(image_path)

View File

@ -37,6 +37,8 @@ from gns3server.schemas.project import (
import logging import logging
log = logging.getLogger() log = logging.getLogger()
CHUNK_SIZE = 1024 * 8 # 8KB
class ProjectHandler: class ProjectHandler:
@ -248,64 +250,7 @@ class ProjectHandler:
raise aiohttp.web.HTTPForbidden() raise aiohttp.web.HTTPForbidden()
path = os.path.join(project.path, path) path = os.path.join(project.path, path)
response.content_type = "application/octet-stream" await response.stream_file(path)
response.set_status(200)
response.enable_chunked_encoding()
try:
with open(path, "rb") as f:
await response.prepare(request)
while True:
data = f.read(4096)
if not data:
break
await response.write(data)
except FileNotFoundError:
raise aiohttp.web.HTTPNotFound()
except PermissionError:
raise aiohttp.web.HTTPForbidden()
@Route.get(
r"/projects/{project_id}/stream/{path:.+}",
description="Stream a file from a project",
parameters={
"project_id": "Project UUID",
},
status_codes={
200: "File returned",
403: "Permission denied",
404: "The file doesn't exist"
})
async def stream_file(request, response):
pm = ProjectManager.instance()
project = pm.get_project(request.match_info["project_id"])
path = request.match_info["path"]
path = os.path.normpath(path)
# Raise an error if user try to escape
if path[0] == ".":
raise aiohttp.web.HTTPForbidden()
path = os.path.join(project.path, path)
response.content_type = "application/octet-stream"
response.set_status(200)
response.enable_chunked_encoding()
# FIXME: file streaming is never stopped
try:
with open(path, "rb") as f:
await response.prepare(request)
while True:
data = f.read(4096)
if not data:
await asyncio.sleep(0.1)
await response.write(data)
except FileNotFoundError:
raise aiohttp.web.HTTPNotFound()
except PermissionError:
raise aiohttp.web.HTTPForbidden()
@Route.post( @Route.post(
r"/projects/{project_id}/files/{path:.+}", r"/projects/{project_id}/files/{path:.+}",
@ -338,7 +283,7 @@ class ProjectHandler:
with open(path, 'wb+') as f: with open(path, 'wb+') as f:
while True: while True:
try: try:
chunk = await request.content.read(1024) chunk = await request.content.read(CHUNK_SIZE)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when writing to file '{}'".format(path)) raise aiohttp.web.HTTPRequestTimeout(text="Timeout when writing to file '{}'".format(path))
if not chunk: if not chunk:
@ -349,64 +294,3 @@ class ProjectHandler:
raise aiohttp.web.HTTPNotFound() raise aiohttp.web.HTTPNotFound()
except PermissionError: except PermissionError:
raise aiohttp.web.HTTPForbidden() raise aiohttp.web.HTTPForbidden()
@Route.get(
r"/projects/{project_id}/export",
description="Export a project as a portable archive",
parameters={
"project_id": "Project UUID",
},
raw=True,
status_codes={
200: "File returned",
404: "The project doesn't exist"
})
async def export_project(request, response):
pm = ProjectManager.instance()
project = pm.get_project(request.match_info["project_id"])
response.content_type = 'application/gns3project'
response.headers['CONTENT-DISPOSITION'] = 'attachment; filename="{}.gns3project"'.format(project.name)
response.enable_chunked_encoding()
await response.prepare(request)
include_images = bool(int(request.json.get("include_images", "0")))
for data in project.export(include_images=include_images):
await response.write(data)
#await response.write_eof() #FIXME: shouldn't be needed anymore
@Route.post(
r"/projects/{project_id}/import",
description="Import a project from a portable archive",
parameters={
"project_id": "Project UUID",
},
raw=True,
output=PROJECT_OBJECT_SCHEMA,
status_codes={
200: "Project imported",
403: "Forbidden to import project"
})
async def import_project(request, response):
pm = ProjectManager.instance()
project_id = request.match_info["project_id"]
project = pm.create_project(project_id=project_id)
# We write the content to a temporary location and after we extract it all.
# It could be more optimal to stream this but it is not implemented in Python.
# Spooled means the file is temporarily kept in memory until max_size is reached
try:
with tempfile.SpooledTemporaryFile(max_size=10000) as temp:
while True:
chunk = await request.content.read(1024)
if not chunk:
break
temp.write(chunk)
project.import_zip(temp, gns3vm=bool(int(request.GET.get("gns3vm", "1"))))
except OSError as e:
raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e))
response.json(project)
response.set_status(201)

View File

@ -576,4 +576,4 @@ class QEMUHandler:
if filename[0] == ".": if filename[0] == ".":
raise aiohttp.web.HTTPForbidden() raise aiohttp.web.HTTPForbidden()
await response.file(image_path) await response.stream_file(image_path)

View File

@ -16,11 +16,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os import os
import sys
import aiohttp import aiohttp
import asyncio import asyncio
import tempfile import tempfile
import zipfile import zipfile
import aiofiles
import time import time
from gns3server.web.route import Route from gns3server.web.route import Route
@ -51,6 +51,8 @@ async def process_websocket(ws):
except aiohttp.WSServerHandshakeError: except aiohttp.WSServerHandshakeError:
pass pass
CHUNK_SIZE = 1024 * 8 # 8KB
class ProjectHandler: class ProjectHandler:
@ -304,7 +306,6 @@ class ProjectHandler:
controller = Controller.instance() controller = Controller.instance()
project = await controller.get_loaded_project(request.match_info["project_id"]) project = await controller.get_loaded_project(request.match_info["project_id"])
try: try:
begin = time.time() begin = time.time()
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
@ -321,8 +322,8 @@ class ProjectHandler:
async for chunk in zstream: async for chunk in zstream:
await response.write(chunk) await response.write(chunk)
log.info("Project '{}' exported in {:.4f} seconds".format(project.id, time.time() - begin)) log.info("Project '{}' exported in {:.4f} seconds".format(project.name, time.time() - begin))
#await response.write_eof() #FIXME: shouldn't be needed anymore
# Will be raised if you have no space left or a permission issue on your temporary directory # Will be raised if you have no space left or a permission issue on your temporary directory
# RuntimeError: something was wrong during the zip process # RuntimeError: something was wrong during the zip process
except (ValueError, OSError, RuntimeError) as e: except (ValueError, OSError, RuntimeError) as e:
@ -354,29 +355,23 @@ class ProjectHandler:
# We write the content to a temporary location and after we extract it all. # We write the content to a temporary location and after we extract it all.
# It could be more optimal to stream this but it is not implemented in Python. # It could be more optimal to stream this but it is not implemented in Python.
# Spooled means the file is temporarily kept in memory until max_size is reached
# Cannot use tempfile.SpooledTemporaryFile(max_size=10000) in Python 3.7 due
# to a bug https://bugs.python.org/issue26175
try: try:
if sys.version_info >= (3, 7) and sys.version_info < (3, 8): begin = time.time()
with tempfile.TemporaryFile() as temp: with tempfile.TemporaryDirectory() as tmpdir:
temp_project_path = os.path.join(tmpdir, "project.zip")
async with aiofiles.open(temp_project_path, 'wb') as f:
while True: while True:
chunk = await request.content.read(1024) chunk = await request.content.read(CHUNK_SIZE)
if not chunk: if not chunk:
break break
temp.write(chunk) await f.write(chunk)
project = await import_project(controller, request.match_info["project_id"], temp, location=path, name=name)
else: with open(temp_project_path, "rb") as f:
with tempfile.SpooledTemporaryFile(max_size=10000) as temp: project = await import_project(controller, request.match_info["project_id"], f, location=path, name=name)
while True:
chunk = await request.content.read(1024) log.info("Project '{}' imported in {:.4f} seconds".format(project.name, time.time() - begin))
if not chunk:
break
temp.write(chunk)
project = await import_project(controller, request.match_info["project_id"], temp, location=path, name=name)
except OSError as e: except OSError as e:
raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e)) raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e))
response.json(project) response.json(project)
response.set_status(201) response.set_status(201)
@ -443,7 +438,7 @@ class ProjectHandler:
with open(path, "rb") as f: with open(path, "rb") as f:
await response.prepare(request) await response.prepare(request)
while True: while True:
data = f.read(4096) data = f.read(CHUNK_SIZE)
if not data: if not data:
break break
await response.write(data) await response.write(data)
@ -483,7 +478,7 @@ class ProjectHandler:
with open(path, 'wb+') as f: with open(path, 'wb+') as f:
while True: while True:
try: try:
chunk = await request.content.read(1024) chunk = await request.content.read(CHUNK_SIZE)
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when writing to file '{}'".format(path)) raise aiohttp.web.HTTPRequestTimeout(text="Timeout when writing to file '{}'".format(path))
if not chunk: if not chunk:

View File

@ -53,7 +53,7 @@ class SymbolHandler:
controller = Controller.instance() controller = Controller.instance()
try: try:
await response.file(controller.symbols.get_path(request.match_info["symbol_id"])) await response.stream_file(controller.symbols.get_path(request.match_info["symbol_id"]))
except (KeyError, OSError) as e: except (KeyError, OSError) as e:
log.warning("Could not get symbol file: {}".format(e)) log.warning("Could not get symbol file: {}".format(e))
response.set_status(404) response.set_status(404)

View File

@ -92,7 +92,7 @@ class IndexHandler:
if not os.path.exists(static): if not os.path.exists(static):
static = get_static_path(os.path.join('web-ui', 'index.html')) static = get_static_path(os.path.join('web-ui', 'index.html'))
await response.file(static) await response.stream_file(static)
@Route.get( @Route.get(
r"/v1/version", r"/v1/version",

View File

@ -20,7 +20,7 @@ import jsonschema
import aiohttp import aiohttp
import aiohttp.web import aiohttp.web
import mimetypes import mimetypes
import asyncio import aiofiles
import logging import logging
import jinja2 import jinja2
import sys import sys
@ -32,6 +32,8 @@ from ..version import __version__
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
renderer = jinja2.Environment(loader=jinja2.FileSystemLoader(get_resource('templates'))) renderer = jinja2.Environment(loader=jinja2.FileSystemLoader(get_resource('templates')))
CHUNK_SIZE = 1024 * 8 # 8KB
class Response(aiohttp.web.Response): class Response(aiohttp.web.Response):
@ -112,16 +114,21 @@ class Response(aiohttp.web.Response):
raise aiohttp.web.HTTPBadRequest(text="{}".format(e)) raise aiohttp.web.HTTPBadRequest(text="{}".format(e))
self.body = json.dumps(answer, indent=4, sort_keys=True).encode('utf-8') self.body = json.dumps(answer, indent=4, sort_keys=True).encode('utf-8')
async def file(self, path, status=200, set_content_length=True): async def stream_file(self, path, status=200, set_content_type=None, set_content_length=True):
""" """
Return a file as a response Stream a file as a response
""" """
if not os.path.exists(path): if not os.path.exists(path):
raise aiohttp.web.HTTPNotFound() raise aiohttp.web.HTTPNotFound()
ct, encoding = mimetypes.guess_type(path) if not set_content_type:
if not ct: ct, encoding = mimetypes.guess_type(path)
ct = 'application/octet-stream' if not ct:
ct = 'application/octet-stream'
else:
ct = set_content_type
if encoding: if encoding:
self.headers[aiohttp.hdrs.CONTENT_ENCODING] = encoding self.headers[aiohttp.hdrs.CONTENT_ENCODING] = encoding
self.content_type = ct self.content_type = ct
@ -136,16 +143,13 @@ class Response(aiohttp.web.Response):
self.set_status(status) self.set_status(status)
try: try:
with open(path, 'rb') as fobj: async with aiofiles.open(path, 'rb') as f:
await self.prepare(self._request) await self.prepare(self._request)
while True: while True:
data = fobj.read(4096) data = await f.read(CHUNK_SIZE)
if not data: if not data:
break break
await self.write(data) await self.write(data)
# await self.drain()
except FileNotFoundError: except FileNotFoundError:
raise aiohttp.web.HTTPNotFound() raise aiohttp.web.HTTPNotFound()
except PermissionError: except PermissionError:

View File

@ -293,15 +293,6 @@ def test_json(compute):
} }
def test_streamFile(project, async_run, compute):
response = MagicMock()
response.status = 200
with asyncio_patch("aiohttp.ClientSession.request", return_value=response) as mock:
async_run(compute.stream_file(project, "test/titi", timeout=120))
mock.assert_called_with("GET", "https://example.com:84/v2/compute/projects/{}/stream/test/titi".format(project.id), auth=None, timeout=120)
async_run(compute.close())
def test_downloadFile(project, async_run, compute): def test_downloadFile(project, async_run, compute):
response = MagicMock() response = MagicMock()
response.status = 200 response.status = 200
@ -310,6 +301,7 @@ def test_downloadFile(project, async_run, compute):
mock.assert_called_with("GET", "https://example.com:84/v2/compute/projects/{}/files/test/titi".format(project.id), auth=None) mock.assert_called_with("GET", "https://example.com:84/v2/compute/projects/{}/files/test/titi".format(project.id), auth=None)
async_run(compute.close()) async_run(compute.close())
def test_close(compute, async_run): def test_close(compute, async_run):
assert compute.connected is True assert compute.connected is True
async_run(compute.close()) async_run(compute.close())

View File

@ -34,11 +34,11 @@ def test_response_file(async_run, tmpdir, response):
with open(filename, 'w+') as f: with open(filename, 'w+') as f:
f.write('world') f.write('world')
async_run(response.file(filename)) async_run(response.stream_file(filename))
assert response.status == 200 assert response.status == 200
def test_response_file_not_found(async_run, tmpdir, response): def test_response_file_not_found(async_run, tmpdir, response):
filename = str(tmpdir / 'hello-not-found') filename = str(tmpdir / 'hello-not-found')
pytest.raises(HTTPNotFound, lambda: async_run(response.file(filename))) pytest.raises(HTTPNotFound, lambda: async_run(response.stream_file(filename)))