ニジカ AI 共通サービス
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

149 lines
3.5 KiB

  1. """
  2. DB の queries テーブルにたまってゐるクエリを AI に処理させ answers テーブルに流す.
  3. """
  4. from __future__ import annotations
  5. import random
  6. from datetime import datetime
  7. from typing import TypedDict
  8. from nizika_ai.config import DB
  9. from nizika_ai.consts import AnswerType, Character, Platform
  10. from nizika_ai.models import Answer, AnsweredFlag, Query
  11. from nizika_ai.talk import Talk
  12. def main (
  13. ) -> None:
  14. """
  15. メーン処理
  16. """
  17. queries: list[Query] = Query.where ('answered', False).get ()
  18. if not queries:
  19. return
  20. query: Query = random.choice (queries)
  21. DB.begin_transaction ()
  22. try:
  23. user_name = query.user.name if query.user_id else None
  24. histories: list[History] = []
  25. for history in query.answer_histories:
  26. if history.query is not None:
  27. histories.append ({ 'role': 'user', 'content': history.query.content })
  28. histories.append ({ 'role': 'assistant', 'content': history.content })
  29. for character in [Character.DEERJIKA, Character.GOATOH]:
  30. if query.target_character & character.value:
  31. add_answer (query, character, user_name, histories)
  32. query.answered = True
  33. query.save ()
  34. DB.commit ()
  35. except Exception:
  36. DB.rollback ()
  37. raise
  38. def add_answer (
  39. query: Query,
  40. character: Character,
  41. user_name: str | None,
  42. histories: list[History],
  43. ) -> None:
  44. """
  45. AI の返答を DB に積む.
  46. Parameters
  47. ----------
  48. query: Query
  49. クエリ
  50. character: Character
  51. 返答するキャラクタ
  52. user_name: str | None
  53. クエリの主
  54. histories: list[History]
  55. 履歴
  56. """
  57. message: str | list[dict[str, str | dict[str, str]]]
  58. if query.image_url:
  59. message = [{ 'type': 'text', 'text': query.content },
  60. { 'type': 'image_url', 'image_url': query.image_url }]
  61. else:
  62. message = query.content
  63. answer = Answer ()
  64. answer.query_id = query.id
  65. answer.character = character.value
  66. answer.content = Talk.main (message, user_name, histories,
  67. goatoh_mode = character == Character.GOATOH)
  68. answer.answer_type = query.query_type
  69. answer.sent_at = datetime.now ()
  70. answer.save ()
  71. add_answered_flags (answer)
  72. def add_answered_flags (
  73. answer: Answer,
  74. ) -> None:
  75. """
  76. 返答済フラグを付与する.
  77. Parameters
  78. ----------
  79. answer: Answer
  80. 返答モデル
  81. """
  82. answer_type: AnswerType
  83. try:
  84. answer_type = AnswerType (answer.answer_type)
  85. except (TypeError, ValueError):
  86. return
  87. if answer_type in (AnswerType.YOUTUBE_REPLY,):
  88. add_answered_flag (answer, Platform.YOUTUBE)
  89. if answer_type in (AnswerType.BLUESKY_REPLY,):
  90. add_answered_flag (answer, Platform.BLUESKY)
  91. def add_answered_flag (
  92. answer: Answer,
  93. platform: Platform,
  94. ) -> None:
  95. """
  96. 返答済フラグを付与する.
  97. Parameters
  98. ----------
  99. answer: Answer
  100. 返答モデル
  101. platform: Platform
  102. プラットフォーム
  103. """
  104. answered_flag = AnsweredFlag ()
  105. answered_flag.answer_id = answer.id
  106. answered_flag.platform = platform.value
  107. answered_flag.answered = False
  108. answered_flag.save ()
  109. class History (TypedDict):
  110. """
  111. 会話履歴の 1 要素;ユーザや AI の発話を簡易に保持する型
  112. """
  113. role: str
  114. content: str
if __name__ == '__main__':
    # When run as a script, process a single pending query.
    main ()