ニジカ AI 共通サービス
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

159 lines
4.0 KiB

  1. """
  2. DB の queries テーブルにたまってゐるクエリを AI に処理させ answers テーブルに流す.
  3. """
  4. from __future__ import annotations
  5. import random
  6. from datetime import datetime
  7. from typing import TypedDict
  8. from nizika_ai.config import DB
  9. from nizika_ai.consts import Character, Platform, QueryType
  10. from nizika_ai.models import Answer, AnsweredFlag, Query
  11. from nizika_ai.talk import Talk
  12. def main (
  13. ) -> None:
  14. """
  15. メーン処理
  16. """
  17. queries: list[Query] = Query.where ('answered', False).get ()
  18. if not queries:
  19. return
  20. query: Query = random.choice (queries)
  21. DB.begin_transaction ()
  22. try:
  23. user_name = query.user.name if query.user_id else None
  24. histories: list[History] = []
  25. for history in query.answer_histories:
  26. if history.query is not None:
  27. histories.append ({ 'role': 'user', 'content': history.query.content })
  28. histories.append ({ 'role': 'assistant', 'content': history.content })
  29. for character in [Character.DEERJIKA, Character.GOATOH]:
  30. if query.target_character & character.value:
  31. add_answer (query, character, user_name, histories)
  32. query.answered = True
  33. query.save ()
  34. DB.commit ()
  35. except Exception:
  36. DB.rollback ()
  37. raise
  38. def add_answer (
  39. query: Query,
  40. character: Character,
  41. user_name: str | None,
  42. histories: list[History],
  43. ) -> None:
  44. """
  45. AI の返答を DB に積む.
  46. Parameters
  47. ----------
  48. query: Query
  49. クエリ
  50. character: Character
  51. 返答するキャラクタ
  52. user_name: str | None
  53. クエリの主
  54. histories: list[History]
  55. 履歴
  56. """
  57. message: str | list[dict[str, str | dict[str, str]]]
  58. if query.image_url:
  59. message = [{ 'type': 'text', 'text': query.content },
  60. { 'type': 'image_url', 'image_url': query.image_url }]
  61. else:
  62. message = query.content
  63. answer = Answer ()
  64. answer.query_id = query.id
  65. answer.character = character.value
  66. answer.content = Talk.main (message, user_name, histories,
  67. goatoh_mode = character == Character.GOATOH)
  68. answer.sent_at = datetime.now ()
  69. answer.save ()
  70. add_answered_flags (answer)
  71. def add_answered_flags (
  72. answer: Answer,
  73. ) -> None:
  74. """
  75. 返答済フラグを付与する.
  76. Parameters
  77. ----------
  78. answer: Answer
  79. 返答モデル
  80. """
  81. answer_type: QueryType
  82. try:
  83. answer_type = QueryType (answer.query.query_type)
  84. except (TypeError, ValueError):
  85. return
  86. if answer_type in (QueryType.YOUTUBE_COMMENT,
  87. QueryType.YOUTUBE_COMMENT,
  88. QueryType.KIRIBAN,
  89. QueryType.NICO_REPORT,
  90. QueryType.SNACK_TIME,
  91. QueryType.HOT_SPRING):
  92. add_answered_flag (answer, Platform.YOUTUBE)
  93. if answer_type in (QueryType.BLUESKY_COMMENT,
  94. QueryType.BLUESKY_SYSTEM,
  95. QueryType.KIRIBAN,
  96. QueryType.NICO_REPORT,
  97. QueryType.SNACK_TIME,
  98. QueryType.HOT_SPRING):
  99. add_answered_flag (answer, Platform.BLUESKY)
  100. def add_answered_flag (
  101. answer: Answer,
  102. platform: Platform,
  103. ) -> None:
  104. """
  105. 返答済フラグを付与する.
  106. Parameters
  107. ----------
  108. answer: Answer
  109. 返答モデル
  110. platform: Platform
  111. プラットフォーム
  112. """
  113. answered_flag = AnsweredFlag ()
  114. answered_flag.answer_id = answer.id
  115. answered_flag.platform = platform.value
  116. answered_flag.answered = False
  117. answered_flag.save ()
  118. class History (TypedDict):
  119. """
  120. 会話履歴の 1 要素;ユーザや AI の発話を簡易に保持する型
  121. """
  122. role: str
  123. content: str
  124. if __name__ == '__main__':
  125. main ()