// usb-server/internal/client/share.go
//
// 410 lines
// 9.6 KiB
// Go

package client
import (
"fmt"
"io"
"log"
"sync"
"time"
"github.com/duffy/usb-server/internal/config"
"github.com/duffy/usb-server/internal/protocol"
"github.com/duffy/usb-server/internal/usb"
"github.com/duffy/usb-server/internal/usbip"
"github.com/google/uuid"
)
type (
	// ShareManager shares this machine's USB devices with other clients: it
	// periodically enumerates local devices, broadcasts the list, and serves a
	// USB/IP session over a relay tunnel for every device it grants.
	ShareManager struct {
		client *Client
		cfg    *config.Config

		mu      sync.RWMutex            // guards devices, active and tunnels
		devices []usb.Device            // most recent enumeration result
		active  map[string]*activeShare // busID -> active share
		tunnels map[string]*shareTunnel // tunnelID -> tunnel
	}

	// activeShare records a device that is currently granted to a remote client.
	activeShare struct {
		device   *usb.Device
		server   *usbip.Server
		usedBy   string // client ID using this device
		tunnelID string // key into ShareManager.tunnels
	}

	// shareTunnel bridges one relay tunnel to a USB/IP server through a pair
	// of in-memory pipes.
	shareTunnel struct {
		id      string
		busID   string
		inPipe  *io.PipeWriter // relay data is written here for the server to read
		outPipe *io.PipeReader // read end of the server's output pipe
		done    chan struct{}  // closed once the server's HandleConnection returns
	}
)
// NewShareManager builds a ShareManager for client using cfg and installs its
// handlers as the client's request-device, release-device, tunnel-data,
// client-left and force-release callbacks. Call Run to start enumeration.
func NewShareManager(client *Client, cfg *config.Config) *ShareManager {
	mgr := &ShareManager{cfg: cfg, client: client}
	mgr.active = make(map[string]*activeShare)
	mgr.tunnels = make(map[string]*shareTunnel)

	// Route the client's incoming events to this manager.
	client.OnForceRelease = mgr.handleForceRelease
	client.OnClientLeft = mgr.handleClientLeft
	client.OnTunnelData = mgr.handleTunnelData
	client.OnReleaseDevice = mgr.handleReleaseDevice
	client.OnRequestDevice = mgr.handleRequestDevice

	return mgr
}
// Run drives the share manager until the client's context is cancelled.
//
// It enumerates devices and broadcasts the list once immediately, then again
// every five seconds. On cancellation it tears down all active shares and
// returns nil.
func (sm *ShareManager) Run() error {
	publish := func() {
		sm.refreshDevices()
		sm.broadcastDeviceList()
	}

	publish()

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-sm.client.ctx.Done():
			sm.cleanup()
			return nil
		case <-ticker.C:
			publish()
		}
	}
}
// GetDevices returns a snapshot of the most recently enumerated devices.
//
// The result is a fresh, non-nil slice, so callers may reorder or truncate it
// freely. The copy is shallow: any reference-typed fields inside usb.Device
// still share storage with the manager's copy.
func (sm *ShareManager) GetDevices() []usb.Device {
	sm.mu.RLock()
	snapshot := append(make([]usb.Device, 0, len(sm.devices)), sm.devices...)
	sm.mu.RUnlock()
	return snapshot
}
// refreshDevices re-enumerates local USB devices and replaces the cached list.
// On enumeration failure the error is logged and the previous list is kept.
func (sm *ShareManager) refreshDevices() {
	found, err := usb.Enumerate()
	if err != nil {
		log.Printf("[share] USB enumeration error: %v", err)
		return
	}

	sm.mu.Lock()
	defer sm.mu.Unlock()
	sm.devices = found
}
// broadcastDeviceList sends this client's device inventory as a
// protocol.DeviceList. Each device is marked StatusAvailable, or StatusInUse
// together with the ID of the client that holds it.
//
// The snapshot is built under the read lock, but the message is sent only
// after the lock is released, so network I/O in SendJSON never holds sm.mu
// (which would stall every handler waiting for the write lock). Devices is
// always a non-nil slice, so an empty inventory is not sent as JSON null.
func (sm *ShareManager) broadcastDeviceList() {
	sm.mu.RLock()
	protoDevices := make([]protocol.USBDevice, 0, len(sm.devices))
	for i := range sm.devices {
		dev := &sm.devices[i]

		status := protocol.StatusAvailable
		usedBy := ""
		if share, ok := sm.active[dev.BusID]; ok {
			status = protocol.StatusInUse
			usedBy = share.usedBy
		}

		protoDevices = append(protoDevices, protocol.USBDevice{
			BusID:         dev.BusID,
			BusNum:        dev.BusNum,
			DevNum:        dev.DevNum,
			Speed:         dev.Speed,
			VendorID:      fmt.Sprintf("%04x", dev.VendorID),
			ProductID:     fmt.Sprintf("%04x", dev.ProductID),
			Class:         dev.DeviceClass,
			SubClass:      dev.DeviceSubClass,
			Protocol:      dev.DeviceProtocol,
			Name:          dev.DisplayName(),
			Manufacturer:  dev.Manufacturer,
			NumInterfaces: uint8(len(dev.Interfaces)),
			Status:        status,
			UsedBy:        usedBy,
		})
	}
	sm.mu.RUnlock()

	sm.client.SendJSON(&protocol.DeviceList{
		Type:             protocol.MsgDeviceList,
		ClientID:         sm.client.ID(),
		ClientName:       sm.client.Config().Name,
		Devices:          protoDevices,
		AllowForceDetach: sm.cfg.AllowForceDetach,
	})
}
// handleRequestDevice services a remote client's request to use a local device.
//
// On success it does the following:
//   - attaches busID to a new USB/IP server;
//   - wires the server to a new 16-character tunnel through two io.Pipes;
//   - records the share and replies to fromClient with MsgDeviceGranted;
//   - re-broadcasts the device list.
//
// If the device is unknown, already in use, or fails to attach,
// MsgDeviceDenied with the reason is sent to fromClient instead.
//
// targetClient is part of the callback signature and is not used here.
func (sm *ShareManager) handleRequestDevice(targetClient, fromClient, busID, requestID string) {
	log.Printf("[share] device request: busID=%s from=%s", busID, fromClient)

	// deny replies to the requester with MsgDeviceDenied. Must be called
	// without sm.mu held.
	deny := func(reason string) {
		sm.client.SendJSON(map[string]interface{}{
			"type":          protocol.MsgDeviceDenied,
			"bus_id":        busID,
			"request_id":    requestID,
			"reason":        reason,
			"target_client": fromClient,
		})
	}

	sm.mu.Lock()

	var dev *usb.Device
	for i := range sm.devices {
		if sm.devices[i].BusID == busID {
			dev = &sm.devices[i]
			break
		}
	}
	if dev == nil {
		sm.mu.Unlock()
		deny("device not found")
		return
	}
	if _, inUse := sm.active[busID]; inUse {
		sm.mu.Unlock()
		deny("device already in use")
		return
	}

	// Attach while still holding the lock so a concurrent request for the
	// same busID cannot pass the in-use check and attach the device twice.
	server := usbip.NewServer(dev)
	if err := server.Attach(); err != nil {
		sm.mu.Unlock()
		log.Printf("[share] failed to attach device %s: %v", busID, err)
		deny(fmt.Sprintf("attach failed: %v", err))
		return
	}

	// The tunnel header carries a fixed 16-character ID. A UUID string is
	// always 36 characters long, so the padding loop is purely defensive.
	tunnelID := uuid.New().String()[:16]
	for len(tunnelID) < 16 {
		tunnelID += "0"
	}

	// in:  handleTunnelData -> inWriter -> inReader -> USB/IP server
	// out: USB/IP server -> outWriter -> outReader -> SendTunnelData
	inReader, inWriter := io.Pipe()
	outReader, outWriter := io.Pipe()
	tunnel := &shareTunnel{
		id:      tunnelID,
		busID:   busID,
		inPipe:  inWriter,
		outPipe: outReader,
		done:    make(chan struct{}),
	}
	sm.active[busID] = &activeShare{
		device:   dev,
		server:   server,
		usedBy:   fromClient,
		tunnelID: tunnelID,
	}
	sm.tunnels[tunnelID] = tunnel
	sm.mu.Unlock()

	// Run the USB/IP protocol (management phase with OP_REQ_IMPORT, then the
	// transfer phase) until the input pipe is closed or an error occurs.
	// tunnel.done tells handleReleaseDevice when it is safe to detach.
	go func() {
		defer func() {
			close(tunnel.done)
			inWriter.Close()
			outReader.Close()
		}()
		if err := server.HandleConnection(inReader, outWriter); err != nil {
			log.Printf("[share] USB/IP connection error for %s: %v", busID, err)
		}
	}()

	// Forward the server's output to the tunnel. io.Pipe writes block until
	// they are read, so outReader must be closed when this loop stops.
	// Otherwise a failed send would leave HandleConnection blocked on
	// outWriter forever, tunnel.done would never close, and
	// handleReleaseDevice would hang. Closing twice is harmless.
	go func() {
		defer outReader.Close()
		buf := make([]byte, 65536)
		for {
			n, err := outReader.Read(buf)
			if err != nil {
				return
			}
			if err := sm.client.SendTunnelData(tunnelID, buf[:n]); err != nil {
				log.Printf("[share] tunnel %s send error for %s: %v", tunnelID, busID, err)
				return
			}
		}
	}()

	sm.client.SendJSON(map[string]interface{}{
		"type":          protocol.MsgDeviceGranted,
		"bus_id":        busID,
		"tunnel_id":     tunnelID,
		"request_id":    requestID,
		"dev_id":        dev.DevID(),
		"speed":         dev.Speed,
		"target_client": fromClient,
	})
	log.Printf("[share] device %s granted to %s (tunnel=%s)", busID, fromClient, tunnelID)

	sm.refreshDevices()
	sm.broadcastDeviceList()
}
// handleReleaseDevice ends the share of busID, if there is one.
//
// It proceeds in order:
//  1. Removes the share and its tunnel from the maps.
//  2. Closes the tunnel's input pipe and waits for the USB/IP HandleConnection
//     goroutine to return.
//  3. Detaches the device and sends MsgDeviceReleased.
//  4. Re-broadcasts the device list.
//
// fromClient is only used for logging.
func (sm *ShareManager) handleReleaseDevice(busID, fromClient string) {
	log.Printf("[share] device release: busID=%s from=%s", busID, fromClient)

	sm.mu.Lock()
	share, ok := sm.active[busID]
	if !ok {
		sm.mu.Unlock()
		return
	}
	var handlerDone <-chan struct{}
	tunnel, hasTunnel := sm.tunnels[share.tunnelID]
	if hasTunnel {
		// Closing the input pipe makes HandleConnection's next read fail,
		// which ends the USB/IP session goroutine.
		tunnel.inPipe.Close()
		handlerDone = tunnel.done
		delete(sm.tunnels, share.tunnelID)
	}
	delete(sm.active, busID)
	sm.mu.Unlock()

	// Block until HandleConnection has returned so that no URB is still being
	// submitted when the device is detached.
	if handlerDone != nil {
		<-handlerDone
		log.Printf("[share] HandleConnection goroutine finished for %s", busID)
	}
	share.server.Detach()

	sm.client.SendJSON(&protocol.DeviceReleased{
		Type:  protocol.MsgDeviceReleased,
		BusID: busID,
	})
	log.Printf("[share] device %s released", busID)

	sm.refreshDevices()
	sm.broadcastDeviceList()
}
// handleForceRelease releases busID on behalf of fromClient even though another
// client holds it. The request is ignored unless the configuration sets
// AllowForceDetach, or if busID is not currently shared.
//
// targetClient is part of the callback signature and is not used here.
func (sm *ShareManager) handleForceRelease(targetClient, fromClient, busID string) {
	if !sm.cfg.AllowForceDetach {
		log.Printf("[share] force-release denied for %s (not allowed by config)", busID)
		return
	}

	sm.mu.RLock()
	share, exists := sm.active[busID]
	sm.mu.RUnlock()
	if !exists {
		return
	}
	owner := share.usedBy

	// Client IDs are abbreviated to 8 characters in the log. Guard against
	// shorter IDs, which previously caused a slice-bounds panic.
	short := func(id string) string {
		if len(id) > 8 {
			return id[:8]
		}
		return id
	}
	log.Printf("[share] force-releasing %s (requested by %s, was used by %s)", busID, short(fromClient), short(owner))
	sm.handleReleaseDevice(busID, owner)
}
// handleClientLeft releases every device granted to the client named in msg
// (a protocol.ClientLeft notification). A nil message is ignored.
func (sm *ShareManager) handleClientLeft(msg *protocol.ClientLeft) {
	if msg == nil {
		return
	}
	gone := msg.ClientID

	// Collect under the read lock. handleReleaseDevice takes the write lock
	// itself, so it must be called after the lock is released.
	sm.mu.RLock()
	var toRelease []string
	for busID, share := range sm.active {
		if share.usedBy == gone {
			toRelease = append(toRelease, busID)
		}
	}
	sm.mu.RUnlock()

	// Abbreviate the ID for logging without panicking on IDs shorter than 8.
	shortID := gone
	if len(shortID) > 8 {
		shortID = shortID[:8]
	}
	for _, busID := range toRelease {
		log.Printf("[share] auto-releasing %s (client %s left)", busID, shortID)
		sm.handleReleaseDevice(busID, gone)
	}
}
// handleTunnelData feeds data that arrived for tunnelID into the input pipe of
// the matching USB/IP server. Data for unknown or already-released tunnels is
// dropped.
func (sm *ShareManager) handleTunnelData(tunnelID string, data []byte) {
	sm.mu.RLock()
	t := sm.tunnels[tunnelID]
	sm.mu.RUnlock()
	if t == nil {
		return
	}

	// The io.Pipe write blocks until HandleConnection consumes the bytes. It
	// only fails once the pipe has been closed (the session ended or was
	// released), so the error is deliberately ignored.
	_, _ = t.inPipe.Write(data)
}
// cleanup tears down every active share when the manager stops.
//
// It follows the same order as handleReleaseDevice:
//  1. Close each tunnel's input pipe so the USB/IP HandleConnection goroutine
//     stops reading.
//  2. Wait for that goroutine to finish.
//  3. Detach the device.
//
// Previously devices were detached immediately, which could pull a device out
// from under an in-flight URB. The wait is bounded by a single overall
// deadline so a stuck session cannot block shutdown indefinitely. Detaching
// happens after sm.mu is released.
func (sm *ShareManager) cleanup() {
	const handlerWait = 5 * time.Second

	type pending struct {
		busID  string
		server *usbip.Server
		done   <-chan struct{}
	}

	sm.mu.Lock()
	shares := make([]pending, 0, len(sm.active))
	for busID, share := range sm.active {
		p := pending{busID: busID, server: share.server}
		if tunnel, ok := sm.tunnels[share.tunnelID]; ok {
			tunnel.inPipe.Close()
			p.done = tunnel.done
		}
		shares = append(shares, p)
	}
	sm.active = make(map[string]*activeShare)
	sm.tunnels = make(map[string]*shareTunnel)
	sm.mu.Unlock()

	expired := time.After(handlerWait)
	timedOut := false
	for _, p := range shares {
		if p.done != nil {
			if !timedOut {
				select {
				case <-p.done:
				case <-expired:
					timedOut = true
				}
			}
			if timedOut {
				// Deadline already passed: only check, never block.
				select {
				case <-p.done:
				default:
					log.Printf("[share] USB/IP handler for %s still running after %v; detaching anyway", p.busID, handlerWait)
				}
			}
		}
		p.server.Detach()
		log.Printf("[share] cleaned up device %s", p.busID)
	}
}
// DeviceListForAPI returns the current device list formatted for the web API.
//
// Each device becomes one map with these keys:
//   - bus_id, name and speed;
//   - vendor_id and product_id as zero-padded lowercase hex;
//   - status ("available" or "in_use");
//   - used_by (empty unless in use).
//
// The result is always non-nil, so an empty list serializes to a JSON array
// rather than null.
func (sm *ShareManager) DeviceListForAPI() []map[string]interface{} {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	result := make([]map[string]interface{}, 0, len(sm.devices))
	for i := range sm.devices {
		dev := &sm.devices[i]

		status, usedBy := "available", ""
		if share, ok := sm.active[dev.BusID]; ok {
			status, usedBy = "in_use", share.usedBy
		}

		result = append(result, map[string]interface{}{
			"bus_id":     dev.BusID,
			"vendor_id":  fmt.Sprintf("%04x", dev.VendorID),
			"product_id": fmt.Sprintf("%04x", dev.ProductID),
			"name":       dev.DisplayName(),
			"status":     status,
			"used_by":    usedBy,
			"speed":      dev.Speed,
		})
	}
	return result
}