mirror of
https://github.com/GaMeNu/HFCNotificator.git
synced 2024-11-16 15:24:51 +02:00
v2.2.1
Fixed bot preferring user id over channel id in new location commands, as per issue #005
This commit is contained in:
parent
ae1d439f48
commit
6591bc6dd9
@ -122,6 +122,41 @@ class Notificator(commands.Cog):
|
||||
return
|
||||
self.check_for_updates.start()
|
||||
|
||||
def in_registered_channel(self, intr: discord.Interaction) -> bool | None:
|
||||
"""
|
||||
|
||||
:param intr: discord command interaction
|
||||
:return: True - is a registered server channel, False - is a registered DM, None - was not found (is not registered)
|
||||
"""
|
||||
|
||||
# OPTIONS:
|
||||
# Channel ID not None + DB not None: IS Channel and IS Registered => matching output and end
|
||||
# Channel ID not None + DB None: IS Channel and NOT Registered => matching output and end
|
||||
# Channel ID None cases:
|
||||
# User ID not None + DB not None: IS DM and IS Registered
|
||||
# User ID not None + DB None: IS DM and NOT Registered
|
||||
#
|
||||
# Off I go to make a utility function!
|
||||
|
||||
ch = self.bot.get_channel(intr.channel_id)
|
||||
if ch is not None and self.db.is_registered_channel(ch.id):
|
||||
return True
|
||||
|
||||
ch = self.bot.get_user(intr.user.id)
|
||||
if ch is not None and self.db.is_registered_channel(ch.id):
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
def get_matching_channel_id(self, intr: discord.Interaction) -> int | None:
|
||||
channel_type = self.in_registered_channel(intr)
|
||||
if channel_type is None:
|
||||
return None
|
||||
elif channel_type:
|
||||
return intr.channel_id
|
||||
else:
|
||||
return intr.user.id
|
||||
|
||||
@tasks.loop(seconds=1)
|
||||
async def check_for_updates(self):
|
||||
try:
|
||||
@ -605,9 +640,15 @@ class Notificator(commands.Cog):
|
||||
@location_group.command(name='add', description='Add a location(s) to the location list')
|
||||
@app_commands.describe(locations='A list of comma-separated Area IDs')
|
||||
async def location_add(self, intr: discord.Interaction, locations: str):
|
||||
|
||||
if not await self.has_permission(intr):
|
||||
return
|
||||
|
||||
channel_id = self.get_matching_channel_id(intr)
|
||||
if channel_id is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
locations_ls = [word.strip() for word in locations.split(',')]
|
||||
location_ids = []
|
||||
for location in locations_ls:
|
||||
@ -617,17 +658,8 @@ class Notificator(commands.Cog):
|
||||
await intr.response.send_message(f'District ID {md.b(f"{location}")} is not a valid district ID.')
|
||||
return
|
||||
|
||||
channel_id = intr.channel_id
|
||||
|
||||
channel = self.db.get_channel(channel_id)
|
||||
if channel is None:
|
||||
channel = self.db.get_channel(intr.user.id)
|
||||
if channel is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
try:
|
||||
self.db.add_channel_districts(channel.id, location_ids)
|
||||
self.db.add_channel_districts(channel_id, location_ids)
|
||||
except ValueError as e:
|
||||
await intr.response.send_message(e.__str__())
|
||||
return
|
||||
@ -641,6 +673,11 @@ class Notificator(commands.Cog):
|
||||
if not await self.has_permission(intr):
|
||||
return
|
||||
|
||||
channel_id = self.get_matching_channel_id(intr)
|
||||
if channel_id is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
locations_ls = [word.strip() for word in locations.split(',')]
|
||||
location_ids = []
|
||||
for location in locations_ls:
|
||||
@ -650,16 +687,7 @@ class Notificator(commands.Cog):
|
||||
await intr.response.send_message(f'District ID {md.b(f"{location}")} is not a valid district ID.')
|
||||
return
|
||||
|
||||
channel_id = intr.channel_id
|
||||
|
||||
channel = self.db.get_channel(channel_id)
|
||||
if channel is None:
|
||||
channel = self.db.get_channel(intr.user.id)
|
||||
if channel is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
self.db.remove_channel_districts(channel.id, location_ids)
|
||||
self.db.remove_channel_districts(channel_id, location_ids)
|
||||
await intr.response.send_message('Successfully removed all IDs')
|
||||
|
||||
@location_group.command(name='clear', description='Clear all registered locations (get alerts on all locations)')
|
||||
@ -668,17 +696,13 @@ class Notificator(commands.Cog):
|
||||
if not await self.has_permission(intr):
|
||||
return
|
||||
|
||||
channel_id = self.get_matching_channel_id(intr)
|
||||
if channel_id is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
conf_str = intr.user.name
|
||||
|
||||
channel_id = intr.channel_id
|
||||
|
||||
channel = self.db.get_channel(channel_id)
|
||||
if channel is None:
|
||||
channel = self.db.get_channel(intr.user.id)
|
||||
if channel is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
if confirmation is None:
|
||||
await intr.response.send_message(
|
||||
f'Are you sure you want to clear all registered locations?\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.')
|
||||
@ -694,14 +718,10 @@ class Notificator(commands.Cog):
|
||||
@location_group.command(name='registered', description='List all locations registered to this channel, by IDs and names.')
|
||||
async def location_registered(self, intr: discord.Interaction, page: int = 1):
|
||||
|
||||
channel_id = intr.channel_id
|
||||
|
||||
channel = self.db.get_channel(channel_id)
|
||||
if channel is None:
|
||||
channel = self.db.get_channel(intr.user.id)
|
||||
if channel is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
channel_id = self.get_matching_channel_id(intr)
|
||||
if channel_id is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
# Congrats! It's a MESS!
|
||||
# This code is so ugly, but basically
|
||||
@ -712,7 +732,7 @@ class Notificator(commands.Cog):
|
||||
# Then it sorts it with the key being district_name
|
||||
districts = sorted([self.db.get_district(district_id).to_tuple()
|
||||
for district_id
|
||||
in self.db.get_channel_district_ids(channel.id)],
|
||||
in self.db.get_channel_district_ids(channel_id)],
|
||||
key=lambda tup: tup[1])
|
||||
|
||||
page = self.locations_page(districts, page - 1)
|
||||
|
18
db_access.py
18
db_access.py
@ -199,6 +199,13 @@ class DBAccess:
|
||||
return self.get_server(channel.server_id)
|
||||
|
||||
def channel_iterator(self):
|
||||
"""
|
||||
This function is DEPRECATED!
|
||||
|
||||
Please use get_all_channels() instead.
|
||||
|
||||
Reason: Cannot create more queries while an iterator is active due to unread results.
|
||||
"""
|
||||
with self.connection.cursor() as crsr:
|
||||
|
||||
crsr.execute('SELECT * FROM channels')
|
||||
@ -253,6 +260,13 @@ class DBAccess:
|
||||
return ret
|
||||
|
||||
def district_iterator(self) -> DistrictIterator:
|
||||
"""
|
||||
This function is DEPRECATED!
|
||||
|
||||
Please use get_all_districts() instead.
|
||||
|
||||
Reason: Cannot create more queries while an iterator is active due to unread results.
|
||||
"""
|
||||
with self.connection.cursor() as crsr:
|
||||
crsr.execute('SELECT * FROM district')
|
||||
|
||||
@ -331,3 +345,7 @@ class DBAccess:
|
||||
(channel_id,))
|
||||
|
||||
self.connection.commit()
|
||||
|
||||
def is_registered_channel(self, channel_id: int) -> bool:
|
||||
return self.get_channel(channel_id) is not None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user