first commit

This commit is contained in:
Stefan Hacker
2026-02-18 22:01:54 +01:00
commit 5464e553b3
35 changed files with 5432 additions and 0 deletions
+336
View File
@@ -0,0 +1,336 @@
package relay
import (
"encoding/json"
"log"
"sync"
"github.com/duffy/usb-server/internal/protocol"
"github.com/gorilla/websocket"
)
// Client represents a connected WebSocket client
type Client struct {
ID string
Hash string
Mode string // "share" or "use"
Name string
Conn *websocket.Conn
Send chan []byte // buffered channel for outgoing messages
mu sync.Mutex
}
// WriteJSON sends a JSON message to the client
func (c *Client) WriteJSON(v interface{}) error {
c.mu.Lock()
defer c.mu.Unlock()
return c.Conn.WriteJSON(v)
}
// WriteBinary sends a binary message to the client
func (c *Client) WriteBinary(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return c.Conn.WriteMessage(websocket.BinaryMessage, data)
}
// Hub manages all connected clients and routes messages between them
type Hub struct {
mu sync.RWMutex
groups map[string]map[string]*Client // hash -> client_id -> client
tunnels map[string]*Tunnel // tunnel_id -> tunnel info
}
// Tunnel tracks an active USB/IP tunnel between two clients
type Tunnel struct {
ID string
ShareClient string
UseClient string
BusID string
}
// NewHub creates a new Hub
func NewHub() *Hub {
return &Hub{
groups: make(map[string]map[string]*Client),
tunnels: make(map[string]*Tunnel),
}
}
// Register adds a client to its hash group
func (h *Hub) Register(client *Client) {
h.mu.Lock()
defer h.mu.Unlock()
if h.groups[client.Hash] == nil {
h.groups[client.Hash] = make(map[string]*Client)
}
h.groups[client.Hash][client.ID] = client
log.Printf("[hub] client registered: id=%s hash=%s..%s mode=%s name=%s",
client.ID, client.Hash[:8], client.Hash[len(client.Hash)-4:], client.Mode, client.Name)
// Notify other clients in the group
h.broadcastToGroup(client.Hash, client.ID, &protocol.ClientJoined{
Type: protocol.MsgClientJoined,
ClientID: client.ID,
Mode: client.Mode,
Name: client.Name,
})
}
// Unregister removes a client and cleans up its tunnels
func (h *Hub) Unregister(client *Client) {
h.mu.Lock()
defer h.mu.Unlock()
group := h.groups[client.Hash]
if group == nil {
return
}
delete(group, client.ID)
if len(group) == 0 {
delete(h.groups, client.Hash)
}
// Clean up tunnels involving this client
for tid, tunnel := range h.tunnels {
if tunnel.ShareClient == client.ID || tunnel.UseClient == client.ID {
delete(h.tunnels, tid)
}
}
log.Printf("[hub] client unregistered: id=%s name=%s", client.ID, client.Name)
// Notify others
h.broadcastToGroup(client.Hash, client.ID, &protocol.ClientLeft{
Type: protocol.MsgClientLeft,
ClientID: client.ID,
})
}
// HandleTextMessage processes a JSON control message
func (h *Hub) HandleTextMessage(sender *Client, data []byte) {
var env protocol.Envelope
if err := json.Unmarshal(data, &env); err != nil {
log.Printf("[hub] invalid message from %s: %v", sender.ID, err)
return
}
switch env.Type {
case protocol.MsgDeviceList:
h.handleDeviceList(sender, data)
case protocol.MsgRequestDevice:
h.handleRequestDevice(sender, data)
case protocol.MsgDeviceGranted:
h.handleDeviceGranted(sender, data)
case protocol.MsgDeviceDenied:
h.handleDeviceDenied(sender, data)
case protocol.MsgReleaseDevice:
h.handleReleaseDevice(sender, data)
case protocol.MsgDeviceReleased:
h.handleDeviceReleased(sender, data)
case protocol.MsgPing:
sender.WriteJSON(&protocol.Pong{Type: protocol.MsgPong})
default:
log.Printf("[hub] unknown message type from %s: %s", sender.ID, env.Type)
}
}
// HandleBinaryMessage forwards tunnel data to the other end
func (h *Hub) HandleBinaryMessage(sender *Client, data []byte) {
if len(data) < protocol.TunnelHeaderSize {
return
}
tunnelID := string(data[:protocol.TunnelHeaderSize])
h.mu.RLock()
tunnel := h.tunnels[tunnelID]
h.mu.RUnlock()
if tunnel == nil {
return
}
// Forward to the other end of the tunnel
var targetID string
if sender.ID == tunnel.ShareClient {
targetID = tunnel.UseClient
} else if sender.ID == tunnel.UseClient {
targetID = tunnel.ShareClient
} else {
return
}
h.mu.RLock()
group := h.groups[sender.Hash]
if group != nil {
if target := group[targetID]; target != nil {
target.WriteBinary(data)
}
}
h.mu.RUnlock()
}
// handleDeviceList broadcasts device list from share client to all use clients
func (h *Hub) handleDeviceList(sender *Client, data []byte) {
if sender.Mode != protocol.ModeShare {
return
}
h.mu.RLock()
group := h.groups[sender.Hash]
for _, client := range group {
if client.ID != sender.ID && client.Mode == protocol.ModeUse {
client.mu.Lock()
client.Conn.WriteMessage(websocket.TextMessage, data)
client.mu.Unlock()
}
}
h.mu.RUnlock()
}
// handleRequestDevice forwards a device request to the target share client
func (h *Hub) handleRequestDevice(sender *Client, data []byte) {
var msg protocol.RequestDevice
if err := json.Unmarshal(data, &msg); err != nil {
return
}
h.mu.RLock()
group := h.groups[sender.Hash]
if group != nil {
if target := group[msg.TargetClient]; target != nil && target.Mode == protocol.ModeShare {
// Add the sender's ID so the share client knows who's requesting
enriched := map[string]interface{}{
"type": protocol.MsgRequestDevice,
"target_client": msg.TargetClient,
"bus_id": msg.BusID,
"request_id": msg.RequestID,
"from_client": sender.ID,
}
target.WriteJSON(enriched)
}
}
h.mu.RUnlock()
}
// handleDeviceGranted registers the tunnel and forwards to the requesting client
func (h *Hub) handleDeviceGranted(sender *Client, data []byte) {
var granted struct {
protocol.DeviceGranted
TargetClient string `json:"target_client"`
}
if err := json.Unmarshal(data, &granted); err != nil {
return
}
// Register tunnel
h.mu.Lock()
h.tunnels[granted.TunnelID] = &Tunnel{
ID: granted.TunnelID,
ShareClient: sender.ID,
UseClient: granted.TargetClient,
BusID: granted.BusID,
}
h.mu.Unlock()
log.Printf("[hub] tunnel created: %s (share=%s, use=%s, device=%s)",
granted.TunnelID, sender.ID, granted.TargetClient, granted.BusID)
// Forward to use client
h.mu.RLock()
group := h.groups[sender.Hash]
if group != nil {
if target := group[granted.TargetClient]; target != nil {
target.mu.Lock()
target.Conn.WriteMessage(websocket.TextMessage, data)
target.mu.Unlock()
}
}
h.mu.RUnlock()
}
// handleDeviceDenied forwards denial to the requesting client
func (h *Hub) handleDeviceDenied(sender *Client, data []byte) {
var denied struct {
protocol.DeviceDenied
TargetClient string `json:"target_client"`
}
if err := json.Unmarshal(data, &denied); err != nil {
return
}
h.mu.RLock()
group := h.groups[sender.Hash]
if group != nil {
if target := group[denied.TargetClient]; target != nil {
target.mu.Lock()
target.Conn.WriteMessage(websocket.TextMessage, data)
target.mu.Unlock()
}
}
h.mu.RUnlock()
}
// handleReleaseDevice forwards a release to the share client
func (h *Hub) handleReleaseDevice(sender *Client, data []byte) {
var msg protocol.ReleaseDevice
if err := json.Unmarshal(data, &msg); err != nil {
return
}
// Clean up tunnel
h.mu.Lock()
for tid, tunnel := range h.tunnels {
if tunnel.UseClient == sender.ID && tunnel.BusID == msg.BusID {
delete(h.tunnels, tid)
log.Printf("[hub] tunnel closed: %s", tid)
break
}
}
h.mu.Unlock()
// Forward to share client
h.mu.RLock()
group := h.groups[sender.Hash]
if group != nil {
if target := group[msg.TargetClient]; target != nil {
enriched := map[string]interface{}{
"type": protocol.MsgReleaseDevice,
"target_client": msg.TargetClient,
"bus_id": msg.BusID,
"from_client": sender.ID,
}
target.WriteJSON(enriched)
}
}
h.mu.RUnlock()
}
// handleDeviceReleased broadcasts device released notification
func (h *Hub) handleDeviceReleased(sender *Client, data []byte) {
h.mu.RLock()
group := h.groups[sender.Hash]
for _, client := range group {
if client.ID != sender.ID && client.Mode == protocol.ModeUse {
client.mu.Lock()
client.Conn.WriteMessage(websocket.TextMessage, data)
client.mu.Unlock()
}
}
h.mu.RUnlock()
}
// broadcastToGroup sends a message to all clients in a hash group except the sender
func (h *Hub) broadcastToGroup(hash, excludeID string, msg interface{}) {
group := h.groups[hash]
for _, client := range group {
if client.ID != excludeID {
client.WriteJSON(msg)
}
}
}
+138
View File
@@ -0,0 +1,138 @@
package relay
import (
"encoding/json"
"log"
"net/http"
"time"
"github.com/duffy/usb-server/internal/protocol"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 64 * 1024,
WriteBufferSize: 64 * 1024,
CheckOrigin: func(r *http.Request) bool {
return true // relay accepts all origins
},
}
// Server is the WebSocket relay server
type Server struct {
hub *Hub
addr string
}
// NewServer creates a new relay server
func NewServer(addr string) *Server {
return &Server{
hub: NewHub(),
addr: addr,
}
}
// Run starts the relay server
func (s *Server) Run() error {
mux := http.NewServeMux()
mux.HandleFunc("/ws", s.handleWebSocket)
mux.HandleFunc("/health", s.handleHealth)
log.Printf("[relay] starting on %s", s.addr)
return http.ListenAndServe(s.addr, mux)
}
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("[relay] upgrade error: %v", err)
return
}
defer conn.Close()
// Set read limits and deadlines
conn.SetReadLimit(1024 * 1024) // 1MB max message
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
// Wait for registration message
_, msgData, err := conn.ReadMessage()
if err != nil {
log.Printf("[relay] read error during registration: %v", err)
return
}
var reg protocol.Register
if err := json.Unmarshal(msgData, &reg); err != nil || reg.Type != protocol.MsgRegister {
log.Printf("[relay] invalid registration message")
conn.WriteJSON(&protocol.ErrorMsg{Type: protocol.MsgError, Message: "invalid registration"})
return
}
if reg.Hash == "" || reg.ClientID == "" || (reg.Mode != protocol.ModeShare && reg.Mode != protocol.ModeUse) {
conn.WriteJSON(&protocol.ErrorMsg{Type: protocol.MsgError, Message: "missing required fields"})
return
}
client := &Client{
ID: reg.ClientID,
Hash: reg.Hash,
Mode: reg.Mode,
Name: reg.Name,
Conn: conn,
Send: make(chan []byte, 256),
}
s.hub.Register(client)
defer s.hub.Unregister(client)
// Start ping ticker
done := make(chan struct{})
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
client.mu.Lock()
err := conn.WriteMessage(websocket.PingMessage, nil)
client.mu.Unlock()
if err != nil {
return
}
case <-done:
return
}
}
}()
defer close(done)
// Read loop
for {
msgType, data, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
log.Printf("[relay] read error from %s: %v", client.ID, err)
}
break
}
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
switch msgType {
case websocket.TextMessage:
s.hub.HandleTextMessage(client, data)
case websocket.BinaryMessage:
s.hub.HandleBinaryMessage(client, data)
}
}
}