diff --git a/config.py b/config.py index e96eee1..fe04055 100644 --- a/config.py +++ b/config.py @@ -3,6 +3,8 @@ from __future__ import annotations import os from typing import TypedDict +from eloquent import DatabaseManager, Model # type: ignore + CONFIG: dict[str, DbConfig] = { 'mysql': { 'driver': 'mysql', 'host': 'localhost', 'database': 'nizika_ai', @@ -10,6 +12,9 @@ CONFIG: dict[str, DbConfig] = { 'mysql': { 'driver': 'mysql', 'password': os.environ['MYSQL_PASS'], 'prefix': '' } } +DB = DatabaseManager (CONFIG) +Model.set_connection_resolver (DB) + class DbConfig (TypedDict): driver: str diff --git a/models.py b/models.py index 8dd09b3..77f5d15 100644 --- a/models.py +++ b/models.py @@ -4,6 +4,8 @@ from datetime import datetime from eloquent import Model # type: ignore +from config import DB + class AnsweredFlag (Model): id: int diff --git a/queries_to_answers.py b/queries_to_answers.py index 0ae9a8b..75d1bfd 100644 --- a/queries_to_answers.py +++ b/queries_to_answers.py @@ -4,9 +4,7 @@ import random from datetime import datetime from typing import cast -from eloquent import DatabaseManager, Model # type: ignore - -from config import CONFIG +from config import DB from consts import Character from models import Answer, Query, User from talk import Talk @@ -14,8 +12,7 @@ from talk import Talk def main ( ) -> None: - db = DatabaseManager (CONFIG) - Model.set_connection_resolver (db) + DB.begin_transaction () queries: list[Query] = Query.where ('answered', False).get () if not queries: return @@ -53,6 +50,7 @@ def main ( answer.save () query.answered = True query.save () + DB.commit () if __name__ == '__main__':