TestEnvironment.cs 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using MQTTnet.Client;
  3. using MQTTnet.Client.Options;
  4. using MQTTnet.Diagnostics;
  5. using MQTTnet.Internal;
  6. using MQTTnet.Server;
  7. using System;
  8. using System.Collections.Generic;
  9. using System.Linq;
  10. using System.Threading.Tasks;
  11. namespace MQTTnet.Tests.Mockups
  12. {
  13. public sealed class TestEnvironment : Disposable
  14. {
  15. readonly MqttFactory _mqttFactory = new MqttFactory();
  16. readonly List<IMqttClient> _clients = new List<IMqttClient>();
  17. readonly IMqttNetLogger _serverLogger = new MqttNetLogger("server");
  18. readonly IMqttNetLogger _clientLogger = new MqttNetLogger("client");
  19. readonly List<string> _serverErrors = new List<string>();
  20. readonly List<string> _clientErrors = new List<string>();
  21. readonly List<Exception> _exceptions = new List<Exception>();
  22. public IMqttServer Server { get; private set; }
  23. public bool IgnoreClientLogErrors { get; set; }
  24. public bool IgnoreServerLogErrors { get; set; }
  25. public int ServerPort { get; set; } = 1888;
  26. public IMqttNetLogger ServerLogger => _serverLogger;
  27. public IMqttNetLogger ClientLogger => _clientLogger;
  28. public TestContext TestContext { get; }
  29. public TestEnvironment() : this(null)
  30. {
  31. }
  32. public TestEnvironment(TestContext testContext)
  33. {
  34. TestContext = testContext;
  35. _serverLogger.LogMessagePublished += (s, e) =>
  36. {
  37. if (e.LogMessage.Level == MqttNetLogLevel.Error)
  38. {
  39. lock (_serverErrors)
  40. {
  41. _serverErrors.Add(e.LogMessage.ToString());
  42. }
  43. }
  44. };
  45. _clientLogger.LogMessagePublished += (s, e) =>
  46. {
  47. if (e.LogMessage.Level == MqttNetLogLevel.Error)
  48. {
  49. lock (_clientErrors)
  50. {
  51. _clientErrors.Add(e.LogMessage.ToString());
  52. }
  53. }
  54. };
  55. }
  56. public IMqttClient CreateClient()
  57. {
  58. var client = _mqttFactory.CreateMqttClient(_clientLogger);
  59. _clients.Add(client);
  60. return new TestClientWrapper(client, TestContext);
  61. }
  62. public Task<IMqttServer> StartServerAsync()
  63. {
  64. return StartServerAsync(new MqttServerOptionsBuilder());
  65. }
  66. public async Task<IMqttServer> StartServerAsync(MqttServerOptionsBuilder options)
  67. {
  68. if (Server != null)
  69. {
  70. throw new InvalidOperationException("Server already started.");
  71. }
  72. Server = new TestServerWrapper(_mqttFactory.CreateMqttServer(_serverLogger), TestContext, this);
  73. await Server.StartAsync(options.WithDefaultEndpointPort(ServerPort).Build());
  74. return Server;
  75. }
  76. public Task<IMqttClient> ConnectClientAsync()
  77. {
  78. return ConnectClientAsync(new MqttClientOptionsBuilder());
  79. }
  80. public async Task<IMqttClient> ConnectClientAsync(MqttClientOptionsBuilder options)
  81. {
  82. if (options == null) throw new ArgumentNullException(nameof(options));
  83. options = options.WithTcpServer("localhost", ServerPort);
  84. var client = CreateClient();
  85. await client.ConnectAsync(options.Build());
  86. return client;
  87. }
  88. public async Task<IMqttClient> ConnectClientAsync(IMqttClientOptions options)
  89. {
  90. if (options == null) throw new ArgumentNullException(nameof(options));
  91. var client = CreateClient();
  92. await client.ConnectAsync(options);
  93. return client;
  94. }
  95. public void ThrowIfLogErrors()
  96. {
  97. lock (_serverErrors)
  98. {
  99. if (!IgnoreServerLogErrors && _serverErrors.Count > 0)
  100. {
  101. throw new Exception($"Server had {_serverErrors.Count} errors (${string.Join(Environment.NewLine, _serverErrors)}).");
  102. }
  103. }
  104. lock (_clientErrors)
  105. {
  106. if (!IgnoreClientLogErrors && _clientErrors.Count > 0)
  107. {
  108. throw new Exception($"Client(s) had {_clientErrors.Count} errors (${string.Join(Environment.NewLine, _clientErrors)}).");
  109. }
  110. }
  111. }
  112. protected override void Dispose(bool disposing)
  113. {
  114. if (disposing)
  115. {
  116. foreach (var mqttClient in _clients)
  117. {
  118. mqttClient?.Dispose();
  119. }
  120. Server?.StopAsync().GetAwaiter().GetResult();
  121. ThrowIfLogErrors();
  122. if (_exceptions.Any())
  123. {
  124. throw new Exception($"{_exceptions.Count} exceptions tracked.\r\n" + string.Join(Environment.NewLine, _exceptions));
  125. }
  126. }
  127. base.Dispose(disposing);
  128. }
  129. public void TrackException(Exception exception)
  130. {
  131. lock (_exceptions)
  132. {
  133. _exceptions.Add(exception);
  134. }
  135. }
  136. }
  137. }