mirror of
https://github.com/GaMeNu/HFCNotificator.git
synced 2024-11-16 15:24:51 +02:00
v2.2.0
One more command, small updates, unregister confirmatio
This commit is contained in:
parent
2689cef361
commit
292243146d
@ -308,9 +308,18 @@ class Notificator(commands.Cog):
|
|||||||
@app_commands.command(name='unregister',
|
@app_commands.command(name='unregister',
|
||||||
description='Stop a channel from receiving HFC alerts (Requires Manage Channels)')
|
description='Stop a channel from receiving HFC alerts (Requires Manage Channels)')
|
||||||
@app_commands.checks.has_permissions(manage_channels=True)
|
@app_commands.checks.has_permissions(manage_channels=True)
|
||||||
async def unregister_channel(self, intr: discord.Interaction):
|
async def unregister_channel(self, intr: discord.Interaction, confirmation: str=None):
|
||||||
channel_id = intr.channel_id
|
channel_id = intr.channel_id
|
||||||
|
|
||||||
|
conf_str = intr.user.name
|
||||||
|
|
||||||
|
if confirmation is None:
|
||||||
|
await intr.response.send_message(f'Are you sure you want to unregister the channel?\nThis action will also clear all related data.\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.')
|
||||||
|
return
|
||||||
|
if confirmation != conf_str:
|
||||||
|
await intr.response.send_message(f'Invalid confirmation string!')
|
||||||
|
return
|
||||||
|
|
||||||
channel = self.db.get_channel(channel_id)
|
channel = self.db.get_channel(channel_id)
|
||||||
if channel is None:
|
if channel is None:
|
||||||
channel = self.db.get_channel(intr.user.id)
|
channel = self.db.get_channel(intr.user.id)
|
||||||
@ -523,6 +532,13 @@ class Notificator(commands.Cog):
|
|||||||
"desc": desc
|
"desc": desc
|
||||||
}, districts_ls)
|
}, districts_ls)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def has_permission(intr: discord.Interaction) -> bool:
|
||||||
|
if intr.guild is not None and not intr.user.guild_permissions.manage_channels:
|
||||||
|
await intr.response.send_message('Error: You are missing the Manage Channels permission.')
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def locations_page(data_list: list, page: int, res_in_page: int = 50) -> str:
|
def locations_page(data_list: list, page: int, res_in_page: int = 50) -> str:
|
||||||
"""
|
"""
|
||||||
@ -560,7 +576,7 @@ class Notificator(commands.Cog):
|
|||||||
|
|
||||||
return page_content
|
return page_content
|
||||||
|
|
||||||
@location_group.command(name='locations_list', description='Show the list of all available locations')
|
@location_group.command(name='list', description='Show the list of all available locations')
|
||||||
async def locations_list(self, intr: discord.Interaction, page: int = 1):
|
async def locations_list(self, intr: discord.Interaction, page: int = 1):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -578,9 +594,7 @@ class Notificator(commands.Cog):
|
|||||||
@location_group.command(name='add', description='Add a location(s) to the location list')
|
@location_group.command(name='add', description='Add a location(s) to the location list')
|
||||||
@app_commands.describe(locations='A list of comma-separated Area IDs')
|
@app_commands.describe(locations='A list of comma-separated Area IDs')
|
||||||
async def location_add(self, intr: discord.Interaction, locations: str):
|
async def location_add(self, intr: discord.Interaction, locations: str):
|
||||||
|
if not await self.has_permission(intr):
|
||||||
if intr.guild is not None and intr.user.guild_permissions.manage_channels is False:
|
|
||||||
await intr.response.send_message('Error: You are missing the Manage Channels permission.')
|
|
||||||
return
|
return
|
||||||
|
|
||||||
locations_ls = [word.strip() for word in locations.split(',')]
|
locations_ls = [word.strip() for word in locations.split(',')]
|
||||||
@ -601,14 +615,19 @@ class Notificator(commands.Cog):
|
|||||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||||
return
|
return
|
||||||
|
|
||||||
self.db.add_channel_districts(channel.id, location_ids)
|
try:
|
||||||
|
self.db.add_channel_districts(channel.id, location_ids)
|
||||||
|
except ValueError as e:
|
||||||
|
await intr.response.send_message(e.__str__())
|
||||||
|
return
|
||||||
|
|
||||||
|
await intr.response.send_message('Successfully added all IDs')
|
||||||
|
|
||||||
@location_group.command(name='remove', description='Remove a location(s) to the location list')
|
@location_group.command(name='remove', description='Remove a location(s) to the location list')
|
||||||
@app_commands.describe(locations='A list of comma-separated Area IDs')
|
@app_commands.describe(locations='A list of comma-separated Area IDs')
|
||||||
async def location_remove(self, intr: discord.Interaction, locations: str):
|
async def location_remove(self, intr: discord.Interaction, locations: str):
|
||||||
|
|
||||||
if intr.guild is not None and intr.user.guild_permissions.manage_channels is False:
|
if not await self.has_permission(intr):
|
||||||
await intr.response.send_message('Error: You are missing the Manage Channels permission.')
|
|
||||||
return
|
return
|
||||||
|
|
||||||
locations_ls = [word.strip() for word in locations.split(',')]
|
locations_ls = [word.strip() for word in locations.split(',')]
|
||||||
@ -630,6 +649,34 @@ class Notificator(commands.Cog):
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.db.remove_channel_districts(channel.id, location_ids)
|
self.db.remove_channel_districts(channel.id, location_ids)
|
||||||
|
await intr.response.send_message('Successfully removed all IDs')
|
||||||
|
|
||||||
|
@location_group.command(name='clear', description='Clear all registered locations (get alerts on all locations)')
|
||||||
|
async def location_clear(self, intr: discord.Interaction, confirmation: str = None):
|
||||||
|
|
||||||
|
if not await self.has_permission(intr):
|
||||||
|
return
|
||||||
|
|
||||||
|
conf_str = intr.user.name
|
||||||
|
|
||||||
|
channel_id = intr.channel_id
|
||||||
|
|
||||||
|
channel = self.db.get_channel(channel_id)
|
||||||
|
if channel is None:
|
||||||
|
channel = self.db.get_channel(intr.user.id)
|
||||||
|
if channel is None:
|
||||||
|
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||||
|
return
|
||||||
|
|
||||||
|
if confirmation is None:
|
||||||
|
await intr.response.send_message(f'Are you sure you want to clear all registered locations?\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.')
|
||||||
|
return
|
||||||
|
if confirmation != conf_str:
|
||||||
|
await intr.response.send_message(f'Invalid confirmation string!')
|
||||||
|
return
|
||||||
|
|
||||||
|
self.db.clear_channel_districts(channel.id)
|
||||||
|
await intr.response.send_message(f'Cleared all registered locations.\nChannel will now receive alerts from every location.')
|
||||||
|
|
||||||
@location_group.command(name='registered', description='List all locations registered to this channel')
|
@location_group.command(name='registered', description='List all locations registered to this channel')
|
||||||
async def location_registered(self, intr: discord.Interaction, page: int = 1):
|
async def location_registered(self, intr: discord.Interaction, page: int = 1):
|
||||||
@ -643,7 +690,17 @@ class Notificator(commands.Cog):
|
|||||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||||
return
|
return
|
||||||
|
|
||||||
districts = self.db.get_channel_districts(chanel.id)
|
# Congrats! It's a MESS!
|
||||||
|
# This code is so ugly, but basically
|
||||||
|
#
|
||||||
|
# It gets a list of all of a channel's districts
|
||||||
|
# Converts it to a tuple form in which
|
||||||
|
# district_tuple = (district_id: int, district_name: str, area_id: int, migun_time: int)
|
||||||
|
# Then it sorts it with the key being district_name
|
||||||
|
districts = sorted([self.db.get_district(district_id).to_tuple()
|
||||||
|
for district_id
|
||||||
|
in self.db.get_channel_district_ids(channel.id)],
|
||||||
|
key=lambda tup: tup[1])
|
||||||
|
|
||||||
page = self.locations_page(districts, page-1)
|
page = self.locations_page(districts, page-1)
|
||||||
|
|
||||||
|
50
db_access.py
50
db_access.py
@ -33,6 +33,9 @@ class District:
|
|||||||
def from_tuple(tup: Sequence):
|
def from_tuple(tup: Sequence):
|
||||||
return District(tup[0], tup[1], tup[2], tup[3])
|
return District(tup[0], tup[1], tup[2], tup[3])
|
||||||
|
|
||||||
|
def to_tuple(self) -> tuple:
|
||||||
|
return self.id, self.name, self.area_id, self.migun_time
|
||||||
|
|
||||||
|
|
||||||
class Channel:
|
class Channel:
|
||||||
def __init__(self, id: int, server_id: int | None, channel_lang: str, locations: list):
|
def __init__(self, id: int, server_id: int | None, channel_lang: str, locations: list):
|
||||||
@ -244,6 +247,13 @@ class DBAccess:
|
|||||||
return DistrictIterator(crsr)
|
return DistrictIterator(crsr)
|
||||||
|
|
||||||
def add_channel_district(self, channel_id: int, district_id: int):
|
def add_channel_district(self, channel_id: int, district_id: int):
|
||||||
|
with self.connection.cursor() as crsr:
|
||||||
|
crsr.execute('SELECT * FROM districts WHERE district_id=%s', (district_id,))
|
||||||
|
res = crsr.fetchone()
|
||||||
|
crsr.fetchall()
|
||||||
|
if res is None:
|
||||||
|
raise ValueError(f'Invalid District ID {district_id}')
|
||||||
|
|
||||||
with self.connection.cursor() as crsr:
|
with self.connection.cursor() as crsr:
|
||||||
crsr.execute("UPDATE channels "
|
crsr.execute("UPDATE channels "
|
||||||
"SET locations = JSON_ARRAY_APPEND(locations, '$', %s) "
|
"SET locations = JSON_ARRAY_APPEND(locations, '$', %s) "
|
||||||
@ -253,29 +263,42 @@ class DBAccess:
|
|||||||
|
|
||||||
def add_channel_districts(self, channel_id: int, district_ids: list[int]):
|
def add_channel_districts(self, channel_id: int, district_ids: list[int]):
|
||||||
with self.connection.cursor() as crsr:
|
with self.connection.cursor() as crsr:
|
||||||
|
# Sorry for the messy statement. I'm lazy and it's 02:13 rn
|
||||||
|
crsr.execute(f"SELECT * FROM districts WHERE district_id IN ({','.join(['%s'] * len(district_ids))})", tuple(district_ids))
|
||||||
|
res = crsr.fetchall()
|
||||||
|
|
||||||
|
if len(district_ids) > len(res):
|
||||||
|
raise ValueError('Received invalid district IDs')
|
||||||
|
|
||||||
|
# Sorry for this way of doing things (3 DB queries omg)
|
||||||
|
# JSON_MERGE_PATCH kept overwriting the existing data
|
||||||
|
# while JSON_MERGE_PRESERVE did not remove duplicates
|
||||||
|
dists = self.get_channel_district_ids(channel_id)
|
||||||
|
updated = [district for district in district_ids if district not in dists]
|
||||||
|
|
||||||
|
with self.connection.cursor() as crsr:
|
||||||
|
|
||||||
crsr.execute("UPDATE channels "
|
crsr.execute("UPDATE channels "
|
||||||
"SET locations = JSON_MERGE_PATCH(locations, %s) "
|
"SET locations = JSON_MERGE_PRESERVE(locations, %s) "
|
||||||
"WHERE channel_id=%s;",
|
"WHERE channel_id=%s;",
|
||||||
(json.dumps(district_ids), channel_id))
|
(json.dumps(updated), channel_id))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def get_channel_districts(self, channel_id: int) -> list:
|
def get_channel_district_ids(self, channel_id: int) -> list:
|
||||||
with self.connection.cursor() as crsr:
|
with self.connection.cursor() as crsr:
|
||||||
crsr.execute('SELECT locations '
|
crsr.execute('SELECT locations '
|
||||||
'FROM channels '
|
'FROM channels '
|
||||||
'WHERE channel_id=%s;', (channel_id,))
|
'WHERE channel_id=%s;', (channel_id,))
|
||||||
districts = []
|
|
||||||
dist = crsr.fetchone()
|
dist = crsr.fetchone()
|
||||||
while dist is not None:
|
crsr.fetchall()
|
||||||
districts.append(json.loads(dist[0]))
|
|
||||||
dist = crsr.fetchone()
|
|
||||||
|
|
||||||
return districts
|
districts = json.loads(dist[0])
|
||||||
|
return districts
|
||||||
|
|
||||||
def remove_channel_districts(self, channel_id: int, district_ids: list[int]):
|
def remove_channel_districts(self, channel_id: int, district_ids: list[int]):
|
||||||
with self.connection.cursor() as crsr:
|
with self.connection.cursor() as crsr:
|
||||||
|
|
||||||
districts = self.get_channel_districts(channel_id)
|
districts = self.get_channel_district_ids(channel_id)
|
||||||
|
|
||||||
updated = [district for district in districts if district not in district_ids]
|
updated = [district for district in districts if district not in district_ids]
|
||||||
|
|
||||||
@ -285,3 +308,12 @@ class DBAccess:
|
|||||||
(json.dumps(updated), channel_id))
|
(json.dumps(updated), channel_id))
|
||||||
|
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
|
def clear_channel_districts(self, channel_id: int):
|
||||||
|
with self.connection.cursor() as crsr:
|
||||||
|
crsr.execute('UPDATE channels '
|
||||||
|
'SET locations = JSON_ARRAY() '
|
||||||
|
'WHERE channel_id = %s;',
|
||||||
|
(channel_id,))
|
||||||
|
|
||||||
|
self.connection.commit()
|
||||||
|
Loading…
Reference in New Issue
Block a user