|
1 | 1 | import numpy as np |
2 | 2 | from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum |
3 | 3 | import pytest |
4 | | -from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY |
| 4 | +from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY |
5 | 5 | from sqlalchemy.exc import StatementError |
6 | 6 | from sqlalchemy.ext.automap import automap_base |
7 | 7 | from sqlalchemy.orm import declarative_base, Session |
|
20 | 20 | session.execute(text('CREATE EXTENSION IF NOT EXISTS vector')) |
21 | 21 | session.commit() |
22 | 22 |
|
| 23 | +array_engine = create_engine('postgresql+psycopg2://localhost/pgvector_python_test') |
| 24 | + |
| 25 | + |
| 26 | +@event.listens_for(array_engine, "connect") |
| 27 | +def connect(dbapi_connection, connection_record): |
| 28 | + from pgvector.psycopg2 import register_vector |
| 29 | + register_vector(dbapi_connection, globally=False, arrays=True) |
| 30 | + |
| 31 | + |
23 | 32 | Base = declarative_base() |
24 | 33 |
|
25 | 34 |
|
@@ -435,32 +444,24 @@ def test_automap(self): |
435 | 444 | assert item.embedding.tolist() == [1, 2, 3] |
436 | 445 |
|
437 | 446 | def test_vector_array(self): |
438 | | - session = Session(engine) |
| 447 | + session = Session(array_engine) |
439 | 448 | session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])) |
440 | 449 | session.commit() |
441 | 450 |
|
442 | | - with session.connection() as connection: |
443 | | - from pgvector.psycopg2 import register_vector |
444 | | - register_vector(connection.connection.dbapi_connection, globally=False, arrays=True) |
445 | | - |
446 | | - # this fails if the driver does not cast arrays |
447 | | - item = session.get(Item, 1) |
448 | | - assert item.embeddings[0].tolist() == [1, 2, 3] |
449 | | - assert item.embeddings[1].tolist() == [4, 5, 6] |
| 451 | + # this fails if the driver does not cast arrays |
| 452 | + item = session.get(Item, 1) |
| 453 | + assert item.embeddings[0].tolist() == [1, 2, 3] |
| 454 | + assert item.embeddings[1].tolist() == [4, 5, 6] |
450 | 455 |
|
451 | 456 | def test_halfvec_array(self): |
452 | | - session = Session(engine) |
| 457 | + session = Session(array_engine) |
453 | 458 | session.add(Item(id=1, half_embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])) |
454 | 459 | session.commit() |
455 | 460 |
|
456 | | - with session.connection() as connection: |
457 | | - from pgvector.psycopg2 import register_vector |
458 | | - register_vector(connection.connection.dbapi_connection, globally=False, arrays=True) |
459 | | - |
460 | | - # this fails if the driver does not cast arrays |
461 | | - item = session.get(Item, 1) |
462 | | - assert item.half_embeddings[0].to_list() == [1, 2, 3] |
463 | | - assert item.half_embeddings[1].to_list() == [4, 5, 6] |
| 461 | + # this fails if the driver does not cast arrays |
| 462 | + item = session.get(Item, 1) |
| 463 | + assert item.half_embeddings[0].to_list() == [1, 2, 3] |
| 464 | + assert item.half_embeddings[1].to_list() == [4, 5, 6] |
464 | 465 |
|
465 | 466 | def test_half_precision(self): |
466 | 467 | create_items() |
|
0 commit comments