package tcpserver import ( "IotAdmin/core/logger" "bufio" "crypto/tls" "crypto/x509" "fmt" "log" "net" "os" "sync" "time" ) var ( chanBuffSize = 32 ) type RestResult struct { Success bool `json:"success"` Message string `json:"message"` Data string `json:"data"` Stream []byte `json:"stream"` } type Client struct { conn net.Conn //RD_chan chan []byte //WR_chan chan []byte //Exit_chan chan error // 异常退出通道 //channel_rd chan []byte rwLock sync.RWMutex channelWr chan []byte channelErr chan int rdChanFlag bool //读取buffRawRes通道标志位,true表示有线程等待读通道 gRunFlag bool //协程允许标记 //towr chan []byte tord chan []byte clientId string //客户端ID(SN) clientAddr string //客户端地址 buffRestRes chan RestResult buffRawRes chan []byte //原始数据 Server *Server restCnt int } type Server struct { address string config *tls.Config onNewClientCallback func(c *Client) onClientConnectionClosed func(c *Client, err error) onNewMessage func(c *Client, msg []byte) onMsgClientCallback func(c *Client, msg []byte) bool onMsgClientCallbackDefine func(c *Client, msg []byte) bool } // LockClient 锁定tcp通道,与解锁成对使用 func (c *Client) LockClient() { c.rwLock.Lock() } // UnLockClient 解锁tcp通道 func (c *Client) UnLockClient() { c.rwLock.Unlock() } func (c *Client) SetWaitRawRes() { c.rdChanFlag = true } func (c *Client) ClrWaitRawRes() { c.rdChanFlag = false } func (c *Client) GetWaitRawRes() bool { return c.rdChanFlag } // PutRestRes 以json格式推入 func (c *Client) PutRestRes(data RestResult) { //sync_lock //log.Println("PutRestRes:", data) c.buffRestRes <- data } // PutRawRes 或者以原始格式推入 func (c *Client) PutRawRes(data []byte) { //sync_lock //log.Printf("PutRestRes:%d\r\n", len(data)) c.buffRawRes <- data } // GetRestRes 以json格式获取 func (c *Client) GetRestRes() (RestResult, bool) { //sync_lock var ( ret bool = false res RestResult ) select { case res = <-c.buffRestRes: ret = true //log.Println("GetRestRes:", res) //log.Printf("buffRestRes len:%d", len(c.buffRestRes)) case <-time.After(time.Second * 10): ret = false } c.restCnt++ return res, ret } func (c *Client) GetRestCnt() int { return c.restCnt } // GetRawRes 以原始格式获取 func (c *Client) GetRawRes() ([]byte, bool) { //sync_lock var ( ret bool = false res []byte = make([]byte, 0) //256) ) select { case res = <-c.buffRawRes: ret = true case <-time.After(time.Second * 5): ret = false } c.restCnt++ return res, ret } // ChanWrite 将数据写给User端 func (c *Client) ChanWrite(b []byte) int { //sync_lock c.channelWr <- b c.restCnt++ return c.restCnt } func (c *Client) GetClientHost() string { return c.clientAddr } func (c *Client) GetClientRegisterID() string { return c.clientId } func (c *Client) SetClientRegisterID(id string) { c.clientId = id } // GetGRunFlag 协程运行标记 func (c *Client) GetGRunFlag() bool { return c.gRunFlag } func (c *Client) handleClient() { res := "" cnt := 0 ro := make(chan int) //ctx, cancel := context.WithCancel(context.Background()) //wo := make(chan int) //rexit := make(chan int) //wexit := make(chan int) go c.goRead(c.tord, ro) //go c.handleTaskTimer(ro) //go c.gowrite(c.towr, wo) c.clientAddr = c.Conn().RemoteAddr().String() c.Server.onNewClientCallback(c) c.restCnt = 0 c.gRunFlag = true defer func() { c.closeClient() if res != "ro" { <-ro } }() for { select { case <-time.After(time.Minute * 2): cnt++ if cnt > 2 { res = "over time" return //exit = true } case <-ro: res = "ro" return //exit = true case <-c.channelErr: res = "channel err" return //exit = true case wrData := <-c.channelWr: if _, err := c.conn.Write(wrData); err != nil { // && err != io.EOF { res = "wo" return //exit = true } case rData, ok := <-c.tord: if ok { cnt = 0 c.Server.onNewMessage(c, rData) } } } } func (c *Client) closeClient() { var err error defer func(conn net.Conn) { _ = conn.Close() }(c.conn) c.gRunFlag = false c.Server.onClientConnectionClosed(c, err) } func (c *Client) ExitClient() { c.channelErr <- 0 //errors.New("exit") } func (c *Client) goRead(buff chan<- []byte, out chan<- int) { size := (int)(32 * 1024) data := make([]byte, size) reader := bufio.NewReader(c.conn) for { n, err := reader.Read(data) //c.conn.Read(data) // if err != nil { close(buff) out <- 0 return } else { if n > 0 && n < size-1 { buff <- data[:n] } } } } func SendTo(cli *Client, date []byte) (res []byte, ok bool) { var ( en = false ) ok = true cli.LockClient() cnt := cli.ChanWrite(date) select { case <-time.After(time.Millisecond * 800): } for { res, ok = cli.GetRawRes() nCnt := cli.GetRestCnt() if !ok { break } if nCnt == cnt+1 { break } if nCnt > cnt+1 { ok = false break } if en { break } en = true } cli.UnLockClient() return } func (c *Client) Send(msg string) error { _, err := c.conn.Write([]byte(msg)) return err } func (c *Client) SendBytes(msg []byte) error { _, err := c.conn.Write(msg) return err } func (c *Client) Conn() net.Conn { return c.conn } func (c *Client) Close() error { return c.conn.Close() } func (s *Server) OnNewClient(callback func(c *Client)) { s.onNewClientCallback = callback } func (s *Server) OnClientConnectionClosed(callback func(c *Client, err error)) { s.onClientConnectionClosed = callback } func (s *Server) OnNewMessage(callback func(c *Client, msg []byte)) { s.onNewMessage = callback } func (s *Server) OnMsgClientCallbackDefine(callback func(c *Client, msg []byte) bool) { s.onMsgClientCallbackDefine = callback } func (s *Server) OnMsgClientCallback(callback func(c *Client, msg []byte) bool) { s.onMsgClientCallback = callback } func (s *Server) Listen() { var listener net.Listener var err error if s.config == nil { listener, err = net.Listen("tcp", s.address) } else { listener, err = tls.Listen("tcp", s.address, s.config) } if err != nil { log.Fatal("Error starting TCP server") } defer func(listener net.Listener) { _ = listener.Close() }(listener) for { conn, _ := listener.Accept() if tcpConn, ok := conn.(*net.TCPConn); ok { if err := tcpConn.SetKeepAlive(false); err != nil { //fmt.Println("close keepalive fail") } else { //fmt.Println("close keepalive ok") } } Client := &Client{ conn: conn, Server: s, //towr: make(chan []byte), tord: make(chan []byte), channelWr: make(chan []byte, chanBuffSize), //channel_rd: make(chan []byte, chanBuffSize), channelErr: make(chan int), buffRestRes: make(chan RestResult, chanBuffSize), buffRawRes: make(chan []byte, chanBuffSize), clientId: "", clientAddr: "", rdChanFlag: false, gRunFlag: false, //reConn: make(chan bool), } go Client.handleClient() //go Client.listen() } } func New(address string) *Server { logger.Infof("创建TCP服务,端口 [%s]", address) server := &Server{ address: address, config: nil, } server.OnNewClient(func(c *Client) {}) server.OnNewMessage(func(c *Client, msg []byte) {}) server.OnClientConnectionClosed(func(c *Client, err error) {}) return server } func NewWithTLS(address, certFile, keyFile string) *Server { logger.Infof("创建TCP服务,端口 [%s]", address) conf, err := serverTLSConf(certFile, keyFile) // sconf.ClientAuth = tls.RequireAndVerifyClientCert if err != nil { fmt.Println("创建TCP服务失败", err) return nil } //conf.MaxVersion = tls.VersionTLS12 conf.BuildNameToCertificate() server := &Server{ address: address, config: conf, } server.OnNewClient(func(c *Client) {}) server.OnNewMessage(func(c *Client, msg []byte) {}) server.OnClientConnectionClosed(func(c *Client, err error) {}) return server } func serverTLSConf(certFile, keyFile string) (*tls.Config, error) { cacert, _ := os.ReadFile("./cert/ca.crt") pool := x509.NewCertPool() pool.AppendCertsFromPEM(cacert) tlsConf := new(tls.Config) //tlsConf.PreferServerCipherSuites = true tlsConf.ClientCAs = pool tlsConf.ClientAuth = tls.NoClientCert //RequireAndVerifyClientCert // support http2 //tlsConf.NextProtos = append(tlsConf.NextProtos, "h2", "http/1.1") // 准备证书 tlsConf.Certificates = make([]tls.Certificate, 1) var err error tlsConf.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err } // tlsConf.KeyLogWriter = handsh.KeyLog("server") return tlsConf, nil }