Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,28 +168,53 @@ internal override void CreatePhysicalSNIHandle(
string hostNameInCertificate,
string serverCertificateFilename)
{
if (isIntegratedSecurity)
{
// now allocate proper length of buffer
if (!string.IsNullOrEmpty(serverSPN))
{
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
}
else
{
// This will signal to the interop layer that we need to retrieve the SPN
serverSPN = string.Empty;
}
}
// Normalize SPN based on authentication mode
serverSPN = NormalizeServerSpn(serverSPN, isIntegratedSecurity);

ConsumerInfo myInfo = CreateConsumerInfo(async);
SQLDNSInfo cachedDNSInfo;
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);

_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
resolvedSpn = new(serverSPN.TrimEnd());

// Only produce resolvedSpn when we actually have one.
if (!string.IsNullOrWhiteSpace(serverSPN))
{
resolvedSpn = new(serverSPN.TrimEnd());
}
else
{
resolvedSpn = default;
}
}

/// <summary>
/// Normalizes the serverSPN based on authentication mode.
/// </summary>
/// <param name="serverSPN">The server SPN value from the connection string.</param>
/// <param name="isIntegratedSecurity">Indicates whether integrated security (SSPI) is being used.</param>
/// <returns>
/// For integrated security: returns <paramref name="serverSPN"/> if provided, otherwise <see cref="string.Empty"/> to trigger SPN generation.
/// For SQL auth: returns <see langword="null"/> if <paramref name="serverSPN"/> is empty (no generation), otherwise returns the provided value.
/// </returns>
internal static string NormalizeServerSpn(string serverSPN, bool isIntegratedSecurity)
{
if (isIntegratedSecurity)
{
if (string.IsNullOrWhiteSpace(serverSPN))
{
// Empty signifies to interop layer that SPN needs to be generated
return string.Empty;
}

// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Server SPN `{0}` from the connection string is used.", serverSPN);
return serverSPN;
}

// For SQL auth (and other non-SSPI modes), null means "No SPN generation".
return string.IsNullOrWhiteSpace(serverSPN) ? null : serverSPN;
}

protected override uint SniPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#if NETCOREAPP && WINDOWS

#nullable enable

using Xunit;

namespace Microsoft.Data.SqlClient.UnitTests
{
public class TdsParserStateObjectNativeTests
{
[Theory]
[InlineData(null, true, "")] // Integrated + null -> empty (generate SPN)
[InlineData("", true, "")] // Integrated + empty -> empty (generate SPN)
[InlineData(" ", true, "")] // Integrated + whitespace -> empty (generate SPN)
[InlineData("MSSQLSvc/host", true, "MSSQLSvc/host")] // Integrated + provided -> use it
[InlineData(null, false, null)] // SQL Auth + null -> null (no generation)
[InlineData("", false, null)] // SQL Auth + empty -> null (no generation)
[InlineData(" ", false, null)] // SQL Auth + whitespace -> null (no generation)
[InlineData("MSSQLSvc/host", false, "MSSQLSvc/host")] // SQL Auth + provided -> use it
public void NormalizeServerSpn_ReturnsExpectedValue(
string? inputSpn,
bool isIntegratedSecurity,
string? expectedSpn)
{
string? result = TdsParserStateObjectNative.NormalizeServerSpn(inputSpn, isIntegratedSecurity);
Assert.Equal(expectedSpn, result);
}
}
}

#endif
Loading