MqttClientSessionsManager.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. using MQTTnet.Adapter;
  2. using MQTTnet.Diagnostics;
  3. using MQTTnet.Exceptions;
  4. using MQTTnet.Formatter;
  5. using MQTTnet.Internal;
  6. using MQTTnet.Packets;
  7. using MQTTnet.Protocol;
  8. using MQTTnet.Server.Status;
  9. using System;
  10. using System.Collections.Concurrent;
  11. using System.Collections.Generic;
  12. using System.Threading;
  13. using System.Threading.Tasks;
  namespace MQTTnet.Server
  {
      /// <summary>
      /// Keeps track of the connections and sessions of all clients of the MQTT server and
      /// delivers published application messages to every known session.
      /// </summary>
      public class MqttClientSessionsManager : Disposable
      {
          readonly AsyncQueue<MqttEnqueuedApplicationMessage> _messageQueue = new AsyncQueue<MqttEnqueuedApplicationMessage>();
          readonly AsyncLock _createConnectionGate = new AsyncLock();
          readonly ConcurrentDictionary<string, MqttClientConnection> _connections = new ConcurrentDictionary<string, MqttClientConnection>();
          readonly ConcurrentDictionary<string, MqttClientSession> _sessions = new ConcurrentDictionary<string, MqttClientSession>();

          // Session items handed to the application message interceptor for messages published by the server itself.
          readonly IDictionary<object, object> _serverSessionItems = new ConcurrentDictionary<object, object>();

          readonly CancellationToken _cancellationToken;
          readonly MqttServerEventDispatcher _eventDispatcher;
          readonly IMqttRetainedMessagesManager _retainedMessagesManager;
          readonly IMqttServerOptions _options;
          readonly IMqttNetLogger _logger;

          /// <summary>
          /// Creates the manager.
          /// </summary>
          /// <param name="options">Server options (timeouts, validators, interceptors, session persistence).</param>
          /// <param name="retainedMessagesManager">Receives every message flagged as retained.</param>
          /// <param name="cancellationToken">Ends message processing and is used for connection handling.</param>
          /// <param name="eventDispatcher">Raises the client connected/disconnected and message received events.</param>
          /// <param name="logger">Parent logger; a child logger named after this class is created from it.</param>
          /// <exception cref="ArgumentNullException">Any of the reference arguments is <c>null</c>.</exception>
          public MqttClientSessionsManager(
              IMqttServerOptions options,
              IMqttRetainedMessagesManager retainedMessagesManager,
              CancellationToken cancellationToken,
              MqttServerEventDispatcher eventDispatcher,
              IMqttNetLogger logger)
          {
              _cancellationToken = cancellationToken;

              if (logger == null) throw new ArgumentNullException(nameof(logger));
              _logger = logger.CreateChildLogger(nameof(MqttClientSessionsManager));

              _eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher));
              _options = options ?? throw new ArgumentNullException(nameof(options));
              _retainedMessagesManager = retainedMessagesManager ?? throw new ArgumentNullException(nameof(retainedMessagesManager));
          }

          /// <summary>
          /// Starts the background loop that processes queued application messages until the
          /// cancellation token passed to the constructor is cancelled.
          /// </summary>
          public void Start()
          {
              Task.Run(() => TryProcessQueuedApplicationMessagesAsync(_cancellationToken), _cancellationToken).Forget(_logger);
          }

          /// <summary>
          /// Stops every currently registered client connection, one after another.
          /// </summary>
          public async Task StopAsync()
          {
              foreach (var connection in _connections.Values)
              {
                  await connection.StopAsync().ConfigureAwait(false);
              }
          }

          /// <summary>
          /// Handles a newly accepted client channel: waits for CONNECT, validates it and runs the connection.
          /// </summary>
          /// <param name="clientAdapter">The channel of the new client.</param>
          /// <exception cref="ArgumentNullException"><paramref name="clientAdapter"/> is <c>null</c>.</exception>
          public Task HandleClientConnectionAsync(IMqttChannelAdapter clientAdapter)
          {
              if (clientAdapter is null) throw new ArgumentNullException(nameof(clientAdapter));

              return HandleClientConnectionAsync(clientAdapter, _cancellationToken);
          }

          /// <summary>
          /// Returns the status of every registered connection together with the status of its session.
          /// </summary>
          public Task<IList<IMqttClientStatus>> GetClientStatusAsync()
          {
              var result = new List<IMqttClientStatus>();

              foreach (var connection in _connections.Values)
              {
                  var clientStatus = new MqttClientStatus(connection);
                  connection.FillStatus(clientStatus);

                  var sessionStatus = new MqttSessionStatus(connection.Session, this);
                  connection.Session.FillStatus(sessionStatus);
                  clientStatus.Session = sessionStatus;

                  result.Add(clientStatus);
              }

              return Task.FromResult((IList<IMqttClientStatus>)result);
          }

          /// <summary>
          /// Returns the status of every known session, whether or not its client is connected.
          /// </summary>
          public Task<IList<IMqttSessionStatus>> GetSessionStatusAsync()
          {
              var result = new List<IMqttSessionStatus>();

              foreach (var session in _sessions.Values)
              {
                  var sessionStatus = new MqttSessionStatus(session, this);
                  session.FillStatus(sessionStatus);

                  result.Add(sessionStatus);
              }

              return Task.FromResult((IList<IMqttSessionStatus>)result);
          }

          /// <summary>
          /// Queues an application message for asynchronous processing by the loop started in <see cref="Start"/>.
          /// </summary>
          /// <param name="applicationMessage">The message to dispatch.</param>
          /// <param name="sender">The publishing connection, or <c>null</c> when the server itself publishes.</param>
          /// <exception cref="ArgumentNullException"><paramref name="applicationMessage"/> is <c>null</c>.</exception>
          public void DispatchApplicationMessage(MqttApplicationMessage applicationMessage, MqttClientConnection sender)
          {
              if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

              _messageQueue.Enqueue(new MqttEnqueuedApplicationMessage(applicationMessage, sender));
          }

          /// <summary>
          /// Adds topic filters to the session of the given client.
          /// </summary>
          /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
          /// <exception cref="InvalidOperationException">No session exists for <paramref name="clientId"/>.</exception>
          public Task SubscribeAsync(string clientId, ICollection<MqttTopicFilter> topicFilters)
          {
              if (clientId == null) throw new ArgumentNullException(nameof(clientId));
              if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

              if (!_sessions.TryGetValue(clientId, out var session))
              {
                  throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
              }

              return session.SubscribeAsync(topicFilters);
          }

          /// <summary>
          /// Removes topic filters from the session of the given client.
          /// </summary>
          /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
          /// <exception cref="InvalidOperationException">No session exists for <paramref name="clientId"/>.</exception>
          public Task UnsubscribeAsync(string clientId, IEnumerable<string> topicFilters)
          {
              if (clientId == null) throw new ArgumentNullException(nameof(clientId));
              if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

              if (!_sessions.TryGetValue(clientId, out var session))
              {
                  throw new InvalidOperationException($"Client session '{clientId}' is unknown.");
              }

              return session.UnsubscribeAsync(topicFilters);
          }

          /// <summary>
          /// Stops the client's registered connection (if any) and removes its session (if any).
          /// </summary>
          /// <param name="clientId">The ID of the client whose session is deleted.</param>
          /// <exception cref="ArgumentNullException"><paramref name="clientId"/> is <c>null</c>.</exception>
          public async Task DeleteSessionAsync(string clientId)
          {
              if (clientId == null) throw new ArgumentNullException(nameof(clientId));

              if (_connections.TryGetValue(clientId, out var connection))
              {
                  await connection.StopAsync().ConfigureAwait(false);
              }

              // Only report a deletion when a session was actually removed.
              if (_sessions.TryRemove(clientId, out _))
              {
                  _logger.Verbose("Session for client '{0}' deleted.", clientId);
              }
          }

          protected override void Dispose(bool disposing)
          {
              if (disposing)
              {
                  _messageQueue?.Dispose();
              }

              base.Dispose(disposing);
          }

          // Processes queued application messages one at a time until cancellation is requested.
          async Task TryProcessQueuedApplicationMessagesAsync(CancellationToken cancellationToken)
          {
              while (!cancellationToken.IsCancellationRequested)
              {
                  try
                  {
                      await TryProcessNextQueuedApplicationMessageAsync(cancellationToken).ConfigureAwait(false);
                  }
                  catch (OperationCanceledException)
                  {
                      // Shutdown; the loop condition ends the loop.
                  }
                  catch (Exception exception)
                  {
                      _logger.Error(exception, "Unhandled exception while processing queued application messages.");
                  }
              }
          }

          // Dequeues one message, runs the interceptor, raises the received event, stores retained
          // messages and enqueues the message into every known session.
          async Task TryProcessNextQueuedApplicationMessageAsync(CancellationToken cancellationToken)
          {
              try
              {
                  if (cancellationToken.IsCancellationRequested)
                  {
                      return;
                  }

                  var dequeueResult = await _messageQueue.TryDequeueAsync(cancellationToken).ConfigureAwait(false);
                  if (!dequeueResult.IsSuccess)
                  {
                      return;
                  }

                  var queuedApplicationMessage = dequeueResult.Item;
                  var sender = queuedApplicationMessage.Sender;
                  var applicationMessage = queuedApplicationMessage.ApplicationMessage;

                  var interceptorContext = await InterceptApplicationMessageAsync(sender, applicationMessage).ConfigureAwait(false);
                  if (interceptorContext != null)
                  {
                      if (interceptorContext.CloseConnection)
                      {
                          if (sender != null)
                          {
                              await sender.StopAsync().ConfigureAwait(false);
                          }
                      }

                      if (interceptorContext.ApplicationMessage == null || !interceptorContext.AcceptPublish)
                      {
                          return;
                      }

                      // The interceptor may have replaced the message.
                      applicationMessage = interceptorContext.ApplicationMessage;
                  }

                  await _eventDispatcher.SafeNotifyApplicationMessageReceivedAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);

                  if (applicationMessage.Retain)
                  {
                      await _retainedMessagesManager.HandleMessageAsync(sender?.ClientId, applicationMessage).ConfigureAwait(false);
                  }

                  foreach (var clientSession in _sessions.Values)
                  {
                      clientSession.EnqueueApplicationMessage(
                          applicationMessage,
                          sender?.ClientId,
                          false);
                  }
              }
              catch (OperationCanceledException)
              {
                  // Shutdown; nothing to report.
              }
              catch (Exception exception)
              {
                  _logger.Error(exception, "Unhandled exception while processing next queued application message.");
              }
          }

          // Receives and validates CONNECT, then creates and runs the connection. Once a connection has
          // been created, its stop callback closes the channel. Every earlier exit closes it in the finally block.
          async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
          {
              MqttClientConnection connection = null;

              try
              {
                  MqttConnectPacket connectPacket;
                  try
                  {
                      var firstPacket = await channelAdapter.ReceivePacketAsync(_options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
                      connectPacket = firstPacket as MqttConnectPacket;
                      if (connectPacket == null)
                      {
                          _logger.Warning(null, "The first packet from client '{0}' was no 'CONNECT' packet [MQTT-3.1.0-1].", channelAdapter.Endpoint);
                          return;
                      }
                  }
                  catch (MqttCommunicationTimedOutException)
                  {
                      _logger.Warning(null, "Client '{0}' connected but did not sent a CONNECT packet.", channelAdapter.Endpoint);
                      return;
                  }

                  var connectionValidatorContext = await ValidateConnectionAsync(connectPacket, channelAdapter).ConfigureAwait(false);

                  if (connectionValidatorContext.ReasonCode != MqttConnectReasonCode.Success)
                  {
                      // Send failure response here without preparing a session. The result for a successful connect
                      // will be sent from the session itself. The channel is closed in the finally block afterwards
                      // (the server must close the network connection after a failed CONNACK, [MQTT-3.2.2-5]).
                      var connAckPacket = channelAdapter.PacketFormatterAdapter.DataConverter.CreateConnAckPacket(connectionValidatorContext);
                      await channelAdapter.SendPacketAsync(connAckPacket, _options.DefaultCommunicationTimeout, cancellationToken).ConfigureAwait(false);
                      return;
                  }

                  var clientId = connectPacket.ClientId;

                  connection = await CreateClientConnectionAsync(
                      connectPacket,
                      connectionValidatorContext,
                      channelAdapter,
                      async () => await _eventDispatcher.SafeNotifyClientConnectedAsync(clientId).ConfigureAwait(false),
                      async (stoppedConnection, disconnectType) => await CleanUpClientAsync(clientId, stoppedConnection, channelAdapter, disconnectType).ConfigureAwait(false)
                  ).ConfigureAwait(false);

                  await connection.RunAsync(connectionValidatorContext).ConfigureAwait(false);
              }
              catch (OperationCanceledException)
              {
                  // Shutdown; nothing to report.
              }
              catch (Exception exception)
              {
                  _logger.Error(exception, exception.Message);
              }
              finally
              {
                  // No connection took ownership of the channel (no CONNECT, timeout, rejected CONNECT or an
                  // exception before creation), so nothing else will close it.
                  if (connection == null)
                  {
                      await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false);
                  }
              }
          }

          // Stop callback of a connection. It unregisters the connection and, for non-persistent sessions,
          // deletes the session. It then closes the channel and raises the client-disconnected event.
          async Task CleanUpClientAsync(string clientId, MqttClientConnection connection, IMqttChannelAdapter channelAdapter, MqttClientDisconnectType disconnectType)
          {
              if (clientId != null)
              {
                  // Remove the entry only if it still maps to THIS connection (atomic key + value match).
                  // After a takeover, the client ID may already point to the newer connection. That
                  // connection must stay registered and must keep its session.
                  var connectionEntry = new KeyValuePair<string, MqttClientConnection>(clientId, connection);
                  var wasRegistered = ((ICollection<KeyValuePair<string, MqttClientConnection>>)_connections).Remove(connectionEntry);

                  if (wasRegistered && !_options.EnablePersistentSessions)
                  {
                      await DeleteSessionAsync(clientId).ConfigureAwait(false);
                  }
              }

              await SafeCleanupChannelAsync(channelAdapter).ConfigureAwait(false);

              if (clientId != null)
              {
                  await _eventDispatcher.SafeNotifyClientDisconnectedAsync(clientId, disconnectType).ConfigureAwait(false);
              }
          }

          // Runs the configured connection validator (if any) and then checks the client ID. The client-ID
          // check also runs without a validator. Otherwise an empty ID would be accepted as an ordinary ID,
          // shared by every such client, and a null ID would fail later without a CONNACK.
          async Task<MqttConnectionValidatorContext> ValidateConnectionAsync(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter)
          {
              var context = new MqttConnectionValidatorContext(connectPacket, channelAdapter, new ConcurrentDictionary<object, object>());

              var connectionValidator = _options.ConnectionValidator;
              if (connectionValidator == null)
              {
                  context.ReasonCode = MqttConnectReasonCode.Success;
              }
              else
              {
                  await connectionValidator.ValidateConnectionAsync(context).ConfigureAwait(false);
              }

              // MQTT 5 clients may omit the ID; use the identifier assigned through the context in that case.
              if (string.IsNullOrEmpty(connectPacket.ClientId) && channelAdapter.PacketFormatterAdapter.ProtocolVersion == MqttProtocolVersion.V500)
              {
                  connectPacket.ClientId = context.AssignedClientIdentifier;
              }

              if (string.IsNullOrEmpty(connectPacket.ClientId))
              {
                  context.ReasonCode = MqttConnectReasonCode.ClientIdentifierNotValid;
              }

              return context;
          }

          // Under the creation gate: stops any existing connection with the same client ID, reuses or
          // replaces the session according to CleanSession, then creates and registers the new connection.
          async Task<MqttClientConnection> CreateClientConnectionAsync(
              MqttConnectPacket connectPacket,
              MqttConnectionValidatorContext connectionValidatorContext,
              IMqttChannelAdapter channelAdapter,
              Func<Task> onStart,
              Func<MqttClientConnection, MqttClientDisconnectType, Task> onStop)
          {
              using (await _createConnectionGate.WaitAsync(_cancellationToken).ConfigureAwait(false))
              {
                  var isSessionPresent = _sessions.TryGetValue(connectPacket.ClientId, out var session);

                  var isConnectionPresent = _connections.TryGetValue(connectPacket.ClientId, out var existingConnection);
                  if (isConnectionPresent)
                  {
                      await existingConnection.StopAsync(true).ConfigureAwait(false);
                  }

                  if (isSessionPresent)
                  {
                      if (connectPacket.CleanSession)
                      {
                          session = null;
                          _logger.Verbose("Deleting existing session of client '{0}'.", connectPacket.ClientId);
                      }
                      else
                      {
                          _logger.Verbose("Reusing existing session of client '{0}'.", connectPacket.ClientId);
                      }
                  }

                  if (session == null)
                  {
                      session = new MqttClientSession(connectPacket.ClientId, connectionValidatorContext.SessionItems, _eventDispatcher, _options, _retainedMessagesManager, _logger);
                      _logger.Verbose("Created a new session for client '{0}'.", connectPacket.ClientId);
                  }

                  // The stop callback must know which connection stopped (see CleanUpClientAsync). The local
                  // is assigned right after construction, before the connection is published in _connections.
                  MqttClientConnection connection = null;
                  connection = new MqttClientConnection(
                      connectPacket,
                      channelAdapter,
                      session,
                      _options,
                      this,
                      _retainedMessagesManager,
                      onStart,
                      disconnectType => onStop(connection, disconnectType),
                      _logger);

                  _connections[connection.ClientId] = connection;
                  _sessions[session.ClientId] = session;

                  return connection;
              }
          }

          // Runs the configured application message interceptor, if any. Returns null when none is configured.
          async Task<MqttApplicationMessageInterceptorContext> InterceptApplicationMessageAsync(MqttClientConnection senderConnection, MqttApplicationMessage applicationMessage)
          {
              var interceptor = _options.ApplicationMessageInterceptor;
              if (interceptor == null)
              {
                  return null;
              }

              string senderClientId;
              IDictionary<object, object> sessionItems;

              var messageIsFromServer = senderConnection == null;
              if (messageIsFromServer)
              {
                  senderClientId = _options.ClientId;
                  sessionItems = _serverSessionItems;
              }
              else
              {
                  senderClientId = senderConnection.ClientId;
                  sessionItems = senderConnection.Session.Items;
              }

              var interceptorContext = new MqttApplicationMessageInterceptorContext(senderClientId, sessionItems, applicationMessage);
              await interceptor.InterceptApplicationMessagePublishAsync(interceptorContext).ConfigureAwait(false);
              return interceptorContext;
          }

          // Disconnects the channel and logs (but does not propagate) any error.
          async Task SafeCleanupChannelAsync(IMqttChannelAdapter channelAdapter)
          {
              try
              {
                  await channelAdapter.DisconnectAsync(_options.DefaultCommunicationTimeout, CancellationToken.None).ConfigureAwait(false);
              }
              catch (Exception exception)
              {
                  _logger.Error(exception, "Error while disconnecting client channel.");
              }
          }
      }
  }