mirror of
https://github.com/GaMeNu/HFCNotificator.git
synced 2024-11-16 15:24:51 +02:00
v2.2.0
Started location setting system.
This commit is contained in:
parent
1e69add653
commit
1477003229
@ -1,6 +1,9 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from _xxsubinterpreters import channel_recv
|
||||
|
||||
import requests
|
||||
|
||||
import discord
|
||||
@ -84,6 +87,8 @@ class Notificator(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot, handler: logging.Handler):
|
||||
self.bot = bot
|
||||
|
||||
self.bot.add_command(Location())
|
||||
|
||||
self.log = logging.Logger('Notificator')
|
||||
self.log.addHandler(handler)
|
||||
|
||||
@ -518,3 +523,135 @@ class Notificator(commands.Cog):
|
||||
"desc": desc
|
||||
}, districts_ls)
|
||||
|
||||
location_group = app_commands.Group(name='locations', description='Commands related adding, removing, or setting locations.')
|
||||
|
||||
@staticmethod
|
||||
def locations_page(data_list: list, page: int, res_in_page: int = 50) -> str:
|
||||
"""
|
||||
Page starts at 0
|
||||
|
||||
max_page is EXCLUSIVE
|
||||
:param data_list: custom data list to get page info of
|
||||
:param page: District page
|
||||
:param res_in_page: Amount of districts to put in one pages
|
||||
:return:
|
||||
"""
|
||||
|
||||
dist_ls = data_list
|
||||
|
||||
dist_len = len(dist_ls)
|
||||
|
||||
if dist_len == 0:
|
||||
return 'No results found.'
|
||||
|
||||
max_page = dist_len // res_in_page
|
||||
if dist_len % res_in_page != 0:
|
||||
max_page += 1
|
||||
|
||||
if page >= max_page:
|
||||
raise ValueError('Page number is too high.')
|
||||
if page < 0:
|
||||
raise ValueError('Page number is too low.')
|
||||
|
||||
page_content = f'Page { md.b(f"{page + 1}/{max_page}") }\n\n<District ID> - <District name>\n\n'
|
||||
|
||||
start_i = page * res_in_page
|
||||
end_i = min(start_i + res_in_page, dist_len)
|
||||
for district in dist_ls[start_i:end_i]:
|
||||
page_content += f'{district[0]} - {district[1]}\n'
|
||||
|
||||
return page_content
|
||||
|
||||
@location_group.command(name='locations_list', description='Show the list of all available locations')
|
||||
async def locations_list(self, intr: discord.Interaction, page: int = 1):
|
||||
|
||||
try:
|
||||
page = self.locations_page(sorted(self.db.get_all_districts(), key=lambda tup: tup[1]), page - 1)
|
||||
except ValueError as e:
|
||||
await intr.response.send_message(e.__str__())
|
||||
return
|
||||
|
||||
if len(page) > 2000:
|
||||
await intr.response.send_message('Page content exceeds character limit.\nPlease contact the bot authors with the command you\'ve tried to run.')
|
||||
return
|
||||
|
||||
await intr.response.send_message(page)
|
||||
|
||||
@location_group.command(name='add', description='Add a location(s) to the location list')
|
||||
@app_commands.describe(locations='A list of comma-separated Area IDs')
|
||||
async def location_add(self, intr: discord.Interaction, locations: str):
|
||||
|
||||
if intr.guild is not None and intr.user.guild_permissions.manage_channels is False:
|
||||
await intr.response.send_message('Error: You are missing the Manage Channels permission.')
|
||||
return
|
||||
|
||||
locations_ls = [word.strip() for word in locations.split(',')]
|
||||
location_ids = []
|
||||
for location in locations_ls:
|
||||
try:
|
||||
location_ids.append(int(location))
|
||||
except ValueError:
|
||||
await intr.response.send_message(f'District ID {md.b(f"{location}")} is not a valid district ID.')
|
||||
return
|
||||
|
||||
channel_id = intr.channel_id
|
||||
|
||||
channel = self.db.get_channel(channel_id)
|
||||
if channel is None:
|
||||
channel = self.db.get_channel(intr.user.id)
|
||||
if channel is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
self.db.add_channel_districts(channel.id, location_ids)
|
||||
|
||||
@location_group.command(name='remove', description='Remove a location(s) to the location list')
|
||||
@app_commands.describe(locations='A list of comma-separated Area IDs')
|
||||
async def location_remove(self, intr: discord.Interaction, locations: str):
|
||||
|
||||
if intr.guild is not None and intr.user.guild_permissions.manage_channels is False:
|
||||
await intr.response.send_message('Error: You are missing the Manage Channels permission.')
|
||||
return
|
||||
|
||||
locations_ls = [word.strip() for word in locations.split(',')]
|
||||
location_ids = []
|
||||
for location in locations_ls:
|
||||
try:
|
||||
location_ids.append(int(location))
|
||||
except ValueError:
|
||||
await intr.response.send_message(f'District ID {md.b(f"{location}")} is not a valid district ID.')
|
||||
return
|
||||
|
||||
channel_id = intr.channel_id
|
||||
|
||||
channel = self.db.get_channel(channel_id)
|
||||
if channel is None:
|
||||
channel = self.db.get_channel(intr.user.id)
|
||||
if channel is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
self.db.remove_channel_districts(channel.id, location_ids)
|
||||
|
||||
@location_group.command(name='registered', description='List all locations registered to this channel')
|
||||
async def location_registered(self, intr: discord.Interaction, page: int = 1):
|
||||
|
||||
channel_id = intr.channel_id
|
||||
|
||||
channel = self.db.get_channel(channel_id)
|
||||
if channel is None:
|
||||
channel = self.db.get_channel(intr.user.id)
|
||||
if channel is None:
|
||||
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
|
||||
return
|
||||
|
||||
districts = self.db.get_channel_districts(chanel.id)
|
||||
|
||||
page = self.locations_page(districts, page-1)
|
||||
|
||||
if len(page) > 2000:
|
||||
await intr.response.send_message(
|
||||
'Page content exceeds character limit.\nPlease contact the bot authors with the command you\'ve tried to run.')
|
||||
return
|
||||
|
||||
await intr.response.send_message(page)
|
||||
|
95
db_access.py
95
db_access.py
@ -1,6 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Sequence
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from mysql import connector as mysql
|
||||
@ -15,6 +17,10 @@ class Area:
|
||||
self.id = id
|
||||
self.name = name
|
||||
|
||||
@staticmethod
|
||||
def from_tuple(tup: Sequence):
|
||||
return Area(tup[0], tup[1])
|
||||
|
||||
|
||||
class District:
|
||||
def __init__(self, id: int, name: str, area_id: int, migun_time: int):
|
||||
@ -23,12 +29,17 @@ class District:
|
||||
self.area_id = area_id
|
||||
self.migun_time = migun_time
|
||||
|
||||
@staticmethod
|
||||
def from_tuple(tup: Sequence):
|
||||
return District(tup[0], tup[1], tup[2], tup[3])
|
||||
|
||||
|
||||
class Channel:
|
||||
def __init__(self, id: int, server_id: int | None, channel_lang: str):
|
||||
def __init__(self, id: int, server_id: int | None, channel_lang: str, locations: list):
|
||||
self.id = id
|
||||
self.server_id = server_id
|
||||
self.channel_lang = channel_lang
|
||||
self.locations = locations
|
||||
|
||||
|
||||
class Server:
|
||||
@ -49,7 +60,21 @@ class ChannelIterator:
|
||||
if res is None:
|
||||
self.cursor.close()
|
||||
raise StopIteration
|
||||
return Channel(res[0], res[1], res[2])
|
||||
return Channel(res[0], res[1], res[2], json.loads(res[3]))
|
||||
|
||||
class DistrictIterator:
|
||||
def __init__(self, cursor: mysql.connection.MySQLCursor):
|
||||
self.cursor = cursor
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> District:
|
||||
res = self.cursor.fetchone()
|
||||
if res is None:
|
||||
self.cursor.close()
|
||||
raise StopIteration
|
||||
return District.from_tuple(res)
|
||||
|
||||
|
||||
class DBAccess:
|
||||
@ -75,7 +100,7 @@ class DBAccess:
|
||||
break
|
||||
except mysql.Error as e:
|
||||
self.connection = None
|
||||
log.warning(f"Couldn't connect to db. This is attempt #{i}")
|
||||
log.warning(f"Couldn't connect to db. This is attempt #{i}\n{e.msg}")
|
||||
time.sleep(5)
|
||||
|
||||
if self.connection is None:
|
||||
@ -121,7 +146,7 @@ class DBAccess:
|
||||
crsr.fetchall()
|
||||
|
||||
if res is not None:
|
||||
return Area(res[0], res[1])
|
||||
return Area.from_tuple(res)
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -132,7 +157,7 @@ class DBAccess:
|
||||
crsr.fetchall()
|
||||
|
||||
if res is not None:
|
||||
return District(res[0], res[1], res[2], res[3])
|
||||
return District.from_tuple(res)
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -157,7 +182,7 @@ class DBAccess:
|
||||
crsr.fetchall()
|
||||
|
||||
if res is not None:
|
||||
return Channel(res[0], res[1], res[2])
|
||||
return Channel(res[0], res[1], res[2], json.loads(res[3]))
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -203,4 +228,60 @@ class DBAccess:
|
||||
if res is None:
|
||||
return None
|
||||
|
||||
return District(res[0], res[1], res[2], res[3])
|
||||
return District.from_tuple(res)
|
||||
|
||||
def get_all_districts(self) -> Sequence:
|
||||
with self.connection.cursor() as crsr:
|
||||
crsr.execute('SELECT * FROM districts')
|
||||
ret = crsr.fetchall()
|
||||
return ret
|
||||
|
||||
def district_iterator(self) -> DistrictIterator:
|
||||
crsr = self.connection.cursor()
|
||||
|
||||
crsr.execute('SELECT * FROM district')
|
||||
|
||||
return DistrictIterator(crsr)
|
||||
|
||||
def add_channel_district(self, channel_id: int, district_id: int):
|
||||
with self.connection.cursor() as crsr:
|
||||
crsr.execute("UPDATE channels "
|
||||
"SET locations = JSON_ARRAY_APPEND(locations, '$', %s) "
|
||||
"WHERE channel_id=%s;",
|
||||
(district_id, channel_id))
|
||||
self.connection.commit()
|
||||
|
||||
def add_channel_districts(self, channel_id: int, district_ids: list[int]):
|
||||
with self.connection.cursor() as crsr:
|
||||
crsr.execute("UPDATE channels "
|
||||
"SET locations = JSON_MERGE_PATCH(locations, %s) "
|
||||
"WHERE channel_id=%s;",
|
||||
(json.dumps(district_ids), channel_id))
|
||||
self.connection.commit()
|
||||
|
||||
def get_channel_districts(self, channel_id: int) -> list:
|
||||
with self.connection.cursor() as crsr:
|
||||
crsr.execute('SELECT locations '
|
||||
'FROM channels '
|
||||
'WHERE channel_id=%s;', (channel_id,))
|
||||
districts = []
|
||||
dist = crsr.fetchone()
|
||||
while dist is not None:
|
||||
districts.append(json.loads(dist[0]))
|
||||
dist = crsr.fetchone()
|
||||
|
||||
return districts
|
||||
|
||||
def remove_channel_districts(self, channel_id: int, district_ids: list[int]):
|
||||
with self.connection.cursor() as crsr:
|
||||
|
||||
districts = self.get_channel_districts(channel_id)
|
||||
|
||||
updated = [district for district in districts if district not in district_ids]
|
||||
|
||||
crsr.execute('UPDATE channels '
|
||||
'SET locations = %s '
|
||||
'WHERE channel_id = %s;',
|
||||
(json.dumps(updated), channel_id))
|
||||
|
||||
self.connection.commit()
|
||||
|
@ -81,7 +81,7 @@ CREATE TABLE IF NOT EXISTS `hfc_db`.`channels` (
|
||||
`channel_id` BIGINT(8) UNSIGNED NOT NULL,
|
||||
`server_id` BIGINT(8) UNSIGNED NULL,
|
||||
`channel_lang` VARCHAR(15) NOT NULL,
|
||||
`locations` JSON NULL,
|
||||
`locations` JSON DEFAULT JSON_ARRAY(),
|
||||
PRIMARY KEY (`channel_id`),
|
||||
UNIQUE INDEX `channel_id_UNIQUE` (`channel_id` ASC) VISIBLE,
|
||||
CONSTRAINT `server_id`
|
||||
|
@ -14,7 +14,9 @@ DB_PASSWORD = os.getenv('DB_PASSWORD')
|
||||
|
||||
def updater_1_0_0(connection: mysql.connection.MySQLConnection) -> str:
|
||||
crsr = connection.cursor()
|
||||
crsr.execute('ALTER TABLE `hfc_db`.`channels` ADD COLUMN `locations` JSON NULL;')
|
||||
crsr.execute('ALTER TABLE `hfc_db`.`channels` '
|
||||
'ADD COLUMN `locations` JSON '
|
||||
'DEFAULT JSON_ARRAY();')
|
||||
crsr.close()
|
||||
return '1.0.1'
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user