Started location setting system.
This commit is contained in:
GaMeNu 2023-10-20 01:20:24 +03:00
parent 1e69add653
commit 1477003229
4 changed files with 229 additions and 9 deletions

View File

@ -1,6 +1,9 @@
import datetime
import json
import os
import random
from _xxsubinterpreters import channel_recv
import requests
import discord
@ -84,6 +87,8 @@ class Notificator(commands.Cog):
def __init__(self, bot: commands.Bot, handler: logging.Handler):
self.bot = bot
self.bot.add_command(Location())
self.log = logging.Logger('Notificator')
self.log.addHandler(handler)
@ -518,3 +523,135 @@ class Notificator(commands.Cog):
"desc": desc
}, districts_ls)
location_group = app_commands.Group(name='locations', description='Commands related adding, removing, or setting locations.')
@staticmethod
def locations_page(data_list: list, page: int, res_in_page: int = 50) -> str:
"""
Page starts at 0
max_page is EXCLUSIVE
:param data_list: custom data list to get page info of
:param page: District page
:param res_in_page: Amount of districts to put in one pages
:return:
"""
dist_ls = data_list
dist_len = len(dist_ls)
if dist_len == 0:
return 'No results found.'
max_page = dist_len // res_in_page
if dist_len % res_in_page != 0:
max_page += 1
if page >= max_page:
raise ValueError('Page number is too high.')
if page < 0:
raise ValueError('Page number is too low.')
page_content = f'Page { md.b(f"{page + 1}/{max_page}") }\n\n<District ID> - <District name>\n\n'
start_i = page * res_in_page
end_i = min(start_i + res_in_page, dist_len)
for district in dist_ls[start_i:end_i]:
page_content += f'{district[0]} - {district[1]}\n'
return page_content
@location_group.command(name='locations_list', description='Show the list of all available locations')
async def locations_list(self, intr: discord.Interaction, page: int = 1):
try:
page = self.locations_page(sorted(self.db.get_all_districts(), key=lambda tup: tup[1]), page - 1)
except ValueError as e:
await intr.response.send_message(e.__str__())
return
if len(page) > 2000:
await intr.response.send_message('Page content exceeds character limit.\nPlease contact the bot authors with the command you\'ve tried to run.')
return
await intr.response.send_message(page)
@location_group.command(name='add', description='Add a location(s) to the location list')
@app_commands.describe(locations='A list of comma-separated Area IDs')
async def location_add(self, intr: discord.Interaction, locations: str):
if intr.guild is not None and intr.user.guild_permissions.manage_channels is False:
await intr.response.send_message('Error: You are missing the Manage Channels permission.')
return
locations_ls = [word.strip() for word in locations.split(',')]
location_ids = []
for location in locations_ls:
try:
location_ids.append(int(location))
except ValueError:
await intr.response.send_message(f'District ID {md.b(f"{location}")} is not a valid district ID.')
return
channel_id = intr.channel_id
channel = self.db.get_channel(channel_id)
if channel is None:
channel = self.db.get_channel(intr.user.id)
if channel is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
self.db.add_channel_districts(channel.id, location_ids)
@location_group.command(name='remove', description='Remove a location(s) to the location list')
@app_commands.describe(locations='A list of comma-separated Area IDs')
async def location_remove(self, intr: discord.Interaction, locations: str):
if intr.guild is not None and intr.user.guild_permissions.manage_channels is False:
await intr.response.send_message('Error: You are missing the Manage Channels permission.')
return
locations_ls = [word.strip() for word in locations.split(',')]
location_ids = []
for location in locations_ls:
try:
location_ids.append(int(location))
except ValueError:
await intr.response.send_message(f'District ID {md.b(f"{location}")} is not a valid district ID.')
return
channel_id = intr.channel_id
channel = self.db.get_channel(channel_id)
if channel is None:
channel = self.db.get_channel(intr.user.id)
if channel is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
self.db.remove_channel_districts(channel.id, location_ids)
@location_group.command(name='registered', description='List all locations registered to this channel')
async def location_registered(self, intr: discord.Interaction, page: int = 1):
channel_id = intr.channel_id
channel = self.db.get_channel(channel_id)
if channel is None:
channel = self.db.get_channel(intr.user.id)
if channel is None:
await intr.response.send_message('Could not find this channel. Are you sure it is registered?')
return
districts = self.db.get_channel_districts(chanel.id)
page = self.locations_page(districts, page-1)
if len(page) > 2000:
await intr.response.send_message(
'Page content exceeds character limit.\nPlease contact the bot authors with the command you\'ve tried to run.')
return
await intr.response.send_message(page)

View File

@ -1,6 +1,8 @@
import json
import logging
import os
import time
from typing import Sequence
from dotenv import load_dotenv
from mysql import connector as mysql
@ -15,6 +17,10 @@ class Area:
self.id = id
self.name = name
@staticmethod
def from_tuple(tup: Sequence):
return Area(tup[0], tup[1])
class District:
def __init__(self, id: int, name: str, area_id: int, migun_time: int):
@ -23,12 +29,17 @@ class District:
self.area_id = area_id
self.migun_time = migun_time
@staticmethod
def from_tuple(tup: Sequence):
return District(tup[0], tup[1], tup[2], tup[3])
class Channel:
def __init__(self, id: int, server_id: int | None, channel_lang: str):
def __init__(self, id: int, server_id: int | None, channel_lang: str, locations: list):
self.id = id
self.server_id = server_id
self.channel_lang = channel_lang
self.locations = locations
class Server:
@ -49,7 +60,21 @@ class ChannelIterator:
if res is None:
self.cursor.close()
raise StopIteration
return Channel(res[0], res[1], res[2])
return Channel(res[0], res[1], res[2], json.loads(res[3]))
class DistrictIterator:
def __init__(self, cursor: mysql.connection.MySQLCursor):
self.cursor = cursor
def __iter__(self):
return self
def __next__(self) -> District:
res = self.cursor.fetchone()
if res is None:
self.cursor.close()
raise StopIteration
return District.from_tuple(res)
class DBAccess:
@ -75,7 +100,7 @@ class DBAccess:
break
except mysql.Error as e:
self.connection = None
log.warning(f"Couldn't connect to db. This is attempt #{i}")
log.warning(f"Couldn't connect to db. This is attempt #{i}\n{e.msg}")
time.sleep(5)
if self.connection is None:
@ -121,7 +146,7 @@ class DBAccess:
crsr.fetchall()
if res is not None:
return Area(res[0], res[1])
return Area.from_tuple(res)
else:
return None
@ -132,7 +157,7 @@ class DBAccess:
crsr.fetchall()
if res is not None:
return District(res[0], res[1], res[2], res[3])
return District.from_tuple(res)
else:
return None
@ -157,7 +182,7 @@ class DBAccess:
crsr.fetchall()
if res is not None:
return Channel(res[0], res[1], res[2])
return Channel(res[0], res[1], res[2], json.loads(res[3]))
else:
return None
@ -203,4 +228,60 @@ class DBAccess:
if res is None:
return None
return District(res[0], res[1], res[2], res[3])
return District.from_tuple(res)
def get_all_districts(self) -> Sequence:
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM districts')
ret = crsr.fetchall()
return ret
def district_iterator(self) -> DistrictIterator:
crsr = self.connection.cursor()
crsr.execute('SELECT * FROM district')
return DistrictIterator(crsr)
def add_channel_district(self, channel_id: int, district_id: int):
with self.connection.cursor() as crsr:
crsr.execute("UPDATE channels "
"SET locations = JSON_ARRAY_APPEND(locations, '$', %s) "
"WHERE channel_id=%s;",
(district_id, channel_id))
self.connection.commit()
def add_channel_districts(self, channel_id: int, district_ids: list[int]):
with self.connection.cursor() as crsr:
crsr.execute("UPDATE channels "
"SET locations = JSON_MERGE_PATCH(locations, %s) "
"WHERE channel_id=%s;",
(json.dumps(district_ids), channel_id))
self.connection.commit()
def get_channel_districts(self, channel_id: int) -> list:
with self.connection.cursor() as crsr:
crsr.execute('SELECT locations '
'FROM channels '
'WHERE channel_id=%s;', (channel_id,))
districts = []
dist = crsr.fetchone()
while dist is not None:
districts.append(json.loads(dist[0]))
dist = crsr.fetchone()
return districts
def remove_channel_districts(self, channel_id: int, district_ids: list[int]):
with self.connection.cursor() as crsr:
districts = self.get_channel_districts(channel_id)
updated = [district for district in districts if district not in district_ids]
crsr.execute('UPDATE channels '
'SET locations = %s '
'WHERE channel_id = %s;',
(json.dumps(updated), channel_id))
self.connection.commit()

View File

@ -81,7 +81,7 @@ CREATE TABLE IF NOT EXISTS `hfc_db`.`channels` (
`channel_id` BIGINT(8) UNSIGNED NOT NULL,
`server_id` BIGINT(8) UNSIGNED NULL,
`channel_lang` VARCHAR(15) NOT NULL,
`locations` JSON NULL,
`locations` JSON DEFAULT JSON_ARRAY(),
PRIMARY KEY (`channel_id`),
UNIQUE INDEX `channel_id_UNIQUE` (`channel_id` ASC) VISIBLE,
CONSTRAINT `server_id`

View File

@ -14,7 +14,9 @@ DB_PASSWORD = os.getenv('DB_PASSWORD')
def updater_1_0_0(connection: mysql.connection.MySQLConnection) -> str:
crsr = connection.cursor()
crsr.execute('ALTER TABLE `hfc_db`.`channels` ADD COLUMN `locations` JSON NULL;')
crsr.execute('ALTER TABLE `hfc_db`.`channels` '
'ADD COLUMN `locations` JSON '
'DEFAULT JSON_ARRAY();')
crsr.close()
return '1.0.1'