first commit
This commit is contained in:
@@ -0,0 +1,315 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/duffy/usb-server/internal/config"
|
||||
"github.com/duffy/usb-server/internal/protocol"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Client manages the connection to the relay server
|
||||
type Client struct {
|
||||
cfg *config.Config
|
||||
clientID string
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex
|
||||
|
||||
// Event callbacks
|
||||
OnDeviceList func(msg *protocol.DeviceList)
|
||||
OnDeviceGranted func(msg *protocol.DeviceGranted)
|
||||
OnDeviceDenied func(msg *protocol.DeviceDenied)
|
||||
OnDeviceReleased func(msg *protocol.DeviceReleased)
|
||||
OnClientJoined func(msg *protocol.ClientJoined)
|
||||
OnClientLeft func(msg *protocol.ClientLeft)
|
||||
OnRequestDevice func(targetClient, fromClient, busID, requestID string)
|
||||
OnReleaseDevice func(busID, fromClient string)
|
||||
OnTunnelData func(tunnelID string, data []byte)
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewClient creates a new client instance
|
||||
func NewClient(cfg *config.Config) *Client {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Client{
|
||||
cfg: cfg,
|
||||
clientID: uuid.New().String(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the client ID
|
||||
func (c *Client) ID() string {
|
||||
return c.clientID
|
||||
}
|
||||
|
||||
// Config returns the client config
|
||||
func (c *Client) Config() *config.Config {
|
||||
return c.cfg
|
||||
}
|
||||
|
||||
// Connect establishes connection to the relay server
|
||||
func (c *Client) Connect() error {
|
||||
u, err := url.Parse(c.cfg.RelayAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid relay address: %w", err)
|
||||
}
|
||||
|
||||
// Ensure WebSocket scheme
|
||||
switch u.Scheme {
|
||||
case "ws", "wss":
|
||||
// ok
|
||||
case "http":
|
||||
u.Scheme = "ws"
|
||||
case "https":
|
||||
u.Scheme = "wss"
|
||||
default:
|
||||
u.Scheme = "ws"
|
||||
}
|
||||
|
||||
if u.Path == "" {
|
||||
u.Path = "/ws"
|
||||
}
|
||||
|
||||
log.Printf("[client] connecting to %s", u.String())
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to relay: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.mu.Unlock()
|
||||
|
||||
// Send registration
|
||||
reg := &protocol.Register{
|
||||
Type: protocol.MsgRegister,
|
||||
Hash: c.cfg.Hash,
|
||||
Mode: c.cfg.Mode,
|
||||
ClientID: c.clientID,
|
||||
Name: c.cfg.Name,
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(reg); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("sending registration: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[client] registered as %s (mode=%s, name=%s)", c.clientID, c.cfg.Mode, c.cfg.Name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunReadLoop reads messages from the relay and dispatches them
|
||||
func (c *Client) RunReadLoop() error {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
msgType, data, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
||||
return fmt.Errorf("read error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case websocket.TextMessage:
|
||||
c.handleTextMessage(data)
|
||||
case websocket.BinaryMessage:
|
||||
c.handleBinaryMessage(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run connects and runs the main loop with auto-reconnect
|
||||
func (c *Client) Run() error {
|
||||
for {
|
||||
if err := c.Connect(); err != nil {
|
||||
log.Printf("[client] connection failed: %v, retrying in 5s...", err)
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
continue
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
err := c.RunReadLoop()
|
||||
if err != nil {
|
||||
log.Printf("[client] disconnected: %v, reconnecting in 5s...", err)
|
||||
} else {
|
||||
log.Printf("[client] disconnected, reconnecting in 5s...")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-c.ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the client
|
||||
func (c *Client) Close() {
|
||||
c.cancel()
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// SendJSON sends a JSON message to the relay
|
||||
func (c *Client) SendJSON(v interface{}) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
return c.conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
// SendBinary sends a binary message to the relay
|
||||
func (c *Client) SendBinary(data []byte) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
return c.conn.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
|
||||
// SendTunnelData sends tunnel data with the tunnel ID prefix
|
||||
func (c *Client) SendTunnelData(tunnelID string, data []byte) error {
|
||||
// Tunnel header: 16 bytes tunnel ID + payload
|
||||
msg := make([]byte, protocol.TunnelHeaderSize+len(data))
|
||||
copy(msg[:protocol.TunnelHeaderSize], tunnelID)
|
||||
copy(msg[protocol.TunnelHeaderSize:], data)
|
||||
return c.SendBinary(msg)
|
||||
}
|
||||
|
||||
func (c *Client) handleTextMessage(data []byte) {
|
||||
var env protocol.Envelope
|
||||
if err := json.Unmarshal(data, &env); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch env.Type {
|
||||
case protocol.MsgDeviceList:
|
||||
if c.OnDeviceList != nil {
|
||||
var msg protocol.DeviceList
|
||||
if json.Unmarshal(data, &msg) == nil {
|
||||
c.OnDeviceList(&msg)
|
||||
}
|
||||
}
|
||||
|
||||
case protocol.MsgRequestDevice:
|
||||
if c.OnRequestDevice != nil {
|
||||
var msg struct {
|
||||
TargetClient string `json:"target_client"`
|
||||
FromClient string `json:"from_client"`
|
||||
BusID string `json:"bus_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
if json.Unmarshal(data, &msg) == nil {
|
||||
c.OnRequestDevice(msg.TargetClient, msg.FromClient, msg.BusID, msg.RequestID)
|
||||
}
|
||||
}
|
||||
|
||||
case protocol.MsgDeviceGranted:
|
||||
if c.OnDeviceGranted != nil {
|
||||
var msg protocol.DeviceGranted
|
||||
if json.Unmarshal(data, &msg) == nil {
|
||||
c.OnDeviceGranted(&msg)
|
||||
}
|
||||
}
|
||||
|
||||
case protocol.MsgDeviceDenied:
|
||||
if c.OnDeviceDenied != nil {
|
||||
var msg protocol.DeviceDenied
|
||||
if json.Unmarshal(data, &msg) == nil {
|
||||
c.OnDeviceDenied(&msg)
|
||||
}
|
||||
}
|
||||
|
||||
case protocol.MsgReleaseDevice:
|
||||
if c.OnReleaseDevice != nil {
|
||||
var msg struct {
|
||||
BusID string `json:"bus_id"`
|
||||
FromClient string `json:"from_client"`
|
||||
}
|
||||
if json.Unmarshal(data, &msg) == nil {
|
||||
c.OnReleaseDevice(msg.BusID, msg.FromClient)
|
||||
}
|
||||
}
|
||||
|
||||
case protocol.MsgDeviceReleased:
|
||||
if c.OnDeviceReleased != nil {
|
||||
var msg protocol.DeviceReleased
|
||||
if json.Unmarshal(data, &msg) == nil {
|
||||
c.OnDeviceReleased(&msg)
|
||||
}
|
||||
}
|
||||
|
||||
case protocol.MsgClientJoined:
|
||||
if c.OnClientJoined != nil {
|
||||
var msg protocol.ClientJoined
|
||||
if json.Unmarshal(data, &msg) == nil {
|
||||
c.OnClientJoined(&msg)
|
||||
}
|
||||
}
|
||||
|
||||
case protocol.MsgClientLeft:
|
||||
if c.OnClientLeft != nil {
|
||||
var msg protocol.ClientLeft
|
||||
if json.Unmarshal(data, &msg) == nil {
|
||||
c.OnClientLeft(&msg)
|
||||
}
|
||||
}
|
||||
|
||||
case protocol.MsgPong:
|
||||
// ignore pong
|
||||
|
||||
case protocol.MsgError:
|
||||
var msg protocol.ErrorMsg
|
||||
if json.Unmarshal(data, &msg) == nil {
|
||||
log.Printf("[client] error from relay: %s", msg.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleBinaryMessage(data []byte) {
|
||||
if len(data) < protocol.TunnelHeaderSize {
|
||||
return
|
||||
}
|
||||
|
||||
tunnelID := string(data[:protocol.TunnelHeaderSize])
|
||||
payload := data[protocol.TunnelHeaderSize:]
|
||||
|
||||
if c.OnTunnelData != nil {
|
||||
c.OnTunnelData(tunnelID, payload)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,358 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/duffy/usb-server/internal/protocol"
|
||||
"github.com/duffy/usb-server/internal/usb"
|
||||
"github.com/duffy/usb-server/internal/usbip"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ShareManager handles sharing USB devices
|
||||
type ShareManager struct {
|
||||
client *Client
|
||||
mu sync.RWMutex
|
||||
devices []usb.Device
|
||||
active map[string]*activeShare // busID -> active share
|
||||
tunnels map[string]*shareTunnel // tunnelID -> tunnel
|
||||
}
|
||||
|
||||
type activeShare struct {
|
||||
device *usb.Device
|
||||
server *usbip.Server
|
||||
usedBy string // client ID using this device
|
||||
tunnelID string
|
||||
}
|
||||
|
||||
type shareTunnel struct {
|
||||
id string
|
||||
busID string
|
||||
inPipe *io.PipeWriter
|
||||
outPipe *io.PipeReader
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewShareManager creates a share manager
|
||||
func NewShareManager(client *Client) *ShareManager {
|
||||
sm := &ShareManager{
|
||||
client: client,
|
||||
active: make(map[string]*activeShare),
|
||||
tunnels: make(map[string]*shareTunnel),
|
||||
}
|
||||
|
||||
// Set up callbacks
|
||||
client.OnRequestDevice = sm.handleRequestDevice
|
||||
client.OnReleaseDevice = sm.handleReleaseDevice
|
||||
client.OnTunnelData = sm.handleTunnelData
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// Run starts the share manager: periodic device enumeration + event handling
|
||||
func (sm *ShareManager) Run() error {
|
||||
// Initial enumeration
|
||||
sm.refreshDevices()
|
||||
sm.broadcastDeviceList()
|
||||
|
||||
// Periodic refresh
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
sm.refreshDevices()
|
||||
sm.broadcastDeviceList()
|
||||
case <-sm.client.ctx.Done():
|
||||
sm.cleanup()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetDevices returns the current device list
|
||||
func (sm *ShareManager) GetDevices() []usb.Device {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
result := make([]usb.Device, len(sm.devices))
|
||||
copy(result, sm.devices)
|
||||
return result
|
||||
}
|
||||
|
||||
func (sm *ShareManager) refreshDevices() {
|
||||
devices, err := usb.Enumerate()
|
||||
if err != nil {
|
||||
log.Printf("[share] USB enumeration error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
sm.devices = devices
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
|
||||
func (sm *ShareManager) broadcastDeviceList() {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
var protoDevices []protocol.USBDevice
|
||||
for _, dev := range sm.devices {
|
||||
status := protocol.StatusAvailable
|
||||
usedBy := ""
|
||||
if share, ok := sm.active[dev.BusID]; ok {
|
||||
status = protocol.StatusInUse
|
||||
usedBy = share.usedBy
|
||||
}
|
||||
|
||||
protoDevices = append(protoDevices, protocol.USBDevice{
|
||||
BusID: dev.BusID,
|
||||
BusNum: dev.BusNum,
|
||||
DevNum: dev.DevNum,
|
||||
Speed: dev.Speed,
|
||||
VendorID: fmt.Sprintf("%04x", dev.VendorID),
|
||||
ProductID: fmt.Sprintf("%04x", dev.ProductID),
|
||||
Class: dev.DeviceClass,
|
||||
SubClass: dev.DeviceSubClass,
|
||||
Protocol: dev.DeviceProtocol,
|
||||
Name: dev.DisplayName(),
|
||||
Manufacturer: dev.Manufacturer,
|
||||
NumInterfaces: uint8(len(dev.Interfaces)),
|
||||
Status: status,
|
||||
UsedBy: usedBy,
|
||||
})
|
||||
}
|
||||
|
||||
msg := &protocol.DeviceList{
|
||||
Type: protocol.MsgDeviceList,
|
||||
ClientID: sm.client.ID(),
|
||||
ClientName: sm.client.Config().Name,
|
||||
Devices: protoDevices,
|
||||
}
|
||||
|
||||
sm.client.SendJSON(msg)
|
||||
}
|
||||
|
||||
func (sm *ShareManager) handleRequestDevice(targetClient, fromClient, busID, requestID string) {
|
||||
log.Printf("[share] device request: busID=%s from=%s", busID, fromClient)
|
||||
|
||||
sm.mu.Lock()
|
||||
|
||||
// Check if device exists
|
||||
var dev *usb.Device
|
||||
for i := range sm.devices {
|
||||
if sm.devices[i].BusID == busID {
|
||||
dev = &sm.devices[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if dev == nil {
|
||||
sm.mu.Unlock()
|
||||
sm.client.SendJSON(map[string]interface{}{
|
||||
"type": protocol.MsgDeviceDenied,
|
||||
"bus_id": busID,
|
||||
"request_id": requestID,
|
||||
"reason": "device not found",
|
||||
"target_client": fromClient,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if already in use
|
||||
if _, inUse := sm.active[busID]; inUse {
|
||||
sm.mu.Unlock()
|
||||
sm.client.SendJSON(map[string]interface{}{
|
||||
"type": protocol.MsgDeviceDenied,
|
||||
"bus_id": busID,
|
||||
"request_id": requestID,
|
||||
"reason": "device already in use",
|
||||
"target_client": fromClient,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create USB/IP server for this device
|
||||
server := usbip.NewServer(dev)
|
||||
if err := server.Attach(); err != nil {
|
||||
sm.mu.Unlock()
|
||||
log.Printf("[share] failed to attach device %s: %v", busID, err)
|
||||
sm.client.SendJSON(map[string]interface{}{
|
||||
"type": protocol.MsgDeviceDenied,
|
||||
"bus_id": busID,
|
||||
"request_id": requestID,
|
||||
"reason": fmt.Sprintf("attach failed: %v", err),
|
||||
"target_client": fromClient,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tunnelID := uuid.New().String()[:16] // 16 chars for tunnel header
|
||||
for len(tunnelID) < 16 {
|
||||
tunnelID += "0"
|
||||
}
|
||||
|
||||
inReader, inWriter := io.Pipe()
|
||||
outReader, outWriter := io.Pipe()
|
||||
|
||||
tunnel := &shareTunnel{
|
||||
id: tunnelID,
|
||||
busID: busID,
|
||||
inPipe: inWriter,
|
||||
outPipe: outReader,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
share := &activeShare{
|
||||
device: dev,
|
||||
server: server,
|
||||
usedBy: fromClient,
|
||||
tunnelID: tunnelID,
|
||||
}
|
||||
|
||||
sm.active[busID] = share
|
||||
sm.tunnels[tunnelID] = tunnel
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Start USB/IP protocol handler in background
|
||||
go func() {
|
||||
defer func() {
|
||||
close(tunnel.done)
|
||||
inWriter.Close()
|
||||
outReader.Close()
|
||||
}()
|
||||
|
||||
// First handle the management phase (import request from client)
|
||||
// The USB/IP client will send OP_REQ_IMPORT, we respond, then enter transfer phase
|
||||
err := server.HandleConnection(inReader, outWriter)
|
||||
if err != nil {
|
||||
log.Printf("[share] USB/IP connection error for %s: %v", busID, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Forward outgoing data from USB/IP server to tunnel
|
||||
go func() {
|
||||
buf := make([]byte, 65536)
|
||||
for {
|
||||
n, err := outReader.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := sm.client.SendTunnelData(tunnelID, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Send grant message
|
||||
sm.client.SendJSON(map[string]interface{}{
|
||||
"type": protocol.MsgDeviceGranted,
|
||||
"bus_id": busID,
|
||||
"tunnel_id": tunnelID,
|
||||
"request_id": requestID,
|
||||
"dev_id": dev.DevID(),
|
||||
"speed": dev.Speed,
|
||||
"target_client": fromClient,
|
||||
})
|
||||
|
||||
log.Printf("[share] device %s granted to %s (tunnel=%s)", busID, fromClient, tunnelID)
|
||||
|
||||
// Broadcast updated device list
|
||||
sm.refreshDevices()
|
||||
sm.broadcastDeviceList()
|
||||
}
|
||||
|
||||
func (sm *ShareManager) handleReleaseDevice(busID, fromClient string) {
|
||||
log.Printf("[share] device release: busID=%s from=%s", busID, fromClient)
|
||||
|
||||
sm.mu.Lock()
|
||||
share, exists := sm.active[busID]
|
||||
if !exists {
|
||||
sm.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Clean up tunnel
|
||||
if tunnel, ok := sm.tunnels[share.tunnelID]; ok {
|
||||
tunnel.inPipe.Close()
|
||||
delete(sm.tunnels, share.tunnelID)
|
||||
}
|
||||
|
||||
// Detach device (release interfaces, reconnect kernel driver)
|
||||
share.server.Detach()
|
||||
delete(sm.active, busID)
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Notify client
|
||||
sm.client.SendJSON(&protocol.DeviceReleased{
|
||||
Type: protocol.MsgDeviceReleased,
|
||||
BusID: busID,
|
||||
})
|
||||
|
||||
log.Printf("[share] device %s released", busID)
|
||||
|
||||
// Refresh device list
|
||||
sm.refreshDevices()
|
||||
sm.broadcastDeviceList()
|
||||
}
|
||||
|
||||
func (sm *ShareManager) handleTunnelData(tunnelID string, data []byte) {
|
||||
sm.mu.RLock()
|
||||
tunnel, exists := sm.tunnels[tunnelID]
|
||||
sm.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Write incoming data to the USB/IP server's input pipe
|
||||
tunnel.inPipe.Write(data)
|
||||
}
|
||||
|
||||
func (sm *ShareManager) cleanup() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
for busID, share := range sm.active {
|
||||
if tunnel, ok := sm.tunnels[share.tunnelID]; ok {
|
||||
tunnel.inPipe.Close()
|
||||
}
|
||||
share.server.Detach()
|
||||
log.Printf("[share] cleaned up device %s", busID)
|
||||
}
|
||||
|
||||
sm.active = make(map[string]*activeShare)
|
||||
sm.tunnels = make(map[string]*shareTunnel)
|
||||
}
|
||||
|
||||
// DeviceListForAPI returns device info formatted for the web API
|
||||
func (sm *ShareManager) DeviceListForAPI() []map[string]interface{} {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
var result []map[string]interface{}
|
||||
for _, dev := range sm.devices {
|
||||
status := "available"
|
||||
usedBy := ""
|
||||
if share, ok := sm.active[dev.BusID]; ok {
|
||||
status = "in_use"
|
||||
usedBy = share.usedBy
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
"bus_id": dev.BusID,
|
||||
"vendor_id": fmt.Sprintf("%04x", dev.VendorID),
|
||||
"product_id": fmt.Sprintf("%04x", dev.ProductID),
|
||||
"name": dev.DisplayName(),
|
||||
"status": status,
|
||||
"used_by": usedBy,
|
||||
"speed": dev.Speed,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build linux
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// createSocketPair creates a Unix domain socket pair
|
||||
func createSocketPair() ([2]int, error) {
|
||||
fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
|
||||
if err != nil {
|
||||
return [2]int{}, fmt.Errorf("socketpair: %w", err)
|
||||
}
|
||||
return fds, nil
|
||||
}
|
||||
|
||||
func closeFDs(fds [2]int) {
|
||||
unix.Close(fds[0])
|
||||
unix.Close(fds[1])
|
||||
}
|
||||
|
||||
func fdToFile(fd int, name string) *os.File {
|
||||
return os.NewFile(uintptr(fd), name)
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
//go:build windows
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func createSocketPair() ([2]int, error) {
|
||||
return [2]int{}, fmt.Errorf("socketpair not implemented on Windows")
|
||||
}
|
||||
|
||||
func closeFDs(fds [2]int) {}
|
||||
|
||||
func fdToFile(fd int, name string) *os.File {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,368 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/duffy/usb-server/internal/protocol"
|
||||
"github.com/duffy/usb-server/internal/usbip"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// RemoteDevice represents a USB device available from a share client
|
||||
type RemoteDevice struct {
|
||||
protocol.USBDevice
|
||||
ClientID string `json:"client_id"`
|
||||
ClientName string `json:"client_name"`
|
||||
}
|
||||
|
||||
// AttachedDevice represents a device currently attached via VHCI
|
||||
type AttachedDevice struct {
|
||||
RemoteDevice
|
||||
TunnelID string `json:"tunnel_id"`
|
||||
VHCIPort int `json:"vhci_port"`
|
||||
SocketFD int `json:"socket_fd"`
|
||||
}
|
||||
|
||||
// UseManager handles receiving/using remote USB devices
|
||||
type UseManager struct {
|
||||
client *Client
|
||||
mu sync.RWMutex
|
||||
available map[string][]RemoteDevice // clientID -> devices
|
||||
attached map[string]*AttachedDevice // busID@clientID -> attached info
|
||||
tunnels map[string]*useTunnel // tunnelID -> tunnel
|
||||
pending map[string]chan *protocol.DeviceGranted // requestID -> response channel
|
||||
}
|
||||
|
||||
type useTunnel struct {
|
||||
id string
|
||||
busID string
|
||||
clientID string
|
||||
conn net.Conn // our end of the socketpair
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewUseManager creates a use manager
|
||||
func NewUseManager(client *Client) *UseManager {
|
||||
um := &UseManager{
|
||||
client: client,
|
||||
available: make(map[string][]RemoteDevice),
|
||||
attached: make(map[string]*AttachedDevice),
|
||||
tunnels: make(map[string]*useTunnel),
|
||||
pending: make(map[string]chan *protocol.DeviceGranted),
|
||||
}
|
||||
|
||||
client.OnDeviceList = um.handleDeviceList
|
||||
client.OnDeviceGranted = um.handleDeviceGranted
|
||||
client.OnDeviceDenied = um.handleDeviceDenied
|
||||
client.OnDeviceReleased = um.handleDeviceReleased
|
||||
client.OnTunnelData = um.handleTunnelData
|
||||
client.OnClientLeft = um.handleClientLeft
|
||||
|
||||
return um
|
||||
}
|
||||
|
||||
// GetAvailableDevices returns all available remote devices
|
||||
func (um *UseManager) GetAvailableDevices() []RemoteDevice {
|
||||
um.mu.RLock()
|
||||
defer um.mu.RUnlock()
|
||||
|
||||
var all []RemoteDevice
|
||||
for _, devs := range um.available {
|
||||
all = append(all, devs...)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
// GetAttachedDevices returns currently attached devices
|
||||
func (um *UseManager) GetAttachedDevices() []*AttachedDevice {
|
||||
um.mu.RLock()
|
||||
defer um.mu.RUnlock()
|
||||
|
||||
var result []*AttachedDevice
|
||||
for _, dev := range um.attached {
|
||||
result = append(result, dev)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// AttachDevice requests and attaches a remote USB device
|
||||
func (um *UseManager) AttachDevice(clientID, busID string) error {
|
||||
// Check if VHCI is available
|
||||
if !usbip.IsVHCIAvailable() {
|
||||
return fmt.Errorf("vhci-hcd kernel module not loaded (run: sudo modprobe vhci-hcd)")
|
||||
}
|
||||
|
||||
key := busID + "@" + clientID
|
||||
um.mu.RLock()
|
||||
if _, already := um.attached[key]; already {
|
||||
um.mu.RUnlock()
|
||||
return fmt.Errorf("device %s already attached", key)
|
||||
}
|
||||
um.mu.RUnlock()
|
||||
|
||||
// Create request
|
||||
requestID := uuid.New().String()
|
||||
respChan := make(chan *protocol.DeviceGranted, 1)
|
||||
|
||||
um.mu.Lock()
|
||||
um.pending[requestID] = respChan
|
||||
um.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
um.mu.Lock()
|
||||
delete(um.pending, requestID)
|
||||
um.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Send request to relay
|
||||
err := um.client.SendJSON(&protocol.RequestDevice{
|
||||
Type: protocol.MsgRequestDevice,
|
||||
TargetClient: clientID,
|
||||
BusID: busID,
|
||||
RequestID: requestID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("sending request: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[use] requesting device %s from %s", busID, clientID)
|
||||
|
||||
// Wait for response (with timeout via context)
|
||||
select {
|
||||
case granted := <-respChan:
|
||||
return um.setupVHCI(clientID, busID, granted)
|
||||
case <-um.client.ctx.Done():
|
||||
return fmt.Errorf("client shutting down")
|
||||
}
|
||||
}
|
||||
|
||||
// DetachDevice releases an attached device
|
||||
func (um *UseManager) DetachDevice(clientID, busID string) error {
|
||||
key := busID + "@" + clientID
|
||||
|
||||
um.mu.Lock()
|
||||
dev, exists := um.attached[key]
|
||||
if !exists {
|
||||
um.mu.Unlock()
|
||||
return fmt.Errorf("device %s not attached", key)
|
||||
}
|
||||
|
||||
// Clean up tunnel
|
||||
if tunnel, ok := um.tunnels[dev.TunnelID]; ok {
|
||||
close(tunnel.done)
|
||||
if tunnel.conn != nil {
|
||||
tunnel.conn.Close()
|
||||
}
|
||||
delete(um.tunnels, dev.TunnelID)
|
||||
}
|
||||
|
||||
// Detach from VHCI
|
||||
if dev.VHCIPort >= 0 {
|
||||
if err := usbip.DetachDevice(dev.VHCIPort); err != nil {
|
||||
log.Printf("[use] warning: VHCI detach error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
delete(um.attached, key)
|
||||
um.mu.Unlock()
|
||||
|
||||
// Notify share client
|
||||
um.client.SendJSON(&protocol.ReleaseDevice{
|
||||
Type: protocol.MsgReleaseDevice,
|
||||
TargetClient: clientID,
|
||||
BusID: busID,
|
||||
})
|
||||
|
||||
log.Printf("[use] device %s detached", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (um *UseManager) setupVHCI(clientID, busID string, granted *protocol.DeviceGranted) error {
|
||||
// Create a socketpair - one end for VHCI, one for our tunnel
|
||||
fds, err := createSocketPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating socketpair: %w", err)
|
||||
}
|
||||
|
||||
vhciFD := fds[0]
|
||||
tunnelFD := fds[1]
|
||||
|
||||
// Find a free VHCI port
|
||||
port, err := usbip.FindFreePort(granted.Speed)
|
||||
if err != nil {
|
||||
closeFDs(fds)
|
||||
return fmt.Errorf("finding free VHCI port: %w", err)
|
||||
}
|
||||
|
||||
// Attach to VHCI
|
||||
if err := usbip.AttachDevice(port, vhciFD, granted.DevID, granted.Speed); err != nil {
|
||||
closeFDs(fds)
|
||||
return fmt.Errorf("VHCI attach: %w", err)
|
||||
}
|
||||
|
||||
// The VHCI driver now owns vhciFD, so we don't close it
|
||||
// Create a net.Conn from the tunnel FD
|
||||
tunnelFile := fdToFile(tunnelFD, "usb-tunnel")
|
||||
tunnelConn, err := net.FileConn(tunnelFile)
|
||||
tunnelFile.Close() // FileConn dups the fd
|
||||
if err != nil {
|
||||
usbip.DetachDevice(port)
|
||||
return fmt.Errorf("creating tunnel conn: %w", err)
|
||||
}
|
||||
|
||||
tunnel := &useTunnel{
|
||||
id: granted.TunnelID,
|
||||
busID: busID,
|
||||
clientID: clientID,
|
||||
conn: tunnelConn,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
key := busID + "@" + clientID
|
||||
um.mu.Lock()
|
||||
um.tunnels[granted.TunnelID] = tunnel
|
||||
um.attached[key] = &AttachedDevice{
|
||||
RemoteDevice: RemoteDevice{
|
||||
USBDevice: protocol.USBDevice{BusID: busID},
|
||||
ClientID: clientID,
|
||||
},
|
||||
TunnelID: granted.TunnelID,
|
||||
VHCIPort: port,
|
||||
SocketFD: vhciFD,
|
||||
}
|
||||
um.mu.Unlock()
|
||||
|
||||
// Start reading from the tunnel socket (VHCI -> relay)
|
||||
go um.tunnelReadLoop(tunnel)
|
||||
|
||||
log.Printf("[use] device %s attached on VHCI port %d", key, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
// tunnelReadLoop reads from the VHCI socket and sends to relay
|
||||
func (um *UseManager) tunnelReadLoop(tunnel *useTunnel) {
|
||||
buf := make([]byte, 65536)
|
||||
for {
|
||||
select {
|
||||
case <-tunnel.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := tunnel.conn.Read(buf)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-tunnel.done:
|
||||
return
|
||||
default:
|
||||
log.Printf("[use] tunnel read error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := um.client.SendTunnelData(tunnel.id, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (um *UseManager) handleDeviceList(msg *protocol.DeviceList) {
|
||||
um.mu.Lock()
|
||||
var remoteDevs []RemoteDevice
|
||||
for _, dev := range msg.Devices {
|
||||
remoteDevs = append(remoteDevs, RemoteDevice{
|
||||
USBDevice: dev,
|
||||
ClientID: msg.ClientID,
|
||||
ClientName: msg.ClientName,
|
||||
})
|
||||
}
|
||||
um.available[msg.ClientID] = remoteDevs
|
||||
um.mu.Unlock()
|
||||
|
||||
log.Printf("[use] received device list from %s (%s): %d devices",
|
||||
msg.ClientName, msg.ClientID[:8], len(msg.Devices))
|
||||
}
|
||||
|
||||
func (um *UseManager) handleDeviceGranted(msg *protocol.DeviceGranted) {
|
||||
um.mu.RLock()
|
||||
ch, exists := um.pending[msg.RequestID]
|
||||
um.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
ch <- msg
|
||||
}
|
||||
}
|
||||
|
||||
func (um *UseManager) handleDeviceDenied(msg *protocol.DeviceDenied) {
|
||||
log.Printf("[use] device request denied: %s - %s", msg.BusID, msg.Reason)
|
||||
|
||||
um.mu.RLock()
|
||||
ch, exists := um.pending[msg.RequestID]
|
||||
um.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
close(ch) // signal denial by closing channel
|
||||
}
|
||||
}
|
||||
|
||||
func (um *UseManager) handleDeviceReleased(msg *protocol.DeviceReleased) {
|
||||
log.Printf("[use] device released by share client: %s", msg.BusID)
|
||||
}
|
||||
|
||||
func (um *UseManager) handleTunnelData(tunnelID string, data []byte) {
|
||||
um.mu.RLock()
|
||||
tunnel, exists := um.tunnels[tunnelID]
|
||||
um.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Write to the tunnel socket (relay -> VHCI)
|
||||
tunnel.conn.Write(data)
|
||||
}
|
||||
|
||||
func (um *UseManager) handleClientLeft(msg *protocol.ClientLeft) {
|
||||
um.mu.Lock()
|
||||
delete(um.available, msg.ClientID)
|
||||
|
||||
// Detach any devices from this client
|
||||
for key, dev := range um.attached {
|
||||
if dev.ClientID == msg.ClientID {
|
||||
if tunnel, ok := um.tunnels[dev.TunnelID]; ok {
|
||||
close(tunnel.done)
|
||||
tunnel.conn.Close()
|
||||
delete(um.tunnels, dev.TunnelID)
|
||||
}
|
||||
if dev.VHCIPort >= 0 {
|
||||
usbip.DetachDevice(dev.VHCIPort)
|
||||
}
|
||||
delete(um.attached, key)
|
||||
log.Printf("[use] device %s auto-detached (client left)", key)
|
||||
}
|
||||
}
|
||||
um.mu.Unlock()
|
||||
}
|
||||
|
||||
// Cleanup releases all attached devices
|
||||
func (um *UseManager) Cleanup() {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
for key, dev := range um.attached {
|
||||
if tunnel, ok := um.tunnels[dev.TunnelID]; ok {
|
||||
close(tunnel.done)
|
||||
tunnel.conn.Close()
|
||||
}
|
||||
if dev.VHCIPort >= 0 {
|
||||
usbip.DetachDevice(dev.VHCIPort)
|
||||
}
|
||||
log.Printf("[use] cleaned up device %s", key)
|
||||
}
|
||||
|
||||
um.attached = make(map[string]*AttachedDevice)
|
||||
um.tunnels = make(map[string]*useTunnel)
|
||||
}
|
||||
Reference in New Issue
Block a user