diff --git a/queries_to_answers.py b/queries_to_answers.py index 68c8670..20c7475 100644 --- a/queries_to_answers.py +++ b/queries_to_answers.py @@ -12,26 +12,34 @@ from nizika_ai.talk import Talk def main ( ) -> None: - DB.begin_transaction () queries: list[Query] = Query.where ('answered', False).get () if not queries: return + query: Query = random.choice (queries) - user: User = query.user - user_name: str | None = None - if query.user_id is not None: - user_name = user.name - histories: list[History] = [] - for history in query.answer_histories: - if history.query is not None: - histories.append ({ 'role': 'user', 'content': history.query.content }) - histories.append ({ 'role': 'assistant', 'content': history.content }) - for character in [Character.DEERJIKA, Character.GOATOH]: - if query.target_character & character.value: - add_answer (query, character, user_name, histories) - query.answered = True - query.save () - DB.commit () + + DB.begin_transaction () + try: + user_name = query.user.name if query.user_id else None + + histories: list[History] = [] + for history in query.answer_histories: + if history.query is not None: + histories.append ({ 'role': 'user', 'content': history.query.content }) + histories.append ({ 'role': 'assistant', 'content': history.content }) + + for character in [Character.DEERJIKA, Character.GOATOH]: + if query.target_character & character.value: + add_answer (query, character, user_name, histories) + + query.answered = True + + query.save () + + DB.commit () + except Exception: + DB.rollback () + raise def add_answer ( @@ -41,11 +49,12 @@ def add_answer ( histories: list[History], ) -> None: message: str | list[dict[str, str | dict[str, str]]] - if query.image_url is None: - message = query.content - else: + if query.image_url: message = [{ 'type': 'text', 'text': query.content }, { 'type': 'image_url', 'image_url': query.image_url }] + else: + message = query.content + answer = Answer () answer.query_id = query.id answer.character = character.value @@ -65,9 +74,10 @@ def add_answered_flags ( answer_type = AnswerType (answer.answer_type) except Exception: return - if answer_type in [AnswerType.YOUTUBE_REPLY]: + + if answer_type in (AnswerType.YOUTUBE_REPLY,): add_answered_flag (answer, Platform.YOUTUBE) - if answer_type in [AnswerType.BLUESKY_REPLY]: + if answer_type in (AnswerType.BLUESKY_REPLY,): add_answered_flag (answer, Platform.BLUESKY)