|
12 | 12 | import numpy as np |
13 | 13 | import os |
14 | 14 | import pgvector.django |
15 | | -from pgvector import HalfVector, SparseVector |
| 15 | +from pgvector import Vector, HalfVector, SparseVector |
16 | 16 | from pgvector.django import VectorExtension, VectorField, HalfVectorField, BitField, SparseVectorField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HammingDistance, JaccardDistance |
17 | 17 | from unittest import mock |
18 | 18 |
|
@@ -165,12 +165,11 @@ def setup_method(self): |
165 | 165 | def test_vector(self): |
166 | 166 | Item(id=1, embedding=[1, 2, 3]).save() |
167 | 167 | item = Item.objects.get(pk=1) |
168 | | - assert np.array_equal(item.embedding, np.array([1, 2, 3])) |
169 | | - assert item.embedding.dtype == np.float32 |
| 168 | + assert item.embedding == Vector([1, 2, 3]) |
170 | 169 |
|
171 | 170 | def test_vector_l2_distance(self): |
172 | 171 | create_items() |
173 | | - distance = L2Distance('embedding', [1, 1, 1]) |
| 172 | + distance = L2Distance('embedding', Vector([1, 1, 1])) |
174 | 173 | items = Item.objects.annotate(distance=distance).order_by(distance) |
175 | 174 | assert [v.id for v in items] == [1, 3, 2] |
176 | 175 | assert [v.distance for v in items] == [0, 1, sqrt(3)] |
@@ -293,31 +292,31 @@ def test_vector_avg(self): |
293 | 292 | Item(embedding=[1, 2, 3]).save() |
294 | 293 | Item(embedding=[4, 5, 6]).save() |
295 | 294 | avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg'] |
296 | | - assert np.array_equal(avg, np.array([2.5, 3.5, 4.5])) |
| 295 | + assert avg == Vector([2.5, 3.5, 4.5]) |
297 | 296 |
|
298 | 297 | def test_vector_sum(self): |
299 | 298 | sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] |
300 | 299 | assert sum is None |
301 | 300 | Item(embedding=[1, 2, 3]).save() |
302 | 301 | Item(embedding=[4, 5, 6]).save() |
303 | 302 | sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] |
304 | | - assert np.array_equal(sum, np.array([5, 7, 9])) |
| 303 | + assert sum == Vector([5, 7, 9]) |
305 | 304 |
|
306 | 305 | def test_halfvec_avg(self): |
307 | 306 | avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg'] |
308 | 307 | assert avg is None |
309 | 308 | Item(half_embedding=[1, 2, 3]).save() |
310 | 309 | Item(half_embedding=[4, 5, 6]).save() |
311 | 310 | avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg'] |
312 | | - assert avg.to_list() == [2.5, 3.5, 4.5] |
| 311 | + assert avg == HalfVector([2.5, 3.5, 4.5]) |
313 | 312 |
|
314 | 313 | def test_halfvec_sum(self): |
315 | 314 | sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum'] |
316 | 315 | assert sum is None |
317 | 316 | Item(half_embedding=[1, 2, 3]).save() |
318 | 317 | Item(half_embedding=[4, 5, 6]).save() |
319 | 318 | sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum'] |
320 | | - assert sum.to_list() == [5, 7, 9] |
| 319 | + assert sum == HalfVector([5, 7, 9]) |
321 | 320 |
|
322 | 321 | def test_serialization(self): |
323 | 322 | create_items() |
@@ -347,7 +346,7 @@ def test_vector_form_save(self): |
347 | 346 | assert form.has_changed() |
348 | 347 | assert form.is_valid() |
349 | 348 | assert form.save() |
350 | | - assert [4, 5, 6] == Item.objects.get(pk=1).embedding.tolist() |
| 349 | + assert [4, 5, 6] == Item.objects.get(pk=1).embedding.to_list() |
351 | 350 |
|
352 | 351 | def test_vector_form_save_missing(self): |
353 | 352 | Item(id=1).save() |
@@ -465,8 +464,8 @@ def test_vector_array(self): |
465 | 464 |
|
466 | 465 | # this fails if the driver does not cast arrays |
467 | 466 | item = Item.objects.get(pk=1) |
468 | | - assert item.embeddings[0].tolist() == [1, 2, 3] |
469 | | - assert item.embeddings[1].tolist() == [4, 5, 6] |
| 467 | + assert item.embeddings[0].to_list() == [1, 2, 3] |
| 468 | + assert item.embeddings[1].to_list() == [4, 5, 6] |
470 | 469 |
|
471 | 470 | def test_double_array(self): |
472 | 471 | Item(id=1, double_embedding=[1, 1, 1]).save() |
|
0 commit comments