diff --git a/talk.py b/talk.py index 93a7930..8e51fc8 100644 --- a/talk.py +++ b/talk.py @@ -1,7 +1,8 @@ -from connection import * - import sys import openai +from openai.types import * + +from connection import * class Talk: @@ -11,24 +12,23 @@ class Talk: top_p: int = 1 @classmethod - def main (cls, message: str) -> None: + def main (cls, message: str) -> str: openai.organization = OPENAI_ORGANISATION openai.api_key = OPENAI_API_KEY - print (cls.__get_answers (message)[0]) + return cls.__get_message (message).content @classmethod - def __get_answers (cls, message: str) -> list[str]: - return openai.Completion.create ( - engine = 'text-davinci-002', - prompt = message, - max_tokens = cls.max_tokens_count, - n = cls.temperature, - stop = None, - temperature = cls.temperature, - top_p = cls.top_p).choices + def __get_message (cls, message: str) \ + -> chat.chat_completion_message.ChatCompletionMessage: + return openai.chat.completions.create ( + model = 'gpt-3.5-turbo', + messages = [{'role': 'system', + 'content': 'すべての語尾に「ぬ゛ん」か「だぬ゛ん」をつけて質問に短く答えてください。定期的に「ぬ゛ぅ゛ぅ゛ぅ゛ぅ゛ん゛」、「ボッチチャーン」、「あっ」のいずれかを発してください'}, + {'role': 'user', + 'content': message}]).choices[0].message if __name__ == '__main__': - Talk.main (sys.argv[1] if len (sys.argv) > 1 else '') + print (Talk.main (sys.argv[1] if len (sys.argv) > 1 else ''))