first commit
This commit is contained in:
@@ -0,0 +1,336 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/duffy/usb-server/internal/protocol"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Client represents a connected WebSocket client
|
||||
type Client struct {
|
||||
ID string
|
||||
Hash string
|
||||
Mode string // "share" or "use"
|
||||
Name string
|
||||
Conn *websocket.Conn
|
||||
Send chan []byte // buffered channel for outgoing messages
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// WriteJSON sends a JSON message to the client
|
||||
func (c *Client) WriteJSON(v interface{}) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.Conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
// WriteBinary sends a binary message to the client
|
||||
func (c *Client) WriteBinary(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.Conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
|
||||
// Hub manages all connected clients and routes messages between them
|
||||
type Hub struct {
|
||||
mu sync.RWMutex
|
||||
groups map[string]map[string]*Client // hash -> client_id -> client
|
||||
tunnels map[string]*Tunnel // tunnel_id -> tunnel info
|
||||
}
|
||||
|
||||
// Tunnel tracks an active USB/IP tunnel between two clients
|
||||
type Tunnel struct {
|
||||
ID string
|
||||
ShareClient string
|
||||
UseClient string
|
||||
BusID string
|
||||
}
|
||||
|
||||
// NewHub creates a new Hub
|
||||
func NewHub() *Hub {
|
||||
return &Hub{
|
||||
groups: make(map[string]map[string]*Client),
|
||||
tunnels: make(map[string]*Tunnel),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a client to its hash group
|
||||
func (h *Hub) Register(client *Client) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if h.groups[client.Hash] == nil {
|
||||
h.groups[client.Hash] = make(map[string]*Client)
|
||||
}
|
||||
h.groups[client.Hash][client.ID] = client
|
||||
|
||||
log.Printf("[hub] client registered: id=%s hash=%s..%s mode=%s name=%s",
|
||||
client.ID, client.Hash[:8], client.Hash[len(client.Hash)-4:], client.Mode, client.Name)
|
||||
|
||||
// Notify other clients in the group
|
||||
h.broadcastToGroup(client.Hash, client.ID, &protocol.ClientJoined{
|
||||
Type: protocol.MsgClientJoined,
|
||||
ClientID: client.ID,
|
||||
Mode: client.Mode,
|
||||
Name: client.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// Unregister removes a client and cleans up its tunnels
|
||||
func (h *Hub) Unregister(client *Client) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
group := h.groups[client.Hash]
|
||||
if group == nil {
|
||||
return
|
||||
}
|
||||
|
||||
delete(group, client.ID)
|
||||
if len(group) == 0 {
|
||||
delete(h.groups, client.Hash)
|
||||
}
|
||||
|
||||
// Clean up tunnels involving this client
|
||||
for tid, tunnel := range h.tunnels {
|
||||
if tunnel.ShareClient == client.ID || tunnel.UseClient == client.ID {
|
||||
delete(h.tunnels, tid)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[hub] client unregistered: id=%s name=%s", client.ID, client.Name)
|
||||
|
||||
// Notify others
|
||||
h.broadcastToGroup(client.Hash, client.ID, &protocol.ClientLeft{
|
||||
Type: protocol.MsgClientLeft,
|
||||
ClientID: client.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleTextMessage processes a JSON control message
|
||||
func (h *Hub) HandleTextMessage(sender *Client, data []byte) {
|
||||
var env protocol.Envelope
|
||||
if err := json.Unmarshal(data, &env); err != nil {
|
||||
log.Printf("[hub] invalid message from %s: %v", sender.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch env.Type {
|
||||
case protocol.MsgDeviceList:
|
||||
h.handleDeviceList(sender, data)
|
||||
case protocol.MsgRequestDevice:
|
||||
h.handleRequestDevice(sender, data)
|
||||
case protocol.MsgDeviceGranted:
|
||||
h.handleDeviceGranted(sender, data)
|
||||
case protocol.MsgDeviceDenied:
|
||||
h.handleDeviceDenied(sender, data)
|
||||
case protocol.MsgReleaseDevice:
|
||||
h.handleReleaseDevice(sender, data)
|
||||
case protocol.MsgDeviceReleased:
|
||||
h.handleDeviceReleased(sender, data)
|
||||
case protocol.MsgPing:
|
||||
sender.WriteJSON(&protocol.Pong{Type: protocol.MsgPong})
|
||||
default:
|
||||
log.Printf("[hub] unknown message type from %s: %s", sender.ID, env.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleBinaryMessage forwards tunnel data to the other end
|
||||
func (h *Hub) HandleBinaryMessage(sender *Client, data []byte) {
|
||||
if len(data) < protocol.TunnelHeaderSize {
|
||||
return
|
||||
}
|
||||
|
||||
tunnelID := string(data[:protocol.TunnelHeaderSize])
|
||||
|
||||
h.mu.RLock()
|
||||
tunnel := h.tunnels[tunnelID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if tunnel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Forward to the other end of the tunnel
|
||||
var targetID string
|
||||
if sender.ID == tunnel.ShareClient {
|
||||
targetID = tunnel.UseClient
|
||||
} else if sender.ID == tunnel.UseClient {
|
||||
targetID = tunnel.ShareClient
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
group := h.groups[sender.Hash]
|
||||
if group != nil {
|
||||
if target := group[targetID]; target != nil {
|
||||
target.WriteBinary(data)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
}
|
||||
|
||||
// handleDeviceList broadcasts device list from share client to all use clients
|
||||
func (h *Hub) handleDeviceList(sender *Client, data []byte) {
|
||||
if sender.Mode != protocol.ModeShare {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
group := h.groups[sender.Hash]
|
||||
for _, client := range group {
|
||||
if client.ID != sender.ID && client.Mode == protocol.ModeUse {
|
||||
client.mu.Lock()
|
||||
client.Conn.WriteMessage(websocket.TextMessage, data)
|
||||
client.mu.Unlock()
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
}
|
||||
|
||||
// handleRequestDevice forwards a device request to the target share client
|
||||
func (h *Hub) handleRequestDevice(sender *Client, data []byte) {
|
||||
var msg protocol.RequestDevice
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
group := h.groups[sender.Hash]
|
||||
if group != nil {
|
||||
if target := group[msg.TargetClient]; target != nil && target.Mode == protocol.ModeShare {
|
||||
// Add the sender's ID so the share client knows who's requesting
|
||||
enriched := map[string]interface{}{
|
||||
"type": protocol.MsgRequestDevice,
|
||||
"target_client": msg.TargetClient,
|
||||
"bus_id": msg.BusID,
|
||||
"request_id": msg.RequestID,
|
||||
"from_client": sender.ID,
|
||||
}
|
||||
target.WriteJSON(enriched)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
}
|
||||
|
||||
// handleDeviceGranted registers the tunnel and forwards to the requesting client
|
||||
func (h *Hub) handleDeviceGranted(sender *Client, data []byte) {
|
||||
var granted struct {
|
||||
protocol.DeviceGranted
|
||||
TargetClient string `json:"target_client"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &granted); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Register tunnel
|
||||
h.mu.Lock()
|
||||
h.tunnels[granted.TunnelID] = &Tunnel{
|
||||
ID: granted.TunnelID,
|
||||
ShareClient: sender.ID,
|
||||
UseClient: granted.TargetClient,
|
||||
BusID: granted.BusID,
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Printf("[hub] tunnel created: %s (share=%s, use=%s, device=%s)",
|
||||
granted.TunnelID, sender.ID, granted.TargetClient, granted.BusID)
|
||||
|
||||
// Forward to use client
|
||||
h.mu.RLock()
|
||||
group := h.groups[sender.Hash]
|
||||
if group != nil {
|
||||
if target := group[granted.TargetClient]; target != nil {
|
||||
target.mu.Lock()
|
||||
target.Conn.WriteMessage(websocket.TextMessage, data)
|
||||
target.mu.Unlock()
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
}
|
||||
|
||||
// handleDeviceDenied forwards denial to the requesting client
|
||||
func (h *Hub) handleDeviceDenied(sender *Client, data []byte) {
|
||||
var denied struct {
|
||||
protocol.DeviceDenied
|
||||
TargetClient string `json:"target_client"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &denied); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
group := h.groups[sender.Hash]
|
||||
if group != nil {
|
||||
if target := group[denied.TargetClient]; target != nil {
|
||||
target.mu.Lock()
|
||||
target.Conn.WriteMessage(websocket.TextMessage, data)
|
||||
target.mu.Unlock()
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
}
|
||||
|
||||
// handleReleaseDevice forwards a release to the share client
|
||||
func (h *Hub) handleReleaseDevice(sender *Client, data []byte) {
|
||||
var msg protocol.ReleaseDevice
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Clean up tunnel
|
||||
h.mu.Lock()
|
||||
for tid, tunnel := range h.tunnels {
|
||||
if tunnel.UseClient == sender.ID && tunnel.BusID == msg.BusID {
|
||||
delete(h.tunnels, tid)
|
||||
log.Printf("[hub] tunnel closed: %s", tid)
|
||||
break
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
// Forward to share client
|
||||
h.mu.RLock()
|
||||
group := h.groups[sender.Hash]
|
||||
if group != nil {
|
||||
if target := group[msg.TargetClient]; target != nil {
|
||||
enriched := map[string]interface{}{
|
||||
"type": protocol.MsgReleaseDevice,
|
||||
"target_client": msg.TargetClient,
|
||||
"bus_id": msg.BusID,
|
||||
"from_client": sender.ID,
|
||||
}
|
||||
target.WriteJSON(enriched)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
}
|
||||
|
||||
// handleDeviceReleased broadcasts device released notification
|
||||
func (h *Hub) handleDeviceReleased(sender *Client, data []byte) {
|
||||
h.mu.RLock()
|
||||
group := h.groups[sender.Hash]
|
||||
for _, client := range group {
|
||||
if client.ID != sender.ID && client.Mode == protocol.ModeUse {
|
||||
client.mu.Lock()
|
||||
client.Conn.WriteMessage(websocket.TextMessage, data)
|
||||
client.mu.Unlock()
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
}
|
||||
|
||||
// broadcastToGroup sends a message to all clients in a hash group except the sender
|
||||
func (h *Hub) broadcastToGroup(hash, excludeID string, msg interface{}) {
|
||||
group := h.groups[hash]
|
||||
for _, client := range group {
|
||||
if client.ID != excludeID {
|
||||
client.WriteJSON(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/duffy/usb-server/internal/protocol"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 64 * 1024,
|
||||
WriteBufferSize: 64 * 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // relay accepts all origins
|
||||
},
|
||||
}
|
||||
|
||||
// Server is the WebSocket relay server
|
||||
type Server struct {
|
||||
hub *Hub
|
||||
addr string
|
||||
}
|
||||
|
||||
// NewServer creates a new relay server
|
||||
func NewServer(addr string) *Server {
|
||||
return &Server{
|
||||
hub: NewHub(),
|
||||
addr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the relay server
|
||||
func (s *Server) Run() error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", s.handleWebSocket)
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
|
||||
log.Printf("[relay] starting on %s", s.addr)
|
||||
return http.ListenAndServe(s.addr, mux)
|
||||
}
|
||||
|
||||
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Printf("[relay] upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Set read limits and deadlines
|
||||
conn.SetReadLimit(1024 * 1024) // 1MB max message
|
||||
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Wait for registration message
|
||||
_, msgData, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
log.Printf("[relay] read error during registration: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var reg protocol.Register
|
||||
if err := json.Unmarshal(msgData, ®); err != nil || reg.Type != protocol.MsgRegister {
|
||||
log.Printf("[relay] invalid registration message")
|
||||
conn.WriteJSON(&protocol.ErrorMsg{Type: protocol.MsgError, Message: "invalid registration"})
|
||||
return
|
||||
}
|
||||
|
||||
if reg.Hash == "" || reg.ClientID == "" || (reg.Mode != protocol.ModeShare && reg.Mode != protocol.ModeUse) {
|
||||
conn.WriteJSON(&protocol.ErrorMsg{Type: protocol.MsgError, Message: "missing required fields"})
|
||||
return
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
ID: reg.ClientID,
|
||||
Hash: reg.Hash,
|
||||
Mode: reg.Mode,
|
||||
Name: reg.Name,
|
||||
Conn: conn,
|
||||
Send: make(chan []byte, 256),
|
||||
}
|
||||
|
||||
s.hub.Register(client)
|
||||
defer s.hub.Unregister(client)
|
||||
|
||||
// Start ping ticker
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
client.mu.Lock()
|
||||
err := conn.WriteMessage(websocket.PingMessage, nil)
|
||||
client.mu.Unlock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
// Read loop
|
||||
for {
|
||||
msgType, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
||||
log.Printf("[relay] read error from %s: %v", client.ID, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
switch msgType {
|
||||
case websocket.TextMessage:
|
||||
s.hub.HandleTextMessage(client, data)
|
||||
case websocket.BinaryMessage:
|
||||
s.hub.HandleBinaryMessage(client, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user