diff --git a/src/Analyzers/MSTest.AotReflection.SourceGeneration/Generators/TestClassModelBuilder.cs b/src/Analyzers/MSTest.AotReflection.SourceGeneration/Generators/TestClassModelBuilder.cs index 0e86de15b3..6e0ffde6fe 100644 --- a/src/Analyzers/MSTest.AotReflection.SourceGeneration/Generators/TestClassModelBuilder.cs +++ b/src/Analyzers/MSTest.AotReflection.SourceGeneration/Generators/TestClassModelBuilder.cs @@ -282,10 +282,8 @@ private static TestPropertyModel BuildProperty(IPropertySymbol property) Attributes: BuildAttributes(CollectInheritedAttributes(property))); // Mirror the runtime behavior of MemberInfo.GetCustomAttributes(inherit: true): walk the - // overridden-method chain and union attributes, keeping the most-derived application when - // the same attribute type appears on multiple levels — but respect - // [AttributeUsage(Inherited = false)] (the attribute is NOT visible past the level it was - // declared on) and [AttributeUsage(AllowMultiple = true)] (every occurrence is kept). + // overridden-member chain, honor AttributeUsageAttribute.Inherited, and keep only the + // most-derived application for attributes that do not allow multiple instances. private static ImmutableArray CollectInheritedAttributes(IMethodSymbol method) { ImmutableArray own = method.GetAttributes(); @@ -296,10 +294,10 @@ private static ImmutableArray CollectInheritedAttributes(IMethodS var seen = new HashSet(StringComparer.Ordinal); ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); - AppendUnique(builder, seen, own, isInheritedLevel: false); + AppendAttributes(builder, seen, own, inheritedOnly: false); for (IMethodSymbol? baseMethod = method.OverriddenMethod; baseMethod is not null; baseMethod = baseMethod.OverriddenMethod) { - AppendUnique(builder, seen, baseMethod.GetAttributes(), isInheritedLevel: true); + AppendAttributes(builder, seen, baseMethod.GetAttributes(), inheritedOnly: true); } return builder.ToImmutable(); @@ -315,20 +313,20 @@ private static ImmutableArray CollectInheritedAttributes(IPropert var seen = new HashSet(StringComparer.Ordinal); ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); - AppendUnique(builder, seen, own, isInheritedLevel: false); + AppendAttributes(builder, seen, own, inheritedOnly: false); for (IPropertySymbol? baseProperty = property.OverriddenProperty; baseProperty is not null; baseProperty = baseProperty.OverriddenProperty) { - AppendUnique(builder, seen, baseProperty.GetAttributes(), isInheritedLevel: true); + AppendAttributes(builder, seen, baseProperty.GetAttributes(), inheritedOnly: true); } return builder.ToImmutable(); } - private static void AppendUnique( + private static void AppendAttributes( ImmutableArray.Builder builder, HashSet seen, ImmutableArray attributes, - bool isInheritedLevel) + bool inheritedOnly) { foreach (AttributeData attribute in attributes) { @@ -337,66 +335,83 @@ private static void AppendUnique( continue; } - (bool allowMultiple, bool inherited) = GetAttributeUsage(attributeClass); - - // A base-level attribute declared with AttributeUsage(Inherited = false) must - // not leak onto the derived override (matches MemberInfo.GetCustomAttributes(inherit: true)). - if (isInheritedLevel && !inherited) + AttributeUsageMetadata usage = GetAttributeUsage(attributeClass); + if (inheritedOnly && !usage.Inherited) { continue; } - // Attributes declared with AttributeUsage(AllowMultiple = true) may legitimately - // appear several times across the override chain (e.g. [TestCategory], [DataRow]) - // — keep every instance instead of collapsing them to one. - if (allowMultiple) + string key = attributeClass.ToDisplayString(FullyQualifiedFormat); + if (usage.AllowMultiple || seen.Add(key)) { builder.Add(attribute); - continue; } + } + } - string key = attributeClass.ToDisplayString(FullyQualifiedFormat); - if (seen.Add(key)) + private static AttributeUsageMetadata GetAttributeUsage(INamedTypeSymbol attributeClass) + { + bool inherited = true; + bool allowMultiple = false; + + // [AttributeUsage] is itself inherited (its own AttributeUsage declares Inherited=true). + // Roslyn's GetAttributes() does NOT walk the base-type chain, so we have to walk it + // ourselves to honor an [AttributeUsage] declared on a base attribute type (e.g. when + // a user-defined attribute derives from one of MSTest's attributes without re-declaring + // its own [AttributeUsage]). + for (INamedTypeSymbol? current = attributeClass; + current is not null && current.SpecialType != SpecialType.System_Object; + current = current.BaseType) + { + if (TryReadAttributeUsage(current, out bool currentInherited, out bool currentAllowMultiple)) { - builder.Add(attribute); + inherited = currentInherited; + allowMultiple = currentAllowMultiple; + break; } } + + return new AttributeUsageMetadata(inherited, allowMultiple); } - private static (bool AllowMultiple, bool Inherited) GetAttributeUsage(INamedTypeSymbol attributeClass) + private static bool TryReadAttributeUsage(INamedTypeSymbol attributeClass, out bool inherited, out bool allowMultiple) { - for (INamedTypeSymbol? current = attributeClass; current is not null; current = current.BaseType) + inherited = true; + allowMultiple = false; + + foreach (AttributeData attribute in attributeClass.GetAttributes()) { - foreach (AttributeData attribute in current.GetAttributes()) + if (attribute.AttributeClass?.ToDisplayString(FullyQualifiedFormat) != "global::System.AttributeUsageAttribute") { - if (attribute.AttributeClass?.ToDisplayString(FullyQualifiedFormat) != "global::System.AttributeUsageAttribute") + continue; + } + + foreach (KeyValuePair namedArgument in attribute.NamedArguments) + { + if (namedArgument.Value.Value is not bool value) { continue; } - bool allowMultiple = false; - bool inherited = true; - foreach (KeyValuePair named in attribute.NamedArguments) + switch (namedArgument.Key) { - if (named.Key == "AllowMultiple" && named.Value.Value is bool am) - { - allowMultiple = am; - } - else if (named.Key == "Inherited" && named.Value.Value is bool inh) - { - inherited = inh; - } + case nameof(AttributeUsageAttribute.Inherited): + inherited = value; + break; + case nameof(AttributeUsageAttribute.AllowMultiple): + allowMultiple = value; + break; } - - // [AttributeUsage] on a derived attribute class shadows the base per CLI rules; - // stop at the first level where it is found. - return (allowMultiple, inherited); } + + return true; } - return (AllowMultiple: false, Inherited: true); + return false; } + private readonly record struct AttributeUsageMetadata(bool Inherited, bool AllowMultiple); + private static EquatableArray BuildParameters(IMethodSymbol method) { if (method.Parameters.IsDefaultOrEmpty) diff --git a/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTestReflectionMetadataGeneratorTests.cs b/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTestReflectionMetadataGeneratorTests.cs index 5519ff6119..6f2f9dc15f 100644 --- a/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTestReflectionMetadataGeneratorTests.cs +++ b/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTestReflectionMetadataGeneratorTests.cs @@ -36,7 +36,7 @@ public TestMethodAttribute() { } public string? DisplayName { get; set; } } - [System.AttributeUsage(System.AttributeTargets.Class | System.AttributeTargets.Method, AllowMultiple = true)] + [System.AttributeUsage(System.AttributeTargets.Class | System.AttributeTargets.Method, AllowMultiple = true, Inherited = true)] public class TestCategoryAttribute : System.Attribute { public TestCategoryAttribute(string category) { Category = category; } @@ -403,6 +403,51 @@ public void Sync(int x) { } "the generated source MUST compile cleanly when consumed in the same compilation as the user code"); } + [TestMethod] + public void Generator_SkipsProtectedMembers() + { + const string userCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + namespace Sample + { + public class TestContext { } + + [TestClass] + public class ProtectedShapes + { + [TestContext] + protected TestContext? Context { get; set; } + + [TestMethod] + protected void ProtectedTest() { } + + [TestMethod] + private protected void PrivateProtectedTest() { } + + [TestMethod] + protected internal void ProtectedInternalTest() { } + } + } + """; + + Compilation outputCompilation = RunGeneratorAndGetCompilation(MinimalMSTestStub, userCode); + string registry = outputCompilation + .SyntaxTrees + .Single(t => t.FilePath.EndsWith("MSTestReflectionMetadata.Registry.g.cs", System.StringComparison.Ordinal)) + .ToString(); + + registry.Should().NotContain("ProtectedTest"); + registry.Should().NotContain("PrivateProtectedTest"); + registry.Should().NotContain("Context"); + registry.Should().Contain("ProtectedInternalTest"); + + IEnumerable errors = outputCompilation + .GetDiagnostics() + .Where(d => d.Severity == DiagnosticSeverity.Error); + errors.Should().BeEmpty("the registry can only call members accessible from a non-derived type in the same assembly"); + } + [TestMethod] public void Generator_StripsNullableAnnotation_FromTypeofExpressions() { @@ -681,9 +726,6 @@ public virtual void Run() { } [TestClass] public class DerivedTests : BaseTests { - // [TestMethod] is re-applied here because the real attribute is declared - // with AttributeUsage(Inherited = false) and would not be inherited. - [TestMethod] public override void Run() { } } } @@ -696,7 +738,43 @@ public override void Run() { } runEntries.Should().Be(1, "the derived override must replace the base entry (not duplicate it)"); registry.Should().Contain("((global::Sample.DerivedTests)instance!).Run();"); registry.Should().NotContain("((global::Sample.BaseTests)instance!).Run();"); - registry.Should().Contain("global::Microsoft.VisualStudio.TestTools.UnitTesting.TestMethodAttribute"); + + // TestMethodAttribute is not inherited, so the override should not pick up the base attribute. + registry.Should().NotContain("global::Microsoft.VisualStudio.TestTools.UnitTesting.TestMethodAttribute"); + } + + [TestMethod] + public void Generator_OverriddenVirtualMethod_HonorsInheritedAttributeUsage() + { + const string userCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + + namespace Sample + { + public class BaseTests + { + [TestMethod] + [TestCategory("Base")] + [DataRow(1)] + public virtual void Run(int value) { } + } + + [TestClass] + public class DerivedTests : BaseTests + { + [TestMethod] + [TestCategory("Derived")] + public override void Run(int value) { } + } + } + """; + + string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode)); + + registry.Should().Contain("\"Base\""); + registry.Should().Contain("\"Derived\""); + registry.Should().Contain("DataRows = Array.Empty()"); + registry.Should().NotContain("new object?[] { 1 }"); } [TestMethod] @@ -1221,7 +1299,51 @@ public override void Run() { } } [TestMethod] - public void Generator_DistinguishesGenericArity_BetweenSameNamedMethods() + public void Generator_HonorsAttributeUsage_DeclaredOnBaseAttributeType() + { + // GetAttributeUsage MUST walk the attribute type's base-type chain. AllowMultiple is + // inherited from a base attribute that declares [AttributeUsage(AllowMultiple = true)] + // even when the derived attribute does not redeclare its own [AttributeUsage]. If the + // walk were skipped, the fallback default AllowMultiple=false would silently drop the + // second occurrence below. + const string userCode = """ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using System; + + namespace Sample + { + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + public class BaseTagAttribute : Attribute + { + public BaseTagAttribute(string value) { Value = value; } + public string Value { get; } + } + + // No [AttributeUsage] here on purpose — must inherit AllowMultiple=true from BaseTagAttribute. + public class DerivedTagAttribute : BaseTagAttribute + { + public DerivedTagAttribute(string value) : base(value) { } + } + + [TestClass] + public class Tests + { + [TestMethod] + [DerivedTag("first")] + [DerivedTag("second")] + public void Run() { } + } + } + """; + + string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode)); + + registry.Should().Contain("\"first\""); + registry.Should().Contain("\"second\""); + } + + [TestMethod] + public void Generator_ReportsAndSkipsGenericMethods_WithSameName() { // Methods that differ only in generic arity (e.g. M() vs M()) MUST be treated as // distinct in the per-class dedup key, otherwise the generator might drop the @@ -1251,19 +1373,14 @@ public void Run() { } GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode); - // The two generic overloads each emit AOTSG0004; the non-generic Run is supported. - result.Diagnostics.Where(d => d.Id == "AOTSG0004").Should().HaveCount(2); - Compilation outputCompilation = RunGeneratorAndGetCompilation(MinimalMSTestStub, userCode); IEnumerable errors = outputCompilation .GetDiagnostics() .Where(d => d.Severity == DiagnosticSeverity.Error); errors.Should().BeEmpty(); - string registry = outputCompilation - .SyntaxTrees - .Single(t => t.FilePath.EndsWith("MSTestReflectionMetadata.Registry.g.cs", System.StringComparison.Ordinal)) - .ToString(); + result.Diagnostics.Where(d => d.Id == "AOTSG0004").Should().HaveCount(2); + string registry = GetRegistry(result); // Only the non-generic Run must be emitted. The generic overloads are excluded by AOTSG0004 // but must not cause the non-generic sibling to be dropped through key collision.