diff --git a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs
deleted file mode 100644
index d1d4bfa..0000000
--- a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs
+++ /dev/null
@@ -1,19 +0,0 @@
-using Microsoft.EntityFrameworkCore.Storage;
-using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
-using NpgsqlTypes;
-
-namespace Pgvector.EntityFrameworkCore;
-
-public class HalfvecTypeMapping : RelationalTypeMapping
-{
- public static HalfvecTypeMapping Default { get; } = new();
-
- public HalfvecTypeMapping() : base("halfvec", typeof(HalfVector)) { }
-
- public HalfvecTypeMapping(string storeType) : base(storeType, typeof(HalfVector)) { }
-
- protected HalfvecTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }
-
- protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters)
- => new HalfvecTypeMapping(parameters);
-}
diff --git a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs
deleted file mode 100644
index 663bf71..0000000
--- a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs
+++ /dev/null
@@ -1,11 +0,0 @@
-using Microsoft.EntityFrameworkCore.Storage;
-
-namespace Pgvector.EntityFrameworkCore;
-
-public class HalfvecTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin
-{
- public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo)
- => mappingInfo.ClrType == typeof(HalfVector)
- ? new HalfvecTypeMapping(mappingInfo.StoreTypeName ?? "halfvec")
- : null;
-}
diff --git a/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj b/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj
index f90f39f..604772c 100644
--- a/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj
+++ b/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj
@@ -20,6 +20,10 @@
+
+ True
+ build
+
diff --git a/src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs
deleted file mode 100644
index 422a7cd..0000000
--- a/src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs
+++ /dev/null
@@ -1,19 +0,0 @@
-using Microsoft.EntityFrameworkCore.Storage;
-using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
-using NpgsqlTypes;
-
-namespace Pgvector.EntityFrameworkCore;
-
-public class SparsevecTypeMapping : RelationalTypeMapping
-{
- public static SparsevecTypeMapping Default { get; } = new();
-
- public SparsevecTypeMapping() : base("sparsevec", typeof(SparseVector)) { }
-
- public SparsevecTypeMapping(string storeType) : base(storeType, typeof(SparseVector)) { }
-
- protected SparsevecTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }
-
- protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters)
- => new SparsevecTypeMapping(parameters);
-}
diff --git a/src/Pgvector.EntityFrameworkCore/SparsevecTypeMappingSourcePlugin.cs b/src/Pgvector.EntityFrameworkCore/SparsevecTypeMappingSourcePlugin.cs
deleted file mode 100644
index 290e391..0000000
--- a/src/Pgvector.EntityFrameworkCore/SparsevecTypeMappingSourcePlugin.cs
+++ /dev/null
@@ -1,11 +0,0 @@
-using Microsoft.EntityFrameworkCore.Storage;
-
-namespace Pgvector.EntityFrameworkCore;
-
-public class SparsevecTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin
-{
- public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo)
- => mappingInfo.ClrType == typeof(SparseVector)
- ? new SparsevecTypeMapping(mappingInfo.StoreTypeName ?? "sparsevec")
- : null;
-}
diff --git a/src/Pgvector.EntityFrameworkCore/VectorCodeGeneratorPlugin.cs b/src/Pgvector.EntityFrameworkCore/VectorCodeGeneratorPlugin.cs
new file mode 100644
index 0000000..0abaf5a
--- /dev/null
+++ b/src/Pgvector.EntityFrameworkCore/VectorCodeGeneratorPlugin.cs
@@ -0,0 +1,26 @@
+using System;
+using System.Reflection;
+using Microsoft.EntityFrameworkCore;
+using Microsoft.EntityFrameworkCore.Design;
+using Microsoft.EntityFrameworkCore.Scaffolding;
+using Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure;
+
+namespace Pgvector.EntityFrameworkCore;
+
+public class VectorCodeGeneratorPlugin : ProviderCodeGeneratorPlugin
+{
+ private static readonly MethodInfo _useVectorMethodInfo
+ = typeof(VectorDbContextOptionsBuilderExtensions).GetMethod(
+ nameof(VectorDbContextOptionsBuilderExtensions.UseVector),
+ [typeof(NpgsqlDbContextOptionsBuilder)])!;
+
+ ///
+ /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
+ /// the same compatibility standards as public APIs. It may be changed or removed without notice in
+ /// any release. You should only use it directly in your code with extreme caution and knowing that
+ /// doing so can result in application failures when updating to a new Entity Framework Core release.
+ ///
+ public override MethodCallCodeFragment GenerateProviderOptions()
+ => new(_useVectorMethodInfo);
+
+}
diff --git a/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs b/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs
index ca1af79..570e4dc 100644
--- a/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs
+++ b/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs
@@ -17,8 +17,6 @@ public void ApplyServices(IServiceCollection services)
.TryAdd();
services.AddSingleton();
- services.AddSingleton();
- services.AddSingleton();
}
public void Validate(IDbContextOptions options) { }
diff --git a/src/Pgvector.EntityFrameworkCore/VectorDesignTimeServices.cs b/src/Pgvector.EntityFrameworkCore/VectorDesignTimeServices.cs
new file mode 100644
index 0000000..23e6a96
--- /dev/null
+++ b/src/Pgvector.EntityFrameworkCore/VectorDesignTimeServices.cs
@@ -0,0 +1,14 @@
+using Microsoft.EntityFrameworkCore.Design;
+using Microsoft.EntityFrameworkCore.Scaffolding;
+using Microsoft.EntityFrameworkCore.Storage;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Pgvector.EntityFrameworkCore;
+
+public class VectorDesignTimeServices : IDesignTimeServices
+{
+ public virtual void ConfigureDesignTimeServices(IServiceCollection serviceCollection)
+ => serviceCollection
+ .AddSingleton()
+ .AddSingleton();
+}
diff --git a/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
index 0452c16..6ac4e0b 100644
--- a/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
+++ b/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
@@ -6,11 +6,18 @@ namespace Pgvector.EntityFrameworkCore;
public class VectorTypeMapping : RelationalTypeMapping
{
- public static VectorTypeMapping Default { get; } = new();
+ public static VectorTypeMapping Default { get; } = new("vector", typeof(Vector));
- public VectorTypeMapping() : base("vector", typeof(Vector)) { }
-
- public VectorTypeMapping(string storeType) : base(storeType, typeof(Vector)) { }
+ public VectorTypeMapping(string storeType, Type clrType, int? size = null)
+ : this(
+ new RelationalTypeMappingParameters(
+ new CoreTypeMappingParameters(clrType),
+ storeType,
+ StoreTypePostfix.Size,
+ size: size,
+ fixedLength: size is not null))
+ {
+ }
protected VectorTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }
diff --git a/src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs b/src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs
index be2beab..d299068 100644
--- a/src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs
+++ b/src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs
@@ -5,7 +5,30 @@ namespace Pgvector.EntityFrameworkCore;
public class VectorTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin
{
public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo)
- => mappingInfo.ClrType == typeof(Vector)
- ? new VectorTypeMapping(mappingInfo.StoreTypeName ?? "vector")
- : null;
+ {
+ if (mappingInfo.StoreTypeName is not null)
+ {
+ VectorTypeMapping? mapping = (mappingInfo.StoreTypeNameBase ?? mappingInfo.StoreTypeName) switch
+ {
+ "vector" => new(mappingInfo.StoreTypeName, typeof(Vector), mappingInfo.Size),
+ "halfvec" => new(mappingInfo.StoreTypeName, typeof(HalfVector), mappingInfo.Size),
+ "sparsevec" => new(mappingInfo.StoreTypeName, typeof(SparseVector), mappingInfo.Size),
+ _ => null,
+ };
+
+ // If the caller hasn't specified a CLR type (this is scaffolding), or if the user has specified
+ // the one matching the store type, return the mapping.
+ return mappingInfo.ClrType is null || mappingInfo.ClrType == mapping?.ClrType
+ ? mapping : null;
+ }
+
+ // No store type specified, look up by the CLR type only
+ return mappingInfo.ClrType switch
+ {
+ var t when t == typeof(Vector) => new VectorTypeMapping("vector", typeof(Vector), mappingInfo.Size),
+ var t when t == typeof(HalfVector) => new VectorTypeMapping("halfvec", typeof(HalfVector), mappingInfo.Size),
+ var t when t == typeof(SparseVector) => new VectorTypeMapping("sparsevec", typeof(SparseVector), mappingInfo.Size),
+ _ => null,
+ };
+ }
}
diff --git a/src/Pgvector.EntityFrameworkCore/build/netstandard2.0/Pgvector.EntityFrameworkCore.targets b/src/Pgvector.EntityFrameworkCore/build/netstandard2.0/Pgvector.EntityFrameworkCore.targets
new file mode 100644
index 0000000..7850e59
--- /dev/null
+++ b/src/Pgvector.EntityFrameworkCore/build/netstandard2.0/Pgvector.EntityFrameworkCore.targets
@@ -0,0 +1,46 @@
+
+
+ $(MSBuildAllProjects);$(MSBuildThisFileFullPath)
+ $(IntermediateOutputPath)EFCoreNpgsqlPgvector$(DefaultLanguageSourceExtension)
+
+
+
+
+
+
+ CompileBefore
+
+
+
+
+ CompileAfter
+
+
+
+
+
+
+ Compile
+
+
+
+
+
+
+ <_Parameter1>Pgvector.EntityFrameworkCore.VectorDesignTimeServices, Pgvector.EntityFrameworkCore
+ <_Parameter2>Npgsql.EntityFrameworkCore.PostgreSQL
+
+
+
+
+
+
+
+
diff --git a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs
index fc242e7..be09baa 100644
--- a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs
+++ b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs
@@ -66,65 +66,65 @@ public async Task Main()
var embedding = new Vector(new float[] { 1, 1, 1 });
var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync();
- Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
- Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
- Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray());
+ Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
+ Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray());
+ Assert.Equal([(Half)1, (Half)1, (Half)1], items[0].HalfEmbedding!.ToArray());
Assert.Equal(new BitArray(new bool[] { false, false, false }), items[0].BinaryEmbedding!);
- Assert.Equal(new float[] { 1, 1, 1 }, items[0].SparseEmbedding!.ToArray());
+ Assert.Equal([1, 1, 1], items[0].SparseEmbedding!.ToArray());
// vector distance functions
items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
- Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
+ Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
+ Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray());
items = await ctx.Items.OrderBy(x => x.Embedding!.MaxInnerProduct(embedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
items = await ctx.Items.OrderBy(x => x.Embedding!.CosineDistance(embedding)).Take(5).ToListAsync();
Assert.Equal(3, items[2].Id);
items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
// halfvec distance functions
var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.MaxInnerProduct(halfEmbedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.CosineDistance(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(3, items[2].Id);
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L1Distance(halfEmbedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
// sparsevec distance functions
var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 });
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.MaxInnerProduct(sparseEmbedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.CosineDistance(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(3, items[2].Id);
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L1Distance(sparseEmbedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
// bit distance functions
var binaryEmbedding = new BitArray(new bool[] { true, false, true });
items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.HammingDistance(binaryEmbedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.JaccardDistance(binaryEmbedding)).Take(5).ToListAsync();
- Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
// additional
@@ -132,13 +132,49 @@ public async Task Main()
.OrderBy(x => x.Id)
.Where(x => x.Embedding!.L2Distance(embedding) < 1.5)
.ToListAsync();
- Assert.Equal(new int[] { 1, 3 }, items.Select(v => v.Id).ToArray());
+ Assert.Equal([1, 3], items.Select(v => v.Id).ToArray());
var neighbors = await ctx.Items
.OrderBy(x => x.Embedding!.L2Distance(embedding))
.Select(x => new { Entity = x, Distance = x.Embedding!.L2Distance(embedding) })
.ToListAsync();
- Assert.Equal(new int[] { 1, 3, 2 }, neighbors.Select(v => v.Entity.Id).ToArray());
- Assert.Equal(new double[] { 0, 1, Math.Sqrt(3) }, neighbors.Select(v => v.Distance).ToArray());
+ Assert.Equal([1, 3, 2], neighbors.Select(v => v.Entity.Id).ToArray());
+ Assert.Equal([0, 1, Math.Sqrt(3)], neighbors.Select(v => v.Distance).ToArray());
+ }
+
+ [Theory]
+ [InlineData(typeof(Vector), null, "vector")]
+ [InlineData(typeof(Vector), 3, "vector(3)")]
+ [InlineData(typeof(HalfVector), null, "halfvec")]
+ [InlineData(typeof(HalfVector), 3, "halfvec(3)")]
+ [InlineData(typeof(SparseVector), null, "sparsevec")]
+ [InlineData(typeof(SparseVector), 3, "sparsevec(3)")]
+ public void By_StoreType(Type type, int? size, string expectedStoreType)
+ {
+ using var ctx = new ItemContext();
+ var typeMappingSource = ctx.GetService();
+
+ var typeMapping = typeMappingSource.FindMapping(type, storeTypeName: null, size: size)!;
+ Assert.Equal(expectedStoreType, typeMapping.StoreType);
+ Assert.Same(type, typeMapping.ClrType);
+ Assert.Equal(size, typeMapping.Size);
+ }
+
+ [Theory]
+ [InlineData("vector", typeof(Vector), null)]
+ [InlineData("vector(3)", typeof(Vector), 3)]
+ [InlineData("halfvec", typeof(HalfVector), null)]
+ [InlineData("halfvec(3)", typeof(HalfVector), 3)]
+ [InlineData("sparsevec", typeof(SparseVector), null)]
+ [InlineData("sparsevec(3)", typeof(SparseVector), 3)]
+ public void By_ClrType(string storeType, Type expectedType, int? expectedSize)
+ {
+ using var ctx = new ItemContext();
+ var typeMappingSource = ctx.GetService();
+
+ var typeMapping = typeMappingSource.FindMapping(storeType)!;
+ Assert.Equal(storeType, typeMapping.StoreType);
+ Assert.Same(expectedType, typeMapping.ClrType);
+ Assert.Equal(expectedSize, typeMapping.Size);
}
}