diff --git a/cog_notificator.py b/cog_notificator.py index ec95494..c46deb3 100644 --- a/cog_notificator.py +++ b/cog_notificator.py @@ -308,9 +308,18 @@ class Notificator(commands.Cog): @app_commands.command(name='unregister', description='Stop a channel from receiving HFC alerts (Requires Manage Channels)') @app_commands.checks.has_permissions(manage_channels=True) - async def unregister_channel(self, intr: discord.Interaction): + async def unregister_channel(self, intr: discord.Interaction, confirmation: str=None): channel_id = intr.channel_id + conf_str = intr.user.name + + if confirmation is None: + await intr.response.send_message(f'Are you sure you want to unregister the channel?\nThis action will also clear all related data.\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.') + return + if confirmation != conf_str: + await intr.response.send_message(f'Invalid confirmation string!') + return + channel = self.db.get_channel(channel_id) if channel is None: channel = self.db.get_channel(intr.user.id) @@ -523,6 +532,13 @@ class Notificator(commands.Cog): "desc": desc }, districts_ls) + @staticmethod + async def has_permission(intr: discord.Interaction) -> bool: + if intr.guild is not None and not intr.user.guild_permissions.manage_channels: + await intr.response.send_message('Error: You are missing the Manage Channels permission.') + return False + return True + @staticmethod def locations_page(data_list: list, page: int, res_in_page: int = 50) -> str: """ @@ -560,7 +576,7 @@ class Notificator(commands.Cog): return page_content - @location_group.command(name='locations_list', description='Show the list of all available locations') + @location_group.command(name='list', description='Show the list of all available locations') async def locations_list(self, intr: discord.Interaction, page: int = 1): try: @@ -578,9 +594,7 @@ class Notificator(commands.Cog): @location_group.command(name='add', description='Add a location(s) to the location list') @app_commands.describe(locations='A list of comma-separated Area IDs') async def location_add(self, intr: discord.Interaction, locations: str): - - if intr.guild is not None and intr.user.guild_permissions.manage_channels is False: - await intr.response.send_message('Error: You are missing the Manage Channels permission.') + if not await self.has_permission(intr): return locations_ls = [word.strip() for word in locations.split(',')] @@ -601,14 +615,19 @@ class Notificator(commands.Cog): await intr.response.send_message('Could not find this channel. Are you sure it is registered?') return - self.db.add_channel_districts(channel.id, location_ids) + try: + self.db.add_channel_districts(channel.id, location_ids) + except ValueError as e: + await intr.response.send_message(e.__str__()) + return + + await intr.response.send_message('Successfully added all IDs') @location_group.command(name='remove', description='Remove a location(s) to the location list') @app_commands.describe(locations='A list of comma-separated Area IDs') async def location_remove(self, intr: discord.Interaction, locations: str): - if intr.guild is not None and intr.user.guild_permissions.manage_channels is False: - await intr.response.send_message('Error: You are missing the Manage Channels permission.') + if not await self.has_permission(intr): return locations_ls = [word.strip() for word in locations.split(',')] @@ -630,6 +649,34 @@ class Notificator(commands.Cog): return self.db.remove_channel_districts(channel.id, location_ids) + await intr.response.send_message('Successfully removed all IDs') + + @location_group.command(name='clear', description='Clear all registered locations (get alerts on all locations)') + async def location_clear(self, intr: discord.Interaction, confirmation: str = None): + + if not await self.has_permission(intr): + return + + conf_str = intr.user.name + + channel_id = intr.channel_id + + channel = self.db.get_channel(channel_id) + if channel is None: + channel = self.db.get_channel(intr.user.id) + if channel is None: + await intr.response.send_message('Could not find this channel. Are you sure it is registered?') + return + + if confirmation is None: + await intr.response.send_message(f'Are you sure you want to clear all registered locations?\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.') + return + if confirmation != conf_str: + await intr.response.send_message(f'Invalid confirmation string!') + return + + self.db.clear_channel_districts(channel.id) + await intr.response.send_message(f'Cleared all registered locations.\nChannel will now receive alerts from every location.') @location_group.command(name='registered', description='List all locations registered to this channel') async def location_registered(self, intr: discord.Interaction, page: int = 1): @@ -643,7 +690,17 @@ class Notificator(commands.Cog): await intr.response.send_message('Could not find this channel. Are you sure it is registered?') return - districts = self.db.get_channel_districts(chanel.id) + # Congrats! It's a MESS! + # This code is so ugly, but basically + # + # It gets a list of all of a channel's districts + # Converts it to a tuple form in which + # district_tuple = (district_id: int, district_name: str, area_id: int, migun_time: int) + # Then it sorts it with the key being district_name + districts = sorted([self.db.get_district(district_id).to_tuple() + for district_id + in self.db.get_channel_district_ids(channel.id)], + key=lambda tup: tup[1]) page = self.locations_page(districts, page-1) diff --git a/db_access.py b/db_access.py index 80b6b5d..6299e11 100644 --- a/db_access.py +++ b/db_access.py @@ -33,6 +33,9 @@ class District: def from_tuple(tup: Sequence): return District(tup[0], tup[1], tup[2], tup[3]) + def to_tuple(self) -> tuple: + return self.id, self.name, self.area_id, self.migun_time + class Channel: def __init__(self, id: int, server_id: int | None, channel_lang: str, locations: list): @@ -244,6 +247,13 @@ class DBAccess: return DistrictIterator(crsr) def add_channel_district(self, channel_id: int, district_id: int): + with self.connection.cursor() as crsr: + crsr.execute('SELECT * FROM districts WHERE district_id=%s', (district_id,)) + res = crsr.fetchone() + crsr.fetchall() + if res is None: + raise ValueError(f'Invalid District ID {district_id}') + with self.connection.cursor() as crsr: crsr.execute("UPDATE channels " "SET locations = JSON_ARRAY_APPEND(locations, '$', %s) " @@ -253,29 +263,42 @@ class DBAccess: def add_channel_districts(self, channel_id: int, district_ids: list[int]): with self.connection.cursor() as crsr: + # Sorry for the messy statement. I'm lazy and it's 02:13 rn + crsr.execute(f"SELECT * FROM districts WHERE district_id IN ({','.join(['%s'] * len(district_ids))})", tuple(district_ids)) + res = crsr.fetchall() + + if len(district_ids) > len(res): + raise ValueError('Received invalid district IDs') + + # Sorry for this way of doing things (3 DB queries omg) + # JSON_MERGE_PATCH kept overwriting the existing data + # while JSON_MERGE_PRESERVE did not remove duplicates + dists = self.get_channel_district_ids(channel_id) + updated = [district for district in district_ids if district not in dists] + + with self.connection.cursor() as crsr: + crsr.execute("UPDATE channels " - "SET locations = JSON_MERGE_PATCH(locations, %s) " + "SET locations = JSON_MERGE_PRESERVE(locations, %s) " "WHERE channel_id=%s;", - (json.dumps(district_ids), channel_id)) + (json.dumps(updated), channel_id)) self.connection.commit() - def get_channel_districts(self, channel_id: int) -> list: + def get_channel_district_ids(self, channel_id: int) -> list: with self.connection.cursor() as crsr: crsr.execute('SELECT locations ' 'FROM channels ' 'WHERE channel_id=%s;', (channel_id,)) - districts = [] dist = crsr.fetchone() - while dist is not None: - districts.append(json.loads(dist[0])) - dist = crsr.fetchone() + crsr.fetchall() - return districts + districts = json.loads(dist[0]) + return districts def remove_channel_districts(self, channel_id: int, district_ids: list[int]): with self.connection.cursor() as crsr: - districts = self.get_channel_districts(channel_id) + districts = self.get_channel_district_ids(channel_id) updated = [district for district in districts if district not in district_ids] @@ -285,3 +308,12 @@ class DBAccess: (json.dumps(updated), channel_id)) self.connection.commit() + + def clear_channel_districts(self, channel_id: int): + with self.connection.cursor() as crsr: + crsr.execute('UPDATE channels ' + 'SET locations = JSON_ARRAY() ' + 'WHERE channel_id = %s;', + (channel_id,)) + + self.connection.commit() diff --git a/db_creation/__init__.py b/db_creation/__init__.py index 6c4c011..cd7ca49 100644 --- a/db_creation/__init__.py +++ b/db_creation/__init__.py @@ -1 +1 @@ -__version__ = '1.0.1' \ No newline at end of file +__version__ = '1.0.1'