From 990f4206e0dcc14e05ad0c5c716ae23e51f7242e Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 24 Mar 2025 12:11:29 +0100 Subject: [PATCH 1/2] Improvements to EF type mapping * Support scaffolding of vector types * Implemented more proper support for the size facet * Consolidated different mappings and plugins into the same files * A bit of test code cleanup Closes #44 --- .../HalfvecTypeMapping.cs | 19 ----- .../HalfvecTypeMappingSourcePlugin.cs | 11 --- .../SparsevecTypeMapping.cs | 19 ----- .../SparsevecTypeMappingSourcePlugin.cs | 11 --- .../VectorDbContextOptionsExtension.cs | 2 - .../VectorTypeMapping.cs | 15 +++- .../VectorTypeMappingSourcePlugin.cs | 29 +++++++- .../EntityFrameworkCoreTests.cs | 74 ++++++++++++++----- 8 files changed, 92 insertions(+), 88 deletions(-) delete mode 100644 src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs delete mode 100644 src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs delete mode 100644 src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs delete mode 100644 src/Pgvector.EntityFrameworkCore/SparsevecTypeMappingSourcePlugin.cs diff --git a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs deleted file mode 100644 index d1d4bfa..0000000 --- a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs +++ /dev/null @@ -1,19 +0,0 @@ -using Microsoft.EntityFrameworkCore.Storage; -using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping; -using NpgsqlTypes; - -namespace Pgvector.EntityFrameworkCore; - -public class HalfvecTypeMapping : RelationalTypeMapping -{ - public static HalfvecTypeMapping Default { get; } = new(); - - public HalfvecTypeMapping() : base("halfvec", typeof(HalfVector)) { } - - public HalfvecTypeMapping(string storeType) : base(storeType, typeof(HalfVector)) { } - - protected HalfvecTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { } - - protected override RelationalTypeMapping 
Clone(RelationalTypeMappingParameters parameters) - => new HalfvecTypeMapping(parameters); -} diff --git a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs b/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs deleted file mode 100644 index 663bf71..0000000 --- a/src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs +++ /dev/null @@ -1,11 +0,0 @@ -using Microsoft.EntityFrameworkCore.Storage; - -namespace Pgvector.EntityFrameworkCore; - -public class HalfvecTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin -{ - public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo) - => mappingInfo.ClrType == typeof(HalfVector) - ? new HalfvecTypeMapping(mappingInfo.StoreTypeName ?? "halfvec") - : null; -} diff --git a/src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs deleted file mode 100644 index 422a7cd..0000000 --- a/src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs +++ /dev/null @@ -1,19 +0,0 @@ -using Microsoft.EntityFrameworkCore.Storage; -using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping; -using NpgsqlTypes; - -namespace Pgvector.EntityFrameworkCore; - -public class SparsevecTypeMapping : RelationalTypeMapping -{ - public static SparsevecTypeMapping Default { get; } = new(); - - public SparsevecTypeMapping() : base("sparsevec", typeof(SparseVector)) { } - - public SparsevecTypeMapping(string storeType) : base(storeType, typeof(SparseVector)) { } - - protected SparsevecTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { } - - protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters) - => new SparsevecTypeMapping(parameters); -} diff --git a/src/Pgvector.EntityFrameworkCore/SparsevecTypeMappingSourcePlugin.cs b/src/Pgvector.EntityFrameworkCore/SparsevecTypeMappingSourcePlugin.cs deleted file mode 100644 index 290e391..0000000 --- 
a/src/Pgvector.EntityFrameworkCore/SparsevecTypeMappingSourcePlugin.cs +++ /dev/null @@ -1,11 +0,0 @@ -using Microsoft.EntityFrameworkCore.Storage; - -namespace Pgvector.EntityFrameworkCore; - -public class SparsevecTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin -{ - public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo) - => mappingInfo.ClrType == typeof(SparseVector) - ? new SparsevecTypeMapping(mappingInfo.StoreTypeName ?? "sparsevec") - : null; -} diff --git a/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs b/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs index ca1af79..570e4dc 100644 --- a/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs +++ b/src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs @@ -17,8 +17,6 @@ public void ApplyServices(IServiceCollection services) .TryAdd(); services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); } public void Validate(IDbContextOptions options) { } diff --git a/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs b/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs index 0452c16..6ac4e0b 100644 --- a/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs +++ b/src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs @@ -6,11 +6,18 @@ namespace Pgvector.EntityFrameworkCore; public class VectorTypeMapping : RelationalTypeMapping { - public static VectorTypeMapping Default { get; } = new(); + public static VectorTypeMapping Default { get; } = new("vector", typeof(Vector)); - public VectorTypeMapping() : base("vector", typeof(Vector)) { } - - public VectorTypeMapping(string storeType) : base(storeType, typeof(Vector)) { } + public VectorTypeMapping(string storeType, Type clrType, int? 
size = null) + : this( + new RelationalTypeMappingParameters( + new CoreTypeMappingParameters(clrType), + storeType, + StoreTypePostfix.Size, + size: size, + fixedLength: size is not null)) + { + } protected VectorTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { } diff --git a/src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs b/src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs index be2beab..d299068 100644 --- a/src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs +++ b/src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs @@ -5,7 +5,30 @@ namespace Pgvector.EntityFrameworkCore; public class VectorTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin { public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo) - => mappingInfo.ClrType == typeof(Vector) - ? new VectorTypeMapping(mappingInfo.StoreTypeName ?? "vector") - : null; + { + if (mappingInfo.StoreTypeName is not null) + { + VectorTypeMapping? mapping = (mappingInfo.StoreTypeNameBase ?? mappingInfo.StoreTypeName) switch + { + "vector" => new(mappingInfo.StoreTypeName, typeof(Vector), mappingInfo.Size), + "halfvec" => new(mappingInfo.StoreTypeName, typeof(HalfVector), mappingInfo.Size), + "sparsevec" => new(mappingInfo.StoreTypeName, typeof(SparseVector), mappingInfo.Size), + _ => null, + }; + + // If the caller hasn't specified a CLR type (this is scaffolding), or if the user has specified + // the one matching the store type, return the mapping. + return mappingInfo.ClrType is null || mappingInfo.ClrType == mapping?.ClrType + ? 
mapping : null; + } + + // No store type specified, look up by the CLR type only + return mappingInfo.ClrType switch + { + var t when t == typeof(Vector) => new VectorTypeMapping("vector", typeof(Vector), mappingInfo.Size), + var t when t == typeof(HalfVector) => new VectorTypeMapping("halfvec", typeof(HalfVector), mappingInfo.Size), + var t when t == typeof(SparseVector) => new VectorTypeMapping("sparsevec", typeof(SparseVector), mappingInfo.Size), + _ => null, + }; + } } diff --git a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs index fc242e7..be09baa 100644 --- a/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs +++ b/tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs @@ -66,65 +66,65 @@ public async Task Main() var embedding = new Vector(new float[] { 1, 1, 1 }); var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync(); - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); - Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray()); - Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray()); + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); + Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray()); + Assert.Equal([(Half)1, (Half)1, (Half)1], items[0].HalfEmbedding!.ToArray()); Assert.Equal(new BitArray(new bool[] { false, false, false }), items[0].BinaryEmbedding!); - Assert.Equal(new float[] { 1, 1, 1 }, items[0].SparseEmbedding!.ToArray()); + Assert.Equal([1, 1, 1], items[0].SparseEmbedding!.ToArray()); // vector distance functions items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); - Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray()); + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); + 
Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray()); items = await ctx.Items.OrderBy(x => x.Embedding!.MaxInnerProduct(embedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); items = await ctx.Items.OrderBy(x => x.Embedding!.CosineDistance(embedding)).Take(5).ToListAsync(); Assert.Equal(3, items[2].Id); items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); // halfvec distance functions var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 }); items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.MaxInnerProduct(halfEmbedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.CosineDistance(halfEmbedding)).Take(5).ToListAsync(); Assert.Equal(3, items[2].Id); items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L1Distance(halfEmbedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); // sparsevec distance functions var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 }); items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); items = await 
ctx.Items.OrderBy(x => x.SparseEmbedding!.MaxInnerProduct(sparseEmbedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.CosineDistance(sparseEmbedding)).Take(5).ToListAsync(); Assert.Equal(3, items[2].Id); items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L1Distance(sparseEmbedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray()); // bit distance functions var binaryEmbedding = new BitArray(new bool[] { true, false, true }); items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.HammingDistance(binaryEmbedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.JaccardDistance(binaryEmbedding)).Take(5).ToListAsync(); - Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray()); // additional @@ -132,13 +132,49 @@ public async Task Main() .OrderBy(x => x.Id) .Where(x => x.Embedding!.L2Distance(embedding) < 1.5) .ToListAsync(); - Assert.Equal(new int[] { 1, 3 }, items.Select(v => v.Id).ToArray()); + Assert.Equal([1, 3], items.Select(v => v.Id).ToArray()); var neighbors = await ctx.Items .OrderBy(x => x.Embedding!.L2Distance(embedding)) .Select(x => new { Entity = x, Distance = x.Embedding!.L2Distance(embedding) }) .ToListAsync(); - Assert.Equal(new int[] { 1, 3, 2 }, neighbors.Select(v => v.Entity.Id).ToArray()); - Assert.Equal(new double[] { 0, 1, Math.Sqrt(3) }, neighbors.Select(v => v.Distance).ToArray()); + Assert.Equal([1, 3, 2], neighbors.Select(v => v.Entity.Id).ToArray()); + Assert.Equal([0, 1, Math.Sqrt(3)], neighbors.Select(v => 
v.Distance).ToArray()); + } + + [Theory] + [InlineData(typeof(Vector), null, "vector")] + [InlineData(typeof(Vector), 3, "vector(3)")] + [InlineData(typeof(HalfVector), null, "halfvec")] + [InlineData(typeof(HalfVector), 3, "halfvec(3)")] + [InlineData(typeof(SparseVector), null, "sparsevec")] + [InlineData(typeof(SparseVector), 3, "sparsevec(3)")] + public void By_StoreType(Type type, int? size, string expectedStoreType) + { + using var ctx = new ItemContext(); + var typeMappingSource = ctx.GetService(); + + var typeMapping = typeMappingSource.FindMapping(type, storeTypeName: null, size: size)!; + Assert.Equal(expectedStoreType, typeMapping.StoreType); + Assert.Same(type, typeMapping.ClrType); + Assert.Equal(size, typeMapping.Size); + } + + [Theory] + [InlineData("vector", typeof(Vector), null)] + [InlineData("vector(3)", typeof(Vector), 3)] + [InlineData("halfvec", typeof(HalfVector), null)] + [InlineData("halfvec(3)", typeof(HalfVector), 3)] + [InlineData("sparsevec", typeof(SparseVector), null)] + [InlineData("sparsevec(3)", typeof(SparseVector), 3)] + public void By_ClrType(string storeType, Type expectedType, int? 
expectedSize) + { + using var ctx = new ItemContext(); + var typeMappingSource = ctx.GetService(); + + var typeMapping = typeMappingSource.FindMapping(storeType)!; + Assert.Equal(storeType, typeMapping.StoreType); + Assert.Same(expectedType, typeMapping.ClrType); + Assert.Equal(expectedSize, typeMapping.Size); } } From 747e6c62c4bbcde15d4a5656ceaa1611b0b8590e Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 25 Mar 2025 20:52:26 +0100 Subject: [PATCH 2/2] Add missing design-time infra --- .../Pgvector.EntityFrameworkCore.csproj | 4 ++ .../VectorCodeGeneratorPlugin.cs | 26 +++++++++++ .../VectorDesignTimeServices.cs | 14 ++++++ .../Pgvector.EntityFrameworkCore.targets | 46 +++++++++++++++++++ 4 files changed, 90 insertions(+) create mode 100644 src/Pgvector.EntityFrameworkCore/VectorCodeGeneratorPlugin.cs create mode 100644 src/Pgvector.EntityFrameworkCore/VectorDesignTimeServices.cs create mode 100644 src/Pgvector.EntityFrameworkCore/build/netstandard2.0/Pgvector.EntityFrameworkCore.targets diff --git a/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj b/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj index f90f39f..604772c 100644 --- a/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj +++ b/src/Pgvector.EntityFrameworkCore/Pgvector.EntityFrameworkCore.csproj @@ -20,6 +20,10 @@ + + True + build + diff --git a/src/Pgvector.EntityFrameworkCore/VectorCodeGeneratorPlugin.cs b/src/Pgvector.EntityFrameworkCore/VectorCodeGeneratorPlugin.cs new file mode 100644 index 0000000..0abaf5a --- /dev/null +++ b/src/Pgvector.EntityFrameworkCore/VectorCodeGeneratorPlugin.cs @@ -0,0 +1,26 @@ +using System; +using System.Reflection; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Design; +using Microsoft.EntityFrameworkCore.Scaffolding; +using Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure; + +namespace Pgvector.EntityFrameworkCore; + +public class VectorCodeGeneratorPlugin : 
ProviderCodeGeneratorPlugin +{ + private static readonly MethodInfo _useVectorMethodInfo + = typeof(VectorDbContextOptionsBuilderExtensions).GetMethod( + nameof(VectorDbContextOptionsBuilderExtensions.UseVector), + [typeof(NpgsqlDbContextOptionsBuilder)])!; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override MethodCallCodeFragment GenerateProviderOptions() + => new(_useVectorMethodInfo); + +} diff --git a/src/Pgvector.EntityFrameworkCore/VectorDesignTimeServices.cs b/src/Pgvector.EntityFrameworkCore/VectorDesignTimeServices.cs new file mode 100644 index 0000000..23e6a96 --- /dev/null +++ b/src/Pgvector.EntityFrameworkCore/VectorDesignTimeServices.cs @@ -0,0 +1,14 @@ +using Microsoft.EntityFrameworkCore.Design; +using Microsoft.EntityFrameworkCore.Scaffolding; +using Microsoft.EntityFrameworkCore.Storage; +using Microsoft.Extensions.DependencyInjection; + +namespace Pgvector.EntityFrameworkCore; + +public class VectorDesignTimeServices : IDesignTimeServices +{ + public virtual void ConfigureDesignTimeServices(IServiceCollection serviceCollection) + => serviceCollection + .AddSingleton() + .AddSingleton(); +} diff --git a/src/Pgvector.EntityFrameworkCore/build/netstandard2.0/Pgvector.EntityFrameworkCore.targets b/src/Pgvector.EntityFrameworkCore/build/netstandard2.0/Pgvector.EntityFrameworkCore.targets new file mode 100644 index 0000000..7850e59 --- /dev/null +++ b/src/Pgvector.EntityFrameworkCore/build/netstandard2.0/Pgvector.EntityFrameworkCore.targets @@ -0,0 +1,46 @@ + + + $(MSBuildAllProjects);$(MSBuildThisFileFullPath) + 
$(IntermediateOutputPath)EFCoreNpgsqlPgvector$(DefaultLanguageSourceExtension) + + + + + + + CompileBefore + + + + + CompileAfter + + + + + + + Compile + + + + + + + <_Parameter1>Pgvector.EntityFrameworkCore.VectorDesignTimeServices, Pgvector.EntityFrameworkCore + <_Parameter2>Npgsql.EntityFrameworkCore.PostgreSQL + + + + + + + +