Extract the notification part of controller to a dedicated class

This commit is contained in:
Julien Duponchelle 2016-05-18 14:56:23 +02:00
parent d86cefaaeb
commit 694e1a2e68
No known key found for this signature in database
GPG Key ID: CE8B29639E07F5E8
10 changed files with 162 additions and 106 deletions

View File

@ -24,6 +24,7 @@ import aiohttp
from ..config import Config from ..config import Config
from .project import Project from .project import Project
from .compute import Compute from .compute import Compute
from .notification import Notification
from ..version import __version__ from ..version import __version__
import logging import logging
@ -36,6 +37,7 @@ class Controller:
def __init__(self): def __init__(self):
self._computes = {} self._computes = {}
self._projects = {} self._projects = {}
self._notification = Notification(self)
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
config_path = os.path.join(os.path.expandvars("%APPDATA%"), "GNS3") config_path = os.path.join(os.path.expandvars("%APPDATA%"), "GNS3")
@ -116,6 +118,13 @@ class Controller:
password=server_config.get("password", "")) password=server_config.get("password", ""))
return self._computes["local"] return self._computes["local"]
@property
def notification(self):
"""
The notification system
"""
return self._notification
@property @property
def computes(self): def computes(self):
""" """
@ -180,17 +189,3 @@ class Controller:
Controller._instance = Controller() Controller._instance = Controller()
return Controller._instance return Controller._instance
def emit(self, action, event, **kwargs):
"""
Send a notification to clients scoped by projects
"""
if "project_id" in kwargs:
try:
project_id = kwargs.pop("project_id")
self._projects[project_id].emit(action, event, **kwargs)
except KeyError:
pass
else:
for project_instance in self._projects.values():
project_instance.emit(action, event, **kwargs)

View File

@ -196,7 +196,7 @@ class Compute:
msg = json.loads(response.data) msg = json.loads(response.data)
action = msg.pop("action") action = msg.pop("action")
event = msg.pop("event") event = msg.pop("event")
self._controller.emit(action, event, compute_id=self.id, **msg) self._controller.notification.emit(action, event, compute_id=self.id, **msg)
def _getUrl(self, path): def _getUrl(self, path):
return "{}://{}:{}/v2/compute{}".format(self._protocol, self._host, self._port, path) return "{}://{}:{}/v2/compute{}".format(self._protocol, self._host, self._port, path)

View File

@ -0,0 +1,87 @@
#!/usr/bin/env python
#
# Copyright (C) 2016 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from contextlib import contextmanager
from ..notification_queue import NotificationQueue
class Notification:
"""
Manage notification for the controller
"""
def __init__(self, controller):
self._controller = controller
self._listeners = {}
@contextmanager
def queue(self, project):
"""
Get a queue of notifications
Use it with Python with
"""
queue = NotificationQueue()
self._listeners.setdefault(project.id, set())
self._listeners[project.id].add(queue)
yield queue
self._listeners[project.id].remove(queue)
def emit(self, action, event, **kwargs):
"""
Send a notification to clients scoped by projects
:param action: Action name
:param event: Event to send
:param kwargs: Add this meta to the notification
"""
if "project_id" in kwargs:
project_id = kwargs.pop("project_id")
self._send_event_to_project(project_id, action, event, **kwargs)
else:
self._send_event_to_all(action, event, **kwargs)
def _send_event_to_project(self, project_id, action, event, **kwargs):
"""
Send an event to all the client listening for notifications for
this project
:param project: Project where we need to send the event
:param action: Action name
:param event: Event to send
:param kwargs: Add this meta to the notification
"""
try:
project_listeners = self._listeners[project_id]
except KeyError:
return
for listener in project_listeners:
listener.put_nowait((action, event, kwargs))
def _send_event_to_all(self, action, event, **kwargs):
"""
Send an event to all the client listening for notifications on all
projects
:param action: Action name
:param event: Event to send
:param kwargs: Add this meta to the notification
"""
for project_listeners in self._listeners.values():
for listener in project_listeners:
listener.put_nowait((action, event, kwargs))

View File

@ -21,11 +21,9 @@ import aiohttp
import shutil import shutil
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from contextlib import contextmanager
from .node import Node from .node import Node
from .udp_link import UDPLink from .udp_link import UDPLink
from ..notification_queue import NotificationQueue
from ..config import Config from ..config import Config
from ..utils.path import check_path_allowed, get_default_project_directory from ..utils.path import check_path_allowed, get_default_project_directory
@ -59,7 +57,6 @@ class Project:
self._computes = set() self._computes = set()
self._nodes = {} self._nodes = {}
self._links = {} self._links = {}
self._listeners = set()
# Create the project on demand on the compute node # Create the project on demand on the compute node
self._project_created_on_compute = set() self._project_created_on_compute = set()
@ -203,29 +200,6 @@ class Project:
yield from compute.delete("/projects/{}".format(self._id)) yield from compute.delete("/projects/{}".format(self._id))
shutil.rmtree(self.path, ignore_errors=True) shutil.rmtree(self.path, ignore_errors=True)
@contextmanager
def queue(self):
"""
Get a queue of notifications
Use it with Python with
"""
queue = NotificationQueue()
self._listeners.add(queue)
yield queue
self._listeners.remove(queue)
def emit(self, action, event, **kwargs):
"""
Send an event to all the client listening for notifications
:param action: Action name
:param event: Event to send
:param kwargs: Add this meta to the notification (project_id for example)
"""
for listener in self._listeners:
listener.put_nowait((action, event, kwargs))
@classmethod @classmethod
def _get_default_project_directory(cls): def _get_default_project_directory(cls):
""" """

View File

@ -152,7 +152,7 @@ class ProjectHandler:
response.content_length = None response.content_length = None
response.start(request) response.start(request)
with project.queue() as queue: with controller.notification.queue(project) as queue:
while True: while True:
try: try:
msg = yield from queue.get_json(5) msg = yield from queue.get_json(5)
@ -178,7 +178,7 @@ class ProjectHandler:
ws = aiohttp.web.WebSocketResponse() ws = aiohttp.web.WebSocketResponse()
yield from ws.prepare(request) yield from ws.prepare(request)
with project.queue() as queue: with controller.notification.queue(project) as queue:
while True: while True:
try: try:
notification = yield from queue.get_json(5) notification = yield from queue.get_json(5)

View File

@ -160,13 +160,13 @@ def test_connectNotification(compute, async_run):
response.tp = aiohttp.MsgType.closed response.tp = aiohttp.MsgType.closed
return response return response
compute._controller = MagicMock() compute._controller._notifications = MagicMock()
compute._session = AsyncioMagicMock(return_value=ws_mock) compute._session = AsyncioMagicMock(return_value=ws_mock)
compute._session.ws_connect = AsyncioMagicMock(return_value=ws_mock) compute._session.ws_connect = AsyncioMagicMock(return_value=ws_mock)
ws_mock.receive = receive ws_mock.receive = receive
async_run(compute._connect_notification()) async_run(compute._connect_notification())
compute._controller.emit.assert_called_with('test', {'a': 1}, compute_id=compute.id, project_id='42') compute._controller.notification.emit.assert_called_with('test', {'a': 1}, compute_id=compute.id, project_id='42')
assert compute._connected is False assert compute._connected is False

View File

@ -157,47 +157,3 @@ def test_getProject(controller, async_run):
with pytest.raises(aiohttp.web.HTTPNotFound): with pytest.raises(aiohttp.web.HTTPNotFound):
assert controller.get_project("dsdssd") assert controller.get_project("dsdssd")
def test_emit(controller, async_run):
project1 = MagicMock()
uuid1 = str(uuid.uuid4())
controller._projects[uuid1] = project1
project2 = MagicMock()
uuid2 = str(uuid.uuid4())
controller._projects[uuid2] = project2
# Notif without project should be send to all projects
controller.emit("test", {})
assert project1.emit.called
assert project2.emit.called
def test_emit_to_project(controller, async_run):
project1 = MagicMock()
uuid1 = str(uuid.uuid4())
controller._projects[uuid1] = project1
project2 = MagicMock()
uuid2 = str(uuid.uuid4())
controller._projects[uuid2] = project2
# Notif with project should be send to this project
controller.emit("test", {}, project_id=uuid1)
project1.emit.assert_called_with('test', {})
assert not project2.emit.called
def test_emit_to_project_not_exists(controller, async_run):
project1 = MagicMock()
uuid1 = str(uuid.uuid4())
controller._projects[uuid1] = project1
project2 = MagicMock()
uuid2 = str(uuid.uuid4())
controller._projects[uuid2] = project2
# Notif with project should be send to this project
controller.emit("test", {}, project_id="4444444")
assert not project1.emit.called
assert not project2.emit.called

View File

@ -0,0 +1,55 @@
#!/usr/bin/env python
#
# Copyright (C) 2016 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest
from gns3server.controller.notification import Notification
from gns3server.controller.project import Project
def test_emit_to_all(async_run, controller):
"""
Send an event to all if we don't have a project id in the event
"""
project = Project()
notif = controller.notification
with notif.queue(project) as queue:
assert len(notif._listeners[project.id]) == 1
async_run(queue.get(0.1)) #  ping
notif.emit('test', {})
msg = async_run(queue.get(5))
assert msg == ('test', {}, {})
assert len(notif._listeners[project.id]) == 0
def test_emit_to_project(async_run, controller):
"""
Send an event to a project listeners
"""
project = Project()
notif = controller.notification
with notif.queue(project) as queue:
assert len(notif._listeners[project.id]) == 1
async_run(queue.get(0.1)) #  ping
# This event has not listener
notif.emit('ignore', {}, project_id=42)
notif.emit('test', {}, project_id=project.id)
msg = async_run(queue.get(5))
assert msg == ('test', {}, {})
assert len(notif._listeners[project.id]) == 0

View File

@ -194,14 +194,3 @@ def test_getLink(async_run):
with pytest.raises(aiohttp.web_exceptions.HTTPNotFound): with pytest.raises(aiohttp.web_exceptions.HTTPNotFound):
project.get_link("test") project.get_link("test")
def test_emit(async_run):
project = Project()
with project.queue() as queue:
assert len(project._listeners) == 1
async_run(queue.get(0.1)) #  ping
project.emit('test', {})
notif = async_run(queue.get(5))
assert notif == ('test', {}, {})
assert len(project._listeners) == 0

View File

@ -35,11 +35,11 @@ from gns3server.controller import Controller
@pytest.fixture @pytest.fixture
def project(http_controller): def project(http_controller, controller):
u = str(uuid.uuid4()) u = str(uuid.uuid4())
query = {"name": "test", "project_id": u} query = {"name": "test", "project_id": u}
response = http_controller.post("/projects", query) response = http_controller.post("/projects", query)
return Controller.instance().get_project(u) return controller.get_project(u)
def test_create_project_with_path(http_controller, tmpdir): def test_create_project_with_path(http_controller, tmpdir):
@ -121,12 +121,12 @@ def test_close_project(http_controller, project):
assert project not in Controller.instance().projects assert project not in Controller.instance().projects
def test_notification(http_controller, project, loop): def test_notification(http_controller, project, controller, loop):
@asyncio.coroutine @asyncio.coroutine
def go(future): def go(future):
response = yield from aiohttp.request("GET", http_controller.get_url("/projects/{project_id}/notifications".format(project_id=project.id))) response = yield from aiohttp.request("GET", http_controller.get_url("/projects/{project_id}/notifications".format(project_id=project.id)))
response.body = yield from response.content.read(200) response.body = yield from response.content.read(200)
project.emit("node.created", {"a": "b"}) controller.notification.emit("node.created", {"a": "b"})
response.body += yield from response.content.read(50) response.body += yield from response.content.read(50)
response.close() response.close()
future.set_result(response) future.set_result(response)
@ -145,13 +145,13 @@ def test_notification_invalid_id(http_controller):
assert response.status == 404 assert response.status == 404
def test_notification_ws(http_controller, project, async_run): def test_notification_ws(http_controller, controller, project, async_run):
ws = http_controller.websocket("/projects/{project_id}/notifications/ws".format(project_id=project.id)) ws = http_controller.websocket("/projects/{project_id}/notifications/ws".format(project_id=project.id))
answer = async_run(ws.receive()) answer = async_run(ws.receive())
answer = json.loads(answer.data) answer = json.loads(answer.data)
assert answer["action"] == "ping" assert answer["action"] == "ping"
project.emit("test", {}) controller.notification.emit("test", {})
answer = async_run(ws.receive()) answer = async_run(ws.receive())
answer = json.loads(answer.data) answer = json.loads(answer.data)