package jwtauth

import (
	"crypto/rsa"
	"errors"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v4"
)

// JwtPayloadKey is the gin context key under which the middleware stores the
// verified token claims (read back by ExtractClaims).
const JwtPayloadKey = "JWT_IOT_PAYLOAD"

// JwtTokenKey is the gin context key used to store the raw token string
// (read back by GetToken, and by GetClaimsFromJWT when SendAuthorization is set).
const JwtTokenKey = "JWT_TOKEN"

// GinJWTMiddleware provides JSON Web Token authentication for gin.
//
// On failure it calls the Unauthorized callback with an error code and message
// (the default callback installed by MiddlewareInit writes HTTP 200 and puts the
// code in the JSON body). On success it stores the claims under JwtPayloadKey
// and the identity under IdentityKey, then continues the handler chain.
//
// Clients obtain a token from LoginHandler and, with the default TokenLookup
// and TokenHeadName, send it as "Authorization: Bearer <token>".
type GinJWTMiddleware struct {
	// Realm is the realm name sent in the WWW-Authenticate header on failures.
	// MiddlewareInit supplies a default when empty.
	Realm string

	// SigningAlgorithm is the JWT "alg" name. HS256, HS384 and HS512 sign with
	// Key; RS256, RS384 and RS512 use PrivKeyFile/PubKeyFile. Defaults to HS256.
	SigningAlgorithm string

	// Key is the HMAC secret. Required unless an RS* algorithm is used.
	Key []byte

	// Timeout is the lifetime of issued tokens (exp = issue time + Timeout).
	// The original documentation promises a default of one hour.
	// NOTE(review): if Timeout is left at zero and no default is applied,
	// tokens expire at the moment they are issued.
	Timeout time.Duration

	// MaxRefresh is how long after orig_iat a token may still be refreshed, even
	// if it has already expired. Zero (the default) effectively disables
	// refreshing. Because a client can refresh at the last moment, the maximum
	// usable span of a token is Timeout + MaxRefresh.
	MaxRefresh time.Duration

	// Authenticator authenticates the user from the login request. Its first
	// return value is passed to PayloadFunc to build the token claims; a non-nil
	// error is reported to the client. Required by LoginHandler.
	Authenticator func(c *gin.Context) (interface{}, error)

	// Authorizator is called after a token has been verified, with the value
	// returned by IdentityHandler. Returning false rejects the request with
	// ErrForbidden. Defaults to always allowing.
	Authorizator func(data interface{}, c *gin.Context) bool

	// PayloadFunc is called when a token is issued, to add extra claims to it.
	// The claims are available during requests via ExtractClaims (context key
	// JwtPayloadKey). The payload is signed but NOT encrypted. Do not use the
	// registered claim names listed on jwt.io as keys: exp and orig_iat are
	// overwritten after PayloadFunc runs.
	PayloadFunc func(data interface{}) MapClaims

	// Unauthorized is called with (context, code, message) whenever
	// authentication or authorization fails.
	Unauthorized func(*gin.Context, int, string)

	// LoginResponse is a custom login response callback.
	// NOTE(review): LoginHandler replies through AntdLoginResponse, not this field.
	LoginResponse func(*gin.Context, int, string, time.Time)

	// AntdLoginResponse writes the login response (Ant Design style body by default).
	AntdLoginResponse func(*gin.Context, int, string, time.Time)

	// RefreshResponse writes the response of RefreshHandler.
	RefreshResponse func(*gin.Context, int, string, time.Time)

	// IdentityHandler extracts the identity of the current request; the result is
	// stored in the context under IdentityKey. Defaults to returning all claims.
	IdentityHandler func(*gin.Context) interface{}

	// The following fields name the keys used to store user information.

	// IdentityKey is the gin context key under which the identity is stored.
	IdentityKey string
	// NiceKey is the key name for the user's nickname (original comment: "user name").
	NiceKey string
	// DataScopeKey is the key name for the data-permission scope type.
	DataScopeKey string
	// RKey is the key name for the role.
	RKey string
	// RoleIdKey is the key name for the role ID.
	RoleIdKey string
	// RoleKey is the key name for the role key.
	RoleKey string
	// RoleNameKey is the key name for the role name.
	RoleNameKey string

	// TokenLookup is a comma-separated list of "<source>:<name>" entries that are
	// tried in order to extract the token from the request.
	// Optional. Defaults to "header:Authorization".
	// Possible values:
	//   - "header:<name>"
	//   - "query:<name>"
	//   - "cookie:<name>"
	//   - "param:<name>"
	TokenLookup string

	// TokenHeadName is the scheme expected before the token in the header.
	// Defaults to "Bearer".
	TokenHeadName string

	// TimeFunc provides the current time. Override it for tests, or when the
	// server uses a different time zone than the tokens. Defaults to time.Now.
	TimeFunc func() time.Time

	// HTTPStatusMessageFunc maps an error to the message passed to Unauthorized.
	// Defaults to e.Error().
	HTTPStatusMessageFunc func(e error, c *gin.Context) string

	// PrivKeyFile is the PEM-encoded RSA private key file used by RS* algorithms.
	PrivKeyFile string
	// PubKeyFile is the PEM-encoded RSA public key file used by RS* algorithms.
	PubKeyFile string
	// privKey is the parsed private key loaded from PrivKeyFile.
	privKey *rsa.PrivateKey
	// pubKey is the parsed public key loaded from PubKeyFile.
	pubKey *rsa.PublicKey

	// SendCookie also sets the issued token as a cookie on login and refresh.
	SendCookie bool
	// SecureCookie is passed as the cookie's Secure flag (HTTPS only when true);
	// set it to false only for development over plain HTTP.
	SecureCookie bool
	// CookieHTTPOnly is passed as the cookie's HttpOnly flag (hidden from
	// client-side scripts when true); set it to false only if scripts must read
	// the cookie, e.g. during development.
	CookieHTTPOnly bool
	// CookieDomain is the Domain attribute of the cookie.
	CookieDomain string
	// SendAuthorization echoes the token back in the Authorization response
	// header on every authenticated request.
	SendAuthorization bool
	// DisabledAbort stops the middleware from calling c.Abort() on failure.
	DisabledAbort bool
	// CookieName is the name of the token cookie. MiddlewareInit supplies a
	// default when empty. NOTE(review): cookie names must be RFC 6265 tokens (no
	// spaces); net/http silently drops cookies with invalid names.
	CookieName string
}

var (
	// ErrMissingSecretKey indicates that Key is required but was not set.
	ErrMissingSecretKey = errors.New("secret key is required")

	// ErrForbidden is reported (code 403) when Authorizator rejects a request.
	ErrForbidden = errors.New("you don't have permission to access this resource")

	// ErrMissingAuthenticatorFunc indicates Authenticator is required.
	ErrMissingAuthenticatorFunc = errors.New("ginJWTMiddleware.Authenticator func is undefined")

	// ErrMissingLoginValues indicates a user tried to authenticate without a
	// username, password or code.
	ErrMissingLoginValues = errors.New("missing Username or Password or Code")

	// ErrFailedAuthentication indicates authentication failed, e.g. a wrong
	// username or password.
	ErrFailedAuthentication = errors.New("incorrect Username or Password")

	// ErrFailedTokenCreation indicates the JWT could not be signed.
	ErrFailedTokenCreation = errors.New("failed to create JWT Token")

	// ErrExpiredToken indicates the JWT has expired and can no longer be refreshed.
	ErrExpiredToken = errors.New("token is expired")

	// ErrEmptyAuthHeader is returned when the configured auth header is missing or empty.
	ErrEmptyAuthHeader = errors.New("auth header is empty")

	// ErrMissingExpField indicates the token has no exp claim.
	ErrMissingExpField = errors.New("missing exp field")

	// ErrWrongFormatOfExp indicates the exp claim has an unexpected type.
	ErrWrongFormatOfExp = errors.New("exp must be float64 format")

	// ErrInvalidAuthHeader indicates the auth header is malformed, e.g. it does
	// not start with TokenHeadName.
	ErrInvalidAuthHeader = errors.New("auth header is invalid")

	// ErrEmptyQueryToken is returned when the configured query parameter is empty.
	ErrEmptyQueryToken = errors.New("query token is empty")

	// ErrEmptyCookieToken is returned when the configured cookie is missing or empty.
	ErrEmptyCookieToken = errors.New("cookie token is empty")

	// ErrEmptyParamToken is returned when the configured path parameter is empty.
	ErrEmptyParamToken = errors.New("parameter token is empty")

	// ErrInvalidSigningAlgorithm indicates the token's algorithm does not match
	// SigningAlgorithm (HS256, HS384, HS512, RS256, RS384 or RS512).
	ErrInvalidSigningAlgorithm = errors.New("invalid signing algorithm")

	// ErrInvalidVerificationode indicates a wrong verification code
	// (the misspelled identifier is kept for compatibility).
	ErrInvalidVerificationode = errors.New("验证码错误")

	// ErrNoPrivKeyFile indicates that the private key file is unreadable.
	ErrNoPrivKeyFile = errors.New("private key file unreadable")

	// ErrNoPubKeyFile indicates that the public key file is unreadable.
	ErrNoPubKeyFile = errors.New("public key file unreadable")

	// ErrInvalidPrivKey indicates that the private key could not be parsed.
	ErrInvalidPrivKey = errors.New("private key invalid")

	// ErrInvalidPubKey indicates that the public key could not be parsed.
	ErrInvalidPubKey = errors.New("public key invalid")

	// IdentityKey is the default context key for the identity.
	IdentityKey = "identity"
	// UserIdKey is the key name for the user ID.
	UserIdKey = "userid"
	// UserNameKey is the key name for the user name.
	UserNameKey = "username"
	// NiceKey is the key name for the nickname.
	NiceKey = "nice"
	// DataScopeKey is the key name for the data-permission scope.
	DataScopeKey = "datascope"
	// RKey is the key name for the role.
	RKey = "r"

	// RoleIdKey is the key name for the role ID (marked "Old" in the original comment).
	RoleIdKey = "roleid"

	// RoleKey is the key name for the role key (the original comment said
	// "role name"; marked "Old").
	RoleKey = "rolekey"

	// RoleNameKey is the key name for the role name (marked "Old").
	RoleNameKey = "rolename"

	// OrgIdKey is the key name for the organization ID.
	OrgIdKey = "orgid"

	// OrgNameKey is the key name for the organization name.
	OrgNameKey = "orgname"
)

// New runs MiddlewareInit on mw and returns mw, or the initialization error.
func New(mw *GinJWTMiddleware) (*GinJWTMiddleware, error) {
	if err := mw.MiddlewareInit(); err != nil {
		return nil, err
	}

	return mw, nil
}

// readKeys loads the RSA private key (PrivKeyFile) and then the RSA public key
// (PubKeyFile), stopping at the first error.
func (mw *GinJWTMiddleware) readKeys() error {
	err := mw.privateKey()
	if err != nil {
		return err
	}
	err = mw.publicKey()
	if err != nil {
		return err
	}
	return nil
}

// privateKey reads and parses the PEM RSA private key at PrivKeyFile into
// mw.privKey. The underlying error is replaced by ErrNoPrivKeyFile (read
// failure) or ErrInvalidPrivKey (parse failure).
func (mw *GinJWTMiddleware) privateKey() error {
	keyData, err := os.ReadFile(mw.PrivKeyFile)
	if err != nil {
		return ErrNoPrivKeyFile
	}
	key, err := jwt.ParseRSAPrivateKeyFromPEM(keyData)
	if err != nil {
		return ErrInvalidPrivKey
	}
	mw.privKey = key
	return nil
}

// publicKey reads and parses the PEM RSA public key at PubKeyFile into
// mw.pubKey. The underlying error is replaced by ErrNoPubKeyFile (read
// failure) or ErrInvalidPubKey (parse failure).
func (mw *GinJWTMiddleware) publicKey() error {
	keyData, err := os.ReadFile(mw.PubKeyFile)
	if err != nil {
		return ErrNoPubKeyFile
	}
	key, err := jwt.ParseRSAPublicKeyFromPEM(keyData)
	if err != nil {
		return ErrInvalidPubKey
	}
	mw.pubKey = key
	return nil
}

// usingPublicKeyAlgo reports whether SigningAlgorithm is one of the RSA
// algorithms (RS256, RS384, RS512) that use the key files instead of Key.
func (mw *GinJWTMiddleware) usingPublicKeyAlgo() bool {
	switch mw.SigningAlgorithm {
	case "RS256", "RS512", "RS384":
		return true
	}
	return false
}

// MiddlewareInit applies defaults to unset configuration fields, then loads the
// RSA key files for RS* algorithms or requires Key for the other algorithms.
func (mw *GinJWTMiddleware) MiddlewareInit() error { if mw.TokenLookup == "" { mw.TokenLookup = "header:Authorization" } if mw.SigningAlgorithm == "" { mw.SigningAlgorithm = "HS256" } if mw.TimeFunc == nil { mw.TimeFunc = time.Now } mw.TokenHeadName = strings.TrimSpace(mw.TokenHeadName) if len(mw.TokenHeadName) == 0 { mw.TokenHeadName = "Bearer" } if mw.Authorizator == nil { mw.Authorizator = func(data interface{}, c *gin.Context) bool { return true } } if mw.Unauthorized == nil { mw.Unauthorized = func(c *gin.Context, code int, message string) { c.JSON(http.StatusOK, gin.H{ "code": code, "message": message, }) } } if mw.LoginResponse == nil { mw.LoginResponse = func(c *gin.Context, code int, token string, expire time.Time) { c.JSON(http.StatusOK, gin.H{ "code": http.StatusOK, "token": token, "expire": expire.Format(time.RFC3339), }) } } if mw.AntdLoginResponse == nil { mw.AntdLoginResponse = func(c *gin.Context, code int, token string, expire time.Time) { c.JSON(http.StatusOK, gin.H{ "code": http.StatusOK, "success": true, "token": token, "currentAuthority": token, "expire": expire.Format(time.RFC3339), }) } } if mw.RefreshResponse == nil { mw.RefreshResponse = func(c *gin.Context, code int, token string, expire time.Time) { c.JSON(http.StatusOK, gin.H{ "code": http.StatusOK, "token": token, "expire": expire.Format(time.RFC3339), }) } } if mw.IdentityKey == "" { mw.IdentityKey = IdentityKey } if mw.IdentityHandler == nil { mw.IdentityHandler = func(c *gin.Context) interface{} { claims := ExtractClaims(c) return claims } } if mw.HTTPStatusMessageFunc == nil { mw.HTTPStatusMessageFunc = func(e error, c *gin.Context) string { return e.Error() } } if mw.Realm == "" { mw.Realm = "gin vb jwt" } if mw.CookieName == "" { mw.CookieName = "IotAdmin jwt" } if mw.usingPublicKeyAlgo() { return mw.readKeys() } if mw.Key == nil { return ErrMissingSecretKey } return nil } // MiddlewareFunc makes GinJWTMiddleware implement the Middleware interface. 
func (mw *GinJWTMiddleware) MiddlewareFunc() gin.HandlerFunc { return func(c *gin.Context) { mw.middlewareImpl(c) } } func (mw *GinJWTMiddleware) middlewareImpl(c *gin.Context) { claims, err := mw.GetClaimsFromJWT(c) if err != nil { mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, c)) return } exp, err := claims.Exp() if err != nil { mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(err, c)) return } if exp < mw.TimeFunc().Unix() { mw.unauthorized(c, 6401, mw.HTTPStatusMessageFunc(ErrExpiredToken, c)) return } c.Set(JwtPayloadKey, claims) identity := mw.IdentityHandler(c) if identity != nil { c.Set(mw.IdentityKey, identity) } if !mw.Authorizator(identity, c) { mw.unauthorized(c, http.StatusForbidden, mw.HTTPStatusMessageFunc(ErrForbidden, c)) return } c.Next() } // GetClaimsFromJWT get claims from JWT token func (mw *GinJWTMiddleware) GetClaimsFromJWT(c *gin.Context) (MapClaims, error) { token, err := mw.ParseToken(c) if err != nil { return nil, err } if mw.SendAuthorization { if v, ok := c.Get(JwtTokenKey); ok { c.Header("Authorization", mw.TokenHeadName+" "+v.(string)) } } return MapClaims(token.Claims.(jwt.MapClaims)), nil } // LoginHandler can be used by clients to get a jwt token. // Payload needs to be json in the form of {"username": "USERNAME", "password": "PASSWORD"}. // Reply will be of the form {"token": "TOKEN"}. 
func (mw *GinJWTMiddleware) LoginHandler(c *gin.Context) { if mw.Authenticator == nil { mw.unauthorized(c, http.StatusInternalServerError, mw.HTTPStatusMessageFunc(ErrMissingAuthenticatorFunc, c)) return } data, err := mw.Authenticator(c) if err != nil { mw.unauthorized(c, 400, mw.HTTPStatusMessageFunc(err, c)) return } // Create the token token := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm)) claims := token.Claims.(jwt.MapClaims) if mw.PayloadFunc != nil { for key, value := range mw.PayloadFunc(data) { claims[key] = value } } expire := mw.TimeFunc().Add(mw.Timeout) claims["exp"] = expire.Unix() claims["orig_iat"] = mw.TimeFunc().Unix() tokenString, err := mw.signedString(token) if err != nil { mw.unauthorized(c, http.StatusOK, mw.HTTPStatusMessageFunc(ErrFailedTokenCreation, c)) return } // set cookie if mw.SendCookie { maxage := int(expire.Unix() - time.Now().Unix()) c.SetCookie( mw.CookieName, tokenString, maxage, "/", mw.CookieDomain, mw.SecureCookie, mw.CookieHTTPOnly, ) } mw.AntdLoginResponse(c, http.StatusOK, tokenString, expire) } func (mw *GinJWTMiddleware) signedString(token *jwt.Token) (string, error) { var tokenString string var err error if mw.usingPublicKeyAlgo() { tokenString, err = token.SignedString(mw.privKey) } else { tokenString, err = token.SignedString(mw.Key) } return tokenString, err } // RefreshHandler can be used to refresh a token. The token still needs to be valid on refresh. // Shall be put under an endpoint that is using the GinJWTMiddleware. // Reply will be of the form {"token": "TOKEN"}. 
func (mw *GinJWTMiddleware) RefreshHandler(c *gin.Context) { tokenString, expire, err := mw.RefreshToken(c) if err != nil { mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, c)) return } mw.RefreshResponse(c, http.StatusOK, tokenString, expire) } // RefreshToken refresh token and check if token is expired func (mw *GinJWTMiddleware) RefreshToken(c *gin.Context) (string, time.Time, error) { claims, err := mw.CheckIfTokenExpire(c) if err != nil { return "", time.Now(), err } // Create the token newToken := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm)) newClaims := newToken.Claims.(jwt.MapClaims) for key := range claims { newClaims[key] = claims[key] } expire := mw.TimeFunc().Add(mw.Timeout) newClaims["exp"] = expire.Unix() newClaims["orig_iat"] = mw.TimeFunc().Unix() tokenString, err := mw.signedString(newToken) if err != nil { return "", time.Now(), err } // set cookie if mw.SendCookie { maxage := int(expire.Unix() - time.Now().Unix()) c.SetCookie( mw.CookieName, tokenString, maxage, "/", mw.CookieDomain, mw.SecureCookie, mw.CookieHTTPOnly, ) } return tokenString, expire, nil } // CheckIfTokenExpire check if token expire func (mw *GinJWTMiddleware) CheckIfTokenExpire(c *gin.Context) (jwt.MapClaims, error) { token, err := mw.ParseToken(c) if err != nil { // If we receive an error, and the error is anything other than a single // ValidationErrorExpired, we want to return the error. // If the error is just ValidationErrorExpired, we want to continue, as we can still // refresh the token if it's within the MaxRefresh time. 
// (see https://github.com/appleboy/gin-jwt/issues/176) validationErr, ok := err.(*jwt.ValidationError) if !ok || validationErr.Errors != jwt.ValidationErrorExpired { return nil, err } } claims := MapClaims(token.Claims.(jwt.MapClaims)) origIat, err := claims.OrigIat() if err != nil { return nil, err } if origIat < mw.TimeFunc().Add(-mw.MaxRefresh).Unix() { return nil, ErrExpiredToken } return token.Claims.(jwt.MapClaims), nil } // TokenGenerator method that clients can use to get a jwt token. func (mw *GinJWTMiddleware) TokenGenerator(data interface{}) (string, time.Time, error) { token := jwt.New(jwt.GetSigningMethod(mw.SigningAlgorithm)) claims := token.Claims.(jwt.MapClaims) if mw.PayloadFunc != nil { for key, value := range mw.PayloadFunc(data) { claims[key] = value } } expire := mw.TimeFunc().UTC().Add(mw.Timeout) claims["exp"] = expire.Unix() claims["orig_iat"] = mw.TimeFunc().Unix() tokenString, err := mw.signedString(token) if err != nil { return "", time.Time{}, err } return tokenString, expire, nil } func (mw *GinJWTMiddleware) jwtFromHeader(c *gin.Context, key string) (string, error) { authHeader := c.Request.Header.Get(key) if authHeader == "" { return "", ErrEmptyAuthHeader } parts := strings.SplitN(authHeader, " ", 2) if !(len(parts) == 2 && parts[0] == mw.TokenHeadName) { return "", ErrInvalidAuthHeader } return parts[1], nil } func (mw *GinJWTMiddleware) jwtFromQuery(c *gin.Context, key string) (string, error) { token := c.Query(key) if token == "" { return "", ErrEmptyQueryToken } return token, nil } func (mw *GinJWTMiddleware) jwtFromCookie(c *gin.Context, key string) (string, error) { cookie, _ := c.Cookie(key) if cookie == "" { return "", ErrEmptyCookieToken } return cookie, nil } func (mw *GinJWTMiddleware) jwtFromParam(c *gin.Context, key string) (string, error) { token := c.Param(key) if token == "" { return "", ErrEmptyParamToken } return token, nil } // ParseToken parse jwt token from gin context func (mw *GinJWTMiddleware) ParseToken(c 
*gin.Context) (*jwt.Token, error) { var token string var err error methods := strings.Split(mw.TokenLookup, ",") for _, method := range methods { if len(token) > 0 { break } parts := strings.Split(strings.TrimSpace(method), ":") k := strings.TrimSpace(parts[0]) v := strings.TrimSpace(parts[1]) switch k { case "header": token, err = mw.jwtFromHeader(c, v) case "query": token, err = mw.jwtFromQuery(c, v) case "cookie": token, err = mw.jwtFromCookie(c, v) case "param": token, err = mw.jwtFromParam(c, v) } } if err != nil { return nil, err } return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { if jwt.GetSigningMethod(mw.SigningAlgorithm) != t.Method { return nil, ErrInvalidSigningAlgorithm } if mw.usingPublicKeyAlgo() { return mw.pubKey, nil } c.Set(JwtTokenKey, token) return mw.Key, nil }, jwt.WithJSONNumber()) } // ParseTokenString 解析jwt令牌 func (mw *GinJWTMiddleware) ParseTokenString(token string) (*jwt.Token, error) { return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { if jwt.GetSigningMethod(mw.SigningAlgorithm) != t.Method { return nil, ErrInvalidSigningAlgorithm } if mw.usingPublicKeyAlgo() { return mw.pubKey, nil } return mw.Key, nil }) } func (mw *GinJWTMiddleware) unauthorized(c *gin.Context, code int, message string) { c.Header("WWW-Authenticate", "JWT realm="+mw.Realm) if !mw.DisabledAbort { c.Abort() } mw.Unauthorized(c, code, message) } // ExtractClaims 获取JWT声明 func ExtractClaims(c *gin.Context) MapClaims { claims, exists := c.Get(JwtPayloadKey) if !exists { return make(MapClaims) } return claims.(MapClaims) } // ExtractClaimsFromToken 从令牌中获取JWT声明 func ExtractClaimsFromToken(token *jwt.Token) MapClaims { if token == nil { return make(MapClaims) } claims := MapClaims{} for key, value := range token.Claims.(jwt.MapClaims) { claims[key] = value } return claims } // GetToken 获取JWT令牌 func GetToken(c *gin.Context) string { token, exists := c.Get(JwtTokenKey) if !exists { return "" } return token.(string) }