329 lines
7.3 KiB
Go
329 lines
7.3 KiB
Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/duffy/usb-server/internal/config"
|
|
"github.com/duffy/usb-server/internal/protocol"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Client manages the connection to the relay server
|
|
type Client struct {
|
|
cfg *config.Config
|
|
clientID string
|
|
conn *websocket.Conn
|
|
mu sync.Mutex
|
|
|
|
// Event callbacks
|
|
OnDeviceList func(msg *protocol.DeviceList)
|
|
OnDeviceGranted func(msg *protocol.DeviceGranted)
|
|
OnDeviceDenied func(msg *protocol.DeviceDenied)
|
|
OnDeviceReleased func(msg *protocol.DeviceReleased)
|
|
OnClientJoined func(msg *protocol.ClientJoined)
|
|
OnClientLeft func(msg *protocol.ClientLeft)
|
|
OnRequestDevice func(targetClient, fromClient, busID, requestID string)
|
|
OnReleaseDevice func(busID, fromClient string)
|
|
OnForceRelease func(targetClient, fromClient, busID string)
|
|
OnTunnelData func(tunnelID string, data []byte)
|
|
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// NewClient creates a new client instance
|
|
func NewClient(cfg *config.Config) *Client {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
return &Client{
|
|
cfg: cfg,
|
|
clientID: uuid.New().String(),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
// ID returns the client ID
|
|
func (c *Client) ID() string {
|
|
return c.clientID
|
|
}
|
|
|
|
// Config returns the client config
|
|
func (c *Client) Config() *config.Config {
|
|
return c.cfg
|
|
}
|
|
|
|
// Connect establishes connection to the relay server
|
|
func (c *Client) Connect() error {
|
|
u, err := url.Parse(c.cfg.RelayAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid relay address: %w", err)
|
|
}
|
|
|
|
// Ensure WebSocket scheme
|
|
switch u.Scheme {
|
|
case "ws", "wss":
|
|
// ok
|
|
case "http":
|
|
u.Scheme = "ws"
|
|
case "https":
|
|
u.Scheme = "wss"
|
|
default:
|
|
u.Scheme = "ws"
|
|
}
|
|
|
|
if u.Path == "" {
|
|
u.Path = "/ws"
|
|
}
|
|
|
|
log.Printf("[client] connecting to %s", u.String())
|
|
|
|
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
if err != nil {
|
|
return fmt.Errorf("connecting to relay: %w", err)
|
|
}
|
|
|
|
c.mu.Lock()
|
|
c.conn = conn
|
|
c.mu.Unlock()
|
|
|
|
// Send registration
|
|
reg := &protocol.Register{
|
|
Type: protocol.MsgRegister,
|
|
Hash: c.cfg.Hash,
|
|
Mode: c.cfg.Mode,
|
|
ClientID: c.clientID,
|
|
Name: c.cfg.Name,
|
|
}
|
|
|
|
if err := conn.WriteJSON(reg); err != nil {
|
|
conn.Close()
|
|
return fmt.Errorf("sending registration: %w", err)
|
|
}
|
|
|
|
log.Printf("[client] registered as %s (mode=%s, name=%s)", c.clientID, c.cfg.Mode, c.cfg.Name)
|
|
|
|
return nil
|
|
}
|
|
|
|
// RunReadLoop reads messages from the relay and dispatches them
|
|
func (c *Client) RunReadLoop() error {
|
|
for {
|
|
select {
|
|
case <-c.ctx.Done():
|
|
return nil
|
|
default:
|
|
}
|
|
|
|
msgType, data, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
|
return fmt.Errorf("read error: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
switch msgType {
|
|
case websocket.TextMessage:
|
|
c.handleTextMessage(data)
|
|
case websocket.BinaryMessage:
|
|
c.handleBinaryMessage(data)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Run connects and runs the main loop with auto-reconnect
|
|
func (c *Client) Run() error {
|
|
for {
|
|
if err := c.Connect(); err != nil {
|
|
log.Printf("[client] connection failed: %v, retrying in 5s...", err)
|
|
select {
|
|
case <-time.After(5 * time.Second):
|
|
continue
|
|
case <-c.ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
|
|
err := c.RunReadLoop()
|
|
if err != nil {
|
|
log.Printf("[client] disconnected: %v, reconnecting in 5s...", err)
|
|
} else {
|
|
log.Printf("[client] disconnected, reconnecting in 5s...")
|
|
}
|
|
|
|
c.mu.Lock()
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
c.mu.Unlock()
|
|
|
|
select {
|
|
case <-time.After(5 * time.Second):
|
|
case <-c.ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close shuts down the client
|
|
func (c *Client) Close() {
|
|
c.cancel()
|
|
c.mu.Lock()
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
}
|
|
c.mu.Unlock()
|
|
}
|
|
|
|
// SendJSON sends a JSON message to the relay
|
|
func (c *Client) SendJSON(v interface{}) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.conn == nil {
|
|
return fmt.Errorf("not connected")
|
|
}
|
|
return c.conn.WriteJSON(v)
|
|
}
|
|
|
|
// SendBinary sends a binary message to the relay
|
|
func (c *Client) SendBinary(data []byte) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.conn == nil {
|
|
return fmt.Errorf("not connected")
|
|
}
|
|
return c.conn.WriteMessage(websocket.BinaryMessage, data)
|
|
}
|
|
|
|
// SendTunnelData sends tunnel data with the tunnel ID prefix
|
|
func (c *Client) SendTunnelData(tunnelID string, data []byte) error {
|
|
// Tunnel header: 16 bytes tunnel ID + payload
|
|
msg := make([]byte, protocol.TunnelHeaderSize+len(data))
|
|
copy(msg[:protocol.TunnelHeaderSize], tunnelID)
|
|
copy(msg[protocol.TunnelHeaderSize:], data)
|
|
return c.SendBinary(msg)
|
|
}
|
|
|
|
func (c *Client) handleTextMessage(data []byte) {
|
|
var env protocol.Envelope
|
|
if err := json.Unmarshal(data, &env); err != nil {
|
|
return
|
|
}
|
|
|
|
switch env.Type {
|
|
case protocol.MsgDeviceList:
|
|
if c.OnDeviceList != nil {
|
|
var msg protocol.DeviceList
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
c.OnDeviceList(&msg)
|
|
}
|
|
}
|
|
|
|
case protocol.MsgRequestDevice:
|
|
if c.OnRequestDevice != nil {
|
|
var msg struct {
|
|
TargetClient string `json:"target_client"`
|
|
FromClient string `json:"from_client"`
|
|
BusID string `json:"bus_id"`
|
|
RequestID string `json:"request_id"`
|
|
}
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
c.OnRequestDevice(msg.TargetClient, msg.FromClient, msg.BusID, msg.RequestID)
|
|
}
|
|
}
|
|
|
|
case protocol.MsgDeviceGranted:
|
|
if c.OnDeviceGranted != nil {
|
|
var msg protocol.DeviceGranted
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
c.OnDeviceGranted(&msg)
|
|
}
|
|
}
|
|
|
|
case protocol.MsgDeviceDenied:
|
|
if c.OnDeviceDenied != nil {
|
|
var msg protocol.DeviceDenied
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
c.OnDeviceDenied(&msg)
|
|
}
|
|
}
|
|
|
|
case protocol.MsgForceRelease:
|
|
if c.OnForceRelease != nil {
|
|
var msg struct {
|
|
TargetClient string `json:"target_client"`
|
|
FromClient string `json:"from_client"`
|
|
BusID string `json:"bus_id"`
|
|
}
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
c.OnForceRelease(msg.TargetClient, msg.FromClient, msg.BusID)
|
|
}
|
|
}
|
|
|
|
case protocol.MsgReleaseDevice:
|
|
if c.OnReleaseDevice != nil {
|
|
var msg struct {
|
|
BusID string `json:"bus_id"`
|
|
FromClient string `json:"from_client"`
|
|
}
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
c.OnReleaseDevice(msg.BusID, msg.FromClient)
|
|
}
|
|
}
|
|
|
|
case protocol.MsgDeviceReleased:
|
|
if c.OnDeviceReleased != nil {
|
|
var msg protocol.DeviceReleased
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
c.OnDeviceReleased(&msg)
|
|
}
|
|
}
|
|
|
|
case protocol.MsgClientJoined:
|
|
if c.OnClientJoined != nil {
|
|
var msg protocol.ClientJoined
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
c.OnClientJoined(&msg)
|
|
}
|
|
}
|
|
|
|
case protocol.MsgClientLeft:
|
|
if c.OnClientLeft != nil {
|
|
var msg protocol.ClientLeft
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
c.OnClientLeft(&msg)
|
|
}
|
|
}
|
|
|
|
case protocol.MsgPong:
|
|
// ignore pong
|
|
|
|
case protocol.MsgError:
|
|
var msg protocol.ErrorMsg
|
|
if json.Unmarshal(data, &msg) == nil {
|
|
log.Printf("[client] error from relay: %s", msg.Message)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleBinaryMessage(data []byte) {
|
|
if len(data) < protocol.TunnelHeaderSize {
|
|
return
|
|
}
|
|
|
|
tunnelID := string(data[:protocol.TunnelHeaderSize])
|
|
payload := data[protocol.TunnelHeaderSize:]
|
|
|
|
if c.OnTunnelData != nil {
|
|
c.OnTunnelData(tunnelID, payload)
|
|
}
|
|
}
|