diff --git a/src/Scrutor/RegistrationStrategy.cs b/src/Scrutor/RegistrationStrategy.cs index 48943849..d8503bec 100644 --- a/src/Scrutor/RegistrationStrategy.cs +++ b/src/Scrutor/RegistrationStrategy.cs @@ -6,10 +6,15 @@ namespace Scrutor; public abstract class RegistrationStrategy { /// - /// Skips registrations for services that already exists. + /// Appends a new registration when no registration exists for the same Service type. /// public static readonly RegistrationStrategy Skip = new SkipRegistrationStrategy(); + /// + /// Appends a new registration when no registration exists for the same Service and Implementation type. + /// + public static readonly RegistrationStrategy Distinct = new DistinctRegistrationStrategy(); + /// /// Appends a new registration for existing services. /// @@ -49,6 +54,26 @@ private sealed class SkipRegistrationStrategy : RegistrationStrategy public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) => services.TryAdd(descriptor); } + private sealed class DistinctRegistrationStrategy : RegistrationStrategy { + /// + /// Adds the service descriptor if the service collection does not contain a desriptor with the same Service and Implementation type. + /// + /// The service collection. + /// The descriptor to apply. + /// + /// Unable to use + /// TryAddEnumerable() + /// since it would throw an ArgumentException when used with AsSelf(). + /// + public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) + { + if (services.HasRegistration(descriptor)) { + return; + } + services.Add(descriptor); + } + } + private sealed class AppendRegistrationStrategy : RegistrationStrategy { public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) => services.Add(descriptor); diff --git a/src/Scrutor/ServiceCollectionExtensions.cs b/src/Scrutor/ServiceCollectionExtensions.cs index d4002c61..f03d8e4f 100644 --- a/src/Scrutor/ServiceCollectionExtensions.cs +++ b/src/Scrutor/ServiceCollectionExtensions.cs @@ -10,4 +10,15 @@ public static bool HasRegistration(this IServiceCollection services, Type servic { return services.Any(x => x.ServiceType == serviceType); } -} \ No newline at end of file + + /// + /// Determines whether the service collection has a descriptor with the same Service and Implementation types. + /// + /// The service collection. + /// The service descriptor. + /// true if the service collection contains the specified service descriptor; otherwise, false. + public static bool HasRegistration(this IServiceCollection services, ServiceDescriptor descriptor) + { + return services.Any(x => x.ServiceType == descriptor.ServiceType && x.ImplementationType == descriptor.ImplementationType); + } +} diff --git a/test/Scrutor.Tests/ScanningTests.cs b/test/Scrutor.Tests/ScanningTests.cs index 541f198f..976469a1 100644 --- a/test/Scrutor.Tests/ScanningTests.cs +++ b/test/Scrutor.Tests/ScanningTests.cs @@ -48,13 +48,30 @@ public void UsingRegistrationStrategy_None() } [Fact] - public void UsingRegistrationStrategy_SkipIfExists() + public void UsingRegistrationStrategy_Skip() { Collection.Scan(scan => scan .FromAssemblyOf() + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Skip) + .AsImplementedInterfaces() + .WithTransientLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(1, services.Count(x => x.ServiceType == typeof(ITransientService))); + } + + [Fact] + public void UsingRegistrationStrategy_SkipAfterNone() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + // registers 4 .AddClasses(classes => classes.AssignableTo()) .AsImplementedInterfaces() .WithTransientLifetime() + // no new registrations .AddClasses(classes => classes.AssignableTo()) .UsingRegistrationStrategy(RegistrationStrategy.Skip) .AsImplementedInterfaces() @@ -65,6 +82,82 @@ public void UsingRegistrationStrategy_SkipIfExists() Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService))); } + [Fact] + public void UsingRegistrationStrategy_Distinct() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Distinct) + .AsImplementedInterfaces() + .WithTransientLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService))); + } + + [Fact] + public void UsingRegistrationStrategy_DistinctAfterSkip() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + // registers 1 + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Skip) + .AsImplementedInterfaces() + .WithTransientLifetime() + // registers the other three + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Distinct) + .AsImplementedInterfaces() + .WithSingletonLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService))); + } + + [Fact] + public void UsingRegistrationStrategy_DistinctAfterNone() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + // register 4 + .AddClasses(classes => classes.AssignableTo()) + .AsImplementedInterfaces() + .WithTransientLifetime() + // no new registrations + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Distinct) + .AsImplementedInterfaces() + .WithSingletonLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService))); + } + + [Fact] + public void UsingRegistrationStrategy_DistinctWithSelf() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + // registers 9 + .AddClasses(classes => classes.AssignableTo()) + .AsImplementedInterfaces() + .AsSelf() + .WithTransientLifetime() + // no new registrations, and does not throw due to not using TryAddEnumerable() with AsSelf() + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Distinct) + .AsImplementedInterfaces() + .AsSelf() + .WithTransientLifetime()); + + Assert.Equal(9, Collection.Count); + } + [Fact] public void UsingRegistrationStrategy_ReplaceDefault() {