diff --git a/examples/sqlalchemy_1.6_to_2.0/output_repo/database.py b/examples/sqlalchemy_1.6_to_2.0/output_repo/database.py deleted file mode 100644 index 26ccd07..0000000 --- a/examples/sqlalchemy_1.6_to_2.0/output_repo/database.py +++ /dev/null @@ -1,10 +0,0 @@ -# database.py -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -SQLALCHEMY_DATABASE_URL = "postgresql://user:password@localhost/dbname" # Change to your database URL - -engine = create_engine(SQLALCHEMY_DATABASE_URL) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -Base = declarative_base() diff --git a/examples/sqlalchemy_1.6_to_2.0/output_repo/main.py b/examples/sqlalchemy_1.6_to_2.0/output_repo/main.py deleted file mode 100644 index 93a716f..0000000 --- a/examples/sqlalchemy_1.6_to_2.0/output_repo/main.py +++ /dev/null @@ -1,102 +0,0 @@ -# main.py -from fastapi import FastAPI, Depends, HTTPException -from sqlalchemy.orm import Session -import models -import schemas -from database import SessionLocal, engine -from typing import List - -# Initialize the app and create database tables -app = FastAPI() -models.Base.metadata.create_all(bind=engine) - -# Dependency for the database session -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - -# CRUD Operations - -@app.post("/books/", response_model=schemas.Book) -def create_book(book: schemas.BookCreate, db: Session = Depends(get_db)): - db_book = models.Book(**book.dict()) - db.add(db_book) - db.commit() - db.refresh(db_book) - return db_book - -@app.get("/books/", response_model=List[schemas.Book]) -def read_books(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)): - books = db.query()(models.Book).offset(skip).limit(limit).scalars().all() - return books - -@app.get("/books/{book_id}", response_model=schemas.Book) -def read_book(book_id: int, db: Session = Depends(get_db)): - book = db.query()(models.Book).where(models.Book.id == book_id).first() - if book is None: - raise HTTPException(status_code=404, detail="Book not found") - return book - -@app.put("/books/{book_id}", response_model=schemas.Book) -def update_book(book_id: int, book: schemas.BookCreate, db: Session = Depends(get_db)): - db_book = db.query()(models.Book).where(models.Book.id == book_id).first() - if db_book is None: - raise HTTPException(status_code=404, detail="Book not found") - for key, value in book.dict().items(): - setattr(db_book, key, value) - db.commit() - db.refresh(db_book) - return db_book - -@app.delete("/books/{book_id}", response_model=schemas.Book) -def delete_book(book_id: int, db: Session = Depends(get_db)): - db_book = db.query()(models.Book).where(models.Book.id == book_id).first() - if db_book is None: - raise HTTPException(status_code=404, detail="Book not found") - db.delete(db_book) - db.commit() - return db_book - -@app.post("/publishers/", response_model=schemas.Publisher) -def create_publisher(publisher: schemas.PublisherCreate, db: Session = Depends(get_db)): - db_publisher = models.Publisher(**publisher.dict()) - db.add(db_publisher) - db.commit() - db.refresh(db_publisher) - return db_publisher - -@app.get("/publishers/", response_model=List[schemas.Publisher]) -def read_publishers(skip: int = 0, limit: int = 10, db: Session = Depends(get_db)): - publishers = db.query()(models.Publisher).offset(skip).limit(limit).scalars().all() - return publishers - -@app.get("/publishers/{publisher_id}", response_model=schemas.Publisher) -def read_publisher(publisher_id: int, db: Session = Depends(get_db)): - publisher = db.query()(models.Publisher).where(models.Publisher.id == publisher_id).first() - if not publisher: - raise HTTPException(status_code=404, detail="Publisher not found") - return publisher - -@app.put("/publishers/{publisher_id}", response_model=schemas.Publisher) -def update_publisher(publisher_id: int, publisher: schemas.PublisherCreate, db: Session = Depends(get_db)): - db_publisher = db.query()(models.Publisher).where(models.Publisher.id == publisher_id).first() - if not db_publisher: - raise HTTPException(status_code=404, detail="Publisher not found") - for key, value in publisher.dict().items(): - setattr(db_publisher, key, value) - db.commit() - db.refresh(db_publisher) - return db_publisher - -@app.delete("/publishers/{publisher_id}", response_model=schemas.Publisher) -def delete_publisher(publisher_id: int, db: Session = Depends(get_db)): - db_publisher = db.query()(models.Publisher).where(models.Publisher.id == publisher_id).first() - if not db_publisher: - raise HTTPException(status_code=404, detail="Publisher not found") - db.delete(db_publisher) - db.commit() - return db_publisher - diff --git a/examples/sqlalchemy_1.6_to_2.0/output_repo/models.py b/examples/sqlalchemy_1.6_to_2.0/output_repo/models.py deleted file mode 100644 index d6fcceb..0000000 --- a/examples/sqlalchemy_1.6_to_2.0/output_repo/models.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import List, Optional -from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.orm import relationship, Mapped, mapped_column -from database import Base - -class Publisher(Base): - __tablename__ = "publishers" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) - name: Mapped[str] = mapped_column(String, unique=True, index=True) - books: Mapped[List["Book"]] = relationship( - "Book", - back_populates="publisher", - lazy='selectin' - ) - -class Book(Base): - __tablename__ = "books" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) - title: Mapped[str] = mapped_column(String, index=True) - author: Mapped[str] = mapped_column(String, index=True) - description: Mapped[Optional[str]] = mapped_column(String, nullable=True) - publisher_id: Mapped[Optional[int]] = mapped_column( - Integer, - ForeignKey("publishers.id"), - nullable=True - ) - publisher: Mapped[Optional["Publisher"]] = relationship( - "Publisher", - back_populates="books" - ) \ No newline at end of file diff --git a/examples/sqlalchemy_1.6_to_2.0/output_repo/schemas.py b/examples/sqlalchemy_1.6_to_2.0/output_repo/schemas.py deleted file mode 100644 index daf4fb9..0000000 --- a/examples/sqlalchemy_1.6_to_2.0/output_repo/schemas.py +++ /dev/null @@ -1,31 +0,0 @@ -from pydantic import BaseModel -from typing import List, Optional - -class PublisherBase(BaseModel): - name: str - -class PublisherCreate(PublisherBase): - pass - -class Publisher(PublisherBase): - id: int - books: List["Book"] = [] - - class Config: - orm_mode = True - -class BookBase(BaseModel): - title: str - author: str - description: str - publisher_id: Optional[int] - -class BookCreate(BookBase): - pass - -class Book(BookBase): - id: int - publisher: Optional[Publisher] - - class Config: - orm_mode = True diff --git a/examples/sqlalchemy_1.6_to_2.0/run.py b/examples/sqlalchemy_1.6_to_2.0/run.py new file mode 100644 index 0000000..639dabd --- /dev/null +++ b/examples/sqlalchemy_1.6_to_2.0/run.py @@ -0,0 +1,105 @@ +import codegen +from codegen import Codebase +from codegen.sdk.core.detached_symbols.function_call import FunctionCall +from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute + + +@codegen.function("sqlalchemy-1.6-to-2.0") +def run(codebase: Codebase): + """ + Convert SQLAlchemy 1.6 codebases to 2.0. + """ + files_modified = 0 + functions_modified = 0 + + print("\nStarting SQLAlchemy 1.6 to 2.0 migration...") + + for file in codebase.files: + file_modified = False + print(f"\nProcessing file: {file.path}") + + # Step 1: Convert Query to Select + for call in file.function_calls: + if call.name == "query": + chain = call + while chain.parent and isinstance(chain.parent, ChainedAttribute): + chain = chain.parent + + original_code = chain.source + new_query = chain.source.replace("query(", "select(") + if "filter(" in new_query: + new_query = new_query.replace(".filter(", ".where(") + if "filter_by(" in new_query: + model = call.args[0].value + conditions = chain.source.split("filter_by(")[1].split(")")[0] + new_conditions = [f"{model}.{cond.strip().replace('=', ' == ')}" for cond in conditions.split(",")] + new_query = f".where({' & '.join(new_conditions)})" + if "execute" not in chain.parent.source: + new_query = f"execute({new_query}).scalars()" + + print(f"\nConverting query in {file.path}:\n") + print("Original code:") + print(original_code) + print("\nNew code:") + print(new_query) + print("-" * 50) + + chain.edit(new_query) + file_modified = True + functions_modified += 1 + + # Step 2: Modernize ORM Relationships + for cls in file.classes: + for attr in cls.attributes: + if isinstance(attr.value, FunctionCall) and attr.value.name == "relationship": + if "lazy=" not in attr.value.source: + original_rel = attr.value.source + new_rel = original_rel + ', lazy="select"' + if "backref" in new_rel: + new_rel = new_rel.replace("backref", "back_populates") + + print(f"\nUpdating relationship in class {cls.name}:\n") + print("Original code:") + print(original_rel) + print("\nNew code:") + print(new_rel) + print("-" * 50) + + attr.value.edit(new_rel) + file_modified = True + functions_modified += 1 + + # Step 3: Convert Column Definitions to Type Annotations + for cls in file.classes: + for attr in cls.attributes: + if "Column(" in attr.source: + original_attr = attr.source + new_attr = original_attr.replace("Column", "mapped_column") + type_hint = "Mapped" + original_attr.split("= Column")[1] + new_attr = f"{attr.name}: {type_hint}" + + print(f"\nUpdating column definition in class {cls.name}:\n") + print("Original code:") + print(original_attr) + print("\nNew code:") + print(new_attr) + print("-" * 50) + + attr.edit(new_attr) + file_modified = True + functions_modified += 1 + + if file_modified: + files_modified += 1 + + print("\nMigration complete:") + print(f"Files modified: {files_modified}") + print(f"Functions modified: {functions_modified}") + + +if __name__ == "__main__": + repo_path = "./input_repo" + print("Initializing codebase...") + codebase = Codebase(repo_path) + print("Running SQLAlchemy 1.6 to 2.0 codemod...") + run(codebase)