dialog.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import operator
  2. from asyncio import sleep, create_task, gather, get_event_loop, Queue, current_task
  3. from typing import Any, List
  4. from aiogram.dispatcher.filters.state import StatesGroup, State
  5. from aiogram.types import Message, CallbackQuery
  6. from aiogram_dialog import Window, Dialog, DialogManager, StartMode, Data
  7. from aiogram_dialog.widgets.kbd import Radio
  8. from aiogram_dialog.widgets.text import Format
  9. from aiogram.types import ParseMode
  10. from summarize import get_summary
  11. from progress import Bg, background
  12. warn = "***Note, that it is an AI generated summary, and it may contain complete bullshit***"
  13. class CrossData():
  14. queue: Queue = Queue()
  15. intent_queue: dict = {}
  16. context: dict = {}
  17. class MySG(StatesGroup):
  18. main = State()
  19. cross_data = CrossData()
  20. buttons = [
  21. ("Abstract", '1'),
  22. ("Summary", '2'),
  23. ("Highlights", '3'),
  24. ("Findings", '4')
  25. ]
  26. async def get_data(dialog_manager: DialogManager, **kwargs):
  27. text_message = {"text": "OOOPS!"}
  28. if data := dialog_manager.current_context():
  29. if cr_data:=cross_data.context.get(data.id):
  30. data.dialog_data.update(cr_data.dialog_data)
  31. del cross_data.context[data.id]
  32. else:
  33. cross_data.context[data.id] = None
  34. item_id = data.widget_data.get('radio_buttons') #type: ignore
  35. title = data.start_data["title"] #type: ignore
  36. url = data.start_data["url"] #type: ignore
  37. if data.dialog_data.get('abs'): #type: ignore
  38. abstract = data.dialog_data.get('abs') #type: ignore
  39. else:
  40. data.dialog_data["abs"] = data.start_data["reply_message"] #type: ignore
  41. abstract = data.dialog_data.get('abs') #type: ignore
  42. if item_id == "2":
  43. if summary:=data.dialog_data.get("summary"):
  44. if isinstance(summary, list):
  45. while sum(len(sentence) for sentence in summary) > 3700:
  46. summary = summary[:-1]
  47. summary = " ".join(summary)
  48. text_message = {"text": f"{url}\n\n***{title}***\n\n{summary}\n\n{warn}"}
  49. elif item_id == "3":
  50. if highlights:=data.dialog_data.get("highlights"):
  51. if isinstance(highlights, list):
  52. highlights = "\n\n- ".join(highlights)
  53. text_message = {"text": f"{url}\n\n***{title}***\n\n- {highlights}\n\n{warn}"}
  54. elif item_id == "4":
  55. if findings:=data.dialog_data.get("findings"):
  56. if isinstance(findings, list):
  57. findings = "\n\n- ".join(findings)
  58. text_message = {"text": f"{url}\n\n***{title}***\n\n- {findings}\n\n{warn}"}
  59. else:
  60. text_message = {"text": abstract}
  61. return text_message
  62. else:
  63. return text_message
  64. async def on_button_selected(c: CallbackQuery, widget: Any, manager: DialogManager, item_id: str):
  65. if context := manager.current_context():
  66. id_, url = context.start_data["id"], context.start_data["url"] # type: ignore
  67. context.widget_data["radio_buttons"] = item_id
  68. cross_data.intent_queue[context.id] = None
  69. cross_data.context[context.id] = context
  70. if item_id == "2":
  71. if context.dialog_data.get('summary'):
  72. pass
  73. else:
  74. await c.answer("Getting summary, please wait")
  75. await manager.start(Bg.progress)
  76. gather(
  77. background(c, manager.bg(), cross_data, context.id),
  78. get_summary(cross_data=cross_data, paper_id=id_, paper_url=url, context_id=context.id),
  79. )
  80. elif item_id == "3":
  81. if context.dialog_data.get("highlights"):
  82. pass
  83. else:
  84. await c.answer("Getting highlights, please wait")
  85. await manager.start(Bg.progress)
  86. gather(
  87. background(c, manager.bg(), cross_data, context.id),
  88. get_summary(cross_data=cross_data, paper_id=id_, paper_url=url, context_id=context.id)
  89. )
  90. elif item_id == "4":
  91. if context.dialog_data.get("findings"):
  92. pass
  93. else:
  94. await c.answer("Getting findings, please wait")
  95. await manager.start(Bg.progress)
  96. gather(
  97. background(c, manager.bg(), cross_data, context.id),
  98. get_summary(cross_data=cross_data, paper_id=id_, paper_url=url, context_id=context.id)
  99. )
  100. else:
  101. pass
  102. else:
  103. return {"text": 1}
  104. return {"text": item_id}
  105. buttons_kbd = Radio(
  106. Format("✓ {item[0]}"),
  107. Format("{item[0]}"),
  108. id="radio_buttons",
  109. item_id_getter=operator.itemgetter(1),
  110. items=buttons,
  111. on_click=on_button_selected,
  112. )
  113. dialog = Dialog(
  114. Window(
  115. Format("{text}"),
  116. buttons_kbd,
  117. state=MySG.main,
  118. getter=get_data,
  119. parse_mode=ParseMode.MARKDOWN, # type: ignore
  120. # preview_data={"button": "1"}
  121. )
  122. )