From 6cc85a2d6cdcf23e358d33697875bd359ba7086f Mon Sep 17 00:00:00 2001 From: Vishal Shenoy Date: Wed, 29 Jan 2025 11:10:17 -0800 Subject: [PATCH 1/3] . --- .../output_repo/database.py | 10 -- .../sqlalchemy_1.6_to_2.0/output_repo/main.py | 102 ------------------ .../output_repo/models.py | 32 ------ .../output_repo/schemas.py | 31 ------ examples/sqlalchemy_1.6_to_2.0/run.py | 89 +++++++++++++++ 5 files changed, 89 insertions(+), 175 deletions(-) delete mode 100644 examples/sqlalchemy_1.6_to_2.0/output_repo/database.py delete mode 100644 examples/sqlalchemy_1.6_to_2.0/output_repo/main.py delete mode 100644 examples/sqlalchemy_1.6_to_2.0/output_repo/models.py delete mode 100644 examples/sqlalchemy_1.6_to_2.0/output_repo/schemas.py create mode 100644 examples/sqlalchemy_1.6_to_2.0/run.py diff --git a/examples/sqlalchemy_1.6_to_2.0/output_repo/database.py b/examples/sqlalchemy_1.6_to_2.0/output_repo/database.py deleted file mode 100644 index 26ccd07..0000000 --- a/examples/sqlalchemy_1.6_to_2.0/output_repo/database.py +++ /dev/null @@ -1,10 +0,0 @@ -# database.py -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -SQLALCHEMY_DATABASE_URL = "postgresql://user:password@localhost/dbname" # Change to your database URL - -engine = create_engine(SQLALCHEMY_DATABASE_URL) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -Base = declarative_base() diff --git a/examples/sqlalchemy_1.6_to_2.0/output_repo/main.py b/examples/sqlalchemy_1.6_to_2.0/output_repo/main.py deleted file mode 100644 index 93a716f..0000000 --- a/examples/sqlalchemy_1.6_to_2.0/output_repo/main.py +++ /dev/null @@ -1,102 +0,0 @@ -# main.py -from fastapi import FastAPI, Depends, HTTPException -from sqlalchemy.orm import Session -import models -import schemas -from database import SessionLocal, engine -from typing import List - -# Initialize the app and create database tables -app = FastAPI() -models.Base.metadata.create_all(bind=engine) - -# Dependency for the database session -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - -# CRUD Operations - -@app.post("/books/", response_model=schemas.Book) -def create_book(book: schemas.BookCreate, db: Session = Depends(get_db)): - db_book = models.Book(**book.dict()) - db.add(db_book) - db.commit() - db.refresh(db_book) - return db_book - -@app.get("/books/", response_model=List[schemas.Book]) -def read_books(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)): - books = db.query()(models.Book).offset(skip).limit(limit).scalars().all() - return books - -@app.get("/books/{book_id}", response_model=schemas.Book) -def read_book(book_id: int, db: Session = Depends(get_db)): - book = db.query()(models.Book).where(models.Book.id == book_id).first() - if book is None: - raise HTTPException(status_code=404, detail="Book not found") - return book - -@app.put("/books/{book_id}", response_model=schemas.Book) -def update_book(book_id: int, book: schemas.BookCreate, db: Session = Depends(get_db)): - db_book = db.query()(models.Book).where(models.Book.id == book_id).first() - if db_book is None: - raise HTTPException(status_code=404, detail="Book not found") - for key, value in book.dict().items(): - setattr(db_book, key, value) - db.commit() - db.refresh(db_book) - return db_book - -@app.delete("/books/{book_id}", response_model=schemas.Book) -def delete_book(book_id: int, db: Session = Depends(get_db)): - db_book = db.query()(models.Book).where(models.Book.id == book_id).first() - if db_book is None: - raise HTTPException(status_code=404, detail="Book not found") - db.delete(db_book) - db.commit() - return db_book - -@app.post("/publishers/", response_model=schemas.Publisher) -def create_publisher(publisher: schemas.PublisherCreate, db: Session = Depends(get_db)): - db_publisher = models.Publisher(**publisher.dict()) - db.add(db_publisher) - db.commit() - db.refresh(db_publisher) - return db_publisher - -@app.get("/publishers/", response_model=List[schemas.Publisher]) -def read_publishers(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)): - publishers = db.query()(models.Publisher).offset(skip).limit(limit).scalars().all() - return publishers - -@app.get("/publishers/{publisher_id}", response_model=schemas.Publisher) -def read_publisher(publisher_id: int, db: Session = Depends(get_db)): - publisher = db.query()(models.Publisher).where(models.Publisher.id == publisher_id).first() - if not publisher: - raise HTTPException(status_code=404, detail="Publisher not found") - return publisher - -@app.put("/publishers/{publisher_id}", response_model=schemas.Publisher) -def update_publisher(publisher_id: int, publisher: schemas.PublisherCreate, db: Session = Depends(get_db)): - db_publisher = db.query()(models.Publisher).where(models.Publisher.id == publisher_id).first() - if not db_publisher: - raise HTTPException(status_code=404, detail="Publisher not found") - for key, value in publisher.dict().items(): - setattr(db_publisher, key, value) - db.commit() - db.refresh(db_publisher) - return db_publisher - -@app.delete("/publishers/{publisher_id}", response_model=schemas.Publisher) -def delete_publisher(publisher_id: int, db: Session = Depends(get_db)): - db_publisher = db.query()(models.Publisher).where(models.Publisher.id == publisher_id).first() - if not db_publisher: - raise HTTPException(status_code=404, detail="Publisher not found") - db.delete(db_publisher) - db.commit() - return db_publisher - diff --git a/examples/sqlalchemy_1.6_to_2.0/output_repo/models.py b/examples/sqlalchemy_1.6_to_2.0/output_repo/models.py deleted file mode 100644 index d6fcceb..0000000 --- a/examples/sqlalchemy_1.6_to_2.0/output_repo/models.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import List, Optional -from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.orm import relationship, Mapped, mapped_column -from database import Base - -class Publisher(Base): - __tablename__ = "publishers" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) - name: Mapped[str] = mapped_column(String, unique=True, index=True) - books: Mapped[List["Book"]] = relationship( - "Book", - back_populates="publisher", - lazy='selectin' - ) - -class Book(Base): - __tablename__ = "books" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) - title: Mapped[str] = mapped_column(String, index=True) - author: Mapped[str] = mapped_column(String, index=True) - description: Mapped[Optional[str]] = mapped_column(String, nullable=True) - publisher_id: Mapped[Optional[int]] = mapped_column( - Integer, - ForeignKey("publishers.id"), - nullable=True - ) - publisher: Mapped[Optional["Publisher"]] = relationship( - "Publisher", - back_populates="books" - ) \ No newline at end of file diff --git a/examples/sqlalchemy_1.6_to_2.0/output_repo/schemas.py b/examples/sqlalchemy_1.6_to_2.0/output_repo/schemas.py deleted file mode 100644 index daf4fb9..0000000 --- a/examples/sqlalchemy_1.6_to_2.0/output_repo/schemas.py +++ /dev/null @@ -1,31 +0,0 @@ -from pydantic import BaseModel -from typing import List, Optional - -class PublisherBase(BaseModel): - name: str - -class PublisherCreate(PublisherBase): - pass - -class Publisher(PublisherBase): - id: int - books: List["Book"] = [] - - class Config: - orm_mode = True - -class BookBase(BaseModel): - title: str - author: str - description: str - publisher_id: Optional[int] - -class BookCreate(BookBase): - pass - -class Book(BookBase): - id: int - publisher: Optional[Publisher] - - class Config: - orm_mode = True diff --git a/examples/sqlalchemy_1.6_to_2.0/run.py b/examples/sqlalchemy_1.6_to_2.0/run.py new file mode 100644 index 0000000..8b3e224 --- /dev/null +++ b/examples/sqlalchemy_1.6_to_2.0/run.py @@ -0,0 +1,89 @@ +import codegen +from codegen import Codebase +from codegen.sdk.core.detached_symbols.function_call import FunctionCall +from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute + + +@codegen.function("sqlalchemy-1.6-to-2.0") +def run(codebase: Codebase): + """ + Convert SQLAlchemy 1.6 codebases to 2.0. + """ + files_modified = 0 + functions_modified = 0 + + print("\nStarting SQLAlchemy 1.6 to 2.0 migration...") + + for file in codebase.files: + file_modified = False + print(f"\nProcessing file: {file.path}") + + # Step 1: Convert Query to Select + for call in file.function_calls: + if call.name == "query": + chain = call + while chain.parent and isinstance(chain.parent, ChainedAttribute): + chain = chain.parent + + original_code = chain.source + new_query = chain.source.replace("query(", "select(") + if "filter(" in new_query: + new_query = new_query.replace(".filter(", ".where(") + if "filter_by(" in new_query: + model = call.args[0].value + conditions = chain.source.split("filter_by(")[1].split(")")[0] + new_conditions = [f"{model}.{cond.strip().replace('=', ' == ')}" for cond in conditions.split(",")] + new_query = f".where({' & '.join(new_conditions)})" + if "execute" not in chain.parent.source: + new_query = f"execute({new_query}).scalars()" + print("\nConverting query:") + print("Original:", original_code) + print("New:", new_query) + chain.edit(new_query) + file_modified = True + functions_modified += 1 + + # Step 2: Modernize ORM Relationships + for cls in file.classes: + for attr in cls.attributes: + if isinstance(attr.value, FunctionCall) and attr.value.name == "relationship": + if "lazy=" not in attr.value.source: + original_rel = attr.value.source + new_rel = original_rel + ', lazy="selectin"' + if "backref" in new_rel: + new_rel = new_rel.replace("backref", "back_populates") + print("\nUpdating relationship:") + print("Original:", original_rel) + print("New:", new_rel) + attr.value.edit(new_rel) + file_modified = True + functions_modified += 1 + + # Step 3: Convert Column Definitions to Type Annotations + for cls in file.classes: + for attr in cls.attributes: + if "Column(" in attr.source: + original_attr = attr.source + new_attr = original_attr.replace("Column", "mapped_column") + type_hint = "Mapped" + original_attr.split("= Column")[1] + new_attr = f"{attr.name}: {type_hint}" + print("\nUpdating column definition:") + print("Original:", original_attr) + print("New:", new_attr) + attr.edit(new_attr) + file_modified = True + functions_modified += 1 + + if file_modified: + files_modified += 1 + + print("\nMigration complete:") + print(f"Files modified: {files_modified}") + print(f"Functions modified: {functions_modified}") + + +if __name__ == "__main__": + print("Initializing codebase...") + codebase = Codebase("./input_repo") + print("Running SQLAlchemy 1.6 to 2.0 codemod...") + run(codebase) From 3afef6b5c18160765b4b534cac1d064212326c26 Mon Sep 17 00:00:00 2001 From: codegen-bot Date: Wed, 29 Jan 2025 11:16:56 -0800 Subject: [PATCH 2/3] . --- examples/sqlalchemy_1.6_to_2.0/run.py | 36 +++++++++++++++++++-------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/examples/sqlalchemy_1.6_to_2.0/run.py b/examples/sqlalchemy_1.6_to_2.0/run.py index 8b3e224..5a71233 100644 --- a/examples/sqlalchemy_1.6_to_2.0/run.py +++ b/examples/sqlalchemy_1.6_to_2.0/run.py @@ -36,9 +36,14 @@ def run(codebase: Codebase): new_query = f".where({' & '.join(new_conditions)})" if "execute" not in chain.parent.source: new_query = f"execute({new_query}).scalars()" - print("\nConverting query:") - print("Original:", original_code) - print("New:", new_query) + + print(f"\nConverting query in {file.path}:\n") + print("Original code:") + print(original_code) + print("\nNew code:") + print(new_query) + print("-" * 50) + chain.edit(new_query) file_modified = True functions_modified += 1 @@ -52,9 +57,14 @@ def run(codebase: Codebase): new_rel = original_rel + ', lazy="selectin"' if "backref" in new_rel: new_rel = new_rel.replace("backref", "back_populates") - print("\nUpdating relationship:") - print("Original:", original_rel) - print("New:", new_rel) + + print(f"\nUpdating relationship in class {cls.name}:\n") + print("Original code:") + print(original_rel) + print("\nNew code:") + print(new_rel) + print("-" * 50) + attr.value.edit(new_rel) file_modified = True functions_modified += 1 @@ -67,9 +77,14 @@ def run(codebase: Codebase): new_attr = original_attr.replace("Column", "mapped_column") type_hint = "Mapped" + original_attr.split("= Column")[1] new_attr = f"{attr.name}: {type_hint}" - print("\nUpdating column definition:") - print("Original:", original_attr) - print("New:", new_attr) + + print(f"\nUpdating column definition in class {cls.name}:\n") + print("Original code:") + print(original_attr) + print("\nNew code:") + print(new_attr) + print("-" * 50) + attr.edit(new_attr) file_modified = True functions_modified += 1 @@ -83,7 +98,8 @@ def run(codebase: Codebase): if __name__ == "__main__": + repo_path = "./input_repo" print("Initializing codebase...") - codebase = Codebase("./input_repo") + codebase = Codebase(repo_path) print("Running SQLAlchemy 1.6 to 2.0 codemod...") run(codebase) From 312180971f2eda553d615e17d870564fc677e43b Mon Sep 17 00:00:00 2001 From: codegen-bot Date: Wed, 29 Jan 2025 11:25:56 -0800 Subject: [PATCH 3/3] select --- examples/sqlalchemy_1.6_to_2.0/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sqlalchemy_1.6_to_2.0/run.py b/examples/sqlalchemy_1.6_to_2.0/run.py index 5a71233..639dabd 100644 --- a/examples/sqlalchemy_1.6_to_2.0/run.py +++ b/examples/sqlalchemy_1.6_to_2.0/run.py @@ -54,7 +54,7 @@ def run(codebase: Codebase): if isinstance(attr.value, FunctionCall) and attr.value.name == "relationship": if "lazy=" not in attr.value.source: original_rel = attr.value.source - new_rel = original_rel + ', lazy="selectin"' + new_rel = original_rel + ', lazy="select"' if "backref" in new_rel: new_rel = new_rel.replace("backref", "back_populates")