337 lines
8.3 KiB
Go
337 lines
8.3 KiB
Go
package relay
|
|
|
|
import (
|
|
"encoding/json"
|
|
"log"
|
|
"sync"
|
|
|
|
"github.com/duffy/usb-server/internal/protocol"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Client represents a connected WebSocket client
|
|
type Client struct {
|
|
ID string
|
|
Hash string
|
|
Mode string // "share" or "use"
|
|
Name string
|
|
Conn *websocket.Conn
|
|
Send chan []byte // buffered channel for outgoing messages
|
|
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// WriteJSON sends a JSON message to the client
|
|
func (c *Client) WriteJSON(v interface{}) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.Conn.WriteJSON(v)
|
|
}
|
|
|
|
// WriteBinary sends a binary message to the client
|
|
func (c *Client) WriteBinary(data []byte) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.Conn.WriteMessage(websocket.BinaryMessage, data)
|
|
}
|
|
|
|
// Hub manages all connected clients and routes messages between them
|
|
type Hub struct {
|
|
mu sync.RWMutex
|
|
groups map[string]map[string]*Client // hash -> client_id -> client
|
|
tunnels map[string]*Tunnel // tunnel_id -> tunnel info
|
|
}
|
|
|
|
// Tunnel describes one active USB/IP tunnel between two clients.
type Tunnel struct {
	// ID is the tunnel identifier; the Hub uses it as the key of its tunnel
	// table.
	ID string

	// ShareClient is the ID of the client that granted the device.
	ShareClient string

	// UseClient is the ID of the client the device was granted to.
	UseClient string

	// BusID identifies the device carried by the tunnel.
	BusID string
}
|
|
|
|
// NewHub creates a new Hub
|
|
func NewHub() *Hub {
|
|
return &Hub{
|
|
groups: make(map[string]map[string]*Client),
|
|
tunnels: make(map[string]*Tunnel),
|
|
}
|
|
}
|
|
|
|
// Register adds a client to its hash group
|
|
func (h *Hub) Register(client *Client) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if h.groups[client.Hash] == nil {
|
|
h.groups[client.Hash] = make(map[string]*Client)
|
|
}
|
|
h.groups[client.Hash][client.ID] = client
|
|
|
|
log.Printf("[hub] client registered: id=%s hash=%s..%s mode=%s name=%s",
|
|
client.ID, client.Hash[:8], client.Hash[len(client.Hash)-4:], client.Mode, client.Name)
|
|
|
|
// Notify other clients in the group
|
|
h.broadcastToGroup(client.Hash, client.ID, &protocol.ClientJoined{
|
|
Type: protocol.MsgClientJoined,
|
|
ClientID: client.ID,
|
|
Mode: client.Mode,
|
|
Name: client.Name,
|
|
})
|
|
}
|
|
|
|
// Unregister removes a client and cleans up its tunnels
|
|
func (h *Hub) Unregister(client *Client) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
group := h.groups[client.Hash]
|
|
if group == nil {
|
|
return
|
|
}
|
|
|
|
delete(group, client.ID)
|
|
if len(group) == 0 {
|
|
delete(h.groups, client.Hash)
|
|
}
|
|
|
|
// Clean up tunnels involving this client
|
|
for tid, tunnel := range h.tunnels {
|
|
if tunnel.ShareClient == client.ID || tunnel.UseClient == client.ID {
|
|
delete(h.tunnels, tid)
|
|
}
|
|
}
|
|
|
|
log.Printf("[hub] client unregistered: id=%s name=%s", client.ID, client.Name)
|
|
|
|
// Notify others
|
|
h.broadcastToGroup(client.Hash, client.ID, &protocol.ClientLeft{
|
|
Type: protocol.MsgClientLeft,
|
|
ClientID: client.ID,
|
|
})
|
|
}
|
|
|
|
// HandleTextMessage processes a JSON control message
|
|
func (h *Hub) HandleTextMessage(sender *Client, data []byte) {
|
|
var env protocol.Envelope
|
|
if err := json.Unmarshal(data, &env); err != nil {
|
|
log.Printf("[hub] invalid message from %s: %v", sender.ID, err)
|
|
return
|
|
}
|
|
|
|
switch env.Type {
|
|
case protocol.MsgDeviceList:
|
|
h.handleDeviceList(sender, data)
|
|
case protocol.MsgRequestDevice:
|
|
h.handleRequestDevice(sender, data)
|
|
case protocol.MsgDeviceGranted:
|
|
h.handleDeviceGranted(sender, data)
|
|
case protocol.MsgDeviceDenied:
|
|
h.handleDeviceDenied(sender, data)
|
|
case protocol.MsgReleaseDevice:
|
|
h.handleReleaseDevice(sender, data)
|
|
case protocol.MsgDeviceReleased:
|
|
h.handleDeviceReleased(sender, data)
|
|
case protocol.MsgPing:
|
|
sender.WriteJSON(&protocol.Pong{Type: protocol.MsgPong})
|
|
default:
|
|
log.Printf("[hub] unknown message type from %s: %s", sender.ID, env.Type)
|
|
}
|
|
}
|
|
|
|
// HandleBinaryMessage forwards tunnel data to the other end
|
|
func (h *Hub) HandleBinaryMessage(sender *Client, data []byte) {
|
|
if len(data) < protocol.TunnelHeaderSize {
|
|
return
|
|
}
|
|
|
|
tunnelID := string(data[:protocol.TunnelHeaderSize])
|
|
|
|
h.mu.RLock()
|
|
tunnel := h.tunnels[tunnelID]
|
|
h.mu.RUnlock()
|
|
|
|
if tunnel == nil {
|
|
return
|
|
}
|
|
|
|
// Forward to the other end of the tunnel
|
|
var targetID string
|
|
if sender.ID == tunnel.ShareClient {
|
|
targetID = tunnel.UseClient
|
|
} else if sender.ID == tunnel.UseClient {
|
|
targetID = tunnel.ShareClient
|
|
} else {
|
|
return
|
|
}
|
|
|
|
h.mu.RLock()
|
|
group := h.groups[sender.Hash]
|
|
if group != nil {
|
|
if target := group[targetID]; target != nil {
|
|
target.WriteBinary(data)
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
}
|
|
|
|
// handleDeviceList broadcasts device list from share client to all use clients
|
|
func (h *Hub) handleDeviceList(sender *Client, data []byte) {
|
|
if sender.Mode != protocol.ModeShare {
|
|
return
|
|
}
|
|
|
|
h.mu.RLock()
|
|
group := h.groups[sender.Hash]
|
|
for _, client := range group {
|
|
if client.ID != sender.ID && client.Mode == protocol.ModeUse {
|
|
client.mu.Lock()
|
|
client.Conn.WriteMessage(websocket.TextMessage, data)
|
|
client.mu.Unlock()
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
}
|
|
|
|
// handleRequestDevice forwards a device request to the target share client
|
|
func (h *Hub) handleRequestDevice(sender *Client, data []byte) {
|
|
var msg protocol.RequestDevice
|
|
if err := json.Unmarshal(data, &msg); err != nil {
|
|
return
|
|
}
|
|
|
|
h.mu.RLock()
|
|
group := h.groups[sender.Hash]
|
|
if group != nil {
|
|
if target := group[msg.TargetClient]; target != nil && target.Mode == protocol.ModeShare {
|
|
// Add the sender's ID so the share client knows who's requesting
|
|
enriched := map[string]interface{}{
|
|
"type": protocol.MsgRequestDevice,
|
|
"target_client": msg.TargetClient,
|
|
"bus_id": msg.BusID,
|
|
"request_id": msg.RequestID,
|
|
"from_client": sender.ID,
|
|
}
|
|
target.WriteJSON(enriched)
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
}
|
|
|
|
// handleDeviceGranted registers the tunnel and forwards to the requesting client
|
|
func (h *Hub) handleDeviceGranted(sender *Client, data []byte) {
|
|
var granted struct {
|
|
protocol.DeviceGranted
|
|
TargetClient string `json:"target_client"`
|
|
}
|
|
if err := json.Unmarshal(data, &granted); err != nil {
|
|
return
|
|
}
|
|
|
|
// Register tunnel
|
|
h.mu.Lock()
|
|
h.tunnels[granted.TunnelID] = &Tunnel{
|
|
ID: granted.TunnelID,
|
|
ShareClient: sender.ID,
|
|
UseClient: granted.TargetClient,
|
|
BusID: granted.BusID,
|
|
}
|
|
h.mu.Unlock()
|
|
|
|
log.Printf("[hub] tunnel created: %s (share=%s, use=%s, device=%s)",
|
|
granted.TunnelID, sender.ID, granted.TargetClient, granted.BusID)
|
|
|
|
// Forward to use client
|
|
h.mu.RLock()
|
|
group := h.groups[sender.Hash]
|
|
if group != nil {
|
|
if target := group[granted.TargetClient]; target != nil {
|
|
target.mu.Lock()
|
|
target.Conn.WriteMessage(websocket.TextMessage, data)
|
|
target.mu.Unlock()
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
}
|
|
|
|
// handleDeviceDenied forwards denial to the requesting client
|
|
func (h *Hub) handleDeviceDenied(sender *Client, data []byte) {
|
|
var denied struct {
|
|
protocol.DeviceDenied
|
|
TargetClient string `json:"target_client"`
|
|
}
|
|
if err := json.Unmarshal(data, &denied); err != nil {
|
|
return
|
|
}
|
|
|
|
h.mu.RLock()
|
|
group := h.groups[sender.Hash]
|
|
if group != nil {
|
|
if target := group[denied.TargetClient]; target != nil {
|
|
target.mu.Lock()
|
|
target.Conn.WriteMessage(websocket.TextMessage, data)
|
|
target.mu.Unlock()
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
}
|
|
|
|
// handleReleaseDevice forwards a release to the share client
|
|
func (h *Hub) handleReleaseDevice(sender *Client, data []byte) {
|
|
var msg protocol.ReleaseDevice
|
|
if err := json.Unmarshal(data, &msg); err != nil {
|
|
return
|
|
}
|
|
|
|
// Clean up tunnel
|
|
h.mu.Lock()
|
|
for tid, tunnel := range h.tunnels {
|
|
if tunnel.UseClient == sender.ID && tunnel.BusID == msg.BusID {
|
|
delete(h.tunnels, tid)
|
|
log.Printf("[hub] tunnel closed: %s", tid)
|
|
break
|
|
}
|
|
}
|
|
h.mu.Unlock()
|
|
|
|
// Forward to share client
|
|
h.mu.RLock()
|
|
group := h.groups[sender.Hash]
|
|
if group != nil {
|
|
if target := group[msg.TargetClient]; target != nil {
|
|
enriched := map[string]interface{}{
|
|
"type": protocol.MsgReleaseDevice,
|
|
"target_client": msg.TargetClient,
|
|
"bus_id": msg.BusID,
|
|
"from_client": sender.ID,
|
|
}
|
|
target.WriteJSON(enriched)
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
}
|
|
|
|
// handleDeviceReleased broadcasts device released notification
|
|
func (h *Hub) handleDeviceReleased(sender *Client, data []byte) {
|
|
h.mu.RLock()
|
|
group := h.groups[sender.Hash]
|
|
for _, client := range group {
|
|
if client.ID != sender.ID && client.Mode == protocol.ModeUse {
|
|
client.mu.Lock()
|
|
client.Conn.WriteMessage(websocket.TextMessage, data)
|
|
client.mu.Unlock()
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
}
|
|
|
|
// broadcastToGroup sends a message to all clients in a hash group except the sender
|
|
func (h *Hub) broadcastToGroup(hash, excludeID string, msg interface{}) {
|
|
group := h.groups[hash]
|
|
for _, client := range group {
|
|
if client.ID != excludeID {
|
|
client.WriteJSON(msg)
|
|
}
|
|
}
|
|
}
|