| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- using MQTTnet.Packets;
- using MQTTnet.Protocol;
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Threading.Tasks;
namespace MQTTnet.Server
{
    /// <summary>
    /// Holds the topic filters of a single client session. Subscribe and unsubscribe requests are passed
    /// through the interceptors configured in the server options before the stored filters are changed,
    /// and accepted changes are reported through the server event dispatcher.
    /// </summary>
    public class MqttClientSubscriptionsManager
    {
        // Topic filter string -> subscription. Subscribing to an already known filter replaces the old entry.
        // All reads and writes happen inside lock (_subscriptions).
        private readonly Dictionary<string, MqttTopicFilter> _subscriptions = new Dictionary<string, MqttTopicFilter>();
        private readonly MqttClientSession _clientSession;
        private readonly IMqttServerOptions _serverOptions;
        private readonly MqttServerEventDispatcher _eventDispatcher;

        /// <summary>
        /// Creates a subscriptions manager bound to one client session.
        /// </summary>
        /// <param name="clientSession">Session whose ClientId and Items are handed to interceptors and notifications.</param>
        /// <param name="eventDispatcher">Dispatcher used to report subscribed and unsubscribed topics.</param>
        /// <param name="serverOptions">Server options providing the optional subscription/unsubscription interceptors.</param>
        /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
        public MqttClientSubscriptionsManager(MqttClientSession clientSession, MqttServerEventDispatcher eventDispatcher, IMqttServerOptions serverOptions)
        {
            _clientSession = clientSession ?? throw new ArgumentNullException(nameof(clientSession));

            // TODO: Consider removing the server options here and build a new class "ISubscriptionInterceptor" and just pass it. The instance is generated in the root server class upon start.
            _serverOptions = serverOptions ?? throw new ArgumentNullException(nameof(serverOptions));
            _eventDispatcher = eventDispatcher ?? throw new ArgumentNullException(nameof(eventDispatcher));
        }

        /// <summary>
        /// Processes a SUBSCRIBE packet received from the client and builds the matching SUBACK.
        /// Accepted filters (as possibly rewritten by the interceptor) are stored and reported as subscribed.
        /// </summary>
        /// <param name="subscribePacket">The received SUBSCRIBE packet.</param>
        /// <param name="connectPacket">The client's CONNECT packet; only validated for <c>null</c> here.</param>
        /// <returns>
        /// The SUBACK, holding one return code and one reason code per requested topic filter in request order,
        /// and whether any interceptor asked for the connection to be closed.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="subscribePacket"/> or <paramref name="connectPacket"/> is <c>null</c>.</exception>
        public async Task<MqttClientSubscribeResult> SubscribeAsync(MqttSubscribePacket subscribePacket, MqttConnectPacket connectPacket)
        {
            if (subscribePacket == null) throw new ArgumentNullException(nameof(subscribePacket));
            if (connectPacket == null) throw new ArgumentNullException(nameof(connectPacket));

            var result = new MqttClientSubscribeResult
            {
                ResponsePacket = new MqttSubAckPacket
                {
                    PacketIdentifier = subscribePacket.PacketIdentifier
                },
                CloseConnection = false
            };

            foreach (var originalTopicFilter in subscribePacket.TopicFilters)
            {
                var interceptorContext = await InterceptSubscribeAsync(originalTopicFilter).ConfigureAwait(false);

                if (interceptorContext.CloseConnection)
                {
                    result.CloseConnection = true;
                }

                if (!IsAcceptedSubscription(interceptorContext))
                {
                    result.ResponsePacket.ReturnCodes.Add(MqttSubscribeReturnCode.Failure);
                    result.ResponsePacket.ReasonCodes.Add(MqttSubscribeReasonCode.UnspecifiedError);
                    continue;
                }

                // The interceptor may have replaced the filter, so its final version is acknowledged and stored.
                var finalTopicFilter = interceptorContext.TopicFilter;
                result.ResponsePacket.ReturnCodes.Add(ConvertToSubscribeReturnCode(finalTopicFilter.QualityOfServiceLevel));
                result.ResponsePacket.ReasonCodes.Add(ConvertToSubscribeReasonCode(finalTopicFilter.QualityOfServiceLevel));

                lock (_subscriptions)
                {
                    _subscriptions[finalTopicFilter.Topic] = finalTopicFilter;
                }

                await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false);
            }

            return result;
        }

        /// <summary>
        /// Subscribes the session to the given topic filters without a SUBSCRIBE packet.
        /// Each filter still passes through the subscription interceptor. Rejected filters, and filters whose
        /// final version has no topic, are skipped silently.
        /// </summary>
        /// <param name="topicFilters">The topic filters to subscribe to.</param>
        /// <exception cref="ArgumentNullException"><paramref name="topicFilters"/> is <c>null</c>.</exception>
        public async Task SubscribeAsync(IEnumerable<MqttTopicFilter> topicFilters)
        {
            if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

            foreach (var topicFilter in topicFilters)
            {
                var interceptorContext = await InterceptSubscribeAsync(topicFilter).ConfigureAwait(false);
                if (!IsAcceptedSubscription(interceptorContext))
                {
                    continue;
                }

                // Use the filter as possibly rewritten by the interceptor, consistent with the packet based overload.
                var finalTopicFilter = interceptorContext.TopicFilter;

                lock (_subscriptions)
                {
                    _subscriptions[finalTopicFilter.Topic] = finalTopicFilter;
                }

                await _eventDispatcher.SafeNotifyClientSubscribedTopicAsync(_clientSession.ClientId, finalTopicFilter).ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Processes an UNSUBSCRIBE packet received from the client and builds the matching UNSUBACK.
        /// Only filters accepted by the unsubscription interceptor are removed and reported as unsubscribed.
        /// </summary>
        /// <param name="unsubscribePacket">The received UNSUBSCRIBE packet.</param>
        /// <returns>The UNSUBACK with one reason code per requested topic filter, in request order.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="unsubscribePacket"/> is <c>null</c>.</exception>
        public async Task<MqttUnsubAckPacket> UnsubscribeAsync(MqttUnsubscribePacket unsubscribePacket)
        {
            if (unsubscribePacket == null) throw new ArgumentNullException(nameof(unsubscribePacket));

            var unsubAckPacket = new MqttUnsubAckPacket
            {
                PacketIdentifier = unsubscribePacket.PacketIdentifier
            };

            var acceptedTopicFilters = new List<string>();

            foreach (var topicFilter in unsubscribePacket.TopicFilters)
            {
                var interceptorContext = await InterceptUnsubscribeAsync(topicFilter).ConfigureAwait(false);
                if (!interceptorContext.AcceptUnsubscription)
                {
                    unsubAckPacket.ReasonCodes.Add(MqttUnsubscribeReasonCode.ImplementationSpecificError);
                    continue;
                }

                bool wasRemoved;
                lock (_subscriptions)
                {
                    wasRemoved = _subscriptions.Remove(topicFilter);
                }

                unsubAckPacket.ReasonCodes.Add(wasRemoved
                    ? MqttUnsubscribeReasonCode.Success
                    : MqttUnsubscribeReasonCode.NoSubscriptionExisted);

                acceptedTopicFilters.Add(topicFilter);
            }

            // A rejected unsubscription leaves the subscription in place, so it must not be reported as unsubscribed.
            foreach (var topicFilter in acceptedTopicFilters)
            {
                await _eventDispatcher.SafeNotifyClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
            }

            return unsubAckPacket;
        }

        /// <summary>
        /// Unsubscribes the session from the given topic filters without an UNSUBSCRIBE packet.
        /// Each filter still passes through the unsubscription interceptor. Rejected filters and <c>null</c>
        /// entries are skipped. Accepted filters are removed and reported as unsubscribed, mirroring
        /// <see cref="SubscribeAsync(IEnumerable{MqttTopicFilter})"/>.
        /// </summary>
        /// <param name="topicFilters">The topic filter strings to unsubscribe from.</param>
        /// <exception cref="ArgumentNullException"><paramref name="topicFilters"/> is <c>null</c>.</exception>
        public async Task UnsubscribeAsync(IEnumerable<string> topicFilters)
        {
            if (topicFilters == null) throw new ArgumentNullException(nameof(topicFilters));

            foreach (var topicFilter in topicFilters)
            {
                // Dictionary.Remove throws for a null key; there is nothing to unsubscribe from anyway.
                if (topicFilter == null)
                {
                    continue;
                }

                var interceptorContext = await InterceptUnsubscribeAsync(topicFilter).ConfigureAwait(false);
                if (!interceptorContext.AcceptUnsubscription)
                {
                    continue;
                }

                lock (_subscriptions)
                {
                    _subscriptions.Remove(topicFilter);
                }

                await _eventDispatcher.SafeNotifyClientUnsubscribedTopicAsync(_clientSession.ClientId, topicFilter).ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Checks whether any stored topic filter matches <paramref name="topic"/>.
        /// </summary>
        /// <param name="topic">The topic of the application message being routed.</param>
        /// <param name="qosLevel">The QoS level the message was published with.</param>
        /// <returns>
        /// <c>IsSubscribed = false</c> when no filter matches. Otherwise the QoS level to deliver with, which is the
        /// lower of <paramref name="qosLevel"/> and the highest QoS granted by any matching subscription.
        /// </returns>
        public CheckSubscriptionsResult CheckSubscriptions(string topic, MqttQualityOfServiceLevel qosLevel)
        {
            var qosLevels = new HashSet<MqttQualityOfServiceLevel>();

            lock (_subscriptions)
            {
                foreach (var subscription in _subscriptions)
                {
                    if (!MqttTopicFilterComparer.IsMatch(topic, subscription.Key))
                    {
                        continue;
                    }

                    qosLevels.Add(subscription.Value.QualityOfServiceLevel);
                }
            }

            if (qosLevels.Count == 0)
            {
                return new CheckSubscriptionsResult
                {
                    IsSubscribed = false
                };
            }

            return CreateSubscriptionResult(qosLevel, qosLevels);
        }

        // A subscription is only stored when the interceptor accepted it and left a filter with a non-empty topic.
        private static bool IsAcceptedSubscription(MqttSubscriptionInterceptorContext interceptorContext)
        {
            return interceptorContext.AcceptSubscription
                && interceptorContext.TopicFilter != null
                && !string.IsNullOrEmpty(interceptorContext.TopicFilter.Topic);
        }

        // Maps a granted QoS level to the MQTT 3.1.1 SUBACK return code; unknown levels map to Failure.
        private static MqttSubscribeReturnCode ConvertToSubscribeReturnCode(MqttQualityOfServiceLevel qualityOfServiceLevel)
        {
            switch (qualityOfServiceLevel)
            {
                case MqttQualityOfServiceLevel.AtMostOnce: return MqttSubscribeReturnCode.SuccessMaximumQoS0;
                case MqttQualityOfServiceLevel.AtLeastOnce: return MqttSubscribeReturnCode.SuccessMaximumQoS1;
                case MqttQualityOfServiceLevel.ExactlyOnce: return MqttSubscribeReturnCode.SuccessMaximumQoS2;
                default: return MqttSubscribeReturnCode.Failure;
            }
        }

        // Maps a granted QoS level to the MQTT 5 SUBACK reason code; unknown levels map to UnspecifiedError.
        private static MqttSubscribeReasonCode ConvertToSubscribeReasonCode(MqttQualityOfServiceLevel qualityOfServiceLevel)
        {
            switch (qualityOfServiceLevel)
            {
                case MqttQualityOfServiceLevel.AtMostOnce: return MqttSubscribeReasonCode.GrantedQoS0;
                case MqttQualityOfServiceLevel.AtLeastOnce: return MqttSubscribeReasonCode.GrantedQoS1;
                case MqttQualityOfServiceLevel.ExactlyOnce: return MqttSubscribeReasonCode.GrantedQoS2;
                default: return MqttSubscribeReasonCode.UnspecifiedError;
            }
        }

        // Builds the interceptor context for a subscription and runs the configured interceptor, if any.
        // The returned context carries the (possibly modified) topic filter and the accept/close decisions.
        private async Task<MqttSubscriptionInterceptorContext> InterceptSubscribeAsync(MqttTopicFilter topicFilter)
        {
            var context = new MqttSubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items);
            if (_serverOptions.SubscriptionInterceptor != null)
            {
                await _serverOptions.SubscriptionInterceptor.InterceptSubscriptionAsync(context).ConfigureAwait(false);
            }

            return context;
        }

        // Builds the interceptor context for an unsubscription and runs the configured interceptor, if any.
        private async Task<MqttUnsubscriptionInterceptorContext> InterceptUnsubscribeAsync(string topicFilter)
        {
            var context = new MqttUnsubscriptionInterceptorContext(_clientSession.ClientId, topicFilter, _clientSession.Items);
            if (_serverOptions.UnsubscriptionInterceptor != null)
            {
                await _serverOptions.UnsubscriptionInterceptor.InterceptUnsubscriptionAsync(context).ConfigureAwait(false);
            }

            return context;
        }

        // Delivery QoS is min(published QoS, highest QoS granted by any matching subscription).
        // Overlapping subscriptions are honoured via the maximum, but the message is never upgraded
        // beyond the QoS it was published with.
        private static CheckSubscriptionsResult CreateSubscriptionResult(MqttQualityOfServiceLevel qosLevel, HashSet<MqttQualityOfServiceLevel> subscribedQoSLevels)
        {
            var maximumGrantedQoS = subscribedQoSLevels.Max();
            var effectiveQoS = qosLevel < maximumGrantedQoS ? qosLevel : maximumGrantedQoS;

            return new CheckSubscriptionsResult
            {
                IsSubscribed = true,
                QualityOfServiceLevel = effectiveQoS
            };
        }
    }
}
|