Skip to content

Commit 7736cf9

Browse files
authored
Added an ability to specify pool size with a callback using IServiceProvider (#757)
1 parent 0288134 commit 7736cf9

File tree

2 files changed

+86
-44
lines changed

2 files changed

+86
-44
lines changed

src/NATS.Extensions.Microsoft.DependencyInjection/NatsBuilder.cs

+21-44
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public class NatsBuilder
1313
{
1414
private readonly IServiceCollection _services;
1515

16-
private int _poolSize = 1;
16+
private Func<IServiceProvider, int> _poolSizeConfigurer = _ => 1;
1717
private Func<IServiceProvider, NatsOpts, NatsOpts>? _configureOpts;
1818
private Action<IServiceProvider, NatsConnection>? _configureConnection;
1919
private object? _diKey = null;
@@ -25,7 +25,14 @@ public NatsBuilder(IServiceCollection services)
2525

2626
public NatsBuilder WithPoolSize(int size)
2727
{
28-
_poolSize = Math.Max(size, 1);
28+
_poolSizeConfigurer = _ => Math.Max(size, 1);
29+
30+
return this;
31+
}
32+
33+
public NatsBuilder WithPoolSize(Func<IServiceProvider, int> sizeConfigurer)
34+
{
35+
_poolSizeConfigurer = sp => Math.Max(sizeConfigurer(sp), 1);
2936

3037
return this;
3138
}
@@ -120,43 +127,23 @@ public NatsBuilder WithSerializerRegistry(INatsSerializerRegistry registry)
120127

121128
internal IServiceCollection Build()
122129
{
123-
if (_poolSize != 1)
130+
if (_diKey == null)
124131
{
125-
if (_diKey == null)
126-
{
127-
_services.TryAddSingleton<NatsConnectionPool>(provider => PoolFactory(provider));
128-
_services.TryAddSingleton<INatsConnectionPool>(static provider => provider.GetRequiredService<NatsConnectionPool>());
129-
_services.TryAddTransient<NatsConnection>(static provider => PooledConnectionFactory(provider, null));
130-
_services.TryAddTransient<INatsConnection>(static provider => provider.GetRequiredService<NatsConnection>());
131-
_services.TryAddTransient<INatsClient>(static provider => provider.GetRequiredService<NatsConnection>());
132-
}
133-
else
134-
{
135-
#if NET8_0_OR_GREATER
136-
_services.TryAddKeyedSingleton<NatsConnectionPool>(_diKey, PoolFactory);
137-
_services.TryAddKeyedSingleton<INatsConnectionPool>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnectionPool>(key));
138-
_services.TryAddKeyedTransient<NatsConnection>(_diKey, PooledConnectionFactory);
139-
_services.TryAddKeyedTransient<INatsConnection>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
140-
_services.TryAddKeyedTransient<INatsClient>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
141-
#endif
142-
}
132+
_services.TryAddSingleton<NatsConnectionPool>(provider => PoolFactory(provider));
133+
_services.TryAddSingleton<INatsConnectionPool>(static provider => provider.GetRequiredService<NatsConnectionPool>());
134+
_services.TryAddTransient<NatsConnection>(static provider => PooledConnectionFactory(provider, null));
135+
_services.TryAddTransient<INatsConnection>(static provider => provider.GetRequiredService<NatsConnection>());
136+
_services.TryAddTransient<INatsClient>(static provider => provider.GetRequiredService<NatsConnection>());
143137
}
144138
else
145139
{
146-
if (_diKey == null)
147-
{
148-
_services.TryAddSingleton<NatsConnection>(provider => SingleConnectionFactory(provider));
149-
_services.TryAddSingleton<INatsConnection>(static provider => provider.GetRequiredService<NatsConnection>());
150-
_services.TryAddSingleton<INatsClient>(static provider => provider.GetRequiredService<NatsConnection>());
151-
}
152-
else
153-
{
154140
#if NET8_0_OR_GREATER
155-
_services.TryAddKeyedSingleton(_diKey, SingleConnectionFactory);
156-
_services.TryAddKeyedSingleton<INatsConnection>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
157-
_services.TryAddKeyedSingleton<INatsClient>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
141+
_services.TryAddKeyedSingleton<NatsConnectionPool>(_diKey, PoolFactory);
142+
_services.TryAddKeyedSingleton<INatsConnectionPool>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnectionPool>(key));
143+
_services.TryAddKeyedTransient<NatsConnection>(_diKey, PooledConnectionFactory);
144+
_services.TryAddKeyedTransient<INatsConnection>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
145+
_services.TryAddKeyedTransient<INatsClient>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
158146
#endif
159-
}
160147
}
161148

162149
return _services;
@@ -179,17 +166,7 @@ private NatsConnectionPool PoolFactory(IServiceProvider provider, object? diKey
179166
{
180167
var options = GetNatsOpts(provider);
181168

182-
return new NatsConnectionPool(_poolSize, options, con => _configureConnection?.Invoke(provider, con));
183-
}
184-
185-
private NatsConnection SingleConnectionFactory(IServiceProvider provider, object? diKey = null)
186-
{
187-
var options = GetNatsOpts(provider);
188-
189-
var conn = new NatsConnection(options);
190-
_configureConnection?.Invoke(provider, conn);
191-
192-
return conn;
169+
return new NatsConnectionPool(_poolSizeConfigurer(provider), options, con => _configureConnection?.Invoke(provider, con));
193170
}
194171

195172
private NatsOpts GetNatsOpts(IServiceProvider provider)

tests/NATS.Extensions.Microsoft.DependencyInjection.Tests/NatsHostingExtensionsTests.cs

+65
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,21 @@ public async Task AddNatsClient_WithDefaultSerializer()
177177
}
178178
}
179179

180+
[Fact]
181+
public void AddNatsClient_RegistersNatsConnectionAsTransient_WhenPoolSizeFuncIsGreaterThanOne()
182+
{
183+
var services = new ServiceCollection();
184+
services.AddSingleton<ILoggerFactory, NullLoggerFactory>();
185+
services.AddNatsClient(nats => nats.WithPoolSize(_ => 2));
186+
187+
var provider = services.BuildServiceProvider();
188+
var natsConnection1 = provider.GetRequiredService<INatsConnection>();
189+
var natsConnection2 = provider.GetRequiredService<INatsConnection>();
190+
191+
Assert.NotNull(natsConnection1);
192+
Assert.NotSame(natsConnection1, natsConnection2); // Transient should return different instances
193+
}
194+
180195
[Fact]
181196
public async Task AddNatsClient_WithJsonSerializer()
182197
{
@@ -347,6 +362,56 @@ public void AddNats_RegistersKeyedNatsConnection_WhenKeyIsProvided_pooled()
347362
Assert.NotSame(obj1, obj2);
348363
}
349364
}
365+
366+
[Fact]
367+
public void AddNats_RegistersKeyedNatsConnection_WhenKeyIsProvided_pooledFunc()
368+
{
369+
var key1 = "TestKey1";
370+
var key2 = "TestKey2";
371+
372+
var services = new ServiceCollection();
373+
services.AddSingleton<ILoggerFactory, NullLoggerFactory>();
374+
375+
services.AddNatsClient(builder => builder.WithPoolSize(_ => 2).WithKey(key1));
376+
services.AddNatsClient(builder => builder.WithPoolSize(_ => 2).WithKey(key2));
377+
var provider = services.BuildServiceProvider();
378+
379+
Dictionary<string, List<object>> connections = new();
380+
foreach (var key in new[] { key1, key2 })
381+
{
382+
var nats1 = provider.GetKeyedService<INatsConnection>(key);
383+
Assert.NotNull(nats1);
384+
var nats2 = provider.GetKeyedService<INatsConnection>(key);
385+
Assert.NotNull(nats2);
386+
var nats3 = provider.GetKeyedService<INatsConnection>(key);
387+
Assert.NotNull(nats3);
388+
var nats4 = provider.GetKeyedService<INatsConnection>(key);
389+
Assert.NotNull(nats4);
390+
391+
// relying on the fact that the pool size is 2 and connections are returned in a round-robin fashion
392+
Assert.NotSame(nats1, nats2);
393+
Assert.Same(nats1, nats3);
394+
Assert.NotSame(nats2, nats3);
395+
Assert.Same(nats2, nats4);
396+
397+
if (!connections.TryGetValue(key, out var list))
398+
{
399+
list = new List<object>();
400+
connections.Add(key, list);
401+
}
402+
403+
list.Add(nats1);
404+
list.Add(nats2);
405+
list.Add(nats3);
406+
list.Add(nats4);
407+
}
408+
409+
foreach (var obj1 in connections[key1])
410+
{
411+
foreach (var obj2 in connections[key2])
412+
Assert.NotSame(obj1, obj2);
413+
}
414+
}
350415
#endif
351416
}
352417

0 commit comments

Comments
 (0)