usb-server/internal/relay/hub.go

374 lines
9.3 KiB
Go

package relay
import (
"encoding/json"
"log"
"sync"
"github.com/duffy/usb-server/internal/protocol"
"github.com/gorilla/websocket"
)
// Client is one connected WebSocket peer as tracked by the Hub.
type Client struct {
	// ID is the key under which the client is stored inside its hash group.
	ID string
	// Hash selects the group the client is registered in.
	Hash string
	// Mode is the client's role: "share" or "use".
	Mode string
	// Name is a label that appears in hub log lines and join notifications.
	Name string
	// Conn is the underlying WebSocket connection.
	Conn *websocket.Conn
	// Send is described by its author as a buffered channel for outgoing
	// messages; none of the hub code in this file reads or writes it.
	Send chan []byte
	// mu is held around every write to Conn made by this file.
	mu sync.Mutex
}
// WriteJSON writes v to the client's connection as a JSON message.
// c.mu is held for the whole call and released on return, so writes issued
// through the Client's helpers do not interleave on the connection.
// The error from the underlying connection write is returned unchanged.
func (c *Client) WriteJSON(v interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	conn := c.Conn
	return conn.WriteJSON(v)
}
// WriteBinary writes data to the client's connection as a single binary
// WebSocket message. c.mu is held for the whole call and released on return.
// The error from the underlying connection write is returned unchanged.
func (c *Client) WriteBinary(data []byte) error {
	const frame = websocket.BinaryMessage

	c.mu.Lock()
	defer c.mu.Unlock()
	return c.Conn.WriteMessage(frame, data)
}
// Hub keeps track of every connected client and of the active tunnels, and
// routes control and tunnel messages between clients of the same hash group.
type Hub struct {
	// mu guards groups and tunnels.
	mu sync.RWMutex
	// groups maps hash -> client ID -> client.
	groups map[string]map[string]*Client
	// tunnels maps tunnel ID -> tunnel record.
	tunnels map[string]*Tunnel
}
// Tunnel records one active USB/IP tunnel between two clients.
type Tunnel struct {
	// ID is the key of the tunnel in the hub's tunnel table.
	ID string
	// ShareClient is the ID of the client on the sharing side of the tunnel.
	ShareClient string
	// UseClient is the ID of the client on the using side of the tunnel.
	UseClient string
	// BusID identifies the device carried by the tunnel.
	BusID string
}
// NewHub returns a Hub whose group and tunnel tables are empty and ready
// to be written to.
func NewHub() *Hub {
	hub := new(Hub)
	hub.groups = map[string]map[string]*Client{}
	hub.tunnels = map[string]*Tunnel{}
	return hub
}
// Register adds client to the group selected by client.Hash, creating the
// group on first use, and then announces the newcomer to every other client
// in that group with a ClientJoined message.
//
// The hash is abbreviated in the log line to its first 8 and last 4 bytes.
// Hashes shorter than that are logged in full rather than sliced, because
// slicing past the end of the string would panic.
func (h *Hub) Register(client *Client) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.groups[client.Hash] == nil {
		h.groups[client.Hash] = make(map[string]*Client)
	}
	h.groups[client.Hash][client.ID] = client

	// Clamp both slices to the hash length; for hashes of 8 bytes or more
	// this yields exactly Hash[:8] and Hash[len-4:].
	head, tail := client.Hash, client.Hash
	if len(head) > 8 {
		head = head[:8]
	}
	if len(tail) > 4 {
		tail = tail[len(tail)-4:]
	}
	log.Printf("[hub] client registered: id=%s hash=%s..%s mode=%s name=%s",
		client.ID, head, tail, client.Mode, client.Name)

	// Notify other clients in the group
	h.broadcastToGroup(client.Hash, client.ID, &protocol.ClientJoined{
		Type:     protocol.MsgClientJoined,
		ClientID: client.ID,
		Mode:     client.Mode,
		Name:     client.Name,
	})
}
// Unregister drops client from its hash group (removing the group once it
// is empty), deletes every tunnel that has the client at either end, and
// sends a ClientLeft message to the clients remaining in the group.
// It does nothing when the client's group does not exist.
func (h *Hub) Unregister(client *Client) {
	h.mu.Lock()
	defer h.mu.Unlock()

	members := h.groups[client.Hash]
	if members == nil {
		return
	}

	delete(members, client.ID)
	if len(members) == 0 {
		delete(h.groups, client.Hash)
	}

	// Drop every tunnel in which this client is one of the two endpoints.
	for id, tn := range h.tunnels {
		switch client.ID {
		case tn.ShareClient, tn.UseClient:
			delete(h.tunnels, id)
		}
	}

	log.Printf("[hub] client unregistered: id=%s name=%s", client.ID, client.Name)

	// Tell whoever is left in the group.
	left := &protocol.ClientLeft{
		Type:     protocol.MsgClientLeft,
		ClientID: client.ID,
	}
	h.broadcastToGroup(client.Hash, client.ID, left)
}
// HandleTextMessage processes one JSON control message received from sender.
//
// The payload is first decoded into an Envelope to learn its type and is then
// handed, still in raw form, to the handler for that type. Pings are answered
// directly with a Pong. Undecodable payloads and unknown types are logged and
// dropped.
func (h *Hub) HandleTextMessage(sender *Client, data []byte) {
	var env protocol.Envelope
	if err := json.Unmarshal(data, &env); err != nil {
		log.Printf("[hub] invalid message from %s: %v", sender.ID, err)
		return
	}

	switch env.Type {
	case protocol.MsgDeviceList:
		h.handleDeviceList(sender, data)
	case protocol.MsgRequestDevice:
		h.handleRequestDevice(sender, data)
	case protocol.MsgDeviceGranted:
		h.handleDeviceGranted(sender, data)
	case protocol.MsgDeviceDenied:
		h.handleDeviceDenied(sender, data)
	case protocol.MsgForceRelease:
		h.handleForceRelease(sender, data)
	case protocol.MsgReleaseDevice:
		h.handleReleaseDevice(sender, data)
	case protocol.MsgDeviceReleased:
		h.handleDeviceReleased(sender, data)
	case protocol.MsgPing:
		// A failed pong means the connection is unusable; surface it in the
		// log instead of discarding the error.
		if err := sender.WriteJSON(&protocol.Pong{Type: protocol.MsgPong}); err != nil {
			log.Printf("[hub] pong to %s failed: %v", sender.ID, err)
		}
	default:
		log.Printf("[hub] unknown message type from %s: %s", sender.ID, env.Type)
	}
}
// HandleBinaryMessage forwards one tunnel frame to the other end of its tunnel.
//
// The first protocol.TunnelHeaderSize bytes of data are the tunnel ID. The
// frame is dropped when it is shorter than the header, when the tunnel is
// unknown, when sender is not one of the tunnel's two endpoints, or when the
// peer is not present in sender's hash group. Otherwise the whole frame,
// header included, is written to the peer.
//
// The peer is resolved in a single read-locked section and the network write
// happens after the hub lock is released, so a slow peer cannot stall
// Register/Unregister or other routing.
func (h *Hub) HandleBinaryMessage(sender *Client, data []byte) {
	if len(data) < protocol.TunnelHeaderSize {
		return
	}
	tunnelID := string(data[:protocol.TunnelHeaderSize])

	var target *Client
	h.mu.RLock()
	if tunnel := h.tunnels[tunnelID]; tunnel != nil {
		peerID, isEndpoint := "", true
		switch sender.ID {
		case tunnel.ShareClient:
			peerID = tunnel.UseClient
		case tunnel.UseClient:
			peerID = tunnel.ShareClient
		default:
			// sender does not belong to this tunnel
			isEndpoint = false
		}
		if isEndpoint {
			// Indexing a missing group yields a nil map, whose lookup is nil.
			target = h.groups[sender.Hash][peerID]
		}
	}
	h.mu.RUnlock()

	if target == nil {
		return
	}
	if err := target.WriteBinary(data); err != nil {
		log.Printf("[hub] tunnel %q: forward to %s failed: %v", tunnelID, target.ID, err)
	}
}
// handleDeviceList relays a device list, byte for byte, from a share-mode
// sender to every use-mode client in the sender's hash group. Messages from
// senders in any other mode are ignored. Write errors are not reported.
func (h *Hub) handleDeviceList(sender *Client, data []byte) {
	if sender.Mode != protocol.ModeShare {
		return
	}

	h.mu.RLock()
	for _, peer := range h.groups[sender.Hash] {
		if peer.ID == sender.ID || peer.Mode != protocol.ModeUse {
			continue
		}
		peer.mu.Lock()
		peer.Conn.WriteMessage(websocket.TextMessage, data)
		peer.mu.Unlock()
	}
	h.mu.RUnlock()
}
// handleRequestDevice passes a device request on to the share-mode client
// named in the message's target_client field, provided that client is in the
// sender's hash group. The forwarded message repeats the request fields and
// adds from_client so the share client knows who is asking. Undecodable
// requests and requests for absent or non-share targets are dropped.
func (h *Hub) handleRequestDevice(sender *Client, data []byte) {
	var req protocol.RequestDevice
	if json.Unmarshal(data, &req) != nil {
		return
	}

	h.mu.RLock()
	// A missing group is a nil map, so this lookup is nil in that case too.
	owner := h.groups[sender.Hash][req.TargetClient]
	if owner != nil && owner.Mode == protocol.ModeShare {
		owner.WriteJSON(map[string]interface{}{
			"type":          protocol.MsgRequestDevice,
			"target_client": req.TargetClient,
			"bus_id":        req.BusID,
			"request_id":    req.RequestID,
			"from_client":   sender.ID,
		})
	}
	h.mu.RUnlock()
}
// handleDeviceGranted records the tunnel announced by a share client and
// forwards the grant, byte for byte, to the client named in target_client.
//
// A grant is accepted only when the sender is in share mode and the target
// is currently registered in the sender's hash group. Checking the target
// before recording the tunnel avoids leaving a tunnel entry that points at a
// client that is not there. Lookup and registration share one critical
// section; the network write happens after the hub lock is released.
func (h *Hub) handleDeviceGranted(sender *Client, data []byte) {
	var granted struct {
		protocol.DeviceGranted
		TargetClient string `json:"target_client"`
	}
	if err := json.Unmarshal(data, &granted); err != nil {
		return
	}

	// The sender becomes the tunnel's share endpoint, so it must be a share
	// client (same role check as handleDeviceList).
	if sender.Mode != protocol.ModeShare {
		log.Printf("[hub] device grant from non-share client %s ignored", sender.ID)
		return
	}

	h.mu.Lock()
	target := h.groups[sender.Hash][granted.TargetClient]
	if target != nil {
		h.tunnels[granted.TunnelID] = &Tunnel{
			ID:          granted.TunnelID,
			ShareClient: sender.ID,
			UseClient:   granted.TargetClient,
			BusID:       granted.BusID,
		}
	}
	h.mu.Unlock()

	if target == nil {
		log.Printf("[hub] device grant from %s ignored: target %q not connected",
			sender.ID, granted.TargetClient)
		return
	}
	log.Printf("[hub] tunnel created: %s (share=%s, use=%s, device=%s)",
		granted.TunnelID, sender.ID, granted.TargetClient, granted.BusID)

	// Forward to use client
	target.mu.Lock()
	err := target.Conn.WriteMessage(websocket.TextMessage, data)
	target.mu.Unlock()
	if err != nil {
		log.Printf("[hub] forwarding device grant to %s failed: %v", target.ID, err)
	}
}
// handleDeviceDenied relays a denial, byte for byte, to the client named in
// the message's target_client field when that client is in the sender's hash
// group. Undecodable messages and absent targets are dropped silently.
func (h *Hub) handleDeviceDenied(sender *Client, data []byte) {
	var denial struct {
		protocol.DeviceDenied
		TargetClient string `json:"target_client"`
	}
	if json.Unmarshal(data, &denial) != nil {
		return
	}

	h.mu.RLock()
	// A missing group is a nil map, so this lookup is nil in that case too.
	if requester := h.groups[sender.Hash][denial.TargetClient]; requester != nil {
		requester.mu.Lock()
		requester.Conn.WriteMessage(websocket.TextMessage, data)
		requester.mu.Unlock()
	}
	h.mu.RUnlock()
}
// handleReleaseDevice closes the sender's tunnel for a device and forwards
// the release to the share client named in target_client.
//
// The tunnel to close is the one whose use side is the sender, whose share
// side is msg.TargetClient and whose bus ID is msg.BusID. Matching on the
// share client as well keeps the release from closing a tunnel the sender
// holds to a different share client that exposes the same bus ID. The
// forwarded message repeats the release fields and adds from_client.
func (h *Hub) handleReleaseDevice(sender *Client, data []byte) {
	var msg protocol.ReleaseDevice
	if err := json.Unmarshal(data, &msg); err != nil {
		return
	}

	// Clean up tunnel
	h.mu.Lock()
	for tid, tunnel := range h.tunnels {
		if tunnel.UseClient == sender.ID &&
			tunnel.ShareClient == msg.TargetClient &&
			tunnel.BusID == msg.BusID {
			delete(h.tunnels, tid)
			log.Printf("[hub] tunnel closed: %s", tid)
			break
		}
	}
	h.mu.Unlock()

	// Forward to share client
	h.mu.RLock()
	group := h.groups[sender.Hash]
	if group != nil {
		if target := group[msg.TargetClient]; target != nil {
			enriched := map[string]interface{}{
				"type":          protocol.MsgReleaseDevice,
				"target_client": msg.TargetClient,
				"bus_id":        msg.BusID,
				"from_client":   sender.ID,
			}
			if err := target.WriteJSON(enriched); err != nil {
				log.Printf("[hub] forwarding release to %s failed: %v", target.ID, err)
			}
		}
	}
	h.mu.RUnlock()
}
// handleForceRelease closes the tunnel for a device regardless of which use
// client holds it, then forwards the force-release to the target share client.
//
// The first tunnel whose bus ID is msg.BusID and whose share side is
// msg.TargetClient is removed. The forward happens only when the target is a
// share-mode client in the sender's hash group; the forwarded message repeats
// the request fields and adds from_client.
func (h *Hub) handleForceRelease(sender *Client, data []byte) {
	var msg protocol.ForceRelease
	if json.Unmarshal(data, &msg) != nil {
		return
	}

	// Remove the matching tunnel, whoever is on its use side.
	h.mu.Lock()
	for id, tn := range h.tunnels {
		if tn.BusID != msg.BusID || tn.ShareClient != msg.TargetClient {
			continue
		}
		delete(h.tunnels, id)
		log.Printf("[hub] tunnel force-closed: %s", id)
		break
	}
	h.mu.Unlock()

	// Pass the request on to the share client.
	h.mu.RLock()
	owner := h.groups[sender.Hash][msg.TargetClient]
	if owner != nil && owner.Mode == protocol.ModeShare {
		owner.WriteJSON(map[string]interface{}{
			"type":          protocol.MsgForceRelease,
			"target_client": msg.TargetClient,
			"bus_id":        msg.BusID,
			"from_client":   sender.ID,
		})
	}
	h.mu.RUnlock()
}
// handleDeviceReleased relays a device-released notification, byte for byte,
// to every use-mode client in the sender's hash group other than the sender.
// Write errors are not reported.
func (h *Hub) handleDeviceReleased(sender *Client, data []byte) {
	h.mu.RLock()
	for _, peer := range h.groups[sender.Hash] {
		if peer.ID == sender.ID || peer.Mode != protocol.ModeUse {
			continue
		}
		peer.mu.Lock()
		peer.Conn.WriteMessage(websocket.TextMessage, data)
		peer.mu.Unlock()
	}
	h.mu.RUnlock()
}
// broadcastToGroup writes msg as JSON to every client in the given hash
// group except the one whose ID is excludeID. Write errors are not reported.
//
// It reads h.groups without taking h.mu; Register and Unregister both call
// it with h.mu already held.
func (h *Hub) broadcastToGroup(hash, excludeID string, msg interface{}) {
	for _, peer := range h.groups[hash] {
		if peer.ID == excludeID {
			continue
		}
		peer.WriteJSON(msg)
	}
}