@@ -3,6 +3,8 @@ from __future__ import annotations | |||||
import os | import os | ||||
from typing import TypedDict | from typing import TypedDict | ||||
from eloquent import DatabaseManager, Model # type: ignore | |||||
CONFIG: dict[str, DbConfig] = { 'mysql': { 'driver': 'mysql', | CONFIG: dict[str, DbConfig] = { 'mysql': { 'driver': 'mysql', | ||||
'host': 'localhost', | 'host': 'localhost', | ||||
'database': 'nizika_ai', | 'database': 'nizika_ai', | ||||
@@ -10,6 +12,9 @@ CONFIG: dict[str, DbConfig] = { 'mysql': { 'driver': 'mysql', | |||||
'password': os.environ['MYSQL_PASS'], | 'password': os.environ['MYSQL_PASS'], | ||||
'prefix': '' } } | 'prefix': '' } } | ||||
DB = DatabaseManager (CONFIG) | |||||
Model.set_connection_resolver (DB) | |||||
class DbConfig (TypedDict): | class DbConfig (TypedDict): | ||||
driver: str | driver: str | ||||
@@ -4,6 +4,8 @@ from datetime import datetime | |||||
from eloquent import Model # type: ignore | from eloquent import Model # type: ignore | ||||
from config import DB | |||||
class AnsweredFlag (Model): | class AnsweredFlag (Model): | ||||
id: int | id: int | ||||
@@ -4,9 +4,7 @@ import random | |||||
from datetime import datetime | from datetime import datetime | ||||
from typing import cast | from typing import cast | ||||
from eloquent import DatabaseManager, Model # type: ignore | |||||
from config import CONFIG | |||||
from config import DB | |||||
from consts import Character | from consts import Character | ||||
from models import Answer, Query, User | from models import Answer, Query, User | ||||
from talk import Talk | from talk import Talk | ||||
@@ -14,8 +12,7 @@ from talk import Talk | |||||
def main ( | def main ( | ||||
) -> None: | ) -> None: | ||||
db = DatabaseManager (CONFIG) | |||||
Model.set_connection_resolver (db) | |||||
DB.begin_transaction () | |||||
queries: list[Query] = Query.where ('answered', False).get () | queries: list[Query] = Query.where ('answered', False).get () | ||||
if not queries: | if not queries: | ||||
return | return | ||||
@@ -53,6 +50,7 @@ def main ( | |||||
answer.save () | answer.save () | ||||
query.answered = True | query.answered = True | ||||
query.save () | query.save () | ||||
DB.commit () | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||