diff --git a/services/discord_bot.py b/services/discord_bot.py
index b127d6b..7387084 100644
--- a/services/discord_bot.py
+++ b/services/discord_bot.py
@@ -2,6 +2,8 @@
 from discord.ext import commands
 import asyncio
 from typing import Dict
+import time
+import threading
 
 class DiscordBot(commands.Bot):
     def __init__(self, config, llm_service):
@@ -46,38 +48,203 @@ async def process_message(self, message: discord.Message):
         lock = self.user_locks.setdefault(user_id, asyncio.Lock())
 
         async with lock:
+            # Initialize typing_task to None
+            typing_task = None
             try:
-                if not self.config.stream_mode:
-                    typing_task = asyncio.create_task(self.start_typing(message.channel))
-                    self.typing_tasks[user_id] = typing_task
-
-                response = await asyncio.to_thread(
-                    self.llm.get_response,
-                    user_id,
-                    message.content,
-                    message.author.name
-                )
-
                 if self.config.stream_mode:
                     # Stream mode
-                    sent_message = await message.channel.send("▌")
+                    sent_message = None
                     current_response = ""
-                    for chunk in response:
-                        current_response += chunk
-                        await sent_message.edit(content=current_response + " ▌")
-                        #await asyncio.sleep(0.5)
-                    await sent_message.edit(content=current_response)
+                    buffer = ""  # Accumulate chunks between updates
+                    last_update_time = 0.0  # Initialize timer
+                    update_interval = 0.75  # Seconds between edits
+                    edit_error_count = 0
+                    max_edit_errors = 3  # Stop editing after consecutive errors
+
+                    try:
+                        # Queue for thread-safe communication between LLM thread and async loop
+                        queue = asyncio.Queue()
+                        loop = asyncio.get_running_loop()
+                        # Event to signal when the LLM thread function has finished
+                        llm_task_done = asyncio.Event()
+
+                        # Function to run the LLM generator in a separate thread
+                        def llm_thread_target():
+                            try:
+                                # Get the generator from the LLM service
+                                llm_generator = self.llm.get_response(
+                                    user_id, message.content, message.author.name
+                                )
+                                # Iterate over chunks yielded by the generator
+                                for chunk in llm_generator:
+                                    # Put chunk into the async queue safely from the thread
+                                    future = asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
+                                    future.result()  # Wait for the put() to be processed before producing the next chunk
+
+                                # Signal end of stream by putting None
+                                asyncio.run_coroutine_threadsafe(queue.put(None), loop).result()
+                            except Exception as e:
+                                # Handle errors during generation within the thread;
+                                # the prefix must match the error check in the consumer below
+                                err_msg = f"🚨 Critical error: LLM generation failed: {e}"
+                                print(err_msg)
+                                # Put error message onto queue for main thread to handle
+                                asyncio.run_coroutine_threadsafe(queue.put(err_msg), loop).result()
+                                # Still signal completion after error
+                                asyncio.run_coroutine_threadsafe(queue.put(None), loop).result()
+                            finally:
+                                # Signal that the thread function is finished executing
+                                loop.call_soon_threadsafe(llm_task_done.set)
+
+                        # Start the LLM generation in the background thread
+                        thread = threading.Thread(target=llm_thread_target, daemon=True)
+                        thread.start()
+
+                        # Consume items from the queue in the main async event loop
+                        while True:
+                            chunk = await queue.get()
+
+                            # Check for end-of-stream signal
+                            if chunk is None:
+                                queue.task_done()  # Mark item processed
+                                break  # Exit consumption loop
+
+                            # Check if the chunk is an error message
+                            if chunk.startswith(("⚠️ Error:", "🚨 Critical error:")):
+                                current_response = chunk  # Display the error
+                                buffer = ""  # Clear any pending buffer
+                                if sent_message:
+                                    try:
+                                        # Try to edit the existing message to show the error
+                                        await sent_message.edit(content=current_response)
+                                    except (discord.HTTPException, discord.NotFound):
+                                        # If edit fails, try sending error as new message
+                                        print(f"Failed to edit message with error, sending new.")
+                                        await self.safe_send(message.channel, current_response)
+                                else:
+                                    # If no message sent yet, send the error directly
+                                    await self.safe_send(message.channel, current_response)
+                                queue.task_done()
+                                # Wait for thread cleanup before returning
+                                await llm_task_done.wait()
+                                return  # Stop processing this request
+
+                            # Process a normal chunk
+                            current_response += chunk
+                            buffer += chunk
+                            now = time.time()
+
+                            # Send initial message or edit existing one
+                            if not sent_message:
+                                try:
+                                    # Initial message send (use buffer first, then add cursor)
+                                    sent_message = await message.channel.send(buffer + " ▌")
+                                    current_response = buffer  # Ensure current_response matches sent content
+                                    buffer = ""  # Clear buffer after send
+                                    last_update_time = now  # Start timer after first send
+                                except discord.HTTPException as e:
+                                    print(f"Error sending initial stream message: {e}")
+                                    await message.channel.send("⚠️ Error starting stream.")
+                                    # Need to signal thread to stop? For now, just break locally.
+                                    break  # Stop processing chunks
+
+                            # Check if update interval passed and there's new content
+                            elif now - last_update_time >= update_interval and buffer:
+                                if edit_error_count < max_edit_errors:
+                                    try:
+                                        await sent_message.edit(content=current_response + " ▌")
+                                        buffer = ""  # Clear buffer on success
+                                        last_update_time = now
+                                        edit_error_count = 0  # Reset errors on success
+                                        await asyncio.sleep(0.05)  # Small yield
+                                    except discord.NotFound:
+                                        print("Message not found during edit (deleted?). Stopping updates.")
+                                        break  # Stop processing chunks
+                                    except discord.HTTPException as e:
+                                        edit_error_count += 1
+                                        print(f"Failed to edit message ({edit_error_count}/{max_edit_errors}): {e}")
+                                        # Keep buffer, will try again or use in final edit
+                                        last_update_time = now  # Prevent rapid retries
+                                else:
+                                    # Max edit errors reached, stop trying to edit this message
+                                    if buffer:  # Only print warning once
+                                        print("Max edit errors reached, further edits skipped.")
+                                    buffer = ""  # Discard buffer for this interval to prevent spam
+
+                            queue.task_done()  # Mark chunk as processed
+
+                        # --- End of queue consumption loop ---
+
+                        # Final edit after loop finishes to show complete response and remove cursor
+                        if sent_message:
+                            final_content = current_response  # Contains full response now
+                            if edit_error_count < max_edit_errors:
+                                try:
+                                    # Edit one last time to remove cursor and ensure all content is present
+                                    await sent_message.edit(content=final_content)
+                                except (discord.HTTPException, discord.NotFound) as e:
+                                    print(f"Failed final message edit: {e}")
+                                    # Optionally send remaining buffer if final edit failed
+                                    # if buffer: await self.safe_send(message.channel, "..." + buffer)
+                            else:
+                                # If editing failed too many times, maybe send the full thing as a new message
+                                print("Final edit skipped due to previous errors.")
+                                # await self.safe_send(message.channel, final_content)  # Alternative
+
+                        # Wait for the background thread function to fully complete
+                        await llm_task_done.wait()
+
+                    except Exception as e:
+                        # Catch errors in the main async stream handling logic
+                        print(f"Error during stream processing/queue handling: {e}")
+                        await message.channel.send(f"🚨 An error occurred processing the stream.")
+
+                # --- NON-STREAMING MODE ---
                 else:
-                    # No stream
-                    await self.safe_send(message.channel, ''.join(response))
-
+                    # Start typing indicator for non-stream mode
+                    typing_task = asyncio.create_task(self.start_typing(message.channel))
+                    self.typing_tasks[user_id] = typing_task
+
+                    response_text = ""
+                    try:
+                        # Define helper to run generator and collect chunks in thread
+                        def collect_chunks_target():
+                            generator = self.llm.get_response(
+                                user_id, message.content, message.author.name
+                            )
+                            # Consume the generator completely within the thread
+                            chunks = list(generator)
+                            # Check if the first (and potentially only) chunk is an error
+                            if chunks and chunks[0].startswith(("⚠️ Error:", "🚨 Critical error:")):
+                                return chunks[0]  # Return only the error message
+                            else:
+                                return "".join(chunks)  # Join normal chunks
+
+                        # Run the helper in a thread
+                        response_text = await asyncio.to_thread(collect_chunks_target)
+
+                        # Send the complete response (or error) collected from the thread
+                        await self.safe_send(message.channel, response_text)
+
+                    except Exception as e:
+                        print(f"Critical error processing non-stream message: {e}")
+                        await message.channel.send(f"🚨 Critical error: {str(e)}")
+                    finally:
+                        # Always cancel typing task in non-stream mode after completion/error
+                        if typing_task:
+                            typing_task.cancel()
+                        # Remove task from dict (handled in outer finally too, but good practice here)
+                        self.typing_tasks.pop(user_id, None)
+
             except Exception as e:
-                print(f"Critical error: {str(e)}")
-                await message.channel.send(f"🚨 Critical error: {str(e)}")
+                # Catch-all for unexpected errors in process_message
+                print(f"Outer critical error in process_message: {str(e)}")
+                await message.channel.send(f"🚨 An unexpected critical error occurred.")
             finally:
-                typing_task.cancel()
-                self.typing_tasks.pop(user_id, None)
+                # Ensure typing task (if any started) is always cleaned up
+                # This handles cases where errors occurred before specific finally blocks
+                task = self.typing_tasks.pop(user_id, None)
+                if task and not task.done():
+                    task.cancel()
 
     async def on_ready(self):
         await self.change_presence(activity=discord.Game(name="DM me"))
diff --git a/services/llm_service.py b/services/llm_service.py
index d43ea0e..1d44e49 100644
--- a/services/llm_service.py
+++ b/services/llm_service.py
@@ -1,6 +1,6 @@
 import threading
 from llama_cpp import Llama
-from typing import Dict, List
+from typing import Dict, List, Generator
 
 class LLMService:
     def __init__(self, config):
@@ -10,64 +10,67 @@ def __init__(self, config):
         self.lock = threading.Lock()
 
     def initialize_model(self):
-        with self.lock:
-            if not self.model:
-                self.model = Llama(
-                    model_path=self.config.model_path,
-                    chat_format=self.config.chat_format,
-                    n_ctx=self.config.model_params.get('n_ctx', 1024),
-                    n_gpu_layers=self.config.model_params.get('n_gpu_layers', 0),
-                    verbose=self.config.full_log
-                )
-                print("LLM model initialized")
+        if not self.model:
+            self.model = Llama(
+                model_path=self.config.model_path,
+                chat_format=self.config.chat_format,
+                n_ctx=self.config.model_params.get('n_ctx', 1024),
+                n_gpu_layers=self.config.model_params.get('n_gpu_layers', 0),
+                verbose=self.config.full_log
+            )
+            print("LLM model initialized")
 
-    def get_response(self, user_id, message, username=None):
+    def get_response(self, user_id: str, message: str, username: str = None) -> Generator[str, None, None]:
         with self.lock:
-            try:
-                history = self.conversations.get(user_id, [])
-
-                messages = [
-                    {"role": "system", "content": self.config.system_prompt.replace("[user]", username or "user")},
-                    *history[-self.config.history_limit:],
-                    {"role": "user", "content": message}
-                ]
+            history = self.conversations.get(user_id, [])
+
+            messages = [
+                {"role": "system", "content": self.config.system_prompt.replace("[user]", username or "user")},
+                *history[-self.config.history_limit:],
+                {"role": "user", "content": message}
+            ]
 
-                print(f"Model config: {self.config.model_params}") if self.config.full_log else None
-                print(f"History: {messages}") if self.config.full_log else None
+            print(f"Model config: {self.config.model_params}") if self.config.full_log else None
+            print(f"History: {messages}") if self.config.full_log else None
 
-                completion_params = {
-                    'messages': messages,
-                    'stream': True,
-                    'max_tokens': self.config.model_params.get('max_tokens'),
-                    'temperature': self.config.model_params.get('temperature'),
-                    'top_k': self.config.model_params.get('top_k'),
-                    'top_p': self.config.model_params.get('top_p'),
-                    'repeat_penalty': self.config.model_params.get('repeat_penalty')
-                }
+            completion_params = {
+                'messages': messages,
+                'stream': True,
+                'max_tokens': self.config.model_params.get('max_tokens'),
+                'temperature': self.config.model_params.get('temperature'),
+                'top_k': self.config.model_params.get('top_k'),
+                'top_p': self.config.model_params.get('top_p'),
+                'repeat_penalty': self.config.model_params.get('repeat_penalty')
+            }
 
+            # Store chunks temporarily to build the full response for history
+            response_chunks_for_history = []
+            try:
                 stream = self.model.create_chat_completion(**completion_params)
 
-                response_chunks = []
                 for part in stream:
-                    delta = part["choices"][0]["delta"]
-                    if "content" in delta:
-                        response_chunks.append(delta["content"])
-
-                full_response = "".join(response_chunks)
+                    delta = part["choices"][0].get("delta", {})
+                    chunk = delta.get("content")
+                    if chunk:
+                        response_chunks_for_history.append(chunk)
+                        yield chunk  # Yield chunk to the caller
+
+                # ---- History Update ----
+                # This part executes only after the generator has been fully iterated by the caller
+                full_response = "".join(response_chunks_for_history)
 
-                # Update history
+                # Update history if generation was successful
                 new_history = history + [
                     {"role": "user", "content": message},
                     {"role": "assistant", "content": full_response}
                 ]
-                self.conversations[user_id] = new_history[-self.config.history_limit * 2 :]
-
-                return response_chunks
-
+                # Apply history limit
+                self.conversations[user_id] = new_history[-(self.config.history_limit * 2):]
+
             except Exception as e:
                 self.conversations.pop(user_id, None)
-                return [f"⚠️ Error: {str(e)}"]
-
+                # Yield the error so callers can detect the "⚠️ Error:" prefix;
+                # a plain return would be swallowed now that this is a generator
+                yield f"⚠️ Error: {str(e)}"
+
     def clear_history(self, user_id):
         with self.lock:
            if user_id in self.conversations:
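
Review note: the streaming path bridges a blocking generator running in a worker thread to the bot's event loop via `asyncio.run_coroutine_threadsafe` and an `asyncio.Queue`, with `None` as the end-of-stream sentinel. A minimal standalone sketch of that handoff (the `fake_llm_stream` function below is a hypothetical stand-in for `LLMService.get_response`, not part of this patch):

```python
import asyncio
import threading

def fake_llm_stream():
    # Hypothetical stand-in for LLMService.get_response: a blocking generator of text chunks.
    yield from ("Hel", "lo, ", "wor", "ld!")

async def main():
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def producer():
        # Runs in a worker thread; hands each chunk to the event loop thread-safely.
        for chunk in fake_llm_stream():
            asyncio.run_coroutine_threadsafe(queue.put(chunk), loop).result()
        # None marks end-of-stream, mirroring the sentinel used in the patch.
        asyncio.run_coroutine_threadsafe(queue.put(None), loop).result()

    threading.Thread(target=producer, daemon=True).start()

    parts = []
    while True:
        chunk = await queue.get()
        if chunk is None:
            break
        parts.append(chunk)
    print("".join(parts))  # -> Hello, world!

asyncio.run(main())
```

One design point: since `asyncio.Queue()` is unbounded, `queue.put` completes immediately and the `.result()` calls give no backpressure; if the producer can outrun Discord edits by a wide margin, passing `maxsize=` to the queue (with the same `run_coroutine_threadsafe` pattern) would throttle the worker thread.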
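
Also worth flagging for reviewers: in the new `get_response`, the history update sits after the `yield` loop, so (as the inline comment says) it only runs once the caller has driven the generator to exhaustion. A small illustration of that ordering, using a hypothetical `get_response_sketch` in place of the real method:

```python
def get_response_sketch():
    # Hypothetical stand-in for the generator version of LLMService.get_response.
    for chunk in ("Hel", "lo"):
        yield chunk
    # Code placed here (e.g. the history update) runs only after the caller
    # has requested every chunk and the generator resumes one final time.
    print("history updated")

gen = get_response_sketch()
print(next(gen))  # Hel
print(next(gen))  # lo
try:
    next(gen)     # resuming past the last chunk prints "history updated"
except StopIteration:
    pass          # then StopIteration ends the generator
```

Both `list(generator)` in the non-streaming helper and the full iteration inside `llm_thread_target` exhaust the generator, so history is saved in both modes; a consumer that stops iterating early (or calls `close()`) would skip the update.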