Fixed bot preferring user id over channel id in new location commands, as per issue #005
This commit is contained in:
GaMeNu 2023-10-22 14:03:07 +03:00
parent ae1d439f48
commit 6591bc6dd9
2 changed files with 76 additions and 38 deletions

View File

@ -122,6 +122,41 @@ class Notificator(commands.Cog):
return
self.check_for_updates.start()
def in_registered_channel(self, intr: discord.Interaction) -> bool | None:
"""
:param intr: discord command interaction
:return: True - is a registered server channel, False - is a registered DM, None - was not found (is not registered)
"""
# OPTIONS:
# Channel ID not None + DB not None: IS Channel and IS Registered => matching output and end
# Channel ID not None + DB None: IS Channel and NOT Registered => matching output and end
# Channel ID None cases:
# User ID not None + DB not None: IS DM and IS Registered
# User ID not None + DB None: IS DM and NOT Registered
#
# Off I go to make a utility function!
ch = self.bot.get_channel(intr.channel_id)
if ch is not None and self.db.is_registered_channel(ch.id):
return True
ch = self.bot.get_user(intr.user.id)
if ch is not None and self.db.is_registered_channel(ch.id):
return False
return None
def get_matching_channel_id(self, intr: discord.Interaction) -> int | None:
channel_type = self.in_registered_channel(intr)
if channel_type is None:
return None
elif channel_type:
return intr.channel_id
else:
return intr.user.id
@tasks.loop(seconds=1)
async def check_for_updates(self):
try:
@ -605,9 +640,15 @@ class Notificator(commands.Cog):
@location_group.command(name='add', description='Add a location(s) to the location list')
@app_commands.describe(locations='A list of comma-separated Area IDs')
async def location_add(self, intr: discord.Interaction, locations: str):
if not await self.has_permission(intr):
return
channel_id = self.get_matching_channel_id(intr)
if channel_id is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
locations_ls = [word.strip() for word in locations.split(',')]
location_ids = []
for location in locations_ls:
@ -617,17 +658,8 @@ class Notificator(commands.Cog):
await intr.response.send_message(f'District ID {md.b(f"{location}")} is not a valid district ID.')
return
channel_id = intr.channel_id
channel = self.db.get_channel(channel_id)
if channel is None:
channel = self.db.get_channel(intr.user.id)
if channel is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
try:
self.db.add_channel_districts(channel.id, location_ids)
self.db.add_channel_districts(channel_id, location_ids)
except ValueError as e:
await intr.response.send_message(e.__str__())
return
@ -641,6 +673,11 @@ class Notificator(commands.Cog):
if not await self.has_permission(intr):
return
channel_id = self.get_matching_channel_id(intr)
if channel_id is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
locations_ls = [word.strip() for word in locations.split(',')]
location_ids = []
for location in locations_ls:
@ -650,16 +687,7 @@ class Notificator(commands.Cog):
await intr.response.send_message(f'District ID {md.b(f"{location}")} is not a valid district ID.')
return
channel_id = intr.channel_id
channel = self.db.get_channel(channel_id)
if channel is None:
channel = self.db.get_channel(intr.user.id)
if channel is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
self.db.remove_channel_districts(channel.id, location_ids)
self.db.remove_channel_districts(channel_id, location_ids)
await intr.response.send_message('Successfully removed all IDs')
@location_group.command(name='clear', description='Clear all registered locations (get alerts on all locations)')
@ -668,17 +696,13 @@ class Notificator(commands.Cog):
if not await self.has_permission(intr):
return
channel_id = self.get_matching_channel_id(intr)
if channel_id is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
conf_str = intr.user.name
channel_id = intr.channel_id
channel = self.db.get_channel(channel_id)
if channel is None:
channel = self.db.get_channel(intr.user.id)
if channel is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
if confirmation is None:
await intr.response.send_message(
f'Are you sure you want to clear all registered locations?\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.')
@ -694,14 +718,10 @@ class Notificator(commands.Cog):
@location_group.command(name='registered', description='List all locations registered to this channel, by IDs and names.')
async def location_registered(self, intr: discord.Interaction, page: int = 1):
channel_id = intr.channel_id
channel = self.db.get_channel(channel_id)
if channel is None:
channel = self.db.get_channel(intr.user.id)
if channel is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
channel_id = self.get_matching_channel_id(intr)
if channel_id is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
# Congrats! It's a MESS!
# This code is so ugly, but basically
@ -712,7 +732,7 @@ class Notificator(commands.Cog):
# Then it sorts it with the key being district_name
districts = sorted([self.db.get_district(district_id).to_tuple()
for district_id
in self.db.get_channel_district_ids(channel.id)],
in self.db.get_channel_district_ids(channel_id)],
key=lambda tup: tup[1])
page = self.locations_page(districts, page - 1)

View File

@ -199,6 +199,13 @@ class DBAccess:
return self.get_server(channel.server_id)
def channel_iterator(self):
"""
This function is DEPRECATED!
Please use get_all_channels() instead.
Reason: Cannot create more queries while an iterator is active due to unread results.
"""
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM channels')
@ -253,6 +260,13 @@ class DBAccess:
return ret
def district_iterator(self) -> DistrictIterator:
"""
This function is DEPRECATED!
Please use get_all_districts() instead.
Reason: Cannot create more queries while an iterator is active due to unread results.
"""
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM district')
@ -331,3 +345,7 @@ class DBAccess:
(channel_id,))
self.connection.commit()
def is_registered_channel(self, channel_id: int) -> bool:
return self.get_channel(channel_id) is not None