Skip to content

Commit 030def9

Browse files
committed
Improved example and tests for arrays with SQLAlchemy - #101 [skip ci]
1 parent 0a76066 commit 030def9

File tree

2 files changed

+24
-21
lines changed

2 files changed

+24
-21
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,11 @@ And register the types with the underlying driver
284284

285285
```python
286286
from pgvector.psycopg2 import register_vector
287+
from sqlalchemy import engine
287288

288-
with session.connection() as connection:
289-
register_vector(connection.connection.dbapi_connection, globally=True, arrays=True)
289+
@event.listens_for(engine, "connect")
290+
def connect(dbapi_connection, connection_record):
291+
register_vector(dbapi_connection, arrays=True)
290292
```
291293

292294
## SQLModel

tests/test_sqlalchemy.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum
33
import pytest
4-
from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY
4+
from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY
55
from sqlalchemy.exc import StatementError
66
from sqlalchemy.ext.automap import automap_base
77
from sqlalchemy.orm import declarative_base, Session
@@ -20,6 +20,15 @@
2020
session.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
2121
session.commit()
2222

23+
array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test')
24+
25+
26+
@event.listens_for(array_engine, "connect")
27+
def connect(dbapi_connection, connection_record):
28+
from pgvector.psycopg2 import register_vector
29+
register_vector(dbapi_connection, globally=False, arrays=True)
30+
31+
2332
Base = declarative_base()
2433

2534

@@ -435,32 +444,24 @@ def test_automap(self):
435444
assert item.embedding.tolist() == [1, 2, 3]
436445

437446
def test_vector_array(self):
438-
session = Session(engine)
447+
session = Session(array_engine)
439448
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
440449
session.commit()
441450

442-
with session.connection() as connection:
443-
from pgvector.psycopg2 import register_vector
444-
register_vector(connection.connection.dbapi_connection, globally=False, arrays=True)
445-
446-
# this fails if the driver does not cast arrays
447-
item = session.get(Item, 1)
448-
assert item.embeddings[0].tolist() == [1, 2, 3]
449-
assert item.embeddings[1].tolist() == [4, 5, 6]
451+
# this fails if the driver does not cast arrays
452+
item = session.get(Item, 1)
453+
assert item.embeddings[0].tolist() == [1, 2, 3]
454+
assert item.embeddings[1].tolist() == [4, 5, 6]
450455

451456
def test_halfvec_array(self):
452-
session = Session(engine)
457+
session = Session(array_engine)
453458
session.add(Item(id=1, half_embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
454459
session.commit()
455460

456-
with session.connection() as connection:
457-
from pgvector.psycopg2 import register_vector
458-
register_vector(connection.connection.dbapi_connection, globally=False, arrays=True)
459-
460-
# this fails if the driver does not cast arrays
461-
item = session.get(Item, 1)
462-
assert item.half_embeddings[0].to_list() == [1, 2, 3]
463-
assert item.half_embeddings[1].to_list() == [4, 5, 6]
461+
# this fails if the driver does not cast arrays
462+
item = session.get(Item, 1)
463+
assert item.half_embeddings[0].to_list() == [1, 2, 3]
464+
assert item.half_embeddings[1].to_list() == [4, 5, 6]
464465

465466
def test_half_precision(self):
466467
create_items()

0 commit comments

Comments
 (0)