Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ private static IReadOnlyList<ValueExpression> GetCtorArgs(

private static ModelProvider? GetModelToInstantiateForFactoryMethod(ModelProvider modelProvider)
{
if (modelProvider is SystemObjectModelProvider)
{
return null;
}

var fullConstructor = modelProvider.FullConstructor;
if (modelProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)
|| fullConstructor.Signature.Parameters.Any(p => !p.Type.IsPublic))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ protected internal override PropertyProvider[] BuildProperties()
var properties = new List<PropertyProvider>(propertiesCount + 1);
Dictionary<string, InputModelProperty> baseProperties = [];
HashSet<string> skippedBasePropertyNames = [];
HashSet<string> skippedBasePropertySerializedNames = [];
foreach (var baseModelProvider in EnumerateBaseModelProviders())
{
foreach (var baseProperty in baseModelProvider._inputModel.Properties)
Expand All @@ -546,6 +547,10 @@ protected internal override PropertyProvider[] BuildProperties()
if (baseModelProvider.ShouldSkipDerivedModelProperties)
{
skippedBasePropertyNames.Add(baseProperty.Name);
if (baseProperty.SerializedName is not null)
{
skippedBasePropertySerializedNames.Add(baseProperty.SerializedName);
}
}
else
{
Expand All @@ -564,11 +569,21 @@ protected internal override PropertyProvider[] BuildProperties()
{
var property = _inputModel.Properties[i];
var isDiscriminator = IsDiscriminator(property);
var isSkippedBaseProperty = skippedBasePropertyNames.Contains(property.Name)
|| (property.SerializedName is not null && skippedBasePropertySerializedNames.Contains(property.SerializedName));

// Skip discriminator properties that already exist in the base class
// Check both by C# property name and by serialized name to handle cases where
// the derived model has a discriminator with a different C# name but the same wire name
if (isDiscriminator && (baseProperties.ContainsKey(property.Name) || skippedBasePropertyNames.Contains(property.Name) || (property.SerializedName is not null && baseDiscriminatorSerializedNames.Contains(property.SerializedName))))
if (isDiscriminator &&
(baseProperties.ContainsKey(property.Name) ||
isSkippedBaseProperty ||
(property.SerializedName is not null && baseDiscriminatorSerializedNames.Contains(property.SerializedName))))
{
continue;
}

if (!isDiscriminator && isSkippedBaseProperty)
{
continue;
}
Expand Down Expand Up @@ -609,11 +624,6 @@ protected internal override PropertyProvider[] BuildProperties()
outputProperty.Modifiers |= MethodSignatureModifiers.Virtual;
}
}
if (skippedBasePropertyNames.Contains(property.Name))
{
continue;
}

if (baseProperties.TryGetValue(property.Name, out var baseProperty))
{
if (DomainEqual(baseProperty, property))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ protected override string BuildRelativeFilePath()
/// <inheritdoc/>
protected override bool ShouldSkipDerivedModelProperties => true;

/// <inheritdoc/>
protected override bool CanUpdateIdentity => false;

/// <inheritdoc/>
public override bool ShouldSkipDerivedSerializationMethodOverrides => true;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -606,20 +606,25 @@ public void Update(
_attributes = (attributes as IReadOnlyList<AttributeStatement>) ?? [.. attributes];
}

if (name != null)
if (CanUpdateIdentity)
{
ResetMembersBasedOnIdentityChange(name);
}
if (name != null)
{
ResetMembersBasedOnIdentityChange(name);
}

if (@namespace != null)
{
ResetMembersBasedOnIdentityChange(@namespace: @namespace);
if (@namespace != null)
{
ResetMembersBasedOnIdentityChange(@namespace: @namespace);
}
}

// Rebuild the canonical view
_canonicalView = new(BuildCanonicalView);
}

protected virtual bool CanUpdateIdentity => true;

private void ResetMembersBasedOnIdentityChange(string? name = null, string? @namespace = null)
{
// Reset the custom code view to reflect the new namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,22 @@ protected internal TypeFactory()
CodeModelGenerator.Instance.AddTypeToKeep(modelProvider);
}

CSharpTypeMap[modelProvider.Type] = modelProvider;
TypeProvidersByName[modelProvider.Type.Name] = modelProvider;
RegisterModelProvider(modelProvider);
}
return modelProvider;
}

private void RegisterModelProvider(ModelProvider modelProvider)
{
CSharpTypeMap[modelProvider.Type] = modelProvider;
TypeProvidersByName[modelProvider.Type.Name] = modelProvider;

if (modelProvider is SystemObjectModelProvider systemObjectModelProvider)
{
CSharpTypeMap[systemObjectModelProvider.SystemType] = systemObjectModelProvider;
}
}

protected virtual ModelProvider? CreateModelCore(InputModelType model) => new ModelProvider(model);

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@ public void SkipInternalModels()
Assert.AreEqual(ModelList.Length - ModelList.Where(m => m.Access == "internal").Count(), modelFactory.Methods.Count);
}

[Test]
public void SkipSystemObjectModelProvider()
{
var inputModel = InputFactory.Model("Resource", properties: [], access: "public");
var systemType = new CSharpType(
"ResourceData",
"TestFramework",
isValueType: false,
isNullable: false,
declaringType: null,
args: Array.Empty<CSharpType>(),
isPublic: true,
isStruct: false);
_instance = MockHelpers.LoadMockGenerator(
inputNamespaceName: "Sample.Namespace",
inputModelTypes: [inputModel],
createModelCore: model => new SystemObjectModelProvider(systemType, model)).Object;

var modelFactory = _instance!.OutputLibrary.ModelFactory.Value;

Assert.IsFalse(modelFactory.Methods.Any(m => m.Signature.Name == "ResourceData"));
}

[Test]
public void ListParamShape()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,41 @@ public void DerivedModel_OnlySkipsPropertiesFromSkippingBaseProvider()
Assert.IsTrue(propertyNames.Contains("Location"));
}

[Test]
public void DerivedModel_SkipsPropertiesWithSameWireNameDefinedInSystemObjectBase()
{
var baseProp = InputFactory.Property("Name", InputPrimitiveType.String, serializedName: "name");
var baseInputModel = InputFactory.Model("Resource", properties: [baseProp]);

var derivedNameProp = InputFactory.Property("Name0", InputPrimitiveType.String, serializedName: "name");
var derivedLocationProp = InputFactory.Property("Location", InputPrimitiveType.String, serializedName: "location");
var derivedInputModel = InputFactory.Model(
"TrackedResource",
properties: [derivedNameProp, derivedLocationProp],
baseModel: baseInputModel);

var systemType = CreateSystemCSharpType("ResourceData", "TestFramework");
MockHelpers.LoadMockGenerator(
inputModelTypes: [baseInputModel, derivedInputModel],
createModelCore: (model) =>
{
if (model.Name == "Resource")
return new SystemObjectModelProvider(systemType, model);
return new ModelProvider(model);
});

var derived = CodeModelGenerator.Instance.TypeFactory.CreateModel(derivedInputModel) as ModelProvider;
Assert.IsNotNull(derived);

var propertyNames = derived!.Properties.Select(p => p.Name).ToList();
Assert.IsFalse(propertyNames.Contains("Name0"),
"Property 'Name0' should be skipped because its wire name is defined in the SystemObjectModelProvider base");
Assert.IsTrue(propertyNames.Contains("Location"),
"Property 'Location' should be generated because its wire name is NOT in the base");
Assert.IsFalse(derived.FullConstructor.Signature.Parameters.Any(p => p.Name == "name0"),
"Skipped wire-name duplicate properties should not be emitted as constructor parameters");
}

[Test]
public void RegularBaseModel_DoesNotSkipMatchingProperties()
{
Expand Down Expand Up @@ -336,6 +371,44 @@ public void Name_ComesFromSystemType_WhenTypeNotEarlyCached()
Assert.AreEqual("Azure.ResourceManager.Models", provider.Type.Namespace);
}

[Test]
public void Update_DoesNotChangeSystemTypeIdentity()
{
var systemType = CreateSystemCSharpType("TrackedResourceData", "Azure.ResourceManager.Models");
var inputModel = InputFactory.Model("TrackedResource", properties: [], access: "internal");
var provider = new SystemObjectModelProvider(systemType, inputModel);

provider.Update(name: "TrackedResource", @namespace: "Generated.Models");

Assert.AreEqual("TrackedResourceData", provider.Name);
Assert.AreEqual("Azure.ResourceManager.Models", provider.Type.Namespace);
}

[Test]
public void BaseModelProvider_ResolvesSystemTypeAlias()
{
var baseInputModel = InputFactory.Model("Resource", properties: []);
var derivedInputModel = InputFactory.Model("TrackedResource", properties: [], baseModel: baseInputModel);
var systemType = new CSharpType(typeof(InvalidOperationException));

MockHelpers.LoadMockGenerator(
inputModelTypes: [baseInputModel, derivedInputModel],
createModelCore: (model) =>
{
if (model == baseInputModel)
return new SystemObjectModelProvider(systemType, model);
if (model == derivedInputModel)
return new BuildBaseTypeOverridingModelProvider(model, systemType);
return new ModelProvider(model);
});

var systemBase = CodeModelGenerator.Instance.TypeFactory.CreateModel(baseInputModel);
var derived = CodeModelGenerator.Instance.TypeFactory.CreateModel(derivedInputModel);

Assert.IsNotNull(systemBase);
Assert.AreSame(systemBase, derived!.BaseModelProvider);
}

[Test]
public void CrossLanguageDefinitionId_ComesFromInputModel()
{
Expand Down Expand Up @@ -370,5 +443,17 @@ public void Constructor_ThrowsOnNullSystemType()
var inputModel = InputFactory.Model("Resource", properties: []);
Assert.Throws<ArgumentNullException>(() => new SystemObjectModelProvider(null!, inputModel));
}

private class BuildBaseTypeOverridingModelProvider : ModelProvider
{
private readonly CSharpType _baseType;

public BuildBaseTypeOverridingModelProvider(InputModelType inputModel, CSharpType baseType) : base(inputModel)
{
_baseType = baseType;
}

protected override CSharpType? BuildBaseType() => _baseType;
}
}
}
Loading