Merge pull request #2070 from GNS3/project-export-zstd

zstandard compression support for project export
This commit is contained in:
Jeremy Grossmann 2022-06-03 11:31:25 +07:00 committed by GitHub
commit 466aaf5c13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 265 additions and 39 deletions

View File

@ -21,10 +21,10 @@ API routes for projects.
import os
import asyncio
import tempfile
import zipfile
import aiofiles
import time
import urllib.parse
import gns3server.utils.zipfile_zstd as zipfile
import logging
@ -41,7 +41,7 @@ from pathlib import Path
from gns3server import schemas
from gns3server.controller import Controller
from gns3server.controller.project import Project
from gns3server.controller.controller_error import ControllerError, ControllerForbiddenError
from gns3server.controller.controller_error import ControllerError, ControllerBadRequestError
from gns3server.controller.import_project import import_project as import_controller_project
from gns3server.controller.export_project import export_project as export_controller_project
from gns3server.utils.asyncio import aiozipstream
@ -285,7 +285,8 @@ async def export_project(
include_snapshots: bool = False,
include_images: bool = False,
reset_mac_addresses: bool = False,
compression: str = "zip",
compression: schemas.ProjectCompression = "zstd",
compression_level: int = None,
) -> StreamingResponse:
"""
Export a project as a portable archive.
@ -294,12 +295,23 @@ async def export_project(
compression_query = compression.lower()
if compression_query == "zip":
compression = zipfile.ZIP_DEFLATED
if compression_level is not None and (compression_level < 0 or compression_level > 9):
raise ControllerBadRequestError("Compression level must be between 0 and 9 for ZIP compression")
elif compression_query == "none":
compression = zipfile.ZIP_STORED
elif compression_query == "bzip2":
compression = zipfile.ZIP_BZIP2
if compression_level is not None and (compression_level < 1 or compression_level > 9):
raise ControllerBadRequestError("Compression level must be between 1 and 9 for BZIP2 compression")
elif compression_query == "lzma":
compression = zipfile.ZIP_LZMA
elif compression_query == "zstd":
compression = zipfile.ZIP_ZSTANDARD
if compression_level is not None and (compression_level < 1 or compression_level > 22):
raise ControllerBadRequestError("Compression level must be between 1 and 22 for Zstandard compression")
if compression_level is not None and compression_query in ("none", "lzma"):
raise ControllerBadRequestError(f"Compression level is not supported for '{compression_query}' compression method")
try:
begin = time.time()
@ -307,8 +319,10 @@ async def export_project(
working_dir = os.path.abspath(os.path.join(project.path, os.pardir))
async def streamer():
log.info(f"Exporting project '{project.name}' with '{compression_query}' compression "
f"(level {compression_level})")
with tempfile.TemporaryDirectory(dir=working_dir) as tmpdir:
with aiozipstream.ZipFile(compression=compression) as zstream:
with aiozipstream.ZipFile(compression=compression, compresslevel=compression_level) as zstream:
await export_controller_project(
zstream,
project,

View File

@ -166,12 +166,14 @@ async def sqlalchemry_error_handler(request: Request, exc: SQLAlchemyError):
content={"message": "Database error detected, please check logs to find details"},
)
# FIXME: do not use this middleware since it creates issue when using StreamingResponse
# see https://starlette-context.readthedocs.io/en/latest/middleware.html#why-are-there-two-middlewares-that-do-the-same-thing
@app.middleware("http")
async def add_extra_headers(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
response.headers["X-GNS3-Server-Version"] = f"{__version__}"
return response
# @app.middleware("http")
# async def add_extra_headers(request: Request, call_next):
# start_time = time.time()
# response = await call_next(request)
# process_time = time.time() - start_time
# response.headers["X-Process-Time"] = str(process_time)
# response.headers["X-GNS3-Server-Version"] = f"{__version__}"
# return response

View File

@ -16,7 +16,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import sys
import json
import asyncio
import aiofiles

View File

@ -20,10 +20,10 @@ import sys
import json
import uuid
import shutil
import zipfile
import aiofiles
import itertools
import tempfile
import gns3server.utils.zipfile_zstd as zipfile_zstd
from .controller_error import ControllerError
from .topology import load_topology
@ -60,9 +60,9 @@ async def import_project(controller, project_id, stream, location=None, name=Non
raise ControllerError("The destination path should not contain .gns3")
try:
with zipfile.ZipFile(stream) as zip_file:
with zipfile_zstd.ZipFile(stream) as zip_file:
project_file = zip_file.read("project.gns3").decode()
except zipfile.BadZipFile:
except zipfile_zstd.BadZipFile:
raise ControllerError("Cannot import project, not a GNS3 project (invalid zip)")
except KeyError:
raise ControllerError("Cannot import project, project.gns3 file could not be found")
@ -92,9 +92,9 @@ async def import_project(controller, project_id, stream, location=None, name=Non
raise ControllerError("The project name contain non supported or invalid characters")
try:
with zipfile.ZipFile(stream) as zip_file:
with zipfile_zstd.ZipFile(stream) as zip_file:
await wait_run_in_executor(zip_file.extractall, path)
except zipfile.BadZipFile:
except zipfile_zstd.BadZipFile:
raise ControllerError("Cannot extract files from GNS3 project (invalid zip)")
topology = load_topology(os.path.join(path, "project.gns3"))
@ -264,11 +264,11 @@ async def _import_snapshots(snapshots_path, project_name, project_id):
# extract everything to a temporary directory
try:
with open(snapshot_path, "rb") as f:
with zipfile.ZipFile(f) as zip_file:
with zipfile_zstd.ZipFile(f) as zip_file:
await wait_run_in_executor(zip_file.extractall, tmpdir)
except OSError as e:
raise ControllerError(f"Cannot open snapshot '{os.path.basename(snapshot)}': {e}")
except zipfile.BadZipFile:
except zipfile_zstd.BadZipFile:
raise ControllerError(
f"Cannot extract files from snapshot '{os.path.basename(snapshot)}': not a GNS3 project (invalid zip)"
)
@ -294,7 +294,7 @@ async def _import_snapshots(snapshots_path, project_name, project_id):
# write everything back to the original snapshot file
try:
with aiozipstream.ZipFile(compression=zipfile.ZIP_STORED) as zstream:
with aiozipstream.ZipFile(compression=zipfile_zstd.ZIP_STORED) as zstream:
for root, dirs, files in os.walk(tmpdir, topdown=True, followlinks=False):
for file in files:
path = os.path.join(root, file)

View File

@ -28,7 +28,7 @@ from .controller.appliances import ApplianceVersion, Appliance
from .controller.drawings import Drawing
from .controller.gns3vm import GNS3VM
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile, ProjectCompression
from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup
from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission
from .controller.tokens import Token

View File

@ -102,3 +102,15 @@ class ProjectFile(BaseModel):
path: str = Field(..., description="File path")
md5sum: str = Field(..., description="File checksum")
class ProjectCompression(str, Enum):
    """
    Supported project compression methods.

    Used as the type of the ``compression`` query parameter of the project
    export endpoint; values are matched case-insensitively by the route
    handler and mapped to the corresponding zipfile compression constants.
    """

    none = "none"    # stored, no compression
    zip = "zip"      # deflate
    bzip2 = "bzip2"
    lzma = "lzma"
    zstd = "zstd"    # Zstandard

View File

@ -43,26 +43,38 @@ from zipfile import (
stringEndArchive64Locator,
)
ZIP_ZSTANDARD = 93 # zstandard is supported by WinZIP v24 and later, PowerArchiver 2021 and 7-Zip-zstd
ZSTANDARD_VERSION = 20
stringDataDescriptor = b"PK\x07\x08" # magic number for data descriptor
def _get_compressor(compress_type):
def _get_compressor(compress_type, compresslevel=None):
"""
Return the compressor.
"""
if compress_type == zipfile.ZIP_DEFLATED:
from zipfile import zlib
if compresslevel is not None:
return zlib.compressobj(compresslevel, zlib.DEFLATED, -15)
return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15)
elif compress_type == zipfile.ZIP_BZIP2:
from zipfile import bz2
if compresslevel is not None:
return bz2.BZ2Compressor(compresslevel)
return bz2.BZ2Compressor()
# compresslevel is ignored for ZIP_LZMA
elif compress_type == zipfile.ZIP_LZMA:
from zipfile import LZMACompressor
return LZMACompressor()
elif compress_type == ZIP_ZSTANDARD:
import zstandard as zstd
if compresslevel is not None:
#params = zstd.ZstdCompressionParameters.from_level(compresslevel, threads=-1, enable_ldm=True, window_log=31)
#return zstd.ZstdCompressor(compression_params=params).compressobj()
return zstd.ZstdCompressor(level=compresslevel).compressobj()
return zstd.ZstdCompressor().compressobj()
else:
return None
@ -129,7 +141,15 @@ class ZipInfo(zipfile.ZipInfo):
class ZipFile(zipfile.ZipFile):
def __init__(self, fileobj=None, mode="w", compression=zipfile.ZIP_STORED, allowZip64=True, chunksize=32768):
def __init__(
self,
fileobj=None,
mode="w",
compression=zipfile.ZIP_STORED,
allowZip64=True,
compresslevel=None,
chunksize=32768
):
"""Open the ZIP file with mode write "w"."""
if mode not in ("w",):
@ -138,7 +158,13 @@ class ZipFile(zipfile.ZipFile):
fileobj = PointerIO()
self._comment = b""
zipfile.ZipFile.__init__(self, fileobj, mode=mode, compression=compression, allowZip64=allowZip64)
zipfile.ZipFile.__init__(
self, fileobj,
mode=mode,
compression=compression,
compresslevel=compresslevel,
allowZip64=allowZip64
)
self._chunksize = chunksize
self.paths_to_write = []
@ -195,23 +221,33 @@ class ZipFile(zipfile.ZipFile):
for chunk in self._close():
yield chunk
def write(self, filename, arcname=None, compress_type=None):
def write(self, filename, arcname=None, compress_type=None, compresslevel=None):
    """
    Write a file to the archive under the name `arcname`.

    The write is deferred: the arguments are only queued in
    ``self.paths_to_write`` here and processed later.

    :param filename: path of the file to add
    :param arcname: name stored in the archive (defaults to filename)
    :param compress_type: optional per-entry compression method
    :param compresslevel: optional per-entry compression level
    """
    kwargs = {
        "filename": filename,
        "arcname": arcname,
        "compress_type": compress_type,
        "compresslevel": compresslevel
    }
    self.paths_to_write.append(kwargs)
def write_iter(self, arcname, iterable, compress_type=None):
def write_iter(self, arcname, iterable, compress_type=None, compresslevel=None):
    """
    Write the bytes iterable `iterable` to the archive under the name `arcname`.

    Like ``write()``, this only queues the entry in ``self.paths_to_write``;
    the bytes are consumed later.

    :param arcname: name stored in the archive
    :param iterable: iterable yielding the bytes to store
    :param compress_type: optional per-entry compression method
    :param compresslevel: optional per-entry compression level
    """
    kwargs = {
        "arcname": arcname,
        "iterable": iterable,
        "compress_type": compress_type,
        "compresslevel": compresslevel
    }
    self.paths_to_write.append(kwargs)
def writestr(self, arcname, data, compress_type=None):
def writestr(self, arcname, data, compress_type=None, compresslevel=None):
"""
Writes a str into ZipFile by wrapping data as a generator
"""
@ -219,9 +255,9 @@ class ZipFile(zipfile.ZipFile):
def _iterable():
yield data
return self.write_iter(arcname, _iterable(), compress_type=compress_type)
return self.write_iter(arcname, _iterable(), compress_type=compress_type, compresslevel=compresslevel)
async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None):
async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None, compresslevel=None):
"""
Put the bytes from filename into the archive under the name `arcname`.
"""
@ -256,6 +292,11 @@ class ZipFile(zipfile.ZipFile):
else:
zinfo.compress_type = compress_type
if compresslevel is None:
zinfo._compresslevel = self.compresslevel
else:
zinfo._compresslevel = compresslevel
if st:
zinfo.file_size = st[6]
else:
@ -279,7 +320,7 @@ class ZipFile(zipfile.ZipFile):
yield self.fp.write(zinfo.FileHeader(False))
return
cmpr = _get_compressor(zinfo.compress_type)
cmpr = _get_compressor(zinfo.compress_type, zinfo._compresslevel)
# Must overwrite CRC and sizes with correct data later
zinfo.CRC = CRC = 0
@ -369,6 +410,8 @@ class ZipFile(zipfile.ZipFile):
min_version = max(zipfile.BZIP2_VERSION, min_version)
elif zinfo.compress_type == zipfile.ZIP_LZMA:
min_version = max(zipfile.LZMA_VERSION, min_version)
elif zinfo.compress_type == ZIP_ZSTANDARD:
min_version = max(ZSTANDARD_VERSION, min_version)
extract_version = max(min_version, zinfo.extract_version)
create_version = max(min_version, zinfo.create_version)

View File

@ -0,0 +1,10 @@
# NOTE: this patches the standard zipfile module
from . import _zipfile
from zipfile import *
from zipfile import (
ZIP_ZSTANDARD,
ZSTANDARD_VERSION,
)

View File

@ -0,0 +1,20 @@
import functools
class patch:
    """
    Decorator that monkey-patches attribute ``name`` on object ``host``.

    The implementation being replaced is stored in the class-level
    ``originals`` mapping, keyed by attribute name, so that the patched
    function can delegate to it.
    """

    originals = {}

    def __init__(self, host, name):
        self.host = host
        self.name = name

    def __call__(self, func):
        # Keep a handle on the implementation we are about to replace.
        replaced = getattr(self.host, self.name)
        type(self).originals[self.name] = replaced
        # Make the replacement look like the original (name, docstring, ...).
        functools.update_wrapper(func, replaced)
        setattr(self.host, self.name, func)
        return func

View File

@ -0,0 +1,64 @@
import zipfile
import zstandard as zstd
import inspect
from ._patcher import patch
zipfile.ZIP_ZSTANDARD = 93
zipfile.compressor_names[zipfile.ZIP_ZSTANDARD] = 'zstandard'
zipfile.ZSTANDARD_VERSION = 20
@patch(zipfile, '_check_compression')
def zstd_check_compression(compression):
    """Accept ZIP_ZSTANDARD; defer every other method to the stdlib check."""
    if compression != zipfile.ZIP_ZSTANDARD:
        patch.originals['_check_compression'](compression)
class ZstdDecompressObjWrapper:
    """
    Proxy around a zstandard decompress object.

    zipfile's extraction code reads an ``eof`` attribute that the wrapped
    object does not provide; report ``eof`` as False and forward every other
    attribute access to the wrapped object.
    """

    def __init__(self, o):
        self.o = o

    def __getattr__(self, attr):
        # Only reached for attributes missing on the wrapper itself.
        return False if attr == 'eof' else getattr(self.o, attr)
@patch(zipfile, '_get_decompressor')
def zstd_get_decompressor(compress_type):
    # For Zstandard entries, build a zstd decompressobj wrapped so it exposes
    # the ``eof`` attribute zipfile expects; everything else goes to stdlib.
    # NOTE(review): max_window_size=2 GiB presumably allows reading archives
    # written with large window logs — confirm against the writer settings.
    if compress_type == zipfile.ZIP_ZSTANDARD:
        return ZstdDecompressObjWrapper(zstd.ZstdDecompressor(max_window_size=2147483648).decompressobj())
    else:
        return patch.originals['_get_decompressor'](compress_type)
def _zstd_compressobj(compresslevel=None):
    # Build the Zstandard compressor; default to level 3 when none is given.
    # NOTE(review): threads=12 is hardcoded; consider threads=-1 so the
    # thread count follows the machine's CPU count.
    if compresslevel is None:
        compresslevel = 3
    return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj()


if 'compresslevel' in inspect.signature(zipfile._get_compressor).parameters:
    # Python versions where zipfile._get_compressor() accepts compresslevel.
    @patch(zipfile, '_get_compressor')
    def zstd_get_compressor(compress_type, compresslevel=None):
        if compress_type == zipfile.ZIP_ZSTANDARD:
            return _zstd_compressobj(compresslevel)
        return patch.originals['_get_compressor'](compress_type, compresslevel=compresslevel)
else:
    # Older signature without compresslevel support.
    @patch(zipfile, '_get_compressor')
    def zstd_get_compressor(compress_type, compresslevel=None):
        if compress_type == zipfile.ZIP_ZSTANDARD:
            return _zstd_compressobj(compresslevel)
        return patch.originals['_get_compressor'](compress_type)
@patch(zipfile.ZipInfo, 'FileHeader')
def zstd_FileHeader(self, zip64=None):
    # Raise the recorded create/extract version fields to the Zstandard
    # minimum before delegating header serialization to the original.
    if self.compress_type == zipfile.ZIP_ZSTANDARD:
        self.create_version = max(self.create_version, zipfile.ZSTANDARD_VERSION)
        self.extract_version = max(self.extract_version, zipfile.ZSTANDARD_VERSION)
    return patch.originals['FileHeader'](self, zip64=zip64)

View File

@ -16,4 +16,5 @@ passlib[bcrypt]==1.7.4
python-jose==3.3.0
email-validator==1.2.1
watchfiles==0.14.1
zstandard==0.17.0
setuptools==60.6.0 # don't upgrade because of https://github.com/pypa/setuptools/issues/3084

View File

@ -17,7 +17,6 @@
import uuid
import os
import zipfile
import json
import pytest
@ -26,6 +25,7 @@ from httpx import AsyncClient
from unittest.mock import patch, MagicMock
from tests.utils import asyncio_patch
import gns3server.utils.zipfile_zstd as zipfile_zstd
from gns3server.controller import Controller
from gns3server.controller.project import Project
@ -261,7 +261,7 @@ async def test_export_with_images(app: FastAPI, client: AsyncClient, tmpdir, pro
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
f.write(response.content)
with zipfile.ZipFile(str(tmpdir / 'project.zip')) as myzip:
with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip:
with myzip.open("a") as myfile:
content = myfile.read()
assert content == b"hello"
@ -304,7 +304,7 @@ async def test_export_without_images(app: FastAPI, client: AsyncClient, tmpdir,
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
f.write(response.content)
with zipfile.ZipFile(str(tmpdir / 'project.zip')) as myzip:
with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip:
with myzip.open("a") as myfile:
content = myfile.read()
assert content == b"hello"
@ -313,6 +313,67 @@ async def test_export_without_images(app: FastAPI, client: AsyncClient, tmpdir,
myzip.getinfo("images/IOS/test.image")
@pytest.mark.parametrize(
    "compression, compression_level, status_code",
    (
        ("none", None, status.HTTP_200_OK),
        ("none", 4, status.HTTP_400_BAD_REQUEST),
        ("zip", None, status.HTTP_200_OK),
        ("zip", 1, status.HTTP_200_OK),
        ("zip", 12, status.HTTP_400_BAD_REQUEST),
        ("bzip2", None, status.HTTP_200_OK),
        ("bzip2", 1, status.HTTP_200_OK),
        ("bzip2", 13, status.HTTP_400_BAD_REQUEST),
        ("lzma", None, status.HTTP_200_OK),
        ("lzma", 1, status.HTTP_400_BAD_REQUEST),
        ("zstd", None, status.HTTP_200_OK),
        ("zstd", 12, status.HTTP_200_OK),
        ("zstd", 23, status.HTTP_400_BAD_REQUEST),
    )
)
async def test_export_compression(
    app: FastAPI,
    client: AsyncClient,
    tmpdir,
    project: Project,
    compression: str,
    compression_level: int,
    status_code: int
) -> None:
    """
    Export a project with each supported compression method/level and check
    that out-of-range levels are rejected with HTTP 400, while valid exports
    produce a readable archive containing project.gns3.
    """

    project.dump = MagicMock()
    os.makedirs(project.path, exist_ok=True)

    topology = {
        "topology": {
            "nodes": [
                {
                    "node_type": "qemu"
                }
            ]
        }
    }
    with open(os.path.join(project.path, "test.gns3"), 'w+') as f:
        json.dump(topology, f)

    params = {"compression": compression}
    # explicit None check: a truthiness test would silently drop level 0,
    # which is a valid level for zip/deflate
    if compression_level is not None:
        params["compression_level"] = compression_level
    response = await client.get(app.url_path_for("export_project", project_id=project.id), params=params)
    assert response.status_code == status_code

    if response.status_code == status.HTTP_200_OK:
        assert response.headers['CONTENT-TYPE'] == 'application/gns3project'
        assert response.headers['CONTENT-DISPOSITION'] == 'attachment; filename="{}.gns3project"'.format(project.name)

        with open(str(tmpdir / 'project.zip'), 'wb+') as f:
            f.write(response.content)

        # the exported archive must open and contain the project file
        with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip:
            with myzip.open("project.gns3") as myfile:
                myfile.read()
async def test_get_file(app: FastAPI, client: AsyncClient, project: Project) -> None:
os.makedirs(project.path, exist_ok=True)