Added searching in /locations list and /locations registered

Started creating a support server
This commit is contained in:
GaMeNu 2023-10-25 21:16:31 +03:00
parent 360b57786d
commit 7601de3c0b
3 changed files with 45 additions and 17 deletions

View File

@ -86,5 +86,7 @@ note that the .env file must be in the same directory as main.py
[HFC Website](https://www.oref.org.il/) [HFC Website](https://www.oref.org.il/)
[Support Server](https://discord.gg/K3E4a5ekNy)

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import datetime import datetime
import re
import requests import requests
@ -570,7 +571,8 @@ class Notificator(commands.Cog):
value=md.bq(f'{md.hl("GitHub", "https://github.com/GaMeNu/HFCNotificator")}\n' value=md.bq(f'{md.hl("GitHub", "https://github.com/GaMeNu/HFCNotificator")}\n'
f'{md.hl("Official Bot Invite Link", "https://discord.com/api/oauth2/authorize?client_id=1160344131067977738&permissions=0&scope=applications.commands%20bot")}\n' f'{md.hl("Official Bot Invite Link", "https://discord.com/api/oauth2/authorize?client_id=1160344131067977738&permissions=0&scope=applications.commands%20bot")}\n'
f'{md.hl("HFC Website", "https://www.oref.org.il/")}\n' f'{md.hl("HFC Website", "https://www.oref.org.il/")}\n'
f'{md.hl("Bot Profile (for DMs)", "https://discord.com/users/1160344131067977738")}'), f'{md.hl("Bot Profile (for DMs)", "https://discord.com/users/1160344131067977738")}\n'
f'{md.hl("Support Server", "https://discord.gg/K3E4a5ekNy")}'),
inline=True) inline=True)
e.add_field(name='Created by', value=md.bq('GaMeNu (@gamenu)\n' e.add_field(name='Created by', value=md.bq('GaMeNu (@gamenu)\n'
@ -663,11 +665,18 @@ class Notificator(commands.Cog):
return page_content return page_content
@location_group.command(name='list', description='Show the list of all available locations with matching IDs, sorted alphabetically') @location_group.command(name='list', description='List all available locations, by IDs and names. Sorted alphabetically')
async def locations_list(self, intr: discord.Interaction, page: int = 1): @app_commands.describe(search='Search tokens, separated by spaces')
async def locations_list(self, intr: discord.Interaction, search: str | None = None, page: int = 1):
# decide the search_results
if search is not None:
search_results = self.db.search_districts(*re.split(r"\s+", search))
else:
search_results = self.db.get_all_districts()
try: try:
page = self.locations_page(sorted(self.db.get_all_districts(), key=lambda tup: tup[1]), page - 1) # Turn into a display-able page
page = self.locations_page(sorted(search_results, key=lambda tup: tup[1]), page - 1)
except ValueError as e: except ValueError as e:
await intr.response.send_message(e.__str__()) await intr.response.send_message(e.__str__())
return return
@ -760,25 +769,21 @@ class Notificator(commands.Cog):
await intr.response.send_message( await intr.response.send_message(
f'Cleared all registered locations.\nChannel will now receive alerts from every location.') f'Cleared all registered locations.\nChannel will now receive alerts from every location.')
@location_group.command(name='registered', description='List all locations registered to this channel, by IDs and names.') @location_group.command(name='registered', description='List all locations registered to this channel, by IDs and names. Sorted alphabetically')
async def location_registered(self, intr: discord.Interaction, page: int = 1): @app_commands.describe(search='Search tokens, separated by spaces')
async def location_registered(self, intr: discord.Interaction, search: str | None = None, page: int = 1):
channel = self.get_matching_channel(intr) channel = self.get_matching_channel(intr)
if channel is None: if channel is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?') await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return return
# Congrats! It's a MESS! if search is None:
# This code is so ugly, but basically search_results = [dist.to_tuple() for dist in self.db.district_ids_to_districts(*self.db.get_channel_district_ids(channel.id))]
# else:
# It gets a list of all of a channel's districts search_results = [dist.to_tuple() for dist in self.db.search_channel_districts(channel.id, *re.split(r"\s+", search))]
# Converts it to a tuple form in which
# district_tuple = (district_id: int, district_name: str, area_id: int, migun_time: int) districts = sorted(search_results, key=lambda tup: tup[1])
# Then it sorts it with the key being district_name
districts = sorted([self.db.get_district(district_id).to_tuple()
for district_id
in self.db.get_channel_district_ids(channel.id)],
key=lambda tup: tup[1])
page = self.locations_page(districts, page - 1) page = self.locations_page(districts, page - 1)

View File

@ -270,6 +270,15 @@ class DBAccess:
ret = crsr.fetchall() ret = crsr.fetchall()
return ret return ret
def search_districts(self, *tokens: str) -> Sequence:
with self.connection.cursor() as crsr:
query = 'SELECT * FROM districts WHERE '
query += ' AND '.join(["district_name LIKE %s" for _ in tokens])
query += ';'
crsr.execute(query, [f'%{token}%' for token in tokens])
ret = crsr.fetchall()
return ret
def district_iterator(self) -> DistrictIterator: def district_iterator(self) -> DistrictIterator:
""" """
This function is DEPRECATED! This function is DEPRECATED!
@ -335,6 +344,18 @@ class DBAccess:
districts = json.loads(dist[0]) districts = json.loads(dist[0])
return districts return districts
def district_ids_to_districts(self, *district_ids) -> list[District]:
return [self.get_district(district_id) for district_id in district_ids]
def search_channel_districts(self, channel_id: int, *tokens: str) -> list[District]:
district_ids = self.get_channel_district_ids(channel_id)
districts = [self.get_district(district_id) for district_id in district_ids]
filtered_districts = [district for district in districts if all(token in district.name for token in tokens)]
return filtered_districts
def remove_channel_districts(self, channel_id: int, district_ids: list[int]): def remove_channel_districts(self, channel_id: int, district_ids: list[int]):
with self.connection.cursor() as crsr: with self.connection.cursor() as crsr: