Skip to content

Commit

Permalink
Merge pull request #2074 from oalexandere/feature/async-exception-han…
Browse files Browse the repository at this point in the history
…dler

Fix ExceptionHandler handle() method as async for async telebot
  • Loading branch information
Badiboy authored Nov 17, 2023
2 parents f91f423 + a1bb695 commit 9ff9e11
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 54 deletions.
19 changes: 8 additions & 11 deletions examples/asynchronous_telebot/exception_handler.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
import logging

import telebot
from telebot.async_telebot import AsyncTeleBot


import logging
from telebot.async_telebot import AsyncTeleBot, ExceptionHandler

logger = telebot.logger
telebot.logger.setLevel(logging.DEBUG) # Outputs debug messages to console.
telebot.logger.setLevel(logging.DEBUG) # Outputs debug messages to console.

class ExceptionHandler(telebot.ExceptionHandler):
def handle(self, exception):
logger.error(exception)

bot = AsyncTeleBot('TOKEN',exception_handler=ExceptionHandler())
class MyExceptionHandler(ExceptionHandler):
async def handle(self, exception):
logger.error(exception)


bot = AsyncTeleBot('TOKEN', exception_handler=MyExceptionHandler())


@bot.message_handler(commands=['photo'])
async def photo_send(message: telebot.types.Message):
await bot.send_message(message.chat.id, 'Hi, this is an example of exception handlers.')
raise Exception('test') # Exception goes to ExceptionHandler if it is set
raise Exception('test') # Exception goes to ExceptionHandler if it is set



import asyncio
asyncio.run(bot.polling())
37 changes: 13 additions & 24 deletions telebot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,12 @@ def polling(self, non_stop: Optional[bool]=False, skip_pending: Optional[bool]=F
self.__non_threaded_polling(non_stop=non_stop, interval=interval, timeout=timeout, long_polling_timeout=long_polling_timeout,
logger_level=logger_level, allowed_updates=allowed_updates)

def _handle_exception(self, exception: Exception) -> bool:
if self.exception_handler is None:
return False

handled = self.exception_handler.handle(exception)
return handled

def __threaded_polling(self, non_stop = False, interval = 0, timeout = None, long_polling_timeout = None,
logger_level=logging.ERROR, allowed_updates=None):
Expand Down Expand Up @@ -1074,10 +1080,7 @@ def __threaded_polling(self, non_stop = False, interval = 0, timeout = None, lon
self.worker_pool.raise_exceptions()
error_interval = 0.25
except apihelper.ApiException as e:
if self.exception_handler is not None:
handled = self.exception_handler.handle(e)
else:
handled = False
handled = self._handle_exception(e)
if not handled:
if logger_level and logger_level >= logging.ERROR:
logger.error("Threaded polling exception: %s", str(e))
Expand Down Expand Up @@ -1107,10 +1110,7 @@ def __threaded_polling(self, non_stop = False, interval = 0, timeout = None, lon
self.__stop_polling.set()
break
except Exception as e:
if self.exception_handler is not None:
handled = self.exception_handler.handle(e)
else:
handled = False
handled = self._handle_exception(e)
if not handled:
polling_thread.stop()
polling_thread.clear_exceptions() #*
Expand Down Expand Up @@ -1144,11 +1144,7 @@ def __non_threaded_polling(self, non_stop=False, interval=0, timeout=None, long_
self.__retrieve_updates(timeout, long_polling_timeout, allowed_updates=allowed_updates)
error_interval = 0.25
except apihelper.ApiException as e:
if self.exception_handler is not None:
handled = self.exception_handler.handle(e)
else:
handled = False

handled = self._handle_exception(e)
if not handled:
if logger_level and logger_level >= logging.ERROR:
logger.error("Polling exception: %s", str(e))
Expand All @@ -1171,10 +1167,7 @@ def __non_threaded_polling(self, non_stop=False, interval=0, timeout=None, long_
self.__stop_polling.set()
break
except Exception as e:
if self.exception_handler is not None:
handled = self.exception_handler.handle(e)
else:
handled = False
handled = self._handle_exception(e)
if not handled:
raise e
else:
Expand All @@ -1190,10 +1183,7 @@ def _exec_task(self, task, *args, **kwargs):
try:
task(*args, **kwargs)
except Exception as e:
if self.exception_handler is not None:
handled = self.exception_handler.handle(e)
else:
handled = False
handled = self._handle_exception(e)
if not handled:
raise e

Expand Down Expand Up @@ -6858,9 +6848,8 @@ def _run_middlewares_and_handler(self, message, handlers, middlewares, update_ty
break
except Exception as e:
handler_error = e
if self.exception_handler:
self.exception_handler.handle(e)
else:
handled = self._handle_exception(e)
if not handled:
logger.error(str(e))
logger.debug("Exception traceback:\n%s", traceback.format_exc())

Expand Down
35 changes: 16 additions & 19 deletions telebot/async_telebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class ExceptionHandler:
"""

# noinspection PyMethodMayBeStatic,PyUnusedLocal
def handle(self, exception):
async def handle(self, exception):
return False


Expand Down Expand Up @@ -368,6 +368,16 @@ async def infinity_polling(self, timeout: Optional[int]=20, skip_pending: Option
if logger_level and logger_level >= logging.INFO:
logger.error("Break infinity polling")

async def _handle_exception(self, exception: Exception) -> bool:
if self.exception_handler is None:
return False

if iscoroutinefunction(self.exception_handler.handle):
handled = await self.exception_handler.handle(exception)
else:
handled = self.exception_handler.handle(exception) # noqa
return handled

async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout: int=20,
request_timeout: int=None, allowed_updates: Optional[List[str]]=None):
"""
Expand Down Expand Up @@ -415,11 +425,7 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout:
except asyncio.CancelledError:
return
except asyncio_helper.RequestTimeout as e:
handled = False
if self.exception_handler:
self.exception_handler.handle(e)
handled = True

handled = await self._handle_exception(e)
if not handled:
logger.error('Unhandled exception (full traceback for debug level): %s', str(e))
logger.debug(traceback.format_exc())
Expand All @@ -430,11 +436,7 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout:
else:
return
except asyncio_helper.ApiException as e:
handled = False
if self.exception_handler:
self.exception_handler.handle(e)
handled = True

handled = await self._handle_exception(e)
if not handled:
logger.error('Unhandled exception (full traceback for debug level): %s', str(e))
logger.debug(traceback.format_exc())
Expand All @@ -444,11 +446,7 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout:
else:
break
except Exception as e:
handled = False
if self.exception_handler:
self.exception_handler.handle(e)
handled = True

handled = await self._handle_exception(e)
if not handled:
logger.error('Unhandled exception (full traceback for debug level): %s', str(e))
logger.debug(traceback.format_exc())
Expand Down Expand Up @@ -545,9 +543,8 @@ async def _run_middlewares_and_handlers(self, message, handlers, middlewares, up
break
except Exception as e:
handler_error = e
if self.exception_handler:
self.exception_handler.handle(e)
else:
handled = await self._handle_exception(e)
if not handled:
logger.error(str(e))
logger.debug("Exception traceback:\n%s", traceback.format_exc())

Expand Down

0 comments on commit 9ff9e11

Please sign in to comment.