ニジカ AI 共通サービス
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

queries_to_answers.py 2.6 KiB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from __future__ import annotations
  2. import random
  3. from datetime import datetime
  4. from typing import TypedDict
  5. from nizika_ai.config import DB
  6. from nizika_ai.consts import AnswerType, Character, Platform
  7. from nizika_ai.models import Answer, AnsweredFlag, Query, User
  8. from nizika_ai.talk import Talk
  9. def main (
  10. ) -> None:
  11. DB.begin_transaction ()
  12. queries: list[Query] = Query.where ('answered', False).get ()
  13. if not queries:
  14. return
  15. query: Query = random.choice (queries)
  16. user: User = query.user
  17. user_name: str | None = None
  18. if query.user_id is not None:
  19. user_name = user.name
  20. histories: list[History] = []
  21. for history in query.answer_histories:
  22. if history.query is not None:
  23. histories.append ({ 'role': 'user', 'content': history.query.content })
  24. histories.append ({ 'role': 'assistant', 'content': history.content })
  25. for character in [Character.DEERJIKA, Character.GOATOH]:
  26. if query.target_character & character.value:
  27. add_answer (query, character, user_name, histories)
  28. query.answered = True
  29. query.save ()
  30. DB.commit ()
  31. def add_answer (
  32. query: Query,
  33. character: Character,
  34. user_name: str | None,
  35. histories: list[History],
  36. ) -> None:
  37. message: str | list[dict[str, str | dict[str, str]]]
  38. if query.image_url is None:
  39. message = query.content
  40. else:
  41. message = [{ 'type': 'text', 'text': query.content },
  42. { 'type': 'image_url', 'image_url': query.image_url }]
  43. answer = Answer ()
  44. answer.query_id = query.id
  45. answer.character = character.value
  46. answer.content = Talk.main (message, user_name, histories,
  47. goatoh_mode = character == Character.GOATOH)
  48. answer.answer_type = query.query_type
  49. answer.sent_at = datetime.now ()
  50. answer.save ()
  51. add_answered_flags (answer)
  52. def add_answered_flags (
  53. answer: Answer,
  54. ) -> None:
  55. answer_type: AnswerType
  56. try:
  57. answer_type = AnswerType (answer.answer_type)
  58. except Exception:
  59. return
  60. if answer_type in [AnswerType.YOUTUBE_REPLY]:
  61. add_answered_flag (answer, Platform.YOUTUBE)
  62. if answer_type in [AnswerType.BLUESKY_REPLY]:
  63. add_answered_flag (answer, Platform.BLUESKY)
  64. def add_answered_flag (
  65. answer: Answer,
  66. platform: Platform,
  67. ) -> None:
  68. answered_flag = AnsweredFlag ()
  69. answered_flag.answer_id = answer.id
  70. answered_flag.platform = platform.value
  71. answered_flag.answered = False
  72. answered_flag.save ()
  73. class History (TypedDict):
  74. role: str
  75. content: str
  76. if __name__ == '__main__':
  77. main ()