Skip to content

Commit e350968

Browse files
committed
Changed vector type to return Vector class instead of NumPy array [skip ci]
1 parent 571bf42 commit e350968

File tree

9 files changed

+99
-124
lines changed

9 files changed

+99
-124
lines changed

pgvector/django/vector.py

Lines changed: 8 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from django import forms
22
from django.db.models import Field
3-
import numpy as np
43
from .. import Vector
54

65

@@ -28,45 +27,33 @@ def from_db_value(self, value, expression, connection):
2827
return Vector._from_db(value)
2928

3029
def to_python(self, value):
31-
if isinstance(value, list):
32-
return np.array(value, dtype=np.float32)
33-
return Vector._from_db(value)
30+
if value is None or isinstance(value, Vector):
31+
return value
32+
elif isinstance(value, str):
33+
return Vector._from_db(value)
34+
else:
35+
return Vector(value)
3436

3537
def get_prep_value(self, value):
3638
return Vector._to_db(value)
3739

3840
def value_to_string(self, obj):
3941
return self.get_prep_value(self.value_from_object(obj))
4042

41-
def validate(self, value, model_instance):
42-
if isinstance(value, np.ndarray):
43-
value = value.tolist()
44-
super().validate(value, model_instance)
45-
46-
def run_validators(self, value):
47-
if isinstance(value, np.ndarray):
48-
value = value.tolist()
49-
super().run_validators(value)
50-
5143
def formfield(self, **kwargs):
5244
return super().formfield(form_class=VectorFormField, **kwargs)
5345

5446

5547
class VectorWidget(forms.TextInput):
5648
def format_value(self, value):
57-
if isinstance(value, np.ndarray):
58-
value = value.tolist()
49+
if isinstance(value, Vector):
50+
value = value.to_list()
5951
return super().format_value(value)
6052

6153

6254
class VectorFormField(forms.CharField):
6355
widget = VectorWidget
6456

65-
def has_changed(self, initial, data):
66-
if isinstance(initial, np.ndarray):
67-
initial = initial.tolist()
68-
return super().has_changed(initial, data)
69-
7057
def to_python(self, value):
7158
if isinstance(value, str) and value == '':
7259
return None

pgvector/vector.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -70,14 +70,14 @@ def _to_db_binary(cls, value):
7070

7171
@classmethod
7272
def _from_db(cls, value):
73-
if value is None or isinstance(value, np.ndarray):
73+
if value is None or isinstance(value, cls):
7474
return value
7575

76-
return cls.from_text(value).to_numpy().astype(np.float32)
76+
return cls.from_text(value)
7777

7878
@classmethod
7979
def _from_db_binary(cls, value):
80-
if value is None or isinstance(value, np.ndarray):
80+
if value is None or isinstance(value, cls):
8181
return value
8282

83-
return cls.from_binary(value).to_numpy().astype(np.float32)
83+
return cls.from_binary(value)

tests/test_asyncpg.py

Lines changed: 11 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import asyncpg
22
import numpy as np
3-
from pgvector import SparseVector
3+
from pgvector import Vector, HalfVector, SparseVector
44
from pgvector.asyncpg import register_vector
55
import pytest
66

@@ -15,12 +15,11 @@ async def test_vector(self):
1515

1616
await register_vector(conn)
1717

18-
embedding = np.array([1.5, 2, 3])
18+
embedding = Vector([1.5, 2, 3])
1919
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
2020

2121
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
22-
assert np.array_equal(res[0]['embedding'], embedding)
23-
assert res[0]['embedding'].dtype == np.float32
22+
assert res[0]['embedding'] == embedding
2423
assert res[1]['embedding'] is None
2524

2625
# ensures binary format is correct
@@ -38,11 +37,11 @@ async def test_halfvec(self):
3837

3938
await register_vector(conn)
4039

41-
embedding = [1.5, 2, 3]
40+
embedding = HalfVector([1.5, 2, 3])
4241
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
4342

4443
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
45-
assert res[0]['embedding'].to_list() == [1.5, 2, 3]
44+
assert res[0]['embedding'] == embedding
4645
assert res[1]['embedding'] is None
4746

4847
# ensures binary format is correct
@@ -87,7 +86,7 @@ async def test_sparsevec(self):
8786
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
8887

8988
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
90-
assert res[0]['embedding'].to_list() == [1.5, 2, 3]
89+
assert res[0]['embedding'] == embedding
9190
assert res[1]['embedding'] is None
9291

9392
# ensures binary format is correct
@@ -105,12 +104,12 @@ async def test_vector_array(self):
105104

106105
await register_vector(conn)
107106

108-
embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])]
107+
embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])]
109108
await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings[0], embeddings[1])
110109

111110
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
112-
assert np.array_equal(res[0]['embeddings'][0], embeddings[0])
113-
assert np.array_equal(res[0]['embeddings'][1], embeddings[1])
111+
assert res[0]['embeddings'][0] == embeddings[0]
112+
assert res[0]['embeddings'][1] == embeddings[1]
114113

115114
await conn.close()
116115

@@ -126,10 +125,9 @@ async def init(conn):
126125
await conn.execute('DROP TABLE IF EXISTS asyncpg_items')
127126
await conn.execute('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding vector(3))')
128127

129-
embedding = np.array([1.5, 2, 3])
128+
embedding = Vector([1.5, 2, 3])
130129
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding)
131130

132131
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
133-
assert np.array_equal(res[0]['embedding'], embedding)
134-
assert res[0]['embedding'].dtype == np.float32
132+
assert res[0]['embedding'] == embedding
135133
assert res[1]['embedding'] is None

tests/test_django.py

Lines changed: 10 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313
import os
1414
import pgvector.django
15-
from pgvector import HalfVector, SparseVector
15+
from pgvector import Vector, HalfVector, SparseVector
1616
from pgvector.django import VectorExtension, VectorField, HalfVectorField, BitField, SparseVectorField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HammingDistance, JaccardDistance
1717
from unittest import mock
1818

@@ -165,12 +165,11 @@ def setup_method(self):
165165
def test_vector(self):
166166
Item(id=1, embedding=[1, 2, 3]).save()
167167
item = Item.objects.get(pk=1)
168-
assert np.array_equal(item.embedding, np.array([1, 2, 3]))
169-
assert item.embedding.dtype == np.float32
168+
assert item.embedding == Vector([1, 2, 3])
170169

171170
def test_vector_l2_distance(self):
172171
create_items()
173-
distance = L2Distance('embedding', [1, 1, 1])
172+
distance = L2Distance('embedding', Vector([1, 1, 1]))
174173
items = Item.objects.annotate(distance=distance).order_by(distance)
175174
assert [v.id for v in items] == [1, 3, 2]
176175
assert [v.distance for v in items] == [0, 1, sqrt(3)]
@@ -293,31 +292,31 @@ def test_vector_avg(self):
293292
Item(embedding=[1, 2, 3]).save()
294293
Item(embedding=[4, 5, 6]).save()
295294
avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg']
296-
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
295+
assert avg == Vector([2.5, 3.5, 4.5])
297296

298297
def test_vector_sum(self):
299298
sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum']
300299
assert sum is None
301300
Item(embedding=[1, 2, 3]).save()
302301
Item(embedding=[4, 5, 6]).save()
303302
sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum']
304-
assert np.array_equal(sum, np.array([5, 7, 9]))
303+
assert sum == Vector([5, 7, 9])
305304

306305
def test_halfvec_avg(self):
307306
avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg']
308307
assert avg is None
309308
Item(half_embedding=[1, 2, 3]).save()
310309
Item(half_embedding=[4, 5, 6]).save()
311310
avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg']
312-
assert avg.to_list() == [2.5, 3.5, 4.5]
311+
assert avg == HalfVector([2.5, 3.5, 4.5])
313312

314313
def test_halfvec_sum(self):
315314
sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum']
316315
assert sum is None
317316
Item(half_embedding=[1, 2, 3]).save()
318317
Item(half_embedding=[4, 5, 6]).save()
319318
sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum']
320-
assert sum.to_list() == [5, 7, 9]
319+
assert sum == HalfVector([5, 7, 9])
321320

322321
def test_serialization(self):
323322
create_items()
@@ -347,7 +346,7 @@ def test_vector_form_save(self):
347346
assert form.has_changed()
348347
assert form.is_valid()
349348
assert form.save()
350-
assert [4, 5, 6] == Item.objects.get(pk=1).embedding.tolist()
349+
assert [4, 5, 6] == Item.objects.get(pk=1).embedding.to_list()
351350

352351
def test_vector_form_save_missing(self):
353352
Item(id=1).save()
@@ -465,8 +464,8 @@ def test_vector_array(self):
465464

466465
# this fails if the driver does not cast arrays
467466
item = Item.objects.get(pk=1)
468-
assert item.embeddings[0].tolist() == [1, 2, 3]
469-
assert item.embeddings[1].tolist() == [4, 5, 6]
467+
assert item.embeddings[0].to_list() == [1, 2, 3]
468+
assert item.embeddings[1].to_list() == [4, 5, 6]
470469

471470
def test_double_array(self):
472471
Item(id=1, double_embedding=[1, 1, 1]).save()

tests/test_peewee.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from math import sqrt
22
import numpy as np
33
from peewee import Model, PostgresqlDatabase, fn
4-
from pgvector import SparseVector
4+
from pgvector import Vector, HalfVector, SparseVector
55
from pgvector.peewee import VectorField, HalfVectorField, FixedBitField, SparseVectorField
66

77
db = PostgresqlDatabase('pgvector_python_test')
@@ -43,8 +43,7 @@ def setup_method(self):
4343
def test_vector(self):
4444
Item.create(id=1, embedding=[1, 2, 3])
4545
item = Item.get_by_id(1)
46-
assert np.array_equal(item.embedding, np.array([1, 2, 3]))
47-
assert item.embedding.dtype == np.float32
46+
assert item.embedding == Vector([1, 2, 3])
4847

4948
def test_vector_l2_distance(self):
5049
create_items()
@@ -170,31 +169,31 @@ def test_vector_avg(self):
170169
Item.create(embedding=[1, 2, 3])
171170
Item.create(embedding=[4, 5, 6])
172171
avg = Item.select(fn.avg(Item.embedding).coerce(True)).scalar()
173-
assert np.array_equal(avg, np.array([2.5, 3.5, 4.5]))
172+
assert avg == Vector([2.5, 3.5, 4.5])
174173

175174
def test_vector_sum(self):
176175
sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar()
177176
assert sum is None
178177
Item.create(embedding=[1, 2, 3])
179178
Item.create(embedding=[4, 5, 6])
180179
sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar()
181-
assert np.array_equal(sum, np.array([5, 7, 9]))
180+
assert sum == Vector([5, 7, 9])
182181

183182
def test_halfvec_avg(self):
184183
avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar()
185184
assert avg is None
186185
Item.create(half_embedding=[1, 2, 3])
187186
Item.create(half_embedding=[4, 5, 6])
188187
avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar()
189-
assert avg.to_list() == [2.5, 3.5, 4.5]
188+
assert avg == HalfVector([2.5, 3.5, 4.5])
190189

191190
def test_halfvec_sum(self):
192191
sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar()
193192
assert sum is None
194193
Item.create(half_embedding=[1, 2, 3])
195194
Item.create(half_embedding=[4, 5, 6])
196195
sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar()
197-
assert sum.to_list() == [5, 7, 9]
196+
assert sum == HalfVector([5, 7, 9])
198197

199198
def test_get_or_create(self):
200199
Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]})

0 commit comments

Comments
 (0)