diff --git a/main.py b/main.py index 0bfbc96..0c7b9b5 100644 --- a/main.py +++ b/main.py @@ -11,11 +11,12 @@ import random import subprocess from asyncio import Lock from datetime import date, datetime, time, timedelta -from typing import Any, cast +from typing import Any, TypedDict, cast import nicolib import queries_to_answers as q2a from nicolib import VideoInfo +from nizika_ai.config import DB from nizika_ai.consts import Character, GPTModel, QueryType from nizika_ai.models import Query @@ -45,7 +46,8 @@ async def main ( report_nico (), update_kiriban_list (), report_snack_time (), - report_hot_spring_time ()) + report_hot_spring_time (), + reconnect_db ()) async def queries_to_answers ( @@ -92,7 +94,7 @@ async def report_kiriban ( if comments: prompt += f"人気のコメントは次の通りです:「{ '」、「'.join (c['content'] for c in popular_comments) }」\n" if latest_comments != popular_comments: - prompt += f"最新のコメントは次の通りです:「{ '」、「'.join (c.content for c in latest_comments) }」\n" + prompt += f"最新のコメントは次の通りです:「{ '」、「'.join (c['content'] for c in latest_comments) }」\n" prompt += f""" 概要には次のように書かれています: ```html @@ -156,16 +158,17 @@ def fetch_kiriban_list ( """ result = subprocess.run ( - ['python3', str (base_date), *map (str, KIRABAN_VIEWS_COUNTS)], + ['python3', '/root/nizika_nico/get_kiriban_list.py', + str (base_date), *map (str, KIRIBAN_VIEWS_COUNTS)], cwd = '/root/nizika_nico', env = os.environ, capture_output = True, text = True) kl: list[list[int | str]] = json.loads (result.stdout) - return map (lambda k: (cast (int, k[0]), - nicolib.fetch_video_info (cast (str, k[1])), - datetime.strptime (cast (str, k[2]), '%Y-%m-%d %H:%M:%S.%f')), kl) + return [(cast (int, k[0]), video_info, str_to_datetime (cast (str, k[2]))) + for k in kl + if (video_info := nicolib.fetch_video_info (cast (str, k[1]))) is not None] def fetch_comments ( @@ -186,15 +189,15 @@ def fetch_comments ( """ result = subprocess.run ( - ['python3', video_code], + ['python3', 'get_comments_by_video_code.py', video_code], cwd = '/root/nizika_nico', env = os.environ, capture_output = True, text = True) rows: list[dict[str, Any]] = json.loads (result.stdout) comments: list[CommentDict] = [] - for row in comments: - row['posted_at'] = datetime.strptime (row['posted_at'], '%Y-%m-%d %H:%M:%S.%f') + for row in rows: + row['posted_at'] = str_to_datetime (row['posted_at']) comments.append (cast (CommentDict, row)) return comments @@ -281,6 +284,19 @@ async def report_hot_spring_time ( _add_query ('温泉に入ろう!!!', QueryType.HOT_SPRING) +async def reconnect_db ( +) -> None: + while True: + await asyncio.sleep (1800) + try: + DB.reconnect ('mysql') + except Exception as e: + if getattr (e, 'args', [None])[0] == 2006: + print ('堕ちたな(確信).') + else: + raise + + def _add_query ( content: str, query_type: QueryType, @@ -291,7 +307,7 @@ def _add_query ( query.target_character = Character.DEERJIKA.value query.content = content query.query_type = query_type.value - query.model = GPTModel.GPT3_TURBO.value + query.model = GPTModel.GPT4_O.value query.sent_at = datetime.now () query.answered = False query.transfer_data = transfer_data @@ -324,6 +340,20 @@ def _format_elapsed ( return f"{ days }日{ hours }時間{ mins }分{ seconds }秒" +def str_to_datetime ( + s: str, +) -> datetime: + formats: list[str] = [ + '%Y-%m-%d %H:%M:%S.%f', + '%Y-%m-%d %H:%M:%S'] + for f in formats: + try: + return datetime.strptime (s, f) + except ValueError: + pass + raise ValueError ('うんち!w') + + class CommentDict (TypedDict): id: int video_id: int diff --git a/nicolib b/nicolib index 32ecf2d..8567098 160000 --- a/nicolib +++ b/nicolib @@ -1 +1 @@ -Subproject commit 32ecf2d00fcba876ed5afe48626058b2b4795399 +Subproject commit 85670982f0cd84b3127219dcb1d9ec8efe072aea diff --git a/nizika_ai b/nizika_ai index 3be6d90..5dae2ae 160000 --- a/nizika_ai +++ b/nizika_ai @@ -1 +1 @@ -Subproject commit 3be6d9063c987deaceee24a1d16296d21319778c +Subproject commit 5dae2ae038e1109f3c70d853d2b7dd7542e5e88e diff --git a/queries_to_answers.py b/queries_to_answers.py index cd9d96f..0950c05 100644 --- a/queries_to_answers.py +++ b/queries_to_answers.py @@ -84,6 +84,7 @@ def add_answer ( answer.content = Talk.main (message, user_name, histories, goatoh_mode = character == Character.GOATOH) answer.sent_at = datetime.now () + answer.answer_type = query.query_type answer.save () add_answered_flags (answer) @@ -102,7 +103,7 @@ def add_answered_flags ( answer_type: QueryType try: - answer_type = QueryType (answer.query_ref.query_type) + answer_type = QueryType (answer.query_rel.query_type) except (TypeError, ValueError): return