Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public class ProjectableDescriptor

public ParameterListSyntax? ParametersList { get; set; }

public IEnumerable<string>? ParameterTypeNames { get; set; }

public TypeParameterListSyntax? TypeParameterList { get; set; }

public SyntaxList<TypeParameterConstraintClauseSyntax>? ConstraintClauses { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ x is IPropertySymbol xProperty &&
var expressionSyntaxRewriter = new ExpressionSyntaxRewriter(memberSymbol.ContainingType, nullConditionalRewriteSupport, semanticModel, context);
var declarationSyntaxRewriter = new DeclarationSyntaxRewriter(semanticModel);

var methodSymbol = memberSymbol as IMethodSymbol;

var descriptor = new ProjectableDescriptor {

UsingDirectives = member.SyntaxTree.GetRoot().DescendantNodes().OfType<UsingDirectiveSyntax>(),
Expand All @@ -128,6 +130,14 @@ x is IPropertySymbol xProperty &&
ParametersList = SyntaxFactory.ParameterList()
};

// Collect parameter type names for method overload disambiguation
if (methodSymbol is not null)
{
descriptor.ParameterTypeNames = methodSymbol.Parameters
.Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
.ToList();
}

if (memberSymbol.ContainingType is INamedTypeSymbol { IsGenericType: true } containingNamedType)
{
descriptor.ClassTypeParameterList = SyntaxFactory.TypeParameterList();
Expand Down Expand Up @@ -196,8 +206,6 @@ x is IPropertySymbol xProperty &&
);
}

var methodSymbol = memberSymbol as IMethodSymbol;

if (methodSymbol is { IsExtensionMethod: true })
{
var targetTypeSymbol = methodSymbol.Parameters.First().Type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ static void Execute(MemberDeclarationSyntax member, Compilation compilation, Sou
throw new InvalidOperationException("Expected a memberName here");
}

var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);
var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName, projectable.ParameterTypeNames);
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";

var classSyntax = ClassDeclaration(generatedClassName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,31 @@ public static class ProjectionExpressionClassNameGenerator
public const string Namespace = "EntityFrameworkCore.Projectables.Generated";

public static string GenerateName(string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName)
{
return GenerateName(namespaceName, nestedInClassNames, memberName, null);
}

public static string GenerateName(string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName, IEnumerable<string>? parameterTypeNames)
{
var stringBuilder = new StringBuilder();

return GenerateNameImpl(stringBuilder, namespaceName, nestedInClassNames, memberName);
return GenerateNameImpl(stringBuilder, namespaceName, nestedInClassNames, memberName, parameterTypeNames);
}

public static string GenerateFullName(string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName)
{
return GenerateFullName(namespaceName, nestedInClassNames, memberName, null);
}

public static string GenerateFullName(string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName, IEnumerable<string>? parameterTypeNames)
{
var stringBuilder = new StringBuilder(Namespace);
stringBuilder.Append('.');

return GenerateNameImpl(stringBuilder, namespaceName, nestedInClassNames, memberName);
return GenerateNameImpl(stringBuilder, namespaceName, nestedInClassNames, memberName, parameterTypeNames);
}

static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName)
static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName, IEnumerable<string>? parameterTypeNames)
{
stringBuilder.Append(namespaceName?.Replace('.', '_'));
stringBuilder.Append('_');
Expand Down Expand Up @@ -57,6 +67,35 @@ static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceNam
}
stringBuilder.Append(memberName);

// Add parameter types to make method overloads unique
if (parameterTypeNames is not null)
{
var parameterIndex = 0;
foreach (var parameterTypeName in parameterTypeNames)
{
stringBuilder.Append("_P");
stringBuilder.Append(parameterIndex);
stringBuilder.Append('_');
// Replace characters that are not valid in type names with underscores
var sanitizedTypeName = parameterTypeName
.Replace("global::", "") // Remove global:: prefix
.Replace('.', '_')
.Replace('<', '_')
.Replace('>', '_')
.Replace(',', '_')
.Replace(' ', '_')
.Replace('[', '_')
.Replace(']', '_')
.Replace('`', '_')
.Replace(':', '_') // Additional safety for any remaining colons
.Replace('?', '_'); // Handle nullable reference types
stringBuilder.Append(sanitizedTypeName);
parameterIndex++;
}
}

// Add generic arity at the very end (after parameter types)
// This matches how the CLR names generic types
if (arity > 0)
{
stringBuilder.Append('`');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,41 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo
static LambdaExpression? GetExpressionFromGeneratedType(MemberInfo projectableMemberInfo)
{
var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here");
var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name);

// Keep track of the original declaring type's generic arguments for later use
var originalDeclaringType = declaringType;

// For generic types, use the generic type definition to match the generated name
// which is based on the open generic type
if (declaringType.IsGenericType && !declaringType.IsGenericTypeDefinition)
{
declaringType = declaringType.GetGenericTypeDefinition();
}

// Get parameter types for method overload disambiguation
// Use the same format as Roslyn's SymbolDisplayFormat.FullyQualifiedFormat
// which uses C# keywords for primitive types (int, string, etc.)
string[]? parameterTypeNames = null;
if (projectableMemberInfo is MethodInfo method)
{
// For generic methods, use the generic definition to get parameter types
// This ensures type parameters like TEntity are used instead of concrete types
var methodToInspect = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method;

parameterTypeNames = methodToInspect.GetParameters()
.Select(p => GetFullTypeName(p.ParameterType))
.ToArray();
}

var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name, parameterTypeNames);

var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName);

if (expressionFactoryType is not null)
{
if (expressionFactoryType.IsGenericTypeDefinition)
{
expressionFactoryType = expressionFactoryType.MakeGenericType(declaringType.GenericTypeArguments);
expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments);
}

var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic);
Expand All @@ -93,6 +119,99 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo

return null;
}

static string GetFullTypeName(Type type)
{
// Handle generic type parameters (e.g., T, TEntity)
if (type.IsGenericParameter)
{
return type.Name;
}

// Handle nullable value types (e.g., int? -> int?)
var underlyingType = Nullable.GetUnderlyingType(type);
if (underlyingType != null)
{
return $"{GetFullTypeName(underlyingType)}?";
}

// Handle array types
if (type.IsArray)
{
var elementType = type.GetElementType();
if (elementType == null)
{
// Fallback for edge cases where GetElementType() might return null
return type.Name;
}

var rank = type.GetArrayRank();
var elementTypeName = GetFullTypeName(elementType);

if (rank == 1)
{
return $"{elementTypeName}[]";
}
else
{
var commas = new string(',', rank - 1);
return $"{elementTypeName}[{commas}]";
}
}

// Map primitive types to their C# keyword equivalents to match Roslyn's output
var typeKeyword = GetCSharpKeyword(type);
if (typeKeyword != null)
{
return typeKeyword;
}

// For generic types, construct the full name matching Roslyn's format
if (type.IsGenericType)
{
var genericTypeDef = type.GetGenericTypeDefinition();
var genericArgs = type.GetGenericArguments();
var baseName = genericTypeDef.FullName ?? genericTypeDef.Name;

// Remove the `n suffix (e.g., `1, `2)
var backtickIndex = baseName.IndexOf('`');
if (backtickIndex > 0)
{
baseName = baseName.Substring(0, backtickIndex);
}

var args = string.Join(", ", genericArgs.Select(GetFullTypeName));
return $"{baseName}<{args}>";
}

if (type.FullName != null)
{
// Replace + with . for nested types to match Roslyn's format
return type.FullName.Replace('+', '.');
}

return type.Name;
}

static string? GetCSharpKeyword(Type type)
{
if (type == typeof(bool)) return "bool";
if (type == typeof(byte)) return "byte";
if (type == typeof(sbyte)) return "sbyte";
if (type == typeof(char)) return "char";
if (type == typeof(decimal)) return "decimal";
if (type == typeof(double)) return "double";
if (type == typeof(float)) return "float";
if (type == typeof(int)) return "int";
if (type == typeof(uint)) return "uint";
if (type == typeof(long)) return "long";
if (type == typeof(ulong)) return "ulong";
if (type == typeof(short)) return "short";
if (type == typeof(ushort)) return "ushort";
if (type == typeof(object)) return "object";
if (type == typeof(string)) return "string";
return null;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT [t].[Id]
FROM [TestEntity] AS [t]
WHERE 1 > [t].[Id] AND NEWID() <> '00000000-0000-0000-0000-000000000000'
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT [t].[Id]
FROM [TestEntity] AS [t]
WHERE 1 > [t].[Id] AND NEWID() <> '00000000-0000-0000-0000-000000000000'
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT [t].[Id]
FROM [TestEntity] AS [t]
WHERE 1 > [t].[Id] AND NEWID() <> '00000000-0000-0000-0000-000000000000'
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ public class TestEntity

[Projectable]
public bool IsValid3(params int[] validIds) => validIds.Contains(Id);

[Projectable]
public bool IsValid4(int? validId, Guid? guid) => validId > Id && guid != Guid.Parse("00000000-0000-0000-0000-000000000000");
}

[Fact]
Expand Down Expand Up @@ -52,8 +55,7 @@ public Task ArrayOfPrimitivesArguments()

return Verifier.Verify(query.ToQueryString());
}



[Fact]
public Task ParamsOfPrimitivesArguments()
{
Expand All @@ -64,5 +66,16 @@ public Task ParamsOfPrimitivesArguments()

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task NullableValueTypes()
{
using var dbContext = new SampleDbContext<TestEntity>();

var query = dbContext.Set<TestEntity>()
.Where(x => x.IsValid4(1, Guid.NewGuid()));

return Verifier.Verify(query.ToQueryString());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [g].[Id]
FROM [GenericObject<int>] AS [g]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [g].[Id]
FROM [GenericObject<int>] AS [g]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [g].[Id]
FROM [GenericObject<int>] AS [g]
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ public class Order

public DateTime RecordDate { get; set; }
}

public class GenericObject<T>
{
public T Id { get; set; }
}

[Fact]
public Task ProjectOverNavigationProperty()
Expand Down Expand Up @@ -98,5 +103,16 @@ public Task ProjectQueryFilters()

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task ProjectOverGenericType()
{
using var dbContext = new SampleDbContext<GenericObject<int>>();

var query = dbContext.Set<GenericObject<int>>()
.Select(x => x.Id);

return Verifier.Verify(query.ToQueryString());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id], [e].[Id] + 10 AS [Result]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id], [e].[Id] + 10 AS [Result]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id], [e].[Id] + 10 AS [Result]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id], CAST(LEN(N'Hello_' + [e].[Name]) AS int) AS [Result]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id], CAST(LEN(N'Hello_' + [e].[Name]) AS int) AS [Result]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id], CAST(LEN(N'Hello_' + [e].[Name]) AS int) AS [Result]
FROM [Entity] AS [e]
Loading