Skip to content

Commit 8adfef4

Browse files
committed
Pass SqlAuthenticationParameters in GenerateSspiClientContext
As part of this change, the SSPIContextProvider base class now iterates through all the server names similar to what NegotiateSSPIContextProvider did.
1 parent 9450fde commit 8adfef4

File tree

3 files changed

+82
-35
lines changed

3 files changed

+82
-35
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs

+5-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ private void LoadSSPILibrary()
4949
}
5050
}
5151

52-
protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns)
52+
protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams)
5353
{
5454
#if NETFRAMEWORK
5555
SNIHandle handle = _physicalStateObj.Handle;
@@ -62,9 +62,9 @@ protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
6262
var sendLength = s_maxSSPILength;
6363
var outBuff = outgoingBlobWriter.GetSpan((int)sendLength);
6464

65-
if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, serverSpns[0]))
65+
if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.ServerName))
6666
{
67-
throw new InvalidOperationException(SQLMessage.SSPIGenerateError());
67+
return false;
6868
}
6969

7070
if (sendLength > int.MaxValue)
@@ -73,6 +73,8 @@ protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
7373
}
7474

7575
outgoingBlobWriter.Advance((int)sendLength);
76+
77+
return true;
7678
}
7779
}
7880
}

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs

+13-22
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#if NET
22

33
using System;
4-
using System.Net.Security;
54
using System.Buffers;
5+
using System.Net.Security;
66

77
#nullable enable
88

@@ -12,33 +12,24 @@ internal sealed class NegotiateSSPIContextProvider : SSPIContextProvider
1212
{
1313
private NegotiateAuthentication? _negotiateAuth = null;
1414

15-
protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns)
15+
protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams)
1616
{
1717
NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials;
1818

19-
for (int i = 0; i < serverSpns.Length; i++)
20-
{
21-
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = serverSpns[i] });
22-
var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;
23-
24-
// Log session id, status code and the actual SPN used in the negotiation
25-
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSSPIContextProvider),
26-
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);
27-
if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded)
28-
{
29-
outgoingBlobWriter.Write(sendBuff);
30-
break; // Successful case, exit the loop with current SPN.
31-
}
32-
else
33-
{
34-
_negotiateAuth = null; // Reset _negotiateAuth to be generated again for next SPN.
35-
}
36-
}
19+
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = authParams.ServerName });
20+
var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;
21+
22+
// Log session id, status code and the actual SPN used in the negotiation
23+
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSSPIContextProvider),
24+
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);
3725

38-
if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded)
26+
if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded)
3927
{
40-
throw new InvalidOperationException(SQLMessage.SSPIGenerateError() + Environment.NewLine + statusCode);
28+
outgoingBlobWriter.Write(sendBuff);
29+
return true; // Successful case, exit the loop with current SPN.
4130
}
31+
32+
return false;
4233
}
4334
}
4435
}

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs

+64-10
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,78 @@ private protected virtual void Initialize()
2626
{
2727
}
2828

29-
protected abstract void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns);
29+
protected abstract bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams);
3030

3131
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, string serverSpn)
32-
=> SSPIData(receivedBuff, outgoingBlobWriter, new[] { serverSpn });
32+
{
33+
using var _ = TrySNIEventScope.Create(nameof(SSPIContextProvider));
3334

34-
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, string[] serverSpns)
35+
if (!RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn))
36+
{
37+
// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
38+
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
39+
}
40+
}
41+
42+
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns)
3543
{
36-
using (TrySNIEventScope.Create(nameof(SSPIContextProvider)))
44+
using var _ = TrySNIEventScope.Create(nameof(SSPIContextProvider));
45+
46+
foreach (var serverSpn in serverSpns)
3747
{
38-
try
39-
{
40-
GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpns);
41-
}
42-
catch (Exception e)
48+
if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn))
4349
{
44-
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
50+
return;
4551
}
4652
}
53+
54+
// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
55+
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
56+
}
57+
58+
private bool RunGenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, string serverSpn)
59+
{
60+
var authParams = CreateSqlAuthParams(_parser.Connection, serverSpn);
61+
62+
try
63+
{
64+
#if NET8_0_OR_GREATER
65+
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, SPN={3}", GetType().FullName,
66+
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, serverSpn);
67+
#else
68+
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName,
69+
nameof(GenerateSspiClientContext), serverSpn);
70+
#endif
71+
72+
return GenerateSspiClientContext(incomingBlob, outgoingBlobWriter, authParams);
73+
}
74+
catch (Exception e)
75+
{
76+
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
77+
return false;
78+
}
79+
}
80+
81+
private static SqlAuthenticationParameters CreateSqlAuthParams(SqlInternalConnectionTds connection, string serverSpn)
82+
{
83+
var auth = new SqlAuthenticationParameters.Builder(
84+
authenticationMethod: connection.ConnectionOptions.Authentication,
85+
resource: null,
86+
authority: null,
87+
serverName: serverSpn,
88+
connection.ConnectionOptions.InitialCatalog);
89+
90+
if (connection.ConnectionOptions.UserID is { } userId)
91+
{
92+
auth.WithUserId(userId);
93+
}
94+
95+
if (connection.ConnectionOptions.Password is { } password)
96+
{
97+
auth.WithPassword(password);
98+
}
99+
100+
return auth;
47101
}
48102

49103
protected void SSPIError(string error, string procedure)

0 commit comments

Comments
 (0)