410 lines
9.6 KiB
Go
410 lines
9.6 KiB
Go
package client
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/duffy/usb-server/internal/config"
|
|
"github.com/duffy/usb-server/internal/protocol"
|
|
"github.com/duffy/usb-server/internal/usb"
|
|
"github.com/duffy/usb-server/internal/usbip"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// ShareManager handles sharing local USB devices with other clients.
//
// It periodically enumerates local USB devices, broadcasts the list through
// the owning Client, and serves other clients' device requests by running a
// usbip.Server per granted device, bridged to the Client over in-memory pipes.
type ShareManager struct {
	client *Client
	cfg    *config.Config

	// mu guards devices, active, and tunnels below.
	mu      sync.RWMutex
	devices []usb.Device             // latest snapshot returned by usb.Enumerate
	active  map[string]*activeShare // busID -> active share
	tunnels map[string]*shareTunnel // tunnelID -> tunnel
}
|
|
|
|
// activeShare records a local device that has been granted to another client.
type activeShare struct {
	device   *usb.Device   // points into the ShareManager.devices slice current at grant time
	server   *usbip.Server // USB/IP server for this device; Detach is called on release/cleanup
	usedBy   string        // client ID using this device
	tunnelID string        // key into ShareManager.tunnels
}
|
|
|
|
// shareTunnel links tunnel traffic relayed by the Client to the goroutine
// running usbip.Server.HandleConnection for one shared device.
type shareTunnel struct {
	id      string
	busID   string
	inPipe  *io.PipeWriter // incoming tunnel data is written here; HandleConnection reads the other end
	outPipe *io.PipeReader // read end of HandleConnection's output, forwarded via SendTunnelData
	done    chan struct{}  // closed when the HandleConnection goroutine returns
}
|
|
|
|
// NewShareManager creates a share manager
|
|
func NewShareManager(client *Client, cfg *config.Config) *ShareManager {
|
|
sm := &ShareManager{
|
|
client: client,
|
|
cfg: cfg,
|
|
active: make(map[string]*activeShare),
|
|
tunnels: make(map[string]*shareTunnel),
|
|
}
|
|
|
|
// Set up callbacks
|
|
client.OnRequestDevice = sm.handleRequestDevice
|
|
client.OnReleaseDevice = sm.handleReleaseDevice
|
|
client.OnTunnelData = sm.handleTunnelData
|
|
client.OnClientLeft = sm.handleClientLeft
|
|
client.OnForceRelease = sm.handleForceRelease
|
|
|
|
return sm
|
|
}
|
|
|
|
// Run starts the share manager: periodic device enumeration + event handling
|
|
func (sm *ShareManager) Run() error {
|
|
// Initial enumeration
|
|
sm.refreshDevices()
|
|
sm.broadcastDeviceList()
|
|
|
|
// Periodic refresh
|
|
ticker := time.NewTicker(5 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
sm.refreshDevices()
|
|
sm.broadcastDeviceList()
|
|
case <-sm.client.ctx.Done():
|
|
sm.cleanup()
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetDevices returns the current device list
|
|
func (sm *ShareManager) GetDevices() []usb.Device {
|
|
sm.mu.RLock()
|
|
defer sm.mu.RUnlock()
|
|
result := make([]usb.Device, len(sm.devices))
|
|
copy(result, sm.devices)
|
|
return result
|
|
}
|
|
|
|
func (sm *ShareManager) refreshDevices() {
|
|
devices, err := usb.Enumerate()
|
|
if err != nil {
|
|
log.Printf("[share] USB enumeration error: %v", err)
|
|
return
|
|
}
|
|
|
|
sm.mu.Lock()
|
|
sm.devices = devices
|
|
sm.mu.Unlock()
|
|
}
|
|
|
|
func (sm *ShareManager) broadcastDeviceList() {
|
|
sm.mu.RLock()
|
|
defer sm.mu.RUnlock()
|
|
|
|
var protoDevices []protocol.USBDevice
|
|
for _, dev := range sm.devices {
|
|
status := protocol.StatusAvailable
|
|
usedBy := ""
|
|
if share, ok := sm.active[dev.BusID]; ok {
|
|
status = protocol.StatusInUse
|
|
usedBy = share.usedBy
|
|
}
|
|
|
|
protoDevices = append(protoDevices, protocol.USBDevice{
|
|
BusID: dev.BusID,
|
|
BusNum: dev.BusNum,
|
|
DevNum: dev.DevNum,
|
|
Speed: dev.Speed,
|
|
VendorID: fmt.Sprintf("%04x", dev.VendorID),
|
|
ProductID: fmt.Sprintf("%04x", dev.ProductID),
|
|
Class: dev.DeviceClass,
|
|
SubClass: dev.DeviceSubClass,
|
|
Protocol: dev.DeviceProtocol,
|
|
Name: dev.DisplayName(),
|
|
Manufacturer: dev.Manufacturer,
|
|
NumInterfaces: uint8(len(dev.Interfaces)),
|
|
Status: status,
|
|
UsedBy: usedBy,
|
|
})
|
|
}
|
|
|
|
msg := &protocol.DeviceList{
|
|
Type: protocol.MsgDeviceList,
|
|
ClientID: sm.client.ID(),
|
|
ClientName: sm.client.Config().Name,
|
|
Devices: protoDevices,
|
|
AllowForceDetach: sm.cfg.AllowForceDetach,
|
|
}
|
|
|
|
sm.client.SendJSON(msg)
|
|
}
|
|
|
|
func (sm *ShareManager) handleRequestDevice(targetClient, fromClient, busID, requestID string) {
|
|
log.Printf("[share] device request: busID=%s from=%s", busID, fromClient)
|
|
|
|
sm.mu.Lock()
|
|
|
|
// Check if device exists
|
|
var dev *usb.Device
|
|
for i := range sm.devices {
|
|
if sm.devices[i].BusID == busID {
|
|
dev = &sm.devices[i]
|
|
break
|
|
}
|
|
}
|
|
|
|
if dev == nil {
|
|
sm.mu.Unlock()
|
|
sm.client.SendJSON(map[string]interface{}{
|
|
"type": protocol.MsgDeviceDenied,
|
|
"bus_id": busID,
|
|
"request_id": requestID,
|
|
"reason": "device not found",
|
|
"target_client": fromClient,
|
|
})
|
|
return
|
|
}
|
|
|
|
// Check if already in use
|
|
if _, inUse := sm.active[busID]; inUse {
|
|
sm.mu.Unlock()
|
|
sm.client.SendJSON(map[string]interface{}{
|
|
"type": protocol.MsgDeviceDenied,
|
|
"bus_id": busID,
|
|
"request_id": requestID,
|
|
"reason": "device already in use",
|
|
"target_client": fromClient,
|
|
})
|
|
return
|
|
}
|
|
|
|
// Create USB/IP server for this device
|
|
server := usbip.NewServer(dev)
|
|
if err := server.Attach(); err != nil {
|
|
sm.mu.Unlock()
|
|
log.Printf("[share] failed to attach device %s: %v", busID, err)
|
|
sm.client.SendJSON(map[string]interface{}{
|
|
"type": protocol.MsgDeviceDenied,
|
|
"bus_id": busID,
|
|
"request_id": requestID,
|
|
"reason": fmt.Sprintf("attach failed: %v", err),
|
|
"target_client": fromClient,
|
|
})
|
|
return
|
|
}
|
|
|
|
tunnelID := uuid.New().String()[:16] // 16 chars for tunnel header
|
|
for len(tunnelID) < 16 {
|
|
tunnelID += "0"
|
|
}
|
|
|
|
inReader, inWriter := io.Pipe()
|
|
outReader, outWriter := io.Pipe()
|
|
|
|
tunnel := &shareTunnel{
|
|
id: tunnelID,
|
|
busID: busID,
|
|
inPipe: inWriter,
|
|
outPipe: outReader,
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
share := &activeShare{
|
|
device: dev,
|
|
server: server,
|
|
usedBy: fromClient,
|
|
tunnelID: tunnelID,
|
|
}
|
|
|
|
sm.active[busID] = share
|
|
sm.tunnels[tunnelID] = tunnel
|
|
sm.mu.Unlock()
|
|
|
|
// Start USB/IP protocol handler in background
|
|
go func() {
|
|
defer func() {
|
|
close(tunnel.done)
|
|
inWriter.Close()
|
|
outReader.Close()
|
|
}()
|
|
|
|
// First handle the management phase (import request from client)
|
|
// The USB/IP client will send OP_REQ_IMPORT, we respond, then enter transfer phase
|
|
err := server.HandleConnection(inReader, outWriter)
|
|
if err != nil {
|
|
log.Printf("[share] USB/IP connection error for %s: %v", busID, err)
|
|
}
|
|
}()
|
|
|
|
// Forward outgoing data from USB/IP server to tunnel
|
|
go func() {
|
|
buf := make([]byte, 65536)
|
|
for {
|
|
n, err := outReader.Read(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if err := sm.client.SendTunnelData(tunnelID, buf[:n]); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Send grant message
|
|
sm.client.SendJSON(map[string]interface{}{
|
|
"type": protocol.MsgDeviceGranted,
|
|
"bus_id": busID,
|
|
"tunnel_id": tunnelID,
|
|
"request_id": requestID,
|
|
"dev_id": dev.DevID(),
|
|
"speed": dev.Speed,
|
|
"target_client": fromClient,
|
|
})
|
|
|
|
log.Printf("[share] device %s granted to %s (tunnel=%s)", busID, fromClient, tunnelID)
|
|
|
|
// Broadcast updated device list
|
|
sm.refreshDevices()
|
|
sm.broadcastDeviceList()
|
|
}
|
|
|
|
func (sm *ShareManager) handleReleaseDevice(busID, fromClient string) {
|
|
log.Printf("[share] device release: busID=%s from=%s", busID, fromClient)
|
|
|
|
sm.mu.Lock()
|
|
share, exists := sm.active[busID]
|
|
if !exists {
|
|
sm.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
// Close the tunnel pipe to signal HandleConnection to stop reading
|
|
var tunnelDone <-chan struct{}
|
|
if tunnel, ok := sm.tunnels[share.tunnelID]; ok {
|
|
tunnel.inPipe.Close()
|
|
tunnelDone = tunnel.done
|
|
delete(sm.tunnels, share.tunnelID)
|
|
}
|
|
|
|
server := share.server
|
|
delete(sm.active, busID)
|
|
sm.mu.Unlock()
|
|
|
|
// Wait for HandleConnection goroutine to finish before detaching.
|
|
// This ensures no more URBs are being submitted when we detach.
|
|
if tunnelDone != nil {
|
|
<-tunnelDone
|
|
log.Printf("[share] HandleConnection goroutine finished for %s", busID)
|
|
}
|
|
|
|
// Now safe to detach - no more USB/IP protocol processing
|
|
server.Detach()
|
|
|
|
// Notify client
|
|
sm.client.SendJSON(&protocol.DeviceReleased{
|
|
Type: protocol.MsgDeviceReleased,
|
|
BusID: busID,
|
|
})
|
|
|
|
log.Printf("[share] device %s released", busID)
|
|
|
|
// Refresh device list
|
|
sm.refreshDevices()
|
|
sm.broadcastDeviceList()
|
|
}
|
|
|
|
func (sm *ShareManager) handleForceRelease(targetClient, fromClient, busID string) {
|
|
if !sm.cfg.AllowForceDetach {
|
|
log.Printf("[share] force-release denied for %s (not allowed by config)", busID)
|
|
return
|
|
}
|
|
|
|
sm.mu.RLock()
|
|
share, exists := sm.active[busID]
|
|
sm.mu.RUnlock()
|
|
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
log.Printf("[share] force-releasing %s (requested by %s, was used by %s)", busID, fromClient[:8], share.usedBy[:8])
|
|
sm.handleReleaseDevice(busID, share.usedBy)
|
|
}
|
|
|
|
func (sm *ShareManager) handleClientLeft(msg *protocol.ClientLeft) {
|
|
sm.mu.RLock()
|
|
var toRelease []string
|
|
for busID, share := range sm.active {
|
|
if share.usedBy == msg.ClientID {
|
|
toRelease = append(toRelease, busID)
|
|
}
|
|
}
|
|
sm.mu.RUnlock()
|
|
|
|
for _, busID := range toRelease {
|
|
log.Printf("[share] auto-releasing %s (client %s left)", busID, msg.ClientID[:8])
|
|
sm.handleReleaseDevice(busID, msg.ClientID)
|
|
}
|
|
}
|
|
|
|
func (sm *ShareManager) handleTunnelData(tunnelID string, data []byte) {
|
|
sm.mu.RLock()
|
|
tunnel, exists := sm.tunnels[tunnelID]
|
|
sm.mu.RUnlock()
|
|
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
// Write incoming data to the USB/IP server's input pipe
|
|
tunnel.inPipe.Write(data)
|
|
}
|
|
|
|
func (sm *ShareManager) cleanup() {
|
|
sm.mu.Lock()
|
|
defer sm.mu.Unlock()
|
|
|
|
for busID, share := range sm.active {
|
|
if tunnel, ok := sm.tunnels[share.tunnelID]; ok {
|
|
tunnel.inPipe.Close()
|
|
}
|
|
share.server.Detach()
|
|
log.Printf("[share] cleaned up device %s", busID)
|
|
}
|
|
|
|
sm.active = make(map[string]*activeShare)
|
|
sm.tunnels = make(map[string]*shareTunnel)
|
|
}
|
|
|
|
// DeviceListForAPI returns device info formatted for the web API
|
|
func (sm *ShareManager) DeviceListForAPI() []map[string]interface{} {
|
|
sm.mu.RLock()
|
|
defer sm.mu.RUnlock()
|
|
|
|
var result []map[string]interface{}
|
|
for _, dev := range sm.devices {
|
|
status := "available"
|
|
usedBy := ""
|
|
if share, ok := sm.active[dev.BusID]; ok {
|
|
status = "in_use"
|
|
usedBy = share.usedBy
|
|
}
|
|
|
|
result = append(result, map[string]interface{}{
|
|
"bus_id": dev.BusID,
|
|
"vendor_id": fmt.Sprintf("%04x", dev.VendorID),
|
|
"product_id": fmt.Sprintf("%04x", dev.ProductID),
|
|
"name": dev.DisplayName(),
|
|
"status": status,
|
|
"used_by": usedBy,
|
|
"speed": dev.Speed,
|
|
})
|
|
}
|
|
return result
|
|
}
|
|
|