Browse source

feat: add database and summarizer

add functionality for summarizing papers via an external API, and a SQLite database with CRUD helpers and models to store papers and their summaries

BREAKING CHANGE:
metya 2 years ago
Parent commit 3709695f77
12 files changed with 744 additions and 229 deletions
  1. .gitattributes (+1 -0)
  2. CHANGELOG.md (+13 -1)
  3. Dockerfile (+3 -2)
  4. config.py (+2 -2)
  5. db.py (+169 -0)
  6. develop.yml (+162 -126)
  7. dialog.py (+91 -43)
  8. papers.db (+3 -0)
  9. progress.py (+125 -0)
  10. requirements.txt (+5 -2)
  11. summarize.py (+143 -44)
  12. vanitybot.py (+27 -9)

+ 1 - 0
.gitattributes

@@ -0,0 +1 @@
+papers.db filter=lfs diff=lfs merge=lfs -text

+ 13 - 1
CHANGELOG.md

@@ -1,6 +1,18 @@
-## 0.1.0 (2021-02-11)
+## 0.2.0 (2023-02-08)
+
+### Refactor
+
+- **linter**: ignore some mypy warnings
 
 ### Feat
 
+- add paper summarizer
+- //WIP before PR actions
+- add docker-compose and update README
 - add docker files for convenient deploy
 - add bot's business-logic files
+
+### Fix
+
+- fix spaces in descriptions
+- fix dockerfile

+ 3 - 2
Dockerfile

@@ -5,10 +5,11 @@ ENV PIP_NO_CACHE_DIR=off \
 ARG API_TOKEN
 ENV API_TOKEN=$API_TOKEN
 
-RUN apk add --no-cache --virtual .build-deps gcc musl-dev
+RUN apk add --no-cache --virtual .build-deps gcc musl-dev libffi-dev curl
+RUN apk add --no-cache git git-lfs
 
 WORKDIR /app
-ADD prod-requirements.txt /app
+ADD requirements.txt /app
 RUN pip install --no-cache-dir -r requirements.txt
 RUN apk del .build-deps
 

+ 2 - 2
config.py

@@ -5,7 +5,7 @@ if env_var := dotenv_values('token'):
     API_TOKEN = env_var["API_TOKEN"]
 else:
     API_TOKEN = getenv("API_TOKEN")
-    
-    
+
+
 if __name__ == "__main__":
     print(API_TOKEN)

+ 169 - 0
db.py

@@ -0,0 +1,169 @@
+import asyncio
+from sqlmodel import Field, SQLModel, create_engine, Relationship, select, Session, JSON, Column
+from sqlmodel.ext.asyncio.session import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.orm import selectinload
+
+
+async_sqlite_url = "sqlite+aiosqlite:///papers.db"
+echo = False
+
+async_engine = create_async_engine(async_sqlite_url, echo=echo)
+engine = create_engine("sqlite:///papers.db", echo=echo)
+
+
+# models
+class Papers2Authors(SQLModel, table=True):
+    __table_args__ = {'extend_existing': True}
+    paper_id: str | None = Field(default=None, primary_key=True, foreign_key="papers.id_")
+    author_id: int | None = Field(default=None, primary_key=True, foreign_key="authors.id_")
+
+
+class Authors(SQLModel, table=True):
+    __table_args__ = {'extend_existing': True}
+    id_: int | None = Field(default=None, primary_key=True)
+    name: str = Field(unique=True)
+    papers: list["Papers"] = Relationship(back_populates="authors", link_model=Papers2Authors)
+
+
+class Papers(SQLModel, table=True):
+    __table_args__ = {'extend_existing': True}
+    id_: str = Field(primary_key=True)
+    title: str
+    abstract: str
+    highlights: list[str] | None = Field(default=None, sa_column=Column(JSON))
+    findings: list[str] | None = Field(default=None, sa_column=Column(JSON))
+    summary: list[str] | None = Field(default=None, sa_column=Column(JSON))
+    figures_url: list[str] | None = Field(default=None, sa_column=Column(JSON))
+    full_data: dict | None = Field(default=None, sa_column=Column(JSON))
+    authors: list["Authors"] | None = Relationship(back_populates="papers", link_model=Papers2Authors)
+    class Config:
+        arbitrary_types_allowed = True
+
+
+def create_db():
+    engine = create_engine("sqlite:///papers.db", echo=True)
+    SQLModel.metadata.create_all(engine, checkfirst=True)
+
+
+# CRUD
+def add_authors_to_paper(id_, paper: dict, authors):
+    authors = [Authors(name=name) for name in authors]
+    with Session(engine) as session:
+        paper_ = session.get(Papers, id_)
+        if paper_:
+            # the relationship attribute is always a list, so extend also
+            # covers papers that have no authors yet
+            paper_.authors.extend(authors)
+            session.add(paper_)
+            session.commit()
+        else:
+            print(f"paper {id_} not found")
+
+
+async def add_authors_and_paper(id_: str, paper: dict):
+    if await check_paper(id_):
+        return
+
+    if authors := paper.get("authors"):
+        paper_ = Papers(
+            id_ = id_,
+            title = paper["title"],
+            abstract = paper["abstract"],
+            highlights = paper.get("highlights"),
+            summary = paper.get("summary"),
+            figures_url = paper.get("figures_url"),
+            findings = paper.get("findings"),
+            full_data = paper.get("full_data") or paper
+        )
+        async with AsyncSession(async_engine, expire_on_commit=False) as session:
+            queries = [select(Authors).where(Authors.name==name).
+                       options(selectinload(Authors.papers))
+                       for name in authors]
+            # an AsyncSession does not support concurrent queries, so run them sequentially
+            results = [await session.exec(query) for query in queries] # type: ignore
+            exist_authors = [author for res in results if (author:=res.first())]
+            exist_names = [author.name for author in exist_authors if author]
+            for author in exist_authors:
+                author.papers.append(paper_)
+            session.add_all(exist_authors)
+            new_authors = [author for author in authors if author not in exist_names]
+            new_authors = [Authors(name=name, papers=[paper_]) for name in new_authors]
+            session.add_all(new_authors)
+            await session.commit()
+    else:
+        await add_or_update_paper(id_, paper)
+
+
+def add_paper_to_authors(id_, paper, authors): 
+    paper_ = Papers(
+            id_ = id_,
+            title = paper["title"],
+            abstract = paper["abstract"],
+            highlights = paper.get("highlights"),
+            summary = paper.get("summary"),
+            figures_url = paper.get("figures_url"),
+            findings = paper.get("findings"),
+            full_data = paper.get("full_data") or paper
+    )
+    with Session(engine, expire_on_commit=False) as session:
+        queries = [select(Authors).where(Authors.name==name) for name in authors]
+        results = [session.exec(query) for query in queries]
+        exist_authors = [author for res in results if (author := res.first())]
+        for author in exist_authors:
+            author.papers.append(paper_)
+        session.add_all(exist_authors)
+        session.commit()
+
+
+async def check_paper(id_: str, async_engine=async_engine):
+    async with AsyncSession(async_engine) as session:
+        return await session.get(Papers, id_)
+
+
+async def add_or_update_paper(id_: str, paper: dict, async_engine=async_engine):
+    if exists_paper := await check_paper(id_):
+        paper_ = exists_paper
+        paper_.highlights = paper.get("highlights")
+        paper_.summary = paper.get("summary")
+        paper_.findings = paper.get("findings")
+        paper_.figures_url = paper.get("figures_url")
+        paper_.full_data = paper.get("full_data") or paper
+    else:
+        paper_ = Papers(
+            id_ = id_,
+            title = paper["title"],
+            abstract = paper["abstract"],
+            highlights = paper.get("highlights"),
+            summary = paper.get("summary"),
+            figures_url = paper.get("figures_url"),
+            findings = paper.get("findings"),
+            full_data = paper.get("full_data") or paper
+            )
+    async with AsyncSession(async_engine) as session:
+        session.add(paper_)
+        await session.commit()
+
+
+
+
+if __name__ == "__main__":
+    import json
+    with open("paper.json", 'r') as file:
+        paper = json.load(file)
+    # print(paper)
+    paper["title"] = paper["metadata"]["title"]
+    paper["authors"] = paper["metadata"]["author"].split(",").strip()
+    paper["abstract"] = paper["metadata"]["abstract"]
+    create_db()
+
+    async def main():
+        await add_authors_and_paper("2203.02155v1", paper)
+
+        # paper = await check_paper("1234.1212v1")
+        # if paper:
+        #     print("\n\n\n\n")
+        #     print(paper.summary)
+        #     print("\n\n")
+        #     print(paper.summary[0])
+
+    asyncio.run(main())
+    
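
For reference, a minimal sketch of how the new db.py helpers fit together (the id and field values below are placeholders; the paper dict shape follows what summarize.py passes in):

    # sketch: create the schema once, then upsert and read back a paper
    import asyncio
    from db import create_db, add_or_update_paper, check_paper

    create_db()  # safe to re-run: create_all() is called with checkfirst=True

    paper = {
        "title": "Example title",           # placeholder
        "abstract": "Example abstract...",  # placeholder
        "summary": ["First sentence.", "Second sentence."],  # stored in a JSON column
    }

    async def demo():
        await add_or_update_paper("0000.00000v1", paper)
        stored = await check_paper("0000.00000v1")
        print(stored.summary if stored else "not stored")

    asyncio.run(demo())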

+ 162 - 126
develop.yml

@@ -2,132 +2,165 @@ name: vanity
 channels:
   - conda-forge
 dependencies:
-  - aiodns=3.0.0
-  - aiohttp=3.8.3
-  - aiosignal=1.3.1
-  - appdirs=1.4.4
-  - appnope=0.1.3
-  - argcomplete=1.12.3
-  - asttokens=2.2.1
-  - async-timeout=4.0.2
-  - attrs=22.2.0
-  - backcall=0.2.0
-  - backports=1.0
-  - backports.functools_lru_cache=1.6.4
-  - beautifulsoup4=4.11.2
-  - brotlipy=0.7.0
-  - bzip2=1.0.8
-  - ca-certificates=2022.12.7
-  - cchardet=2.1.7
-  - certifi=2022.12.7
-  - cffi=1.15.1
-  - cfgv=3.3.1
-  - charset-normalizer=2.1.1
-  - click=8.1.3
-  - colorama=0.4.6
-  - comm=0.1.2
-  - commitizen=2.28.1
-  - cryptography=39.0.0
-  - cssselect=1.2.0
-  - debugpy=1.6.6
-  - decli=0.5.2
-  - decorator=5.1.1
-  - distlib=0.3.6
-  - executing=1.2.0
-  - fake-useragent=1.1.1
-  - filelock=3.9.0
-  - frozenlist=1.3.3
-  - icu=70.1
-  - identify=2.5.17
-  - idna=3.4
-  - importlib-metadata=6.0.0
-  - importlib_metadata=6.0.0
-  - ipykernel=6.21.1
-  - ipython=8.9.0
-  - jedi=0.18.2
-  - jinja2=3.1.2
-  - jupyter_client=8.0.2
-  - jupyter_core=5.2.0
-  - libcxx=14.0.6
-  - libffi=3.4.2
-  - libiconv=1.17
-  - libsodium=1.0.18
-  - libsqlite=3.40.0
-  - libuv=1.44.2
-  - libxml2=2.10.3
-  - libxslt=1.1.37
-  - libzlib=1.2.13
-  - lxml=4.9.2
-  - markupsafe=2.1.2
-  - matplotlib-inline=0.1.6
-  - multidict=6.0.4
-  - mypy=1.0.0
-  - mypy_extensions=1.0.0
-  - ncurses=6.3
-  - nest-asyncio=1.5.6
-  - nodeenv=1.7.0
-  - openssl=3.0.8
-  - packaging=21.3
-  - parse=1.19.0
-  - parso=0.8.3
-  - pexpect=4.8.0
-  - pickleshare=0.7.5
-  - pip=23.0
-  - platformdirs=2.6.2
-  - pre-commit=3.0.4
-  - prompt-toolkit=3.0.36
-  - prompt_toolkit=3.0.36
-  - psutil=5.9.4
-  - ptyprocess=0.7.0
-  - pure_eval=0.2.2
-  - pycares=4.0.0
-  - pycparser=2.21
-  - pyee=8.1.0
-  - pygments=2.14.0
-  - pyopenssl=23.0.0
-  - pyparsing=3.0.9
-  - pyppeteer=1.0.2
-  - pyquery=2.0.0
-  - pysocks=1.7.1
-  - python=3.11.0
-  - python-dateutil=2.8.2
-  - python-dotenv=0.21.1
-  - python_abi=3.11
-  - pyyaml=6.0
-  - pyzmq=25.0.0
-  - questionary=1.10.0
-  - readline=8.1.2
-  - requests=2.28.2
-  - requests-html=0.10.0
-  - setuptools=67.1.0
-  - six=1.16.0
-  - soupsieve=2.3.2.post1
-  - stack_data=0.6.2
-  - termcolor=1.1.0
-  - tk=8.6.12
-  - tomli=2.0.1
-  - tomlkit=0.11.6
-  - tornado=6.2
-  - tqdm=4.64.1
-  - traitlets=5.9.0
-  - typing=3.10.0.0
-  - typing-extensions=4.4.0
-  - typing_extensions=4.4.0
-  - tzdata=2022g
-  - ujson=5.7.0
-  - ukkonen=1.0.1
-  - urllib3=1.26.14
-  - uvloop=0.17.0
-  - virtualenv=20.18.0
-  - w3lib=2.1.1
-  - wcwidth=0.2.6
-  - websockets=10.4
-  - wheel=0.38.4
-  - xz=5.2.6
-  - yaml=0.2.5
-  - yarl=1.8.2
-  - zeromq=4.3.4
-  - zipp=3.12.1
+  - aiodns=3.0.0=pyhd8ed1ab_0
+  - aiohttp=3.8.3=py311he2be06e_1
+  - aiosignal=1.3.1=pyhd8ed1ab_0
+  - aiosqlite=0.18.0=pyhd8ed1ab_0
+  - appdirs=1.4.4=pyh9f0ad1d_0
+  - appnope=0.1.3=pyhd8ed1ab_0
+  - argcomplete=1.12.3=pyhd8ed1ab_0
+  - asttokens=2.2.1=pyhd8ed1ab_0
+  - async-timeout=4.0.2=pyhd8ed1ab_0
+  - async_generator=1.10=py_0
+  - attrs=22.2.0=pyh71513ae_0
+  - backcall=0.2.0=pyh9f0ad1d_0
+  - backports=1.0=pyhd8ed1ab_3
+  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
+  - beautifulsoup4=4.11.2=pyha770c72_0
+  - brotlipy=0.7.0=py311he2be06e_1005
+  - bzip2=1.0.8=h3422bc3_4
+  - ca-certificates=2022.12.7=h4653dfc_0
+  - cchardet=2.1.7=py311ha397e9f_4
+  - certifi=2022.12.7=pyhd8ed1ab_0
+  - cffi=1.15.1=py311hae827db_3
+  - cfgv=3.3.1=pyhd8ed1ab_0
+  - charset-normalizer=2.1.1=pyhd8ed1ab_0
+  - click=8.1.3=unix_pyhd8ed1ab_2
+  - colorama=0.4.6=pyhd8ed1ab_0
+  - comm=0.1.2=pyhd8ed1ab_0
+  - commitizen=2.28.1=pyhd8ed1ab_0
+  - cryptography=39.0.0=py311h507f6e9_0
+  - cssselect=1.2.0=pyhd8ed1ab_0
+  - debugpy=1.6.6=py311ha397e9f_0
+  - decli=0.5.2=pyhd8ed1ab_0
+  - decorator=5.1.1=pyhd8ed1ab_0
+  - distlib=0.3.6=pyhd8ed1ab_0
+  - docopt=0.6.2=py_1
+  - executing=1.2.0=pyhd8ed1ab_0
+  - fake-useragent=1.1.1=pyhd8ed1ab_0
+  - filelock=3.9.0=pyhd8ed1ab_0
+  - freetype=2.12.1=hd633e50_1
+  - frozenlist=1.3.3=py311he2be06e_0
+  - greenlet=2.0.2=py311ha397e9f_0
+  - h11=0.14.0=pyhd8ed1ab_0
+  - icu=70.1=h6b3803e_0
+  - identify=2.5.17=pyhd8ed1ab_0
+  - idna=3.4=pyhd8ed1ab_0
+  - importlib-metadata=6.0.0=pyha770c72_0
+  - importlib_metadata=6.0.0=hd8ed1ab_0
+  - ipykernel=6.21.1=pyh736e0ef_0
+  - ipython=8.9.0=pyhd1c38e8_0
+  - jedi=0.18.2=pyhd8ed1ab_0
+  - jinja2=3.1.2=pyhd8ed1ab_1
+  - jpeg=9e=h1a8c8d9_3
+  - jupyter_client=8.0.2=pyhd8ed1ab_0
+  - jupyter_core=5.2.0=py311h267d04e_0
+  - lcms2=2.14=h481adae_1
+  - lerc=4.0.0=h9a09cb3_0
+  - libcxx=14.0.6=h2692d47_0
+  - libdeflate=1.17=h1a8c8d9_0
+  - libffi=3.4.2=h3422bc3_5
+  - libiconv=1.17=he4db4b2_0
+  - libpng=1.6.39=h76d750c_0
+  - libsodium=1.0.18=h27ca646_1
+  - libsqlite=3.40.0=h76d750c_0
+  - libtiff=4.5.0=h5dffbdd_2
+  - libuv=1.44.2=he4db4b2_0
+  - libwebp-base=1.2.4=h57fd34a_0
+  - libxcb=1.13=h9b22ae9_1004
+  - libxml2=2.10.3=h87b0503_0
+  - libxslt=1.1.37=h1bd8bc4_0
+  - libzlib=1.2.13=h03a7124_4
+  - lxml=4.9.2=py311h246f609_0
+  - markupsafe=2.1.2=py311he2be06e_0
+  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
+  - multidict=6.0.4=py311he2be06e_0
+  - mypy=1.0.0=py311he2be06e_0
+  - mypy_extensions=1.0.0=pyha770c72_0
+  - ncurses=6.3=h07bb92c_1
+  - nest-asyncio=1.5.6=pyhd8ed1ab_0
+  - nodeenv=1.7.0=pyhd8ed1ab_0
+  - openjpeg=2.5.0=hbc2ba62_2
+  - openssl=3.0.8=h03a7124_0
+  - outcome=1.2.0=pyhd8ed1ab_0
+  - packaging=21.3=pyhd8ed1ab_0
+  - parse=1.19.0=pyh44b312d_0
+  - parso=0.8.3=pyhd8ed1ab_0
+  - pexpect=4.8.0=pyh1a96a4e_2
+  - pickleshare=0.7.5=py_1003
+  - pillow=9.4.0=py311h627eb56_1
+  - pip=23.0=pyhd8ed1ab_0
+  - pipreqs=0.4.11=pyhd8ed1ab_0
+  - platformdirs=2.6.2=pyhd8ed1ab_0
+  - pre-commit=3.0.4=py311h267d04e_0
+  - prompt-toolkit=3.0.36=pyha770c72_0
+  - prompt_toolkit=3.0.36=hd8ed1ab_0
+  - psutil=5.9.4=py311he2be06e_0
+  - pthread-stubs=0.4=h27ca646_1001
+  - ptyprocess=0.7.0=pyhd3deb0d_0
+  - pure_eval=0.2.2=pyhd8ed1ab_0
+  - pycares=4.0.0=py311he2be06e_2
+  - pycparser=2.21=pyhd8ed1ab_0
+  - pydantic=1.10.4=py311he2be06e_1
+  - pyee=8.1.0=pyhd8ed1ab_0
+  - pygments=2.14.0=pyhd8ed1ab_0
+  - pyopenssl=23.0.0=pyhd8ed1ab_0
+  - pyparsing=3.0.9=pyhd8ed1ab_0
+  - pyppeteer=1.0.2=pyhd8ed1ab_0
+  - pyquery=2.0.0=pyhd8ed1ab_0
+  - pysocks=1.7.1=pyha2e5f31_6
+  - python=3.11.0=h3ba56d0_1_cpython
+  - python-dateutil=2.8.2=pyhd8ed1ab_0
+  - python-dotenv=0.21.1=pyhd8ed1ab_0
+  - python_abi=3.11=3_cp311
+  - pyyaml=6.0=py311he2be06e_5
+  - pyzmq=25.0.0=py311h0f351f6_0
+  - questionary=1.10.0=pyhd8ed1ab_1
+  - readline=8.1.2=h46ed386_0
+  - requests=2.28.2=pyhd8ed1ab_0
+  - requests-html=0.10.0=pyhd8ed1ab_0
+  - selenium=4.7.2=pyhd8ed1ab_0
+  - setuptools=67.1.0=pyhd8ed1ab_0
+  - six=1.16.0=pyh6c4a22f_0
+  - sniffio=1.3.0=pyhd8ed1ab_0
+  - sortedcontainers=2.4.0=pyhd8ed1ab_0
+  - soupsieve=2.3.2.post1=pyhd8ed1ab_0
+  - stack_data=0.6.2=pyhd8ed1ab_0
+  - termcolor=1.1.0=pyhd8ed1ab_3
+  - tk=8.6.12=he1e0b03_0
+  - tomli=2.0.1=pyhd8ed1ab_0
+  - tomlkit=0.11.6=pyha770c72_0
+  - tornado=6.2=py311he2be06e_1
+  - tqdm=4.64.1=pyhd8ed1ab_0
+  - traitlets=5.9.0=pyhd8ed1ab_0
+  - trio=0.22.0=py311h267d04e_1
+  - trio-websocket=0.9.2=pyhd8ed1ab_0
+  - types-requests=2.28.11.12=pyhd8ed1ab_0
+  - types-urllib3=1.26.25.5=pyhd8ed1ab_0
+  - typing=3.10.0.0=pyhd8ed1ab_0
+  - typing-extensions=4.4.0=hd8ed1ab_0
+  - typing_extensions=4.4.0=pyha770c72_0
+  - tzdata=2022g=h191b570_0
+  - ujson=5.7.0=py311ha397e9f_0
+  - ukkonen=1.0.1=py311hd6ee22a_3
+  - urllib3=1.26.14=pyhd8ed1ab_0
+  - uvloop=0.17.0=py311he2be06e_1
+  - virtualenv=20.18.0=pyhd8ed1ab_0
+  - w3lib=2.1.1=pyhd8ed1ab_0
+  - wcwidth=0.2.6=pyhd8ed1ab_0
+  - webdriver-manager=3.8.5=pyhd8ed1ab_0
+  - websockets=10.4=py311he2be06e_1
+  - wheel=0.38.4=pyhd8ed1ab_0
+  - wsproto=1.2.0=pyhd8ed1ab_0
+  - xorg-libxau=1.0.9=h27ca646_0
+  - xorg-libxdmcp=1.1.3=h27ca646_0
+  - xz=5.2.6=h57fd34a_0
+  - yaml=0.2.5=h3422bc3_2
+  - yarg=0.1.9=py_1
+  - yarl=1.8.2=py311he2be06e_0
+  - zeromq=4.3.4=hbdafb3b_1
+  - zipp=3.12.1=pyhd8ed1ab_0
+  - zstd=1.5.2=hf913c23_6
   - pip:
       - aiogram==2.25.1
       - aiogram-dialog==1.9.0
@@ -135,4 +168,7 @@ dependencies:
       - cachetools==4.2.4
       - magic-filter==1.0.9
       - pytz==2022.7.1
+      - sqlalchemy==1.4.41
+      - sqlalchemy2-stubs==0.0.2a32
+      - sqlmodel==0.0.8
 prefix: /opt/homebrew/Caskroom/mambaforge/base/envs/vanity

+ 91 - 43
dialog.py

@@ -1,74 +1,122 @@
 
-from typing import Any
 import operator
+from asyncio import gather, Queue
+from typing import Any
 from aiogram.dispatcher.filters.state import StatesGroup, State
 from aiogram.types import Message, CallbackQuery
-from aiogram_dialog import Window, Dialog, DialogManager, StartMode
+from aiogram_dialog import Window, Dialog, DialogManager, StartMode
 from aiogram_dialog.widgets.kbd import Radio
 from aiogram_dialog.widgets.text import Format
 from aiogram.types import ParseMode
 
-from summarize import get_paper_desc, get_key_moments, get_summary
+from summarize import get_summary
 
+from progress import Bg, background
+
+warn = "***Note, that it is an AI generated summary, and it may contain complete bullshit***"
+
+class CrossData:
+    queue: Queue = Queue()
+    intent_queue: dict = {}
+    context: dict = {}
 
 class MySG(StatesGroup):
     main = State()
 
+cross_data = CrossData()
+
 buttons = [
         ("Abstract", '1'),
         ("Summary", '2'),
-        ("Key Moments", '3'),
+        ("Highlights", '3'),
+        ("Findings", '4')
     ]
 
+
 async def get_data(dialog_manager: DialogManager, **kwargs):
-    data = dialog_manager.current_context()
-    item_id = data.widget_data.get('radio_buttons') #type: ignore
-    p = {"text": "OOOPS!"}
-    title = data.start_data["title"] #type: ignore
-    url = data.start_data["url"] #type: ignore
-    
-    if data.dialog_data.get('abs'): #type: ignore
-        abst = data.dialog_data.get('abs') #type: ignore
-    else:
-        data.dialog_data["abs"] = data.start_data["reply_message"] #type: ignore
-        abst = data.dialog_data.get('abs') #type: ignore
-
-    if item_id == "2":
-        if data.dialog_data.get("summary"):
-            summ = data.dialog_data["summary"]
-            p = {"text": f"{url}\n\n***{title}***\n\n{summ}"}
-    elif item_id == "3":
-        if data.dialog_data.get("key_moments"):
-            keys = data.dialog_data.get("key_moments")
-            p = {"text": f"{url}\n\n***{title}***\n\n{keys}"}
+    text_message = {"text": "OOOPS!"}
+    if data := dialog_manager.current_context():
+        if cr_data:=cross_data.context.get(data.id):
+            data.dialog_data.update(cr_data.dialog_data)
+            del cross_data.context[data.id]
+        else:
+            cross_data.context[data.id] = None
+        item_id = data.widget_data.get('radio_buttons') #type: ignore
+        title = data.start_data["title"] #type: ignore
+        url = data.start_data["url"] #type: ignore
+        if data.dialog_data.get('abs'): #type: ignore
+            abstract = data.dialog_data.get('abs') #type: ignore
+        else:
+            data.dialog_data["abs"] = data.start_data["reply_message"] #type: ignore
+            abstract = data.dialog_data.get('abs') #type: ignore
+        if item_id == "2":
+            if summary:=data.dialog_data.get("summary"):
+                if isinstance(summary, list):
+                    while sum(len(sentence) for sentence in summary) > 3700:
+                        summary = summary[:-1]
+                    summary = " ".join(summary)
+                text_message = {"text": f"{url}\n\n***{title}***\n\n{summary}\n\n{warn}"}
+        elif item_id == "3":
+            if highlights:=data.dialog_data.get("highlights"):
+                if isinstance(highlights, list):
+                    highlights = "\n\n- ".join(highlights)
+                text_message = {"text": f"{url}\n\n***{title}***\n\n- {highlights}\n\n{warn}"}
+        elif item_id == "4":
+            if findings:=data.dialog_data.get("findings"):
+                if isinstance(findings, list):
+                    findings = "\n\n- ".join(findings)
+                text_message = {"text": f"{url}\n\n***{title}***\n\n- {findings}\n\n{warn}"}
+        else:
+            text_message = {"text": abstract}
+        return text_message
     else:
-        p = {"text": abst}
-    return p
+        return text_message
 
 
 async def on_button_selected(c: CallbackQuery, widget: Any, manager: DialogManager, item_id: str):
-    data = manager.current_context()
-    if item_id == "2":
-        if data.dialog_data.get('summary'):
-            pass
+    if context := manager.current_context():
+        id_, url = context.start_data["id"], context.start_data["url"] # type: ignore
+        context.widget_data["radio_buttons"] = item_id
+        cross_data.intent_queue[context.id] = None
+        cross_data.context[context.id] = context
+        if item_id == "2":
+            if context.dialog_data.get('summary'):
+                pass 
+            else:
+                await c.answer("Getting summary, please wait")
+                await manager.start(Bg.progress)
+                gather(
+                    background(c, manager.bg(), cross_data, context.id),
+                    get_summary(cross_data=cross_data, paper_id=id_, paper_url=url, context_id=context.id),
+                    )
+        elif item_id == "3":
+            if context.dialog_data.get("highlights"):
+                pass
+            else:
+                await c.answer("Getting highlights, please wait")
+                await manager.start(Bg.progress)
+                gather(
+                    background(c, manager.bg(), cross_data, context.id),
+                    get_summary(cross_data=cross_data, paper_id=id_, paper_url=url, context_id=context.id)
+                    )
+        elif item_id == "4":
+            if context.dialog_data.get("findings"):
+                pass
+            else:
+                await c.answer("Getting findings, please wait")
+                await manager.start(Bg.progress)
+                gather(
+                    background(c, manager.bg(), cross_data, context.id),
+                    get_summary(cross_data=cross_data, paper_id=id_, paper_url=url, context_id=context.id)
+                    )
         else:
-            await c.answer("Getting Summary, please wait")
-            summary = await get_summary(url = data.start_data["url"])
-            data.dialog_data["summary"] = summary
-    elif item_id == "3":
-        if data.dialog_data.get("key_moments"):
             pass
-        else:
-            await c.answer("Getting Key Moments, please wait")
-            key_moments = await get_key_moments(url=data.start_data["url"])
-            data.dialog_data["key_moments"] = key_moments
     else:
-        pass
+        return {"text": 1}
         
     return {"text": item_id}
 
 
-
 buttons_kbd = Radio(
     Format("✓ {item[0]}"),
     Format("{item[0]}"),
@@ -78,15 +126,15 @@ buttons_kbd = Radio(
     on_click=on_button_selected,
 )
 
+
 dialog = Dialog(
     Window(
         Format("{text}"),
         buttons_kbd,
         state=MySG.main,
         getter=get_data,
-        parse_mode=ParseMode.MARKDOWN,
+        parse_mode=ParseMode.MARKDOWN, # type: ignore
         # preview_data={"button": "1"}
     )
 )
 
-
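
The CrossData instance is the only channel between the progress dialog and the summarizer task that get_summary runs in; roughly, the handoff looks like this (names as defined above):

    # worker side (get_summary): signal completion, then stash the result
    cross_data.intent_queue[context_id] = "done"
    cross_data.context[context_id].dialog_data.update(data)

    # UI side (background in progress.py): polls intent_queue once a second,
    # closes the progress window on "done"; get_data() then merges the
    # stashed dialog_data back into the fresh dialog context on re-render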

+ 3 - 0
papers.db

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfe1c3e9f0cf4072550d42488db3e81008fd8279bd4c7f6326fd758c210cf9f9
+size 106496

+ 125 - 0
progress.py

@@ -0,0 +1,125 @@
+import asyncio
+import logging
+
+from aiogram import Bot, Dispatcher
+from aiogram.contrib.fsm_storage.memory import MemoryStorage
+from aiogram.dispatcher.filters.state import StatesGroup, State
+from aiogram.types import Message, CallbackQuery
+
+from aiogram_dialog import Dialog, DialogManager, Window, DialogRegistry, BaseDialogManager, \
+    StartMode
+from aiogram_dialog.widgets.kbd import Button
+from aiogram_dialog.widgets.text import Const, Multi, Text
+
+from typing import Any
+
+from config import API_TOKEN
+
+
+class Processing(Text):
+    def __init__(self, field: str, width: int = 10, filled="🟩", empty="⬜", when = None):
+        super().__init__(when)
+        self.field = field
+        self.width = width
+        self.filled = filled
+        self.empty = empty
+
+    async def _render_text(self, data: dict, manager: DialogManager) -> str:
+        if manager.is_preview():
+            percent = 15
+        else:
+            percent = data[self.field]
+        rest = self.width - percent
+
+        return f"processing: {self.filled * percent + self.empty * rest}"
+
+# name progress dialog
+class Bg(StatesGroup):
+    progress = State()
+
+
+async def get_bg_data(dialog_manager: DialogManager, **kwargs):
+    if context := dialog_manager.current_context():
+        return {"progress": context.dialog_data.get("progress", 0)}
+    else:
+        return {"progress": "OOOOPS"}
+
+
+bg_dialog = Dialog(
+    Window(
+        Multi(
+            Const("Summarizing is processing, please wait...\n"),
+            Processing("progress", 10),
+        ),
+        state=Bg.progress,
+        getter=get_bg_data,
+    ),
+)
+
+
+# main dialog
+class MainSG(StatesGroup):
+    main = State()
+
+
+async def start_bg(c: CallbackQuery, button: Button, manager: DialogManager):
+    await manager.start(Bg.progress)
+    asyncio.create_task(background(c, manager.bg()))
+
+
+async def background(c: CallbackQuery,
+                     manager: DialogManager | BaseDialogManager,
+                     cross_data: Any = None,
+                     context_id: str | None = None,
+                     time_out = 40):
+    i = 0
+    time = 0
+    await asyncio.sleep(1)
+    while time < time_out:
+
+        i = i + 1 if i < 10 else 0
+        time += 1
+
+        await manager.update({"progress": i})
+        await asyncio.sleep(1)
+
+        if cross_data and (mess := cross_data.intent_queue.get(context_id)):
+            if mess == 'done':
+                del cross_data.intent_queue[context_id]
+                await manager.done()
+                return
+
+    await manager.done()
+
+
+main_menu = Dialog(
+    Window(
+        Const("Press button to start processing"),
+        Button(Const("Start 👀"), id="start", on_click=start_bg),
+        state=MainSG.main,
+        getter=get_bg_data
+    ),
+)
+
+
+async def start(m: Message, dialog_manager: DialogManager):
+    await dialog_manager.start(MainSG.main, mode=StartMode.RESET_STACK)
+
+
+async def main():
+    # real main
+    logging.basicConfig(level=logging.INFO)
+    logging.getLogger("aiogram_dialog").setLevel(logging.DEBUG)
+    storage = MemoryStorage()
+    bot = Bot(token=API_TOKEN) # type: ignore
+    dp = Dispatcher(bot, storage=storage)
+    registry = DialogRegistry(dp)
+    registry.register(bg_dialog)
+    registry.register(main_menu)
+    dp.register_message_handler(start, text="/start", state="*")
+
+    await dp.start_polling()
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
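
Note that Processing treats the "progress" value as a step counter out of width (10 here), not a percentage; an illustration of what _render_text produces:

    # what Processing("progress", 10) renders for a counter value of 4
    filled, empty, width, progress = "🟩", "⬜", 10, 4
    print(f"processing: {filled * progress + empty * (width - progress)}")
    # -> processing: 🟩🟩🟩🟩⬜⬜⬜⬜⬜⬜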

+ 5 - 2
requirements.txt

@@ -1,9 +1,12 @@
 aiogram
+# cython
 aiohttp
 aiodns
 ujson
 uvloop
-cchardet
+# cchardet
 beautifulsoup4
 python-dotenv
-requests-html
+aiogram_dialog
+sqlmodel
+aiosqlite

+ 143 - 44
summarize.py

@@ -1,49 +1,148 @@
-from requests_html import AsyncHTMLSession
+import re
+import asyncio
 from bs4 import BeautifulSoup
 from contextlib import suppress
-from requests import get
-
-def get_paper_desc(id_paper: str) -> tuple | None:
-    request = get(f'https://arxiv.org/abs/{id_paper}')
-    if request.ok:
-        soup = BeautifulSoup(request.content, features="lxml")
-        with suppress(TypeError): 
-            url = soup.find('meta', property='og:url').get('content')
-            title = soup.find('meta', property='og:title').get('content')
-            description = soup.find('meta', property='og:description').get('content').replace('\n', ' ')
-            return url, title, description
-    return None
+from aiohttp import ClientSession
+from typing import Any
 
-async def get_summary(url: str = "https://arxiv.org/abs/2102.12092v2") -> str:
-    url = url.replace("abs", "pdf")
-    async_session = AsyncHTMLSession()
-    async_response = await async_session.get(f"https://labs.kagi.com/ai/sum?url={url}.pdf")
-    await async_response.html.arender(sleep=5)
-    if res := async_response.html.find("p.description", first = True).text:
-        await async_session.close()
-        return res
-    else:
-        await async_response.html.arender(sleep=10)
-        if  res := async_response.html.find("p.description", first = True).text:
-            await async_session.close() 
-            return res
+from db import add_authors_and_paper, add_or_update_paper, check_paper
+
+base_url = "https://engine.scholarcy.com/api/"
+extract_url = "metadata/extract"
+highlights_url = "highlights/extract"
+summarize_endpoint = "https://summarizer.scholarcy.com/summarize"
+extract_endpoint = base_url + extract_url
+highlights_endpoint = base_url + highlights_url
+
+params = dict(
+    external_metadata="false",
+    parse_references="false",
+    generate_summary="true",
+    summary_engine="v4",
+    replace_pronouns="false",
+    strip_dialogue="false",
+    summary_size="400",
+    summary_percent="0",
+    structured_summary="false",
+    keyword_method="sgrank+acr",
+    keyword_limit="25",
+    abbreviation_method="schwartz",
+    extract_claims="true",
+    key_points="5",
+    citation_contexts="false",
+    inline_citation_links="false",
+    extract_pico="false",
+    extract_tables="false",
+    extract_figures="true",
+    require_captions="false",
+    extract_sections="false",
+    unstructured_content="false",
+    include_markdown="true",
+    extract_snippets="true",
+    engine="v2",
+    image_engine="v1+v2"
+)
+
+async def get_summary(cross_data: Any,
+                      paper_id: str,
+                      paper_url: str,
+                      synopsis=False,
+                      highlights=True,
+                      context_id: str = "qwe"):
+
+    async def fetch_summary(paper_url: str, synopsis=False, highlights=False):
+        pdf_url = paper_url.replace("abs", "pdf") + ".pdf"
+        if highlights:
+            url = highlights_endpoint
         else:
-            await async_session.close()
-            return "Nothing to retrieve :("
-
-async def get_key_moments(url: str = "https://arxiv.org/abs/2102.12092v2") -> str:
-    url = url.replace("abs", "pdf")
-    async_session = AsyncHTMLSession()
-    async_response = await async_session.get(f"https://labs.kagi.com/ai/sum?url={url}.pdf&expand=1")
-    await async_response.html.arender(sleep=5)
-    if res := async_response.html.find("p.description", first = True).text:
-        await async_session.close() 
-        return res
-    else:
-        await async_response.html.arender(sleep=10)
-        if  res := async_response.html.find("p.description", first = True).text:
-            await async_session.close()
-            return res
+            url = extract_endpoint
+        if synopsis:
+            url = summarize_endpoint
+        params["url"] = pdf_url
+        async with ClientSession() as session:
+            async with session.get(url, params=params) as response:
+                if response.ok:
+                    data = await response.json()
+                    if data.get("response"):
+                        return data["response"]
+                    else:
+                        return data
+                else:
+                    try:
+                        data = {"code_error": response.status,
+                                "message": (await response.json()).get("message")}
+                        return data
+                    except Exception as e:
+                        data = {"code_error": response.status,
+                                "message": e}
+                        return data
+
+    if paper := await check_paper(paper_id):
+        if paper.highlights:
+            data = {
+                "id": paper.id_,
+                "title": paper.title,
+                "abstract": paper.abstract,
+                "highlights": paper.highlights,
+                "findings": paper.findings,
+                "summary": paper.summary,
+                "figures_url": paper.figures_url,
+                "full_data": paper.full_data,
+                # "authors": paper.authors,
+            }
         else:
-            await async_session.close()
-            return "Nothing to retrieve :("
+            data = await fetch_summary(paper_url, synopsis, highlights)
+            if not data.get("code_error"):
+                await add_or_update_paper(paper_id, data)
+                data["id"] = paper.id_
+                data["title"] = paper.title
+                data["abstract"] = paper.abstract
+                # data["authors"] = paper.authors
+    else:
+        data = await fetch_summary(paper_url, synopsis, highlights)
+        if data.get("metadata"):
+            data["authors"] = data["metadata"].get("author").split(",").strip()
+            data["title"] = data["metadata"].get('title')
+            data["abstract"] = data["metadata"].get("abstract")
+            await add_authors_and_paper(paper_id, data)
+
+    # await asyncio.sleep(1)
+    cross_data.intent_queue[context_id] = "done"
+    cross_data.context[context_id].dialog_data.update(data)
+
+    return
+
+
+
+
+async def get_paper_desc(id_paper: str) -> dict | None:
+    if paper_ := await check_paper(id_paper):
+        paper = {
+            "id_" : paper_.id_,
+            "url" : f"https://arxiv.org/abs/{paper_.id_}",
+            "title" : paper_.title,
+            "abstract" : paper_.abstract,
+            "authors": None
+        }
+        return paper
+    else:
+        async with ClientSession() as session:
+            async with session.get(f'https://arxiv.org/abs/{id_paper}') as request:
+                if request.ok:
+                    soup = BeautifulSoup(await request.text(), features="lxml")
+                    try:
+                        url = soup.find('meta', property='og:url').get('content') # type: ignore
+                        paper = {
+                        "id_": re.findall(r'arxiv.org\/(?:abs|pdf)\/(\d{4}\.\d{5}[v]?[\d]?)', url)[0], # type: ignore
+                        "url" : url,  # type: ignore
+                        "title" : soup.find('meta', property='og:title').get('content'), # type: ignore
+                        "abstract" : soup.find('meta', property='og:description').get('content').replace('\n', ' '), # type: ignore
+                        "authors" : [name.text for name in soup.find("div", class_="authors").find_all("a")] # type: ignore
+                        }
+                        await add_authors_and_paper(paper["id_"], paper)
+                        return paper
+                    except TypeError:
+                        pass
+    return None
+
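
For manual testing, the Scholarcy endpoints can be hit directly with the same params dict; a minimal sketch (the API is external, so the response shape here is assumed from how this module handles it):

    import asyncio
    from aiohttp import ClientSession
    from summarize import highlights_endpoint, params

    async def demo():
        # any arxiv PDF URL works here; this one is illustrative
        params["url"] = "https://arxiv.org/pdf/2203.02155v1.pdf"
        async with ClientSession() as session:
            async with session.get(highlights_endpoint, params=params) as resp:
                data = await resp.json()
        print(data.get("highlights") or data)

    asyncio.run(demo())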

+ 27 - 9
vanitybot.py

@@ -1,17 +1,20 @@
 import re
 import logging
+import asyncio
 from aiogram.contrib.fsm_storage.memory import MemoryStorage
 from aiogram import Bot, Dispatcher, executor, types
 from aiogram_dialog import DialogManager, DialogRegistry, StartMode
 from config import API_TOKEN
 from dialog import dialog, MySG
+from progress import bg_dialog
 from summarize import get_paper_desc
 
 # Initialize bot and dispatcher
-bot = Bot(token=API_TOKEN)
+bot = Bot(token=API_TOKEN) # type: ignore
 dp = Dispatcher(bot, storage=MemoryStorage())
 registry = DialogRegistry(dp)
 registry.register(dialog)
+registry.register(bg_dialog)
 
 help_message = "Hello!\n\n\
 Send me a link to a paper from arxiv.org and \
@@ -27,13 +30,22 @@ async def process_start_command(message: types.Message):
 @dp.message_handler(commands=['help'])
 async def process_help_command(message: types.Message):
     await message.reply(help_message)
+    
+@dp.message_handler(commands=['long'])
+async def long(message: types.Message):
+    import random
+    long = "".join(str(random.randint(1,10)) for _ in range(3700))
+    await message.reply(long)
 
 
 @dp.message_handler(regexp=r'arxiv.org\/(?:abs|pdf)\/\d{4}\.\d{5}')
 async def vanitify(message: types.Message, dialog_manager: DialogManager):
-    papers_ids = re.findall(r'arxiv.org\/(?:abs|pdf)\/(\d{4}\.\d{5})', message.text)
+    papers_ids = re.findall(r'arxiv\.org\/(?:abs|pdf)\/(\d{4}\.\d{5}(?:v\d+)?)', message.text)
+    
+    async def start_dialog(manager=dialog_manager.bg(), state=MySG.main, mode=StartMode.NEW_STACK, data={}):
+        await manager.start(state=state, mode=mode, data=data)
 
-    for id_ in papers_ids:
+    async def get_paper_abs(id_):
         reply_message = f"[Here you can read the paper in mobile friendly way](https://www.arxiv-vanity.com/papers/{id_})"
         data = {
             "id": id_,
@@ -42,22 +54,28 @@ async def vanitify(message: types.Message, dialog_manager: DialogManager):
             "title": None,
             "abs": None
         }
-        if desc := get_paper_desc(id_):
-            url, title, description = desc
-            reply_message = f'{url}\n\n***{title}***\n\n{description}\n\n{reply_message}'
+        if paper := await get_paper_desc(id_):
+            id_, url, title, abstract, authors = paper.values()
+            reply_message = f'{url}\n\n***{title}***\n\n{abstract}\n\n{reply_message}'
             data.update({
+                "id": id_,
                 "reply_message": reply_message,
                 "url": url,
                 "title": title,
-                "abs": description
+                "abs": abstract,
+                "authors": authors
                 })
         else:
             reply_message = f'Something went wrong. Cannot reach arxiv.org :('
             data["reply_message"] = reply_message
+        return data
 
-        await dialog_manager.start(MySG.main, mode=StartMode.NEW_STACK, data=data)
+    list_data = await asyncio.gather(*[get_paper_abs(id_) for id_ in papers_ids])
+    asyncio.gather(*[start_dialog(data=data) for data in list_data])
 
 
 if __name__ == "__main__":
-    logging.basicConfig(level=logging.DEBUG)
+    logging.basicConfig(level=logging.INFO)
+    logging.getLogger("asyncio").setLevel(logging.DEBUG)
+    logging.getLogger("aiogram_dialog").setLevel(logging.DEBUG)
     executor.start_polling(dp, skip_updates=True)
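
The widened id regex also captures version suffixes; for example:

    import re
    pattern = r'arxiv\.org\/(?:abs|pdf)\/(\d{4}\.\d{5}(?:v\d+)?)'
    text = "https://arxiv.org/abs/2203.02155v1 arxiv.org/pdf/2102.12092"
    print(re.findall(pattern, text))  # ['2203.02155v1', '2102.12092']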