v2.2.0 RELEASE!

Added checker for each embed. I SERIOUSLY need to update generate_alert_embed ASAP into the DistrictEmbed class
This commit is contained in:
GaMeNu 2023-10-20 16:07:16 +03:00
parent 292243146d
commit bc65c908bb
2 changed files with 72 additions and 45 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import datetime
import json
import os
@ -12,7 +13,9 @@ from discord.ext import commands, tasks
from discord import app_commands
import logging
from db_access import DBAccess
import db_access
from db_access import *
from markdown import md
load_dotenv()
@ -80,6 +83,11 @@ class Alert:
data.get('desc'))
class DistrictEmbed(discord.Embed):
def __init__(self, district_id: int, **kwargs):
self.district_id = district_id
super().__init__(**kwargs)
# noinspection PyUnresolvedReferences
class Notificator(commands.Cog):
location_group = app_commands.Group(name='locations',
@ -168,9 +176,9 @@ class Notificator(commands.Cog):
@staticmethod
def generate_alert_embed(alert_object: Alert, district: str, arrival_time: int | None, time: str,
lang: str) -> discord.Embed:
lang: str, district_id: int) -> DistrictEmbed:
# TODO: Using 1 generate alert function is probably bad, should probably split into a utility class
e = discord.Embed(color=discord.Color.from_str('#FF0000'))
e = DistrictEmbed(district_id=district_id, color=discord.Color.from_str('#FF0000'))
e.title = f'התראה ב{district}'
e.add_field(name=district, value=alert_object.title, inline=False)
match alert_object.category:
@ -178,7 +186,7 @@ class Notificator(commands.Cog):
if arrival_time is not None:
e.add_field(name='זמן מיגון', value=f'{arrival_time} שניות', inline=False)
else:
e.add_field(name='זמן מיגון', value='שגיאה בהוצאת המידע', inline=False)
e.add_field(name='זמן מיגון', value='שגיאה באחזרת המידע', inline=False)
case _:
pass
@ -205,13 +213,12 @@ class Notificator(commands.Cog):
alert_history = None
self.log.info(f'Sending alerts to channels')
embed_ls: list[discord.Embed] = []
embed_ls_ls: list[list[discord.Embed]] = []
embed_ls: list[DistrictEmbed] = []
new_alert = Alert.from_dict(alert_data)
for district in new_districts:
district_data = self.db.get_district_by_name(district)
district_data = self.db.get_district_by_name(district) # DB
alert_time = datetime.datetime.now() # .strftime()
# TODO: THIS REQUIRES SIMPLIFICATION ASAP
@ -234,26 +241,27 @@ class Notificator(commands.Cog):
alert_time_str = alert_time.strftime("%H:%M:%S\n%d/%m/%Y")
if district_data is not None:
embed_ls.append(Notificator.generate_alert_embed(new_alert, district, district_data.migun_time,
alert_time_str, 'he'))
alert_time_str, 'he', district_data.id))
else:
embed_ls.append(Notificator.generate_alert_embed(new_alert, district, None, alert_time_str, 'he'))
if len(embed_ls) == 10:
embed_ls_ls.append(embed_ls)
embed_ls = []
embed_ls.append(Notificator.generate_alert_embed(new_alert, district, None, alert_time_str, 'he', district_data.id))
if len(embed_ls) > 0:
embed_ls_ls.append(embed_ls)
for channel in self.db.channel_iterator():
for channel_tup in self.db.get_all_channels():
channel = Channel.from_tuple(channel_tup)
if channel.server_id is not None:
dc_ch = self.bot.get_channel(channel.id)
else:
dc_ch = self.bot.get_user(channel.id)
for embed_list in embed_ls_ls:
channel_districts = self.db.get_channel_district_ids(channel.id)
for emb in embed_ls:
if dc_ch is None:
continue
if len(channel.locations) != 0 and emb.district_id not in channel.locations:
continue
try:
await dc_ch.send(embeds=embed_list, view=self.hfc_button_view())
await dc_ch.send(embed=emb, view=self.hfc_button_view())
await asyncio.sleep(0.01)
except BaseException as e:
self.log.warning(f'Failed to send alert in channel id={channel.id}:\n'
f'{e}')
@ -308,13 +316,14 @@ class Notificator(commands.Cog):
@app_commands.command(name='unregister',
description='Stop a channel from receiving HFC alerts (Requires Manage Channels)')
@app_commands.checks.has_permissions(manage_channels=True)
async def unregister_channel(self, intr: discord.Interaction, confirmation: str=None):
async def unregister_channel(self, intr: discord.Interaction, confirmation: str = None):
channel_id = intr.channel_id
conf_str = intr.user.name
if confirmation is None:
await intr.response.send_message(f'Are you sure you want to unregister the channel?\nThis action will also clear all related data.\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.')
await intr.response.send_message(
f'Are you sure you want to unregister the channel?\nThis action will also clear all related data.\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.')
return
if confirmation != conf_str:
await intr.response.send_message(f'Invalid confirmation string!')
@ -516,7 +525,8 @@ class Notificator(commands.Cog):
intr: discord.Interaction,
title: str = 'בדיקת מערכת שליחת התראות',
desc: str = 'התעלמו מהתראה זו',
districts: str = 'בדיקה'):
districts: str = 'בדיקה',
cat: int = 99):
if intr.user.id != AUTHOR_ID:
await intr.response.send_message('No access.')
return
@ -526,7 +536,7 @@ class Notificator(commands.Cog):
await self.send_new_alert({
"id": "133413211330000000",
"cat": "99",
"cat": str(cat),
"title": title,
"data": districts_ls,
"desc": desc
@ -567,7 +577,7 @@ class Notificator(commands.Cog):
if page < 0:
raise ValueError('Page number is too low.')
page_content = f'Page { md.b(f"{page + 1}/{max_page}") }\n\n<District ID> - <District name>\n\n'
page_content = f'Page {md.b(f"{page + 1}/{max_page}")}\n\n<District ID> - <District name>\n\n'
start_i = page * res_in_page
end_i = min(start_i + res_in_page, dist_len)
@ -586,7 +596,8 @@ class Notificator(commands.Cog):
return
if len(page) > 2000:
await intr.response.send_message('Page content exceeds character limit.\nPlease contact the bot authors with the command you\'ve tried to run.')
await intr.response.send_message(
'Page content exceeds character limit.\nPlease contact the bot authors with the command you\'ve tried to run.')
return
await intr.response.send_message(page)
@ -669,14 +680,16 @@ class Notificator(commands.Cog):
return
if confirmation is None:
await intr.response.send_message(f'Are you sure you want to clear all registered locations?\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.')
await intr.response.send_message(
f'Are you sure you want to clear all registered locations?\n{md.b("Warning:")} this action cannot be reversed!\nPlease type your username ("{conf_str}") in the confirmation argument to confirm.')
return
if confirmation != conf_str:
await intr.response.send_message(f'Invalid confirmation string!')
return
self.db.clear_channel_districts(channel.id)
await intr.response.send_message(f'Cleared all registered locations.\nChannel will now receive alerts from every location.')
await intr.response.send_message(
f'Cleared all registered locations.\nChannel will now receive alerts from every location.')
@location_group.command(name='registered', description='List all locations registered to this channel')
async def location_registered(self, intr: discord.Interaction, page: int = 1):
@ -702,7 +715,7 @@ class Notificator(commands.Cog):
in self.db.get_channel_district_ids(channel.id)],
key=lambda tup: tup[1])
page = self.locations_page(districts, page-1)
page = self.locations_page(districts, page - 1)
if len(page) > 2000:
await intr.response.send_message(

View File

@ -6,6 +6,7 @@ from typing import Sequence
from dotenv import load_dotenv
from mysql import connector as mysql
from mysql.connector.abstracts import MySQLCursorAbstract
load_dotenv()
DB_USERNAME = os.getenv('DB_USERNAME')
@ -44,6 +45,10 @@ class Channel:
self.channel_lang = channel_lang
self.locations = locations
@staticmethod
def from_tuple(tup: tuple):
return Channel(tup[0], tup[1], tup[2], json.loads(tup[3]))
class Server:
def __init__(self, id: int, lang: str):
@ -52,7 +57,7 @@ class Server:
class ChannelIterator:
def __init__(self, cursor: mysql.connection.MySQLCursor):
def __init__(self, cursor: MySQLCursorAbstract):
self.cursor = cursor
def __iter__(self):
@ -63,10 +68,11 @@ class ChannelIterator:
if res is None:
self.cursor.close()
raise StopIteration
return Channel(res[0], res[1], res[2], json.loads(res[3]))
return Channel.from_tuple(res)
class DistrictIterator:
def __init__(self, cursor: mysql.connection.MySQLCursor):
def __init__(self, cursor: MySQLCursorAbstract):
self.cursor = cursor
def __iter__(self):
@ -146,7 +152,7 @@ class DBAccess:
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM areas WHERE area_id=%s', (id,))
res = crsr.fetchone()
crsr.fetchall()
crsr.nextset()
if res is not None:
return Area.from_tuple(res)
@ -157,7 +163,7 @@ class DBAccess:
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM districts WHERE district_id=%s', (id,))
res = crsr.fetchone()
crsr.fetchall()
crsr.nextset()
if res is not None:
return District.from_tuple(res)
@ -171,7 +177,7 @@ class DBAccess:
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM servers WHERE server_id=%s', (id,))
res = crsr.fetchone()
crsr.fetchall()
crsr.nextset()
if res is not None:
return Server(res[0], res[1])
@ -182,10 +188,10 @@ class DBAccess:
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM channels WHERE channel_id=%s', (id,))
res = crsr.fetchone()
crsr.fetchall()
crsr.nextset()
if res is not None:
return Channel(res[0], res[1], res[2], json.loads(res[3]))
return Channel.from_tuple(res)
else:
return None
@ -193,11 +199,18 @@ class DBAccess:
return self.get_server(channel.server_id)
def channel_iterator(self):
crsr = self.connection.cursor()
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM channels')
crsr.execute('SELECT * FROM channels')
return ChannelIterator(crsr)
return ChannelIterator(crsr)
def get_all_channels(self):
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM channels')
res = crsr.fetchall()
return res
def remove_channel(self, id: int):
print(id)
@ -226,7 +239,7 @@ class DBAccess:
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM districts WHERE district_name=%s', (name,))
res = crsr.fetchone()
crsr.fetchall()
crsr.nextset()
if res is None:
return None
@ -240,17 +253,16 @@ class DBAccess:
return ret
def district_iterator(self) -> DistrictIterator:
crsr = self.connection.cursor()
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM district')
crsr.execute('SELECT * FROM district')
return DistrictIterator(crsr)
return DistrictIterator(crsr)
def add_channel_district(self, channel_id: int, district_id: int):
with self.connection.cursor() as crsr:
crsr.execute('SELECT * FROM districts WHERE district_id=%s', (district_id,))
res = crsr.fetchone()
crsr.fetchall()
crsr.nextset()
if res is None:
raise ValueError(f'Invalid District ID {district_id}')
@ -286,11 +298,13 @@ class DBAccess:
def get_channel_district_ids(self, channel_id: int) -> list:
with self.connection.cursor() as crsr:
crsr.nextset()
crsr.execute('SELECT locations '
'FROM channels '
'WHERE channel_id=%s;', (channel_id,))
dist = crsr.fetchone()
crsr.fetchall()
crsr.nextset()
districts = json.loads(dist[0])
return districts