# db.py
  1. import asyncio
  2. from sqlmodel import Field, SQLModel, create_engine, Relationship, select, Session, JSON, Column
  3. from sqlmodel.ext.asyncio.session import AsyncSession
  4. from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession as _AsyncSession
  5. from sqlalchemy.orm import selectinload
  6. async_sqlite_url = "sqlite+aiosqlite:///papers.db"
  7. echo = False
  8. async_engine = create_async_engine(async_sqlite_url, echo=echo)
  9. engine = create_engine("sqlite:///papers.db", echo=echo)
# Models
  11. class Papers2Authors(SQLModel, table=True):
  12. __table_args__ = {'extend_existing': True}
  13. paper_id: int | None = Field(default=None, primary_key=True, foreign_key="papers.id_")
  14. author_id: int | None = Field(default=None, primary_key=True, foreign_key="authors.id_")
  15. class Authors(SQLModel, table=True):
  16. __table_args__ = {'extend_existing': True}
  17. id_: int | None = Field(default=None, primary_key=True)
  18. name: str = Field(unique=True)
  19. papers: list["Papers"] = Relationship(back_populates="authors", link_model=Papers2Authors)
  20. class Papers(SQLModel, table=True):
  21. __table_args__ = {'extend_existing': True}
  22. id_: str = Field(primary_key=True)
  23. title: str
  24. abstract: str
  25. highlights: list[str] | None = Field(default=None, sa_column=Column(JSON))
  26. findings: list[str] | None = Field(default=None, sa_column=Column(JSON))
  27. summary: list[str] | None = Field(default=None, sa_column=Column(JSON))
  28. figures_url: list[str] | None = Field(default=None, sa_column=Column(JSON))
  29. full_data: dict | None = Field(default=None, sa_column=Column(JSON))
  30. authors: list["Authors"] | None = Relationship(back_populates="papers", link_model=Papers2Authors)
  31. class Config:
  32. arbitrary_types_allowed = True
  33. def create_db():
  34. engine = create_engine("sqlite:///papers.db", echo=True)
  35. SQLModel.metadata.create_all(engine, checkfirst=True)
# CRUD helpers
  37. def add_authors_to_paper(id_, paper: dict, authors):
  38. authors = [Authors(name=name) for name in authors]
  39. with Session(engine) as session:
  40. paper_ = session.get(Papers, id_)
  41. if paper_:
  42. if paper_.authors:
  43. paper_.authors.extend(authors)
  44. session.add(paper_)
  45. session.commit()
  46. else:
  47. print("OOOOPS")
  48. else:
  49. print('OOOPS')
  50. async def add_authors_and_paper(id_: str, paper: dict):
  51. if await check_paper(id_):
  52. return
  53. if authors := paper.get("authors"):
  54. paper_ = Papers(
  55. id_ = id_,
  56. title = paper["title"],
  57. abstract = paper["abstract"],
  58. highlights = paper.get("highlights"),
  59. summary = paper.get("summary"),
  60. figures_url = paper.get("figures_url"),
  61. findings = paper.get("findings"),
  62. full_data = paper.get("full_data") if paper.get("full_data") else paper
  63. )
  64. async with AsyncSession(async_engine, expire_on_commit=False) as session:
  65. queries = [select(Authors).where(Authors.name==name).
  66. options(selectinload(Authors.papers))
  67. for name in authors]
  68. results = await asyncio.gather(*[session.exec(query) for query in queries]) # type: ignore
  69. exist_authors = [author for res in results if (author:=res.first())]
  70. exist_names = [author.name for author in exist_authors if author]
  71. [author.papers.append(paper_) for author in exist_authors]
  72. session.add_all(exist_authors)
  73. new_authors = [author for author in authors if author not in exist_names]
  74. new_authors = [Authors(name=name, papers=[paper_]) for name in new_authors]
  75. session.add_all(new_authors)
  76. await session.commit()
  77. else:
  78. await add_or_update_paper(id_, paper)
  79. def add_paper_to_authors(id_, paper, authors):
  80. paper_ = Papers(
  81. id_ = id_,
  82. title = paper["title"],
  83. abstract = paper["abstract"],
  84. highlights = paper.get("highlights"),
  85. summary = paper.get("summary"),
  86. figures_url = paper.get("figures_url"),
  87. findings = paper.get("findings"),
  88. full_data = paper.get("full_data") if paper.get("full_data") else paper
  89. )
  90. with Session(engine, expire_on_commit=False) as session:
  91. queries = [select(Authors).where(Authors.name==name) for name in authors]
  92. results = [session.exec(query) for query in queries]
  93. exist_authors = [author for author in [author.first() for author in results] if author]
  94. [author.papers.append(paper_) for author in exist_authors]
  95. session.add_all(exist_authors)
  96. session.commit()
  97. async def check_paper(id_: str, async_engine=async_engine):
  98. async with AsyncSession(async_engine) as session:
  99. return await session.get(Papers, id_)
  100. async def add_or_update_paper(id_:str, paper: dict, async_engine = async_engine):
  101. if exists_paper := await check_paper(id_):
  102. paper_ = exists_paper
  103. paper_.highlights = paper.get("highlights")
  104. paper_.summary = paper.get("summary")
  105. paper_.findings = paper.get("findings")
  106. paper_.figures_url = paper.get("figures_url")
  107. paper_.full_data = paper.get("full_data") if paper.get("full_data") else paper
  108. else:
  109. paper_ = Papers(
  110. id_ = id_,
  111. title = paper["title"],
  112. abstract = paper["abstract"],
  113. highlights = paper.get("highlights"),
  114. summary = paper.get("summary"),
  115. figures_url = paper.get("figures_url"),
  116. findings = paper.get("findings"),
  117. full_data = paper.get("full_data") if paper.get("full_data") else paper
  118. )
  119. async with AsyncSession(async_engine) as session:
  120. session.add(paper_)
  121. await session.commit()
  122. if __name__ == "__main__":
  123. import json
  124. with open("paper.json", 'r') as file:
  125. paper = json.load(file)
  126. paper["title"] = paper["metadata"]["title"]
  127. paper["authors"] = paper["metadata"]["author"].split(",").strip()
  128. paper["abstract"] = paper["metadata"]["abstract"]
  129. create_db()
  130. async def main():
  131. await add_authors_and_paper("2203.02155v1", paper)
  132. asyncio.run(main())