Handle ACE propagate and allowed

This commit is contained in:
grossmj 2023-08-28 12:06:01 +10:00
parent 57197c3d1c
commit 3e0592520b
2 changed files with 27 additions and 87 deletions

View File

@ -32,11 +32,11 @@ class ResourcePool(BaseTable):
resource_id = Column(GUID, primary_key=True) resource_id = Column(GUID, primary_key=True)
resource_type = Column(String) resource_type = Column(String)
# # Create a self-referential relationship to represent a hierarchy of resources # Create a self-referential relationship to represent a hierarchy of resources
# parent_id = Column(GUID, ForeignKey("resources.resource_id", ondelete="CASCADE")) parent_id = Column(GUID, ForeignKey("resources.resource_id", ondelete="CASCADE"))
# children = relationship( children = relationship(
# "Resource", "Resource",
# remote_side=[resource_id], remote_side=[resource_id],
# cascade="all, delete-orphan", cascade="all, delete-orphan",
# single_parent=True single_parent=True
# ) )

View File

@ -16,6 +16,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from uuid import UUID from uuid import UUID
from urllib.parse import urlparse
from typing import Optional, List, Union from typing import Optional, List, Union
from sqlalchemy import select, update, delete, null from sqlalchemy import select, update, delete, null
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -290,25 +291,6 @@ class RbacRepository(BaseRepository):
# permissions_deleted += 1 # permissions_deleted += 1
# log.info(f"{permissions_deleted} orphaned permissions have been deleted") # log.info(f"{permissions_deleted} orphaned permissions have been deleted")
# return permissions_deleted # return permissions_deleted
#
# def _match_permission(
# self,
# permissions: List[models.Permission],
# method: str,
# path: str
# ) -> Union[None, models.Permission]:
# """
# Match the methods and path with a permission.
# """
#
# for permission in permissions:
# log.debug(f"RBAC: checking permission {permission.methods} {permission.path} {permission.action}")
# if method not in permission.methods:
# continue
# if permission.path.endswith("/*") and path.startswith(permission.path[:-2]):
# return permission
# elif permission.path == path:
# return permission
async def delete_all_ace_starting_with_path(self, path: str) -> None: async def delete_all_ace_starting_with_path(self, path: str) -> None:
""" """
@ -323,71 +305,29 @@ class RbacRepository(BaseRepository):
async def check_user_has_privilege(self, user_id: UUID, path: str, privilege_name: str) -> bool: async def check_user_has_privilege(self, user_id: UUID, path: str, privilege_name: str) -> bool:
# query = select(models.Privilege.name).\ #TODO: handle when user belong to one or more groups (left join?)
# join(models.Privilege.roles).\ query = select(models.ACE.path, models.ACE.propagate, models.ACE.allowed, models.Privilege.name).\
# join(models.Role.acl_entries).\
# join(models.ACE.user).\
# filter(models.Privilege.name == privilege). \
# filter(models.User.user_id == user_id).\
# filter(models.ACE.path == path).\
# distinct()
#query = select(models.ACE.path)
#result = await self._db_session.execute(query)
#res = result.scalars().all()
#print("ACL TABLE ==>", res)
#for ace in res:
# print(ace)
query = select(models.Privilege.name, models.ACE.path, models.ACE.propagate).\
join(models.Privilege.roles).\ join(models.Privilege.roles).\
join(models.Role.acl_entries).\ join(models.Role.acl_entries).\
join(models.ACE.user).\ join(models.ACE.user).\
filter(models.User.user_id == user_id).\ filter(models.User.user_id == user_id).\
filter(models.Privilege.name == privilege_name).\ filter(models.Privilege.name == privilege_name).\
filter(models.ACE.path == path).\
order_by(models.ACE.path.desc()) order_by(models.ACE.path.desc())
result = await self._db_session.execute(query) result = await self._db_session.execute(query)
privileges = result.all() aces = result.all()
#print(privileges)
for privilege, privilege_path, propagate in privileges: parsed_url = urlparse(path)
if privilege_path == path: original_path = path
return True path_components = parsed_url.path.split("/")
# traverse the path in reverse order
for i in range(len(path_components), 0, -1):
path = "/".join(path_components[:i])
if not path:
path = "/"
for ace_path, ace_propagate, ace_allowed, ace_privilege in aces:
if ace_path == path:
if not ace_allowed:
return False
if path == original_path or ace_propagate:
return True # only allow if the path is the original path or the ACE is set to propagate
return False return False
async def check_user_is_authorized(self, user_id: UUID, path: str) -> bool:
"""
Check if a user is authorized to access a resource.
"""
return True
# query = select(models.Permission).\
# join(models.Permission.roles).\
# join(models.Role.groups).\
# join(models.UserGroup.users).\
# filter(models.User.user_id == user_id).\
# order_by(models.Permission.path.desc())
#
# result = await self._db_session.execute(query)
# permissions = result.scalars().all()
# log.debug(f"RBAC: checking authorization for user '{user_id}' on {method} '{path}'")
# matched_permission = self._match_permission(permissions, method, path)
# if matched_permission:
# log.debug(f"RBAC: matched role permission {matched_permission.methods} "
# f"{matched_permission.path} {matched_permission.action}")
# if matched_permission.action == "DENY":
# return False
# return True
#
# log.debug(f"RBAC: could not find a role permission, checking user permissions...")
# permissions = await self.get_user_permissions(user_id)
# matched_permission = self._match_permission(permissions, method, path)
# if matched_permission:
# log.debug(f"RBAC: matched user permission {matched_permission.methods} "
# f"{matched_permission.path} {matched_permission.action}")
# if matched_permission.action == "DENY":
# return False
# return True
#
# return False