diff --git a/src/Scrutor/RegistrationStrategy.cs b/src/Scrutor/RegistrationStrategy.cs
index 48943849..d8503bec 100644
--- a/src/Scrutor/RegistrationStrategy.cs
+++ b/src/Scrutor/RegistrationStrategy.cs
@@ -6,10 +6,15 @@ namespace Scrutor;
public abstract class RegistrationStrategy
{
///
- /// Skips registrations for services that already exists.
+ /// Appends a new registration when no registration exists for the same Service type.
///
public static readonly RegistrationStrategy Skip = new SkipRegistrationStrategy();
+ ///
+ /// Appends a new registration when no registration exists for the same Service and Implementation type.
+ ///
+ public static readonly RegistrationStrategy Distinct = new DistinctRegistrationStrategy();
+
///
/// Appends a new registration for existing services.
///
@@ -49,6 +54,26 @@ private sealed class SkipRegistrationStrategy : RegistrationStrategy
public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) => services.TryAdd(descriptor);
}
+ private sealed class DistinctRegistrationStrategy : RegistrationStrategy {
+ ///
+ /// Adds the service descriptor if the service collection does not contain a desriptor with the same Service and Implementation type.
+ ///
+ /// The service collection.
+ /// The descriptor to apply.
+ ///
+ /// Unable to use
+ /// TryAddEnumerable()
+ /// since it would throw an ArgumentException when used with AsSelf().
+ ///
+ public override void Apply(IServiceCollection services, ServiceDescriptor descriptor)
+ {
+ if (services.HasRegistration(descriptor)) {
+ return;
+ }
+ services.Add(descriptor);
+ }
+ }
+
private sealed class AppendRegistrationStrategy : RegistrationStrategy
{
public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) => services.Add(descriptor);
diff --git a/src/Scrutor/ServiceCollectionExtensions.cs b/src/Scrutor/ServiceCollectionExtensions.cs
index d4002c61..f03d8e4f 100644
--- a/src/Scrutor/ServiceCollectionExtensions.cs
+++ b/src/Scrutor/ServiceCollectionExtensions.cs
@@ -10,4 +10,15 @@ public static bool HasRegistration(this IServiceCollection services, Type servic
{
return services.Any(x => x.ServiceType == serviceType);
}
-}
\ No newline at end of file
+
+ ///
+ /// Determines whether the service collection has a descriptor with the same Service and Implementation types.
+ ///
+ /// The service collection.
+ /// The service descriptor.
+ /// true if the service collection contains the specified service descriptor; otherwise, false.
+ public static bool HasRegistration(this IServiceCollection services, ServiceDescriptor descriptor)
+ {
+ return services.Any(x => x.ServiceType == descriptor.ServiceType && x.ImplementationType == descriptor.ImplementationType);
+ }
+}
diff --git a/test/Scrutor.Tests/ScanningTests.cs b/test/Scrutor.Tests/ScanningTests.cs
index 541f198f..976469a1 100644
--- a/test/Scrutor.Tests/ScanningTests.cs
+++ b/test/Scrutor.Tests/ScanningTests.cs
@@ -48,13 +48,30 @@ public void UsingRegistrationStrategy_None()
}
[Fact]
- public void UsingRegistrationStrategy_SkipIfExists()
+ public void UsingRegistrationStrategy_Skip()
{
Collection.Scan(scan => scan
.FromAssemblyOf()
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Skip)
+ .AsImplementedInterfaces()
+ .WithTransientLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(1, services.Count(x => x.ServiceType == typeof(ITransientService)));
+ }
+
+ [Fact]
+ public void UsingRegistrationStrategy_SkipAfterNone()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ // registers 4
.AddClasses(classes => classes.AssignableTo())
.AsImplementedInterfaces()
.WithTransientLifetime()
+ // no new registrations
.AddClasses(classes => classes.AssignableTo())
.UsingRegistrationStrategy(RegistrationStrategy.Skip)
.AsImplementedInterfaces()
@@ -65,6 +82,82 @@ public void UsingRegistrationStrategy_SkipIfExists()
Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService)));
}
+ [Fact]
+ public void UsingRegistrationStrategy_Distinct()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Distinct)
+ .AsImplementedInterfaces()
+ .WithTransientLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService)));
+ }
+
+ [Fact]
+ public void UsingRegistrationStrategy_DistinctAfterSkip()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ // registers 1
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Skip)
+ .AsImplementedInterfaces()
+ .WithTransientLifetime()
+ // registers the other three
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Distinct)
+ .AsImplementedInterfaces()
+ .WithSingletonLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService)));
+ }
+
+ [Fact]
+ public void UsingRegistrationStrategy_DistinctAfterNone()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ // register 4
+ .AddClasses(classes => classes.AssignableTo())
+ .AsImplementedInterfaces()
+ .WithTransientLifetime()
+ // no new registrations
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Distinct)
+ .AsImplementedInterfaces()
+ .WithSingletonLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService)));
+ }
+
+ [Fact]
+ public void UsingRegistrationStrategy_DistinctWithSelf()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ // registers 9
+ .AddClasses(classes => classes.AssignableTo())
+ .AsImplementedInterfaces()
+ .AsSelf()
+ .WithTransientLifetime()
+ // no new registrations, and does not throw due to not using TryAddEnumerable() with AsSelf()
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Distinct)
+ .AsImplementedInterfaces()
+ .AsSelf()
+ .WithTransientLifetime());
+
+ Assert.Equal(9, Collection.Count);
+ }
+
[Fact]
public void UsingRegistrationStrategy_ReplaceDefault()
{