ニジカ AI 共通サービス
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

102 lines
2.7 KiB

  1. from __future__ import annotations
  2. import random
  3. from datetime import datetime
  4. from typing import TypedDict
  5. from nizika_ai.config import DB
  6. from nizika_ai.consts import AnswerType, Character, Platform
  7. from nizika_ai.models import Answer, AnsweredFlag, Query, User
  8. from nizika_ai.talk import Talk
  9. def main (
  10. ) -> None:
  11. queries: list[Query] = Query.where ('answered', False).get ()
  12. if not queries:
  13. return
  14. query: Query = random.choice (queries)
  15. DB.begin_transaction ()
  16. try:
  17. user_name = query.user.name if query.user_id else None
  18. histories: list[History] = []
  19. for history in query.answer_histories:
  20. if history.query is not None:
  21. histories.append ({ 'role': 'user', 'content': history.query.content })
  22. histories.append ({ 'role': 'assistant', 'content': history.content })
  23. for character in [Character.DEERJIKA, Character.GOATOH]:
  24. if query.target_character & character.value:
  25. add_answer (query, character, user_name, histories)
  26. query.answered = True
  27. query.save ()
  28. DB.commit ()
  29. except Exception:
  30. DB.rollback ()
  31. raise
  32. def add_answer (
  33. query: Query,
  34. character: Character,
  35. user_name: str | None,
  36. histories: list[History],
  37. ) -> None:
  38. message: str | list[dict[str, str | dict[str, str]]]
  39. if query.image_url:
  40. message = [{ 'type': 'text', 'text': query.content },
  41. { 'type': 'image_url', 'image_url': query.image_url }]
  42. else:
  43. message = query.content
  44. answer = Answer ()
  45. answer.query_id = query.id
  46. answer.character = character.value
  47. answer.content = Talk.main (message, user_name, histories,
  48. goatoh_mode = character == Character.GOATOH)
  49. answer.answer_type = query.query_type
  50. answer.sent_at = datetime.now ()
  51. answer.save ()
  52. add_answered_flags (answer)
  53. def add_answered_flags (
  54. answer: Answer,
  55. ) -> None:
  56. answer_type: AnswerType
  57. try:
  58. answer_type = AnswerType (answer.answer_type)
  59. except Exception:
  60. return
  61. if answer_type in (AnswerType.YOUTUBE_REPLY,):
  62. add_answered_flag (answer, Platform.YOUTUBE)
  63. if answer_type in (AnswerType.BLUESKY_REPLY,):
  64. add_answered_flag (answer, Platform.BLUESKY)
  65. def add_answered_flag (
  66. answer: Answer,
  67. platform: Platform,
  68. ) -> None:
  69. answered_flag = AnsweredFlag ()
  70. answered_flag.answer_id = answer.id
  71. answered_flag.platform = platform.value
  72. answered_flag.answered = False
  73. answered_flag.save ()
  74. class History (TypedDict):
  75. role: str
  76. content: str
if __name__ == '__main__':
    # Script entry point: run one answering pass (main handles at most one query).
    main ()