diff --git a/README.md b/README.md index 2f384af..6ce4bf3 100644 --- a/README.md +++ b/README.md @@ -86,5 +86,7 @@ note that the .env file must be in the same directory as main.py [HFC Website](https://www.oref.org.il/) +[Support Server](https://discord.gg/K3E4a5ekNy) + diff --git a/cog_notificator.py b/cog_notificator.py index ec525e1..d68957b 100644 --- a/cog_notificator.py +++ b/cog_notificator.py @@ -1,5 +1,6 @@ import asyncio import datetime +import re import requests @@ -570,7 +571,8 @@ class Notificator(commands.Cog): value=md.bq(f'{md.hl("GitHub", "https://github.com/GaMeNu/HFCNotificator")}\n' f'{md.hl("Official Bot Invite Link", "https://discord.com/api/oauth2/authorize?client_id=1160344131067977738&permissions=0&scope=applications.commands%20bot")}\n' f'{md.hl("HFC Website", "https://www.oref.org.il/")}\n' - f'{md.hl("Bot Profile (for DMs)", "https://discord.com/users/1160344131067977738")}'), + f'{md.hl("Bot Profile (for DMs)", "https://discord.com/users/1160344131067977738")}\n' + f'{md.hl("Support Server", "https://discord.gg/K3E4a5ekNy")}'), inline=True) e.add_field(name='Created by', value=md.bq('GaMeNu (@gamenu)\n' @@ -663,11 +665,18 @@ class Notificator(commands.Cog): return page_content - @location_group.command(name='list', description='Show the list of all available locations with matching IDs, sorted alphabetically') - async def locations_list(self, intr: discord.Interaction, page: int = 1): + @location_group.command(name='list', description='List all available locations, by IDs and names. Sorted alphabetically') + @app_commands.describe(search='Search tokens, separated by spaces') + async def locations_list(self, intr: discord.Interaction, search: str | None = None, page: int = 1): + # decide the search_results + if search is not None: + search_results = self.db.search_districts(*re.split(r"\s+", search)) + else: + search_results = self.db.get_all_districts() try: - page = self.locations_page(sorted(self.db.get_all_districts(), key=lambda tup: tup[1]), page - 1) + # Turn into a display-able page + page = self.locations_page(sorted(search_results, key=lambda tup: tup[1]), page - 1) except ValueError as e: await intr.response.send_message(e.__str__()) return @@ -760,25 +769,21 @@ class Notificator(commands.Cog): await intr.response.send_message( f'Cleared all registered locations.\nChannel will now receive alerts from every location.') - @location_group.command(name='registered', description='List all locations registered to this channel, by IDs and names.') - async def location_registered(self, intr: discord.Interaction, page: int = 1): + @location_group.command(name='registered', description='List all locations registered to this channel, by IDs and names. Sorted alphabetically') + @app_commands.describe(search='Search tokens, separated by spaces') + async def location_registered(self, intr: discord.Interaction, search: str | None = None, page: int = 1): channel = self.get_matching_channel(intr) if channel is None: await intr.response.send_message('Could not find this channel. Are you sure it is registered?') return - # Congrats! It's a MESS! - # This code is so ugly, but basically - # - # It gets a list of all of a channel's districts - # Converts it to a tuple form in which - # district_tuple = (district_id: int, district_name: str, area_id: int, migun_time: int) - # Then it sorts it with the key being district_name - districts = sorted([self.db.get_district(district_id).to_tuple() - for district_id - in self.db.get_channel_district_ids(channel.id)], - key=lambda tup: tup[1]) + if search is None: + search_results = [dist.to_tuple() for dist in self.db.district_ids_to_districts(*self.db.get_channel_district_ids(channel.id))] + else: + search_results = [dist.to_tuple() for dist in self.db.search_channel_districts(channel.id, *re.split(r"\s+", search))] + + districts = sorted(search_results, key=lambda tup: tup[1]) page = self.locations_page(districts, page - 1) diff --git a/db_access.py b/db_access.py index 4f83094..c064599 100644 --- a/db_access.py +++ b/db_access.py @@ -270,6 +270,15 @@ class DBAccess: ret = crsr.fetchall() return ret + def search_districts(self, *tokens: str) -> Sequence: + with self.connection.cursor() as crsr: + query = 'SELECT * FROM districts WHERE ' + query += ' AND '.join(["district_name LIKE %s" for _ in tokens]) + query += ';' + crsr.execute(query, [f'%{token}%' for token in tokens]) + ret = crsr.fetchall() + return ret + def district_iterator(self) -> DistrictIterator: """ This function is DEPRECATED! @@ -335,6 +344,18 @@ class DBAccess: districts = json.loads(dist[0]) return districts + def district_ids_to_districts(self, *district_ids) -> list[District]: + return [self.get_district(district_id) for district_id in district_ids] + + def search_channel_districts(self, channel_id: int, *tokens: str) -> list[District]: + district_ids = self.get_channel_district_ids(channel_id) + + districts = [self.get_district(district_id) for district_id in district_ids] + + filtered_districts = [district for district in districts if all(token in district.name for token in tokens)] + + return filtered_districts + def remove_channel_districts(self, channel_id: int, district_ids: list[int]): with self.connection.cursor() as crsr: