package relay

import (
	"encoding/json"
	"log"
	"sync"

	"github.com/duffy/usb-server/internal/protocol"
	"github.com/gorilla/websocket"
)

// Client represents a connected WebSocket client
type Client struct {
	ID   string
	Hash string
	Mode string // "share" or "use"
	Name string
	Conn *websocket.Conn
	Send chan []byte // buffered channel for outgoing messages (not used by the code in this file)

	mu sync.Mutex // serializes every write on Conn
}

// WriteJSON sends a JSON message to the client
func (c *Client) WriteJSON(v interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.Conn.WriteJSON(v)
}

// WriteBinary sends a binary message to the client
func (c *Client) WriteBinary(data []byte) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.Conn.WriteMessage(websocket.BinaryMessage, data)
}

// writeText sends an already-encoded text frame to the client. It is used to
// relay a control message verbatim without re-marshalling it.
func (c *Client) writeText(data []byte) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.Conn.WriteMessage(websocket.TextMessage, data)
}

// Hub manages all connected clients and routes messages between them
type Hub struct {
	mu      sync.RWMutex                  // guards groups and tunnels
	groups  map[string]map[string]*Client // hash -> client_id -> client
	tunnels map[string]*Tunnel            // tunnel_id -> tunnel info
}

// Tunnel tracks an active USB/IP tunnel between two clients
type Tunnel struct {
	ID          string
	ShareClient string
	UseClient   string
	BusID       string
}

// NewHub creates a new Hub
func NewHub() *Hub {
	return &Hub{
		groups:  make(map[string]map[string]*Client),
		tunnels: make(map[string]*Tunnel),
	}
}

// abbreviateHash returns the form of a group hash used in log lines: the
// first 8 and last 4 bytes joined by "..". A hash shorter than 8 bytes is
// returned unchanged instead of being sliced out of range.
func abbreviateHash(hash string) string {
	const prefixLen, suffixLen = 8, 4
	if len(hash) < prefixLen {
		return hash
	}
	return hash[:prefixLen] + ".." + hash[len(hash)-suffixLen:]
}

// writeJSONToAll sends msg as JSON to every target and logs each failed write.
// It must be called without h.mu held so a slow peer cannot stall the hub.
func writeJSONToAll(targets []*Client, msg interface{}) {
	for _, target := range targets {
		if err := target.WriteJSON(msg); err != nil {
			log.Printf("[hub] write to %s failed: %v", target.ID, err)
		}
	}
}

// writeTextToAll relays the raw text frame data to every target and logs each
// failed write. It must be called without h.mu held.
func writeTextToAll(targets []*Client, data []byte) {
	for _, target := range targets {
		if err := target.writeText(data); err != nil {
			log.Printf("[hub] write to %s failed: %v", target.ID, err)
		}
	}
}

// peersLocked returns a snapshot of the clients in the given hash group,
// skipping the client whose ID is excludeID and, when mode is non-empty,
// every client whose Mode differs from mode. The caller must hold h.mu.
func (h *Hub) peersLocked(hash, excludeID, mode string) []*Client {
	group := h.groups[hash]
	if len(group) == 0 {
		return nil
	}
	peers := make([]*Client, 0, len(group))
	for _, c := range group {
		if c.ID == excludeID {
			continue
		}
		if mode != "" && c.Mode != mode {
			continue
		}
		peers = append(peers, c)
	}
	return peers
}

// clientInGroup returns the client registered under id in the given hash
// group, or nil when there is none.
func (h *Hub) clientInGroup(hash, id string) *Client {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.groups[hash][id]
}

// Register adds a client to its hash group and sends a ClientJoined message
// to the other members of that group.
func (h *Hub) Register(client *Client) {
	h.mu.Lock()
	group := h.groups[client.Hash]
	if group == nil {
		group = make(map[string]*Client)
		h.groups[client.Hash] = group
	}
	group[client.ID] = client
	peers := h.peersLocked(client.Hash, client.ID, "")
	h.mu.Unlock()

	log.Printf("[hub] client registered: id=%s hash=%s mode=%s name=%s",
		client.ID, abbreviateHash(client.Hash), client.Mode, client.Name)

	// Notify other clients in the group, outside the lock.
	writeJSONToAll(peers, &protocol.ClientJoined{
		Type:     protocol.MsgClientJoined,
		ClientID: client.ID,
		Mode:     client.Mode,
		Name:     client.Name,
	})
}

// Unregister removes a client, drops every tunnel it is an endpoint of, and
// sends a ClientLeft message to the remaining members of its group. It does
// nothing when the client's hash group does not exist.
func (h *Hub) Unregister(client *Client) {
	h.mu.Lock()
	group := h.groups[client.Hash]
	if group == nil {
		h.mu.Unlock()
		return
	}
	delete(group, client.ID)
	if len(group) == 0 {
		delete(h.groups, client.Hash)
	}

	// Clean up tunnels involving this client
	for tid, tunnel := range h.tunnels {
		if tunnel.ShareClient == client.ID || tunnel.UseClient == client.ID {
			delete(h.tunnels, tid)
			log.Printf("[hub] tunnel closed: %s", tid)
		}
	}
	peers := h.peersLocked(client.Hash, client.ID, "")
	h.mu.Unlock()

	log.Printf("[hub] client unregistered: id=%s name=%s", client.ID, client.Name)

	// Notify others, outside the lock.
	writeJSONToAll(peers, &protocol.ClientLeft{
		Type:     protocol.MsgClientLeft,
		ClientID: client.ID,
	})
}

// HandleTextMessage decodes the envelope of a JSON control message and
// dispatches it on its type. Undecodable messages and unknown types are
// logged and dropped.
func (h *Hub) HandleTextMessage(sender *Client, data []byte) {
	var env protocol.Envelope
	if err := json.Unmarshal(data, &env); err != nil {
		log.Printf("[hub] invalid message from %s: %v", sender.ID, err)
		return
	}

	switch env.Type {
	case protocol.MsgDeviceList:
		h.handleDeviceList(sender, data)
	case protocol.MsgRequestDevice:
		h.handleRequestDevice(sender, data)
	case protocol.MsgDeviceGranted:
		h.handleDeviceGranted(sender, data)
	case protocol.MsgDeviceDenied:
		h.handleDeviceDenied(sender, data)
	case protocol.MsgReleaseDevice:
		h.handleReleaseDevice(sender, data)
	case protocol.MsgDeviceReleased:
		h.handleDeviceReleased(sender, data)
	case protocol.MsgPing:
		if err := sender.WriteJSON(&protocol.Pong{Type: protocol.MsgPong}); err != nil {
			log.Printf("[hub] pong to %s failed: %v", sender.ID, err)
		}
	default:
		log.Printf("[hub] unknown message type from %s: %s", sender.ID, env.Type)
	}
}

// HandleBinaryMessage forwards tunnel data to the other end of the tunnel.
// The first protocol.TunnelHeaderSize bytes of data are the tunnel ID. The
// frame is dropped when it is shorter than that, when the tunnel is unknown,
// when sender is not one of the tunnel's endpoints, or when the other
// endpoint is not in the sender's hash group.
func (h *Hub) HandleBinaryMessage(sender *Client, data []byte) {
	if len(data) < protocol.TunnelHeaderSize {
		return
	}

	// Resolve the destination in one critical section; the write itself
	// happens after the lock is released.
	h.mu.RLock()
	target := h.tunnelPeerLocked(data[:protocol.TunnelHeaderSize], sender)
	h.mu.RUnlock()
	if target == nil {
		return
	}

	if err := target.WriteBinary(data); err != nil {
		log.Printf("[hub] tunnel write to %s failed: %v", target.ID, err)
	}
}

// tunnelPeerLocked returns the client at the opposite end of the tunnel
// identified by tunnelID, or nil when the frame must be dropped. The caller
// must hold h.mu.
func (h *Hub) tunnelPeerLocked(tunnelID []byte, sender *Client) *Client {
	// Converting inside the index expression lets the compiler skip the copy.
	tunnel := h.tunnels[string(tunnelID)]
	if tunnel == nil {
		return nil
	}

	var targetID string
	switch sender.ID {
	case tunnel.ShareClient:
		targetID = tunnel.UseClient
	case tunnel.UseClient:
		targetID = tunnel.ShareClient
	default:
		return nil
	}
	return h.groups[sender.Hash][targetID]
}

// handleDeviceList relays a device list from a share client to every use
// client in the same group. Lists from clients in any other mode are ignored.
func (h *Hub) handleDeviceList(sender *Client, data []byte) {
	if sender.Mode != protocol.ModeShare {
		return
	}

	h.mu.RLock()
	targets := h.peersLocked(sender.Hash, sender.ID, protocol.ModeUse)
	h.mu.RUnlock()

	writeTextToAll(targets, data)
}

// handleRequestDevice forwards a device request to the target share client,
// adding the sender's ID so the share client knows who is requesting. The
// request is dropped when the target is not a share client in the sender's
// group.
func (h *Hub) handleRequestDevice(sender *Client, data []byte) {
	var msg protocol.RequestDevice
	if err := json.Unmarshal(data, &msg); err != nil {
		log.Printf("[hub] invalid device request from %s: %v", sender.ID, err)
		return
	}

	target := h.clientInGroup(sender.Hash, msg.TargetClient)
	if target == nil || target.Mode != protocol.ModeShare {
		return
	}

	enriched := map[string]interface{}{
		"type":          protocol.MsgRequestDevice,
		"target_client": msg.TargetClient,
		"bus_id":        msg.BusID,
		"request_id":    msg.RequestID,
		"from_client":   sender.ID,
	}
	if err := target.WriteJSON(enriched); err != nil {
		log.Printf("[hub] device request to %s failed: %v", target.ID, err)
	}
}

// handleDeviceGranted registers the tunnel and forwards the grant to the
// requesting client. The grant is dropped, and no tunnel is recorded, when
// the target is not in the sender's group or when the tunnel ID is already
// in use by a different share client.
func (h *Hub) handleDeviceGranted(sender *Client, data []byte) {
	var granted struct {
		protocol.DeviceGranted
		TargetClient string `json:"target_client"`
	}
	if err := json.Unmarshal(data, &granted); err != nil {
		log.Printf("[hub] invalid device grant from %s: %v", sender.ID, err)
		return
	}

	// Validate and register in one critical section so the tunnel table never
	// holds an entry whose use client was not present at grant time.
	h.mu.Lock()
	target := h.groups[sender.Hash][granted.TargetClient]
	existing := h.tunnels[granted.TunnelID]
	conflict := existing != nil && existing.ShareClient != sender.ID
	if target != nil && !conflict {
		h.tunnels[granted.TunnelID] = &Tunnel{
			ID:          granted.TunnelID,
			ShareClient: sender.ID,
			UseClient:   granted.TargetClient,
			BusID:       granted.BusID,
		}
	}
	h.mu.Unlock()

	switch {
	case target == nil:
		log.Printf("[hub] device grant from %s dropped: target %q not in group",
			sender.ID, granted.TargetClient)
		return
	case conflict:
		log.Printf("[hub] device grant from %s dropped: tunnel %q belongs to another client",
			sender.ID, granted.TunnelID)
		return
	}

	log.Printf("[hub] tunnel created: %s (share=%s, use=%s, device=%s)",
		granted.TunnelID, sender.ID, granted.TargetClient, granted.BusID)

	// Forward to use client
	if err := target.writeText(data); err != nil {
		log.Printf("[hub] device grant to %s failed: %v", target.ID, err)
	}
}

// handleDeviceDenied forwards a denial verbatim to the requesting client, if
// that client is in the sender's group.
func (h *Hub) handleDeviceDenied(sender *Client, data []byte) {
	var denied struct {
		protocol.DeviceDenied
		TargetClient string `json:"target_client"`
	}
	if err := json.Unmarshal(data, &denied); err != nil {
		log.Printf("[hub] invalid device denial from %s: %v", sender.ID, err)
		return
	}

	target := h.clientInGroup(sender.Hash, denied.TargetClient)
	if target == nil {
		return
	}
	if err := target.writeText(data); err != nil {
		log.Printf("[hub] device denial to %s failed: %v", target.ID, err)
	}
}

// handleReleaseDevice closes the sender's tunnel for the released device and
// forwards the release, tagged with the sender's ID, to the share client.
func (h *Hub) handleReleaseDevice(sender *Client, data []byte) {
	var msg protocol.ReleaseDevice
	if err := json.Unmarshal(data, &msg); err != nil {
		log.Printf("[hub] invalid device release from %s: %v", sender.ID, err)
		return
	}

	h.mu.Lock()
	// Clean up tunnel. When the message names a share client, the tunnel must
	// lead to it: bus IDs alone can repeat across share clients, and map
	// iteration order is random, so matching on bus ID only could close the
	// wrong tunnel.
	for tid, tunnel := range h.tunnels {
		if tunnel.UseClient != sender.ID || tunnel.BusID != msg.BusID {
			continue
		}
		if msg.TargetClient != "" && tunnel.ShareClient != msg.TargetClient {
			continue
		}
		delete(h.tunnels, tid)
		log.Printf("[hub] tunnel closed: %s", tid)
		break
	}
	target := h.groups[sender.Hash][msg.TargetClient]
	h.mu.Unlock()

	if target == nil {
		return
	}

	// Forward to share client
	enriched := map[string]interface{}{
		"type":          protocol.MsgReleaseDevice,
		"target_client": msg.TargetClient,
		"bus_id":        msg.BusID,
		"from_client":   sender.ID,
	}
	if err := target.WriteJSON(enriched); err != nil {
		log.Printf("[hub] device release to %s failed: %v", target.ID, err)
	}
}

// handleDeviceReleased relays a device-released notification verbatim to
// every use client in the sender's group.
func (h *Hub) handleDeviceReleased(sender *Client, data []byte) {
	h.mu.RLock()
	targets := h.peersLocked(sender.Hash, sender.ID, protocol.ModeUse)
	h.mu.RUnlock()

	writeTextToAll(targets, data)
}

// broadcastToGroup sends a message to all clients in a hash group except the
// sender. The caller must hold h.mu, because the group map is read here
// without locking; the writes therefore run while that lock is held. Where
// possible, snapshot the peers under the lock and send after releasing it,
// as Register and Unregister do.
func (h *Hub) broadcastToGroup(hash, excludeID string, msg interface{}) {
	writeJSONToAll(h.peersLocked(hash, excludeID, ""), msg)
}