package client import ( "context" "encoding/json" "fmt" "log" "net/url" "sync" "time" "github.com/duffy/usb-server/internal/config" "github.com/duffy/usb-server/internal/protocol" "github.com/google/uuid" "github.com/gorilla/websocket" ) // Client manages the connection to the relay server type Client struct { cfg *config.Config clientID string conn *websocket.Conn mu sync.Mutex // Event callbacks OnDeviceList func(msg *protocol.DeviceList) OnDeviceGranted func(msg *protocol.DeviceGranted) OnDeviceDenied func(msg *protocol.DeviceDenied) OnDeviceReleased func(msg *protocol.DeviceReleased) OnClientJoined func(msg *protocol.ClientJoined) OnClientLeft func(msg *protocol.ClientLeft) OnRequestDevice func(targetClient, fromClient, busID, requestID string) OnReleaseDevice func(busID, fromClient string) OnTunnelData func(tunnelID string, data []byte) ctx context.Context cancel context.CancelFunc } // NewClient creates a new client instance func NewClient(cfg *config.Config) *Client { ctx, cancel := context.WithCancel(context.Background()) return &Client{ cfg: cfg, clientID: uuid.New().String(), ctx: ctx, cancel: cancel, } } // ID returns the client ID func (c *Client) ID() string { return c.clientID } // Config returns the client config func (c *Client) Config() *config.Config { return c.cfg } // Connect establishes connection to the relay server func (c *Client) Connect() error { u, err := url.Parse(c.cfg.RelayAddr) if err != nil { return fmt.Errorf("invalid relay address: %w", err) } // Ensure WebSocket scheme switch u.Scheme { case "ws", "wss": // ok case "http": u.Scheme = "ws" case "https": u.Scheme = "wss" default: u.Scheme = "ws" } if u.Path == "" { u.Path = "/ws" } log.Printf("[client] connecting to %s", u.String()) conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { return fmt.Errorf("connecting to relay: %w", err) } c.mu.Lock() c.conn = conn c.mu.Unlock() // Send registration reg := &protocol.Register{ Type: protocol.MsgRegister, Hash: c.cfg.Hash, Mode: c.cfg.Mode, ClientID: c.clientID, Name: c.cfg.Name, } if err := conn.WriteJSON(reg); err != nil { conn.Close() return fmt.Errorf("sending registration: %w", err) } log.Printf("[client] registered as %s (mode=%s, name=%s)", c.clientID, c.cfg.Mode, c.cfg.Name) return nil } // RunReadLoop reads messages from the relay and dispatches them func (c *Client) RunReadLoop() error { for { select { case <-c.ctx.Done(): return nil default: } msgType, data, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { return fmt.Errorf("read error: %w", err) } return nil } switch msgType { case websocket.TextMessage: c.handleTextMessage(data) case websocket.BinaryMessage: c.handleBinaryMessage(data) } } } // Run connects and runs the main loop with auto-reconnect func (c *Client) Run() error { for { if err := c.Connect(); err != nil { log.Printf("[client] connection failed: %v, retrying in 5s...", err) select { case <-time.After(5 * time.Second): continue case <-c.ctx.Done(): return nil } } err := c.RunReadLoop() if err != nil { log.Printf("[client] disconnected: %v, reconnecting in 5s...", err) } else { log.Printf("[client] disconnected, reconnecting in 5s...") } c.mu.Lock() if c.conn != nil { c.conn.Close() c.conn = nil } c.mu.Unlock() select { case <-time.After(5 * time.Second): case <-c.ctx.Done(): return nil } } } // Close shuts down the client func (c *Client) Close() { c.cancel() c.mu.Lock() if c.conn != nil { c.conn.Close() } c.mu.Unlock() } // SendJSON sends a JSON message to the relay func (c *Client) SendJSON(v interface{}) error { c.mu.Lock() defer c.mu.Unlock() if c.conn == nil { return fmt.Errorf("not connected") } return c.conn.WriteJSON(v) } // SendBinary sends a binary message to the relay func (c *Client) SendBinary(data []byte) error { c.mu.Lock() defer c.mu.Unlock() if c.conn == nil { return fmt.Errorf("not connected") } return c.conn.WriteMessage(websocket.BinaryMessage, data) } // SendTunnelData sends tunnel data with the tunnel ID prefix func (c *Client) SendTunnelData(tunnelID string, data []byte) error { // Tunnel header: 16 bytes tunnel ID + payload msg := make([]byte, protocol.TunnelHeaderSize+len(data)) copy(msg[:protocol.TunnelHeaderSize], tunnelID) copy(msg[protocol.TunnelHeaderSize:], data) return c.SendBinary(msg) } func (c *Client) handleTextMessage(data []byte) { var env protocol.Envelope if err := json.Unmarshal(data, &env); err != nil { return } switch env.Type { case protocol.MsgDeviceList: if c.OnDeviceList != nil { var msg protocol.DeviceList if json.Unmarshal(data, &msg) == nil { c.OnDeviceList(&msg) } } case protocol.MsgRequestDevice: if c.OnRequestDevice != nil { var msg struct { TargetClient string `json:"target_client"` FromClient string `json:"from_client"` BusID string `json:"bus_id"` RequestID string `json:"request_id"` } if json.Unmarshal(data, &msg) == nil { c.OnRequestDevice(msg.TargetClient, msg.FromClient, msg.BusID, msg.RequestID) } } case protocol.MsgDeviceGranted: if c.OnDeviceGranted != nil { var msg protocol.DeviceGranted if json.Unmarshal(data, &msg) == nil { c.OnDeviceGranted(&msg) } } case protocol.MsgDeviceDenied: if c.OnDeviceDenied != nil { var msg protocol.DeviceDenied if json.Unmarshal(data, &msg) == nil { c.OnDeviceDenied(&msg) } } case protocol.MsgReleaseDevice: if c.OnReleaseDevice != nil { var msg struct { BusID string `json:"bus_id"` FromClient string `json:"from_client"` } if json.Unmarshal(data, &msg) == nil { c.OnReleaseDevice(msg.BusID, msg.FromClient) } } case protocol.MsgDeviceReleased: if c.OnDeviceReleased != nil { var msg protocol.DeviceReleased if json.Unmarshal(data, &msg) == nil { c.OnDeviceReleased(&msg) } } case protocol.MsgClientJoined: if c.OnClientJoined != nil { var msg protocol.ClientJoined if json.Unmarshal(data, &msg) == nil { c.OnClientJoined(&msg) } } case protocol.MsgClientLeft: if c.OnClientLeft != nil { var msg protocol.ClientLeft if json.Unmarshal(data, &msg) == nil { c.OnClientLeft(&msg) } } case protocol.MsgPong: // ignore pong case protocol.MsgError: var msg protocol.ErrorMsg if json.Unmarshal(data, &msg) == nil { log.Printf("[client] error from relay: %s", msg.Message) } } } func (c *Client) handleBinaryMessage(data []byte) { if len(data) < protocol.TunnelHeaderSize { return } tunnelID := string(data[:protocol.TunnelHeaderSize]) payload := data[protocol.TunnelHeaderSize:] if c.OnTunnelData != nil { c.OnTunnelData(tunnelID, payload) } }