mirror of
https://github.com/GaMeNu/HFCNotificator.git
synced 2024-11-16 15:24:51 +02:00
v2.2.3
Added method to re-open db connection. Final commit of version 2.2.3
This commit is contained in:
parent
81f1cad483
commit
d95289c262
238
db_access.py
238
db_access.py
@ -14,60 +14,171 @@ DB_PASSWORD = os.getenv('DB_PASSWORD')
|
|||||||
|
|
||||||
|
|
||||||
class Area:
|
class Area:
|
||||||
|
"""
|
||||||
|
An object representing an Area record in the database
|
||||||
|
|
||||||
|
:var id: area id
|
||||||
|
:var name: area name
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, id: int, name: str):
|
def __init__(self, id: int, name: str):
|
||||||
|
"""
|
||||||
|
Create a new Area object
|
||||||
|
:param id: area id
|
||||||
|
:param name: area name
|
||||||
|
"""
|
||||||
self.id = id
|
self.id = id
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def from_tuple(tup: Sequence):
|
def from_tuple(cls, tup: Sequence):
|
||||||
return Area(tup[0], tup[1])
|
"""
|
||||||
|
Create an area object from tuple of the following form:
|
||||||
|
|
||||||
|
(area_id: int, area_name: str)
|
||||||
|
|
||||||
|
:param tup: tuple to use
|
||||||
|
:return: new Area object
|
||||||
|
"""
|
||||||
|
return cls(tup[0], tup[1])
|
||||||
|
|
||||||
|
|
||||||
class District:
|
class District:
|
||||||
|
"""
|
||||||
|
An object representing a District record in the database
|
||||||
|
|
||||||
|
:var district_id: District ID
|
||||||
|
:var name: District name
|
||||||
|
:var area_id: Area ID of the area the district belongs to
|
||||||
|
:var migun_time: Time (in seconds) to reach shelters in case of a missile alert
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, id: int, name: str, area_id: int, migun_time: int):
|
def __init__(self, id: int, name: str, area_id: int, migun_time: int):
|
||||||
|
"""
|
||||||
|
:param id: district ID
|
||||||
|
:param name: district name
|
||||||
|
:param area_id: area id
|
||||||
|
:param migun_time: migun time
|
||||||
|
"""
|
||||||
self.district_id = id
|
self.district_id = id
|
||||||
self.name = name
|
self.name = name
|
||||||
self.area_id = area_id
|
self.area_id = area_id
|
||||||
self.migun_time = migun_time
|
self.migun_time = migun_time
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def from_tuple(tup: Sequence):
|
def from_tuple(cls, tup: Sequence):
|
||||||
return District(tup[0], tup[1], tup[2], tup[3])
|
"""
|
||||||
|
Create a District object from tuple of the form:
|
||||||
|
|
||||||
|
(id: int, name: str, area_id: int, migun_time: int)
|
||||||
|
|
||||||
|
:param tup: Tuple to pass
|
||||||
|
:return: new instance
|
||||||
|
"""
|
||||||
|
return cls(tup[0], tup[1], tup[2], tup[3])
|
||||||
|
|
||||||
def to_tuple(self) -> tuple:
|
def to_tuple(self) -> tuple:
|
||||||
|
"""
|
||||||
|
Convert the district back to Tuple form
|
||||||
|
:return: tuple representation of district
|
||||||
|
"""
|
||||||
return self.district_id, self.name, self.area_id, self.migun_time
|
return self.district_id, self.name, self.area_id, self.migun_time
|
||||||
|
|
||||||
|
|
||||||
class AreaDistrict(District):
|
class AreaDistrict(District):
|
||||||
|
"""
|
||||||
|
A child class of district, containing also an Area object of which the district belongs to
|
||||||
|
|
||||||
|
:var district_id: District ID
|
||||||
|
:var name: District name
|
||||||
|
:var area_id: Area ID of the area the district belongs to
|
||||||
|
:var migun_time: Time (in seconds) to reach shelters in case of a missile alert
|
||||||
|
:var area: Area object of said area_id
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, id: int, name: str, area_id: int, migun_time: int, area: Area):
|
def __init__(self, id: int, name: str, area_id: int, migun_time: int, area: Area):
|
||||||
|
"""
|
||||||
|
:param id: district ID
|
||||||
|
:param name: district name
|
||||||
|
:param area_id: area id
|
||||||
|
:param migun_time: migun time
|
||||||
|
:param area: Area object
|
||||||
|
"""
|
||||||
super().__init__(id, name, area_id, migun_time)
|
super().__init__(id, name, area_id, migun_time)
|
||||||
self.area = area
|
self.area = area
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_district(cls, district: District, area: Area):
|
def from_district(cls, district: District, area: Area):
|
||||||
|
"""
|
||||||
|
Get a District and an Area, and return an AreaDistrict
|
||||||
|
:param district: District object
|
||||||
|
:param area: Area object
|
||||||
|
:return: new instance of AreaDistrict
|
||||||
|
"""
|
||||||
return cls(district.district_id, district.name, district.area_id, district.migun_time, area)
|
return cls(district.district_id, district.name, district.area_id, district.migun_time, area)
|
||||||
|
|
||||||
|
|
||||||
class Channel:
|
class Channel:
|
||||||
|
"""
|
||||||
|
An object representing a Channel record in the database
|
||||||
|
|
||||||
|
:var id: channel id (matches the discord channel/user id)
|
||||||
|
:var server_id: Channel's server ID (None if is a DM)
|
||||||
|
:var channel_lang: obsolete, just pass in 'he'
|
||||||
|
:var locations: a list of ints, each int correlating to a District ID
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, id: int, server_id: int | None, channel_lang: str, locations: list):
|
def __init__(self, id: int, server_id: int | None, channel_lang: str, locations: list):
|
||||||
|
"""
|
||||||
|
:param id: channel ID
|
||||||
|
:param server_id: server ID (None for DMs)
|
||||||
|
:param channel_lang: obsolete, just pass in 'he'
|
||||||
|
:param locations: List of District IDs
|
||||||
|
"""
|
||||||
self.id = id
|
self.id = id
|
||||||
self.server_id = server_id
|
self.server_id = server_id
|
||||||
self.channel_lang = channel_lang
|
self.channel_lang = channel_lang
|
||||||
self.locations = locations
|
self.locations = locations
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def from_tuple(tup: tuple):
|
def from_tuple(cls, tup: tuple):
|
||||||
return Channel(tup[0], tup[1], tup[2], json.loads(tup[3]))
|
"""
|
||||||
|
Create a Channel object from a tuple of this form:
|
||||||
|
|
||||||
|
(id: int, server_id: int | None, channel_lang: str, locations: list)
|
||||||
|
|
||||||
|
:param tup: Tuple to pass
|
||||||
|
:return: New Channel instance
|
||||||
|
"""
|
||||||
|
return cls(tup[0], tup[1], tup[2], json.loads(tup[3]))
|
||||||
|
|
||||||
|
|
||||||
class Server:
|
class Server:
|
||||||
|
"""
|
||||||
|
An object representing a Server record in the database
|
||||||
|
|
||||||
|
:var id: Server ID
|
||||||
|
:var lang: obsolete, pass in 'he'
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, id: int, lang: str):
|
def __init__(self, id: int, lang: str):
|
||||||
|
"""
|
||||||
|
:param id: Server ID
|
||||||
|
:param lang: obsolete, pass in 'he'
|
||||||
|
"""
|
||||||
self.id = id
|
self.id = id
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
|
|
||||||
|
|
||||||
class ChannelIterator:
|
class ChannelIterator:
|
||||||
|
"""
|
||||||
|
DEPRECATED!
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, cursor: MySQLCursorAbstract):
|
def __init__(self, cursor: MySQLCursorAbstract):
|
||||||
|
raise DeprecationWarning(
|
||||||
|
'This class does not allow database queries while active, and thus has been deprecated.')
|
||||||
self.cursor = cursor
|
self.cursor = cursor
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
@ -83,6 +194,8 @@ class ChannelIterator:
|
|||||||
|
|
||||||
class DistrictIterator:
|
class DistrictIterator:
|
||||||
def __init__(self, cursor: MySQLCursorAbstract):
|
def __init__(self, cursor: MySQLCursorAbstract):
|
||||||
|
raise DeprecationWarning(
|
||||||
|
'This class does not allow database queries while active, and thus has been deprecated.')
|
||||||
self.cursor = cursor
|
self.cursor = cursor
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
@ -97,69 +210,73 @@ class DistrictIterator:
|
|||||||
|
|
||||||
|
|
||||||
class DBAccess:
|
class DBAccess:
|
||||||
|
|
||||||
|
def get_cursor(self):
|
||||||
|
try:
|
||||||
|
crsr = self.connection.cursor()
|
||||||
|
except mysql.errors.OperationalError:
|
||||||
|
self.connection.reconnect()
|
||||||
|
crsr = self.connection.cursor()
|
||||||
|
|
||||||
|
return crsr
|
||||||
|
|
||||||
def __init__(self, handler: logging.Handler = None):
|
def __init__(self, handler: logging.Handler = None):
|
||||||
|
|
||||||
log = logging.Logger('DBAccess')
|
self.log = logging.Logger('DBAccess')
|
||||||
|
|
||||||
if handler is not None:
|
if handler is not None:
|
||||||
log.addHandler(handler)
|
self.log.addHandler(handler)
|
||||||
else:
|
else:
|
||||||
log.addHandler(logging.StreamHandler())
|
self.log.addHandler(logging.StreamHandler())
|
||||||
|
|
||||||
self.connection = None
|
try:
|
||||||
|
|
||||||
for i in range(12):
|
|
||||||
try:
|
|
||||||
self.connection = mysql.connect(
|
|
||||||
host='localhost',
|
|
||||||
user=DB_USERNAME,
|
|
||||||
password=DB_PASSWORD,
|
|
||||||
database='hfc_db'
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except mysql.Error as e:
|
|
||||||
self.connection = None
|
|
||||||
log.warning(f"Couldn't connect to db. This is attempt #{i}\n{e.msg}")
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
if self.connection is None:
|
|
||||||
self.connection = mysql.connect(
|
self.connection = mysql.connect(
|
||||||
host='localhost',
|
host='localhost',
|
||||||
user=DB_USERNAME,
|
user=DB_USERNAME,
|
||||||
password=DB_PASSWORD,
|
password=DB_PASSWORD,
|
||||||
database='hfc_db'
|
database='hfc_db'
|
||||||
)
|
)
|
||||||
|
except mysql.Error as e:
|
||||||
|
self.connection.reconnect(attempts=12, delay=5)
|
||||||
|
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.connection.close()
|
||||||
|
|
||||||
def add_area(self, area_id: int, area_name: str):
|
def add_area(self, area_id: int, area_name: str):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute(f'REPLACE INTO areas (area_id, area_name) VALUES (%s, %s)', (area_id, area_name))
|
crsr.execute(f'REPLACE INTO areas (area_id, area_name) VALUES (%s, %s)', (area_id, area_name))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def add_district(self, district_id: int, district_name: str, area_id: int, area_name: str, migun_time: int):
|
def add_district(self, district_id: int, district_name: str, area_id: int, area_name: str, migun_time: int):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute(f'SELECT * FROM areas WHERE area_id=%s', (area_id,))
|
crsr.execute(f'SELECT * FROM areas WHERE area_id=%s', (area_id,))
|
||||||
crsr.fetchall()
|
crsr.fetchall()
|
||||||
|
|
||||||
if crsr.rowcount == 0:
|
if crsr.rowcount == 0:
|
||||||
self.add_area(area_id, area_name)
|
self.add_area(area_id, area_name)
|
||||||
|
|
||||||
crsr.execute(f'REPLACE INTO districts (district_id, district_name, area_id, migun_time) VALUES (%s, %s, %s, %s)', (district_id, district_name, area_id, migun_time))
|
crsr.execute(
|
||||||
|
f'REPLACE INTO districts (district_id, district_name, area_id, migun_time) VALUES (%s, %s, %s, %s)',
|
||||||
|
(district_id, district_name, area_id, migun_time))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def add_server(self, server_id: int, server_lang: str):
|
def add_server(self, server_id: int, server_lang: str):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute(f'INSERT IGNORE INTO servers (server_id, server_lang) VALUES (%s, %s)', (server_id, server_lang))
|
crsr.execute(f'INSERT IGNORE INTO servers (server_id, server_lang) VALUES (%s, %s)',
|
||||||
|
(server_id, server_lang))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def add_channel(self, channel_id: int, server_id: int | None, channel_lang: str | None):
|
def add_channel(self, channel_id: int, server_id: int | None, channel_lang: str | None):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
if server_id is not None:
|
if server_id is not None:
|
||||||
self.add_server(server_id, channel_lang)
|
self.add_server(server_id, channel_lang)
|
||||||
crsr.execute(f'REPLACE INTO channels (channel_id, server_id, channel_lang) VALUES (%s, %s, %s)', (channel_id, server_id, channel_lang))
|
crsr.execute(f'REPLACE INTO channels (channel_id, server_id, channel_lang) VALUES (%s, %s, %s)',
|
||||||
|
(channel_id, server_id, channel_lang))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def get_area(self, id: int) -> Area | None:
|
def get_area(self, id: int) -> Area | None:
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('SELECT * FROM areas WHERE area_id=%s', (id,))
|
crsr.execute('SELECT * FROM areas WHERE area_id=%s', (id,))
|
||||||
res = crsr.fetchone()
|
res = crsr.fetchone()
|
||||||
crsr.nextset()
|
crsr.nextset()
|
||||||
@ -170,7 +287,7 @@ class DBAccess:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_district(self, id: int) -> District | None:
|
def get_district(self, id: int) -> District | None:
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('SELECT * FROM districts WHERE district_id=%s', (id,))
|
crsr.execute('SELECT * FROM districts WHERE district_id=%s', (id,))
|
||||||
res = crsr.fetchone()
|
res = crsr.fetchone()
|
||||||
crsr.nextset()
|
crsr.nextset()
|
||||||
@ -184,7 +301,7 @@ class DBAccess:
|
|||||||
return self.get_area(district.area_id)
|
return self.get_area(district.area_id)
|
||||||
|
|
||||||
def get_server(self, id: int) -> Server | None:
|
def get_server(self, id: int) -> Server | None:
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('SELECT * FROM servers WHERE server_id=%s', (id,))
|
crsr.execute('SELECT * FROM servers WHERE server_id=%s', (id,))
|
||||||
res = crsr.fetchone()
|
res = crsr.fetchone()
|
||||||
crsr.nextset()
|
crsr.nextset()
|
||||||
@ -195,7 +312,7 @@ class DBAccess:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_channel(self, id: int) -> Channel | None:
|
def get_channel(self, id: int) -> Channel | None:
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('SELECT * FROM channels WHERE channel_id=%s', (id,))
|
crsr.execute('SELECT * FROM channels WHERE channel_id=%s', (id,))
|
||||||
res = crsr.fetchone()
|
res = crsr.fetchone()
|
||||||
crsr.nextset()
|
crsr.nextset()
|
||||||
@ -218,43 +335,42 @@ class DBAccess:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError("This function has been deprecated!")
|
raise NotImplementedError("This function has been deprecated!")
|
||||||
|
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
|
|
||||||
crsr.execute('SELECT * FROM channels')
|
crsr.execute('SELECT * FROM channels')
|
||||||
|
|
||||||
return ChannelIterator(crsr)
|
return ChannelIterator(crsr)
|
||||||
|
|
||||||
def get_all_channels(self):
|
def get_all_channels(self):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('SELECT * FROM channels')
|
crsr.execute('SELECT * FROM channels')
|
||||||
res = crsr.fetchall()
|
res = crsr.fetchall()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def remove_channel(self, id: int):
|
def remove_channel(self, id: int):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('DELETE FROM channels WHERE channel_id=%s', (id,))
|
crsr.execute('DELETE FROM channels WHERE channel_id=%s', (id,))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def remove_server(self, id: int):
|
def remove_server(self, id: int):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('DELETE FROM channels WHERE server_id=%s', (id,))
|
crsr.execute('DELETE FROM channels WHERE server_id=%s', (id,))
|
||||||
crsr.execute('DELETE FROM servers WHERE server_id=%s', (id,))
|
crsr.execute('DELETE FROM servers WHERE server_id=%s', (id,))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def remove_district(self, id: int):
|
def remove_district(self, id: int):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('DELETE FROM districts WHERE district_id=%s', (id,))
|
crsr.execute('DELETE FROM districts WHERE district_id=%s', (id,))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def remove_area(self, id: int):
|
def remove_area(self, id: int):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('DELETE FROM districts WHERE area_id=%s', (id,))
|
crsr.execute('DELETE FROM districts WHERE area_id=%s', (id,))
|
||||||
crsr.execute('DELETE FROM areas WHERE area_id=%s', (id,))
|
crsr.execute('DELETE FROM areas WHERE area_id=%s', (id,))
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def get_district_by_name(self, name: str) -> District | None:
|
def get_district_by_name(self, name: str) -> District | None:
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('SELECT * FROM districts WHERE district_name=%s', (name,))
|
crsr.execute('SELECT * FROM districts WHERE district_name=%s', (name,))
|
||||||
res = crsr.fetchone()
|
res = crsr.fetchone()
|
||||||
crsr.nextset()
|
crsr.nextset()
|
||||||
@ -265,13 +381,13 @@ class DBAccess:
|
|||||||
return District.from_tuple(res)
|
return District.from_tuple(res)
|
||||||
|
|
||||||
def get_all_districts(self) -> Sequence:
|
def get_all_districts(self) -> Sequence:
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('SELECT * FROM districts')
|
crsr.execute('SELECT * FROM districts')
|
||||||
ret = crsr.fetchall()
|
ret = crsr.fetchall()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def search_districts(self, *tokens: str) -> Sequence:
|
def search_districts(self, *tokens: str) -> Sequence:
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
query = 'SELECT * FROM districts WHERE '
|
query = 'SELECT * FROM districts WHERE '
|
||||||
query += ' AND '.join(["district_name LIKE %s" for _ in tokens])
|
query += ' AND '.join(["district_name LIKE %s" for _ in tokens])
|
||||||
query += ';'
|
query += ';'
|
||||||
@ -289,20 +405,20 @@ class DBAccess:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError("This function has been deprecated!")
|
raise NotImplementedError("This function has been deprecated!")
|
||||||
|
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('SELECT * FROM district')
|
crsr.execute('SELECT * FROM district')
|
||||||
|
|
||||||
return DistrictIterator(crsr)
|
return DistrictIterator(crsr)
|
||||||
|
|
||||||
def add_channel_district(self, channel_id: int, district_id: int):
|
def add_channel_district(self, channel_id: int, district_id: int):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('SELECT * FROM districts WHERE district_id=%s', (district_id,))
|
crsr.execute('SELECT * FROM districts WHERE district_id=%s', (district_id,))
|
||||||
res = crsr.fetchone()
|
res = crsr.fetchone()
|
||||||
crsr.nextset()
|
crsr.nextset()
|
||||||
if res is None:
|
if res is None:
|
||||||
raise ValueError(f'Invalid District ID {district_id}')
|
raise ValueError(f'Invalid District ID {district_id}')
|
||||||
|
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute("UPDATE channels "
|
crsr.execute("UPDATE channels "
|
||||||
"SET locations = JSON_ARRAY_APPEND(locations, '$', %s) "
|
"SET locations = JSON_ARRAY_APPEND(locations, '$', %s) "
|
||||||
"WHERE channel_id=%s;",
|
"WHERE channel_id=%s;",
|
||||||
@ -310,9 +426,10 @@ class DBAccess:
|
|||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def add_channel_districts(self, channel_id: int, district_ids: list[int]):
|
def add_channel_districts(self, channel_id: int, district_ids: list[int]):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
# Sorry for the messy statement. I'm lazy and it's 02:13 rn
|
# Sorry for the messy statement. I'm lazy and it's 02:13 rn
|
||||||
crsr.execute(f"SELECT * FROM districts WHERE district_id IN ({','.join(['%s'] * len(district_ids))})", tuple(district_ids))
|
crsr.execute(f"SELECT * FROM districts WHERE district_id IN ({','.join(['%s'] * len(district_ids))})",
|
||||||
|
tuple(district_ids))
|
||||||
res = crsr.fetchall()
|
res = crsr.fetchall()
|
||||||
|
|
||||||
if len(district_ids) > len(res):
|
if len(district_ids) > len(res):
|
||||||
@ -324,8 +441,7 @@ class DBAccess:
|
|||||||
dists = self.get_channel_district_ids(channel_id)
|
dists = self.get_channel_district_ids(channel_id)
|
||||||
updated = [district for district in district_ids if district not in dists]
|
updated = [district for district in district_ids if district not in dists]
|
||||||
|
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
|
|
||||||
crsr.execute("UPDATE channels "
|
crsr.execute("UPDATE channels "
|
||||||
"SET locations = JSON_MERGE_PRESERVE(locations, %s) "
|
"SET locations = JSON_MERGE_PRESERVE(locations, %s) "
|
||||||
"WHERE channel_id=%s;",
|
"WHERE channel_id=%s;",
|
||||||
@ -333,7 +449,7 @@ class DBAccess:
|
|||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def get_channel_district_ids(self, channel_id: int) -> list:
|
def get_channel_district_ids(self, channel_id: int) -> list:
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.nextset()
|
crsr.nextset()
|
||||||
crsr.execute('SELECT locations '
|
crsr.execute('SELECT locations '
|
||||||
'FROM channels '
|
'FROM channels '
|
||||||
@ -357,8 +473,7 @@ class DBAccess:
|
|||||||
return filtered_districts
|
return filtered_districts
|
||||||
|
|
||||||
def remove_channel_districts(self, channel_id: int, district_ids: list[int]):
|
def remove_channel_districts(self, channel_id: int, district_ids: list[int]):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
|
|
||||||
districts = self.get_channel_district_ids(channel_id)
|
districts = self.get_channel_district_ids(channel_id)
|
||||||
|
|
||||||
updated = [district for district in districts if district not in district_ids]
|
updated = [district for district in districts if district not in district_ids]
|
||||||
@ -371,7 +486,7 @@ class DBAccess:
|
|||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def clear_channel_districts(self, channel_id: int):
|
def clear_channel_districts(self, channel_id: int):
|
||||||
with self.connection.cursor() as crsr:
|
with self.get_cursor() as crsr:
|
||||||
crsr.execute('UPDATE channels '
|
crsr.execute('UPDATE channels '
|
||||||
'SET locations = JSON_ARRAY() '
|
'SET locations = JSON_ARRAY() '
|
||||||
'WHERE channel_id = %s;',
|
'WHERE channel_id = %s;',
|
||||||
@ -381,4 +496,3 @@ class DBAccess:
|
|||||||
|
|
||||||
def is_registered_channel(self, channel_id: int) -> bool:
|
def is_registered_channel(self, channel_id: int) -> bool:
|
||||||
return self.get_channel(channel_id) is not None
|
return self.get_channel(channel_id) is not None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user