diff --git a/db_creation/update_db.py b/db_creation/update_db.py index 50dc0e1..3627d22 100644 --- a/db_creation/update_db.py +++ b/db_creation/update_db.py @@ -13,10 +13,22 @@ DB_PASSWORD = os.getenv('DB_PASSWORD') def updater_1_0_0(connection: mysql.connection.MySQLConnection) -> str: - crsr = connection.cursor() - crsr.execute("ALTER TABLE `hfc_db`.`channels` DROP COLUMN IF EXISTS locations") - crsr.execute("ALTER TABLE `hfc_db`.`channels` ADD COLUMN `locations` JSON NOT NULL DEFAULT ('[]');") - crsr.close() + with connection.cursor() as crsr: + crsr.execute("SELECT COLUMN_NAME " + "FROM INFORMATION_SCHEMA.COLUMNS " + "WHERE TABLE_SCHEMA = 'hfc_db' " + "AND TABLE_NAME = 'channels' " + "AND COLUMN_NAME = 'locations';") + + exists = (crsr.fetchone() is not None) + + crsr.nextset() + + if exists: + crsr.execute("ALTER TABLE `hfc_db`.`channels` DROP COLUMN `locations`;") + + crsr.execute("ALTER TABLE `hfc_db`.`channels` ADD COLUMN `locations` JSON NOT NULL DEFAULT ('[]');") + return '1.0.1'