MqttRetainedMessagesManager.cs 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. using MQTTnet.Diagnostics;
  2. using MQTTnet.Implementations;
  3. using MQTTnet.Internal;
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Threading.Tasks;
  namespace MQTTnet.Server
  {
      /// <summary>
      /// Keeps the broker's retained application messages in memory, keyed by topic, and
      /// writes a full snapshot to <c>IMqttServerOptions.Storage</c> whenever the set changes
      /// (if a storage is configured).
      /// </summary>
      /// <remarks>
      /// All access to the in-memory dictionary is serialized through an <see cref="AsyncLock"/>.
      /// Other members dereference the options and logger assigned by <see cref="Start"/>,
      /// so <see cref="Start"/> must be called first.
      /// </remarks>
      public class MqttRetainedMessagesManager : IMqttRetainedMessagesManager
      {
          // Stand-in for a null payload when comparing payload bytes.
          private static readonly byte[] EmptyPayload = new byte[0];

          private readonly AsyncLock _messagesLock = new AsyncLock();
          private readonly Dictionary<string, MqttApplicationMessage> _messages = new Dictionary<string, MqttApplicationMessage>();

          private IMqttNetLogger _logger;
          private IMqttServerOptions _options;

          /// <summary>
          /// Stores the server options and creates a child logger for this manager.
          /// </summary>
          /// <param name="options">Server options; <c>Storage</c> may be null to disable persistence.</param>
          /// <param name="logger">Parent logger.</param>
          /// <returns>An already-completed task.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="options"/> or <paramref name="logger"/> is null.</exception>
          public Task Start(IMqttServerOptions options, IMqttNetLogger logger)
          {
              if (logger == null) throw new ArgumentNullException(nameof(logger));

              _logger = logger.CreateChildLogger(nameof(MqttRetainedMessagesManager));
              _options = options ?? throw new ArgumentNullException(nameof(options));

              return PlatformAbstractionLayer.CompletedTask;
          }

          /// <summary>
          /// Replaces the in-memory retained messages with those returned by the configured storage.
          /// </summary>
          /// <remarks>
          /// Does nothing when no storage is configured, or when the storage returns null or an
          /// empty sequence (the current in-memory messages are then kept). Entries that are null
          /// or have a null topic are skipped. Exceptions thrown by the storage are logged, not rethrown.
          /// </remarks>
          public async Task LoadMessagesAsync()
          {
              if (_options.Storage == null)
              {
                  return;
              }

              try
              {
                  var retainedMessages = await _options.Storage.LoadRetainedMessagesAsync().ConfigureAwait(false);
                  if (retainedMessages?.Any() != true)
                  {
                      return;
                  }

                  using (await _messagesLock.WaitAsync().ConfigureAwait(false))
                  {
                      _messages.Clear();

                      foreach (var retainedMessage in retainedMessages)
                      {
                          // A null entry or topic would make the dictionary indexer throw mid-loop,
                          // leaving the store only partially loaded; skip such entries instead.
                          if (retainedMessage?.Topic == null)
                          {
                              _logger.Verbose("Skipped retained message without topic while loading from storage.");
                              continue;
                          }

                          _messages[retainedMessage.Topic] = retainedMessage;
                      }
                  }
              }
              catch (Exception exception)
              {
                  _logger.Error(exception, "Unhandled exception while loading retained messages.");
              }
          }

          /// <summary>
          /// Applies a retained message published by a client.
          /// </summary>
          /// <remarks>
          /// A null or empty payload removes the retained message for the topic. Any other payload
          /// stores the message, unless an equivalent one (same QoS and payload bytes) is already
          /// retained. When the set changes and a storage is configured, a full snapshot is saved.
          /// Exceptions (including storage failures) are logged, not rethrown.
          /// </remarks>
          /// <param name="clientId">Id of the publishing client; used for logging only.</param>
          /// <param name="applicationMessage">The published message.</param>
          /// <exception cref="ArgumentNullException"><paramref name="applicationMessage"/> is null.</exception>
          public async Task HandleMessageAsync(string clientId, MqttApplicationMessage applicationMessage)
          {
              if (applicationMessage == null) throw new ArgumentNullException(nameof(applicationMessage));

              try
              {
                  using (await _messagesLock.WaitAsync().ConfigureAwait(false))
                  {
                      bool saveIsRequired;

                      var hasPayload = applicationMessage.Payload != null && applicationMessage.Payload.Length > 0;
                      if (!hasPayload)
                      {
                          saveIsRequired = _messages.Remove(applicationMessage.Topic);
                          _logger.Verbose("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic);
                      }
                      else
                      {
                          saveIsRequired = TrySetMessage(applicationMessage);
                          _logger.Verbose("Client '{0}' set retained message for topic '{1}'.", clientId, applicationMessage.Topic);
                      }

                      if (saveIsRequired && _options.Storage != null)
                      {
                          // Saved while still holding the lock, so snapshots reach the storage
                          // in the same order as the changes that produced them.
                          var messagesForSave = new List<MqttApplicationMessage>(_messages.Values);
                          await _options.Storage.SaveRetainedMessagesAsync(messagesForSave).ConfigureAwait(false);
                      }
                  }
              }
              catch (Exception exception)
              {
                  _logger.Error(exception, "Unhandled exception while handling retained messages.");
              }
          }

          /// <summary>
          /// Returns every retained message whose topic matches at least one of the given filters.
          /// </summary>
          /// <remarks>
          /// A snapshot is taken under the lock and matching runs outside it. Each message appears
          /// at most once, in the snapshot's enumeration order.
          /// </remarks>
          /// <param name="topicFilters">Filters to match against retained topics.</param>
          /// <exception cref="ArgumentNullException"><paramref name="topicFilters"/> is null.</exception>
          public async Task<IList<MqttApplicationMessage>> GetSubscribedMessagesAsync(ICollection<MqttTopicFilter> topicFilters)
          {
              if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

              List<MqttApplicationMessage> retainedMessages;
              using (await _messagesLock.WaitAsync().ConfigureAwait(false))
              {
                  retainedMessages = _messages.Values.ToList();
              }

              return retainedMessages
                  .Where(message => topicFilters.Any(filter => MqttTopicFilterComparer.IsMatch(message.Topic, filter.Topic)))
                  .ToList();
          }

          /// <summary>
          /// Returns a snapshot copy of all retained messages.
          /// </summary>
          public async Task<IList<MqttApplicationMessage>> GetMessagesAsync()
          {
              using (await _messagesLock.WaitAsync().ConfigureAwait(false))
              {
                  return _messages.Values.ToList();
              }
          }

          /// <summary>
          /// Removes all retained messages and, when a storage is configured, saves an empty list.
          /// </summary>
          /// <remarks>
          /// Unlike <see cref="HandleMessageAsync"/>, storage exceptions propagate to the caller.
          /// </remarks>
          public async Task ClearMessagesAsync()
          {
              using (await _messagesLock.WaitAsync().ConfigureAwait(false))
              {
                  _messages.Clear();

                  if (_options.Storage != null)
                  {
                      await _options.Storage.SaveRetainedMessagesAsync(new List<MqttApplicationMessage>()).ConfigureAwait(false);
                  }
              }
          }

          // Stores the message for its topic unless an equivalent message is already retained there.
          // Returns true when the dictionary changed. Caller must hold _messagesLock.
          private bool TrySetMessage(MqttApplicationMessage applicationMessage)
          {
              if (_messages.TryGetValue(applicationMessage.Topic, out var existingMessage)
                  && existingMessage != null
                  && HasSameContent(existingMessage, applicationMessage))
              {
                  return false;
              }

              _messages[applicationMessage.Topic] = applicationMessage;
              return true;
          }

          // Two retained messages count as equivalent when QoS and payload bytes match; other
          // properties are not compared. Null payloads compare as empty, because SequenceEqual
          // throws on a null source (e.g. a payload-less message that was loaded from storage).
          private static bool HasSameContent(MqttApplicationMessage existingMessage, MqttApplicationMessage newMessage)
          {
              if (existingMessage.QualityOfServiceLevel != newMessage.QualityOfServiceLevel)
              {
                  return false;
              }

              var existingPayload = existingMessage.Payload ?? EmptyPayload;
              var newPayload = newMessage.Payload ?? EmptyPayload;
              return existingPayload.SequenceEqual(newPayload);
          }
      }
  }