first commit
This commit is contained in:
@@ -0,0 +1,358 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/duffy/usb-server/internal/protocol"
|
||||
"github.com/duffy/usb-server/internal/usb"
|
||||
"github.com/duffy/usb-server/internal/usbip"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ShareManager handles sharing USB devices
|
||||
type ShareManager struct {
|
||||
client *Client
|
||||
mu sync.RWMutex
|
||||
devices []usb.Device
|
||||
active map[string]*activeShare // busID -> active share
|
||||
tunnels map[string]*shareTunnel // tunnelID -> tunnel
|
||||
}
|
||||
|
||||
type activeShare struct {
|
||||
device *usb.Device
|
||||
server *usbip.Server
|
||||
usedBy string // client ID using this device
|
||||
tunnelID string
|
||||
}
|
||||
|
||||
type shareTunnel struct {
|
||||
id string
|
||||
busID string
|
||||
inPipe *io.PipeWriter
|
||||
outPipe *io.PipeReader
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewShareManager creates a share manager
|
||||
func NewShareManager(client *Client) *ShareManager {
|
||||
sm := &ShareManager{
|
||||
client: client,
|
||||
active: make(map[string]*activeShare),
|
||||
tunnels: make(map[string]*shareTunnel),
|
||||
}
|
||||
|
||||
// Set up callbacks
|
||||
client.OnRequestDevice = sm.handleRequestDevice
|
||||
client.OnReleaseDevice = sm.handleReleaseDevice
|
||||
client.OnTunnelData = sm.handleTunnelData
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// Run starts the share manager: periodic device enumeration + event handling
|
||||
func (sm *ShareManager) Run() error {
|
||||
// Initial enumeration
|
||||
sm.refreshDevices()
|
||||
sm.broadcastDeviceList()
|
||||
|
||||
// Periodic refresh
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
sm.refreshDevices()
|
||||
sm.broadcastDeviceList()
|
||||
case <-sm.client.ctx.Done():
|
||||
sm.cleanup()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetDevices returns the current device list
|
||||
func (sm *ShareManager) GetDevices() []usb.Device {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
result := make([]usb.Device, len(sm.devices))
|
||||
copy(result, sm.devices)
|
||||
return result
|
||||
}
|
||||
|
||||
func (sm *ShareManager) refreshDevices() {
|
||||
devices, err := usb.Enumerate()
|
||||
if err != nil {
|
||||
log.Printf("[share] USB enumeration error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
sm.devices = devices
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
|
||||
func (sm *ShareManager) broadcastDeviceList() {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
var protoDevices []protocol.USBDevice
|
||||
for _, dev := range sm.devices {
|
||||
status := protocol.StatusAvailable
|
||||
usedBy := ""
|
||||
if share, ok := sm.active[dev.BusID]; ok {
|
||||
status = protocol.StatusInUse
|
||||
usedBy = share.usedBy
|
||||
}
|
||||
|
||||
protoDevices = append(protoDevices, protocol.USBDevice{
|
||||
BusID: dev.BusID,
|
||||
BusNum: dev.BusNum,
|
||||
DevNum: dev.DevNum,
|
||||
Speed: dev.Speed,
|
||||
VendorID: fmt.Sprintf("%04x", dev.VendorID),
|
||||
ProductID: fmt.Sprintf("%04x", dev.ProductID),
|
||||
Class: dev.DeviceClass,
|
||||
SubClass: dev.DeviceSubClass,
|
||||
Protocol: dev.DeviceProtocol,
|
||||
Name: dev.DisplayName(),
|
||||
Manufacturer: dev.Manufacturer,
|
||||
NumInterfaces: uint8(len(dev.Interfaces)),
|
||||
Status: status,
|
||||
UsedBy: usedBy,
|
||||
})
|
||||
}
|
||||
|
||||
msg := &protocol.DeviceList{
|
||||
Type: protocol.MsgDeviceList,
|
||||
ClientID: sm.client.ID(),
|
||||
ClientName: sm.client.Config().Name,
|
||||
Devices: protoDevices,
|
||||
}
|
||||
|
||||
sm.client.SendJSON(msg)
|
||||
}
|
||||
|
||||
func (sm *ShareManager) handleRequestDevice(targetClient, fromClient, busID, requestID string) {
|
||||
log.Printf("[share] device request: busID=%s from=%s", busID, fromClient)
|
||||
|
||||
sm.mu.Lock()
|
||||
|
||||
// Check if device exists
|
||||
var dev *usb.Device
|
||||
for i := range sm.devices {
|
||||
if sm.devices[i].BusID == busID {
|
||||
dev = &sm.devices[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if dev == nil {
|
||||
sm.mu.Unlock()
|
||||
sm.client.SendJSON(map[string]interface{}{
|
||||
"type": protocol.MsgDeviceDenied,
|
||||
"bus_id": busID,
|
||||
"request_id": requestID,
|
||||
"reason": "device not found",
|
||||
"target_client": fromClient,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if already in use
|
||||
if _, inUse := sm.active[busID]; inUse {
|
||||
sm.mu.Unlock()
|
||||
sm.client.SendJSON(map[string]interface{}{
|
||||
"type": protocol.MsgDeviceDenied,
|
||||
"bus_id": busID,
|
||||
"request_id": requestID,
|
||||
"reason": "device already in use",
|
||||
"target_client": fromClient,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create USB/IP server for this device
|
||||
server := usbip.NewServer(dev)
|
||||
if err := server.Attach(); err != nil {
|
||||
sm.mu.Unlock()
|
||||
log.Printf("[share] failed to attach device %s: %v", busID, err)
|
||||
sm.client.SendJSON(map[string]interface{}{
|
||||
"type": protocol.MsgDeviceDenied,
|
||||
"bus_id": busID,
|
||||
"request_id": requestID,
|
||||
"reason": fmt.Sprintf("attach failed: %v", err),
|
||||
"target_client": fromClient,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tunnelID := uuid.New().String()[:16] // 16 chars for tunnel header
|
||||
for len(tunnelID) < 16 {
|
||||
tunnelID += "0"
|
||||
}
|
||||
|
||||
inReader, inWriter := io.Pipe()
|
||||
outReader, outWriter := io.Pipe()
|
||||
|
||||
tunnel := &shareTunnel{
|
||||
id: tunnelID,
|
||||
busID: busID,
|
||||
inPipe: inWriter,
|
||||
outPipe: outReader,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
share := &activeShare{
|
||||
device: dev,
|
||||
server: server,
|
||||
usedBy: fromClient,
|
||||
tunnelID: tunnelID,
|
||||
}
|
||||
|
||||
sm.active[busID] = share
|
||||
sm.tunnels[tunnelID] = tunnel
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Start USB/IP protocol handler in background
|
||||
go func() {
|
||||
defer func() {
|
||||
close(tunnel.done)
|
||||
inWriter.Close()
|
||||
outReader.Close()
|
||||
}()
|
||||
|
||||
// First handle the management phase (import request from client)
|
||||
// The USB/IP client will send OP_REQ_IMPORT, we respond, then enter transfer phase
|
||||
err := server.HandleConnection(inReader, outWriter)
|
||||
if err != nil {
|
||||
log.Printf("[share] USB/IP connection error for %s: %v", busID, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Forward outgoing data from USB/IP server to tunnel
|
||||
go func() {
|
||||
buf := make([]byte, 65536)
|
||||
for {
|
||||
n, err := outReader.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := sm.client.SendTunnelData(tunnelID, buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Send grant message
|
||||
sm.client.SendJSON(map[string]interface{}{
|
||||
"type": protocol.MsgDeviceGranted,
|
||||
"bus_id": busID,
|
||||
"tunnel_id": tunnelID,
|
||||
"request_id": requestID,
|
||||
"dev_id": dev.DevID(),
|
||||
"speed": dev.Speed,
|
||||
"target_client": fromClient,
|
||||
})
|
||||
|
||||
log.Printf("[share] device %s granted to %s (tunnel=%s)", busID, fromClient, tunnelID)
|
||||
|
||||
// Broadcast updated device list
|
||||
sm.refreshDevices()
|
||||
sm.broadcastDeviceList()
|
||||
}
|
||||
|
||||
func (sm *ShareManager) handleReleaseDevice(busID, fromClient string) {
|
||||
log.Printf("[share] device release: busID=%s from=%s", busID, fromClient)
|
||||
|
||||
sm.mu.Lock()
|
||||
share, exists := sm.active[busID]
|
||||
if !exists {
|
||||
sm.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Clean up tunnel
|
||||
if tunnel, ok := sm.tunnels[share.tunnelID]; ok {
|
||||
tunnel.inPipe.Close()
|
||||
delete(sm.tunnels, share.tunnelID)
|
||||
}
|
||||
|
||||
// Detach device (release interfaces, reconnect kernel driver)
|
||||
share.server.Detach()
|
||||
delete(sm.active, busID)
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Notify client
|
||||
sm.client.SendJSON(&protocol.DeviceReleased{
|
||||
Type: protocol.MsgDeviceReleased,
|
||||
BusID: busID,
|
||||
})
|
||||
|
||||
log.Printf("[share] device %s released", busID)
|
||||
|
||||
// Refresh device list
|
||||
sm.refreshDevices()
|
||||
sm.broadcastDeviceList()
|
||||
}
|
||||
|
||||
func (sm *ShareManager) handleTunnelData(tunnelID string, data []byte) {
|
||||
sm.mu.RLock()
|
||||
tunnel, exists := sm.tunnels[tunnelID]
|
||||
sm.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Write incoming data to the USB/IP server's input pipe
|
||||
tunnel.inPipe.Write(data)
|
||||
}
|
||||
|
||||
func (sm *ShareManager) cleanup() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
for busID, share := range sm.active {
|
||||
if tunnel, ok := sm.tunnels[share.tunnelID]; ok {
|
||||
tunnel.inPipe.Close()
|
||||
}
|
||||
share.server.Detach()
|
||||
log.Printf("[share] cleaned up device %s", busID)
|
||||
}
|
||||
|
||||
sm.active = make(map[string]*activeShare)
|
||||
sm.tunnels = make(map[string]*shareTunnel)
|
||||
}
|
||||
|
||||
// DeviceListForAPI returns device info formatted for the web API
|
||||
func (sm *ShareManager) DeviceListForAPI() []map[string]interface{} {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
var result []map[string]interface{}
|
||||
for _, dev := range sm.devices {
|
||||
status := "available"
|
||||
usedBy := ""
|
||||
if share, ok := sm.active[dev.BusID]; ok {
|
||||
status = "in_use"
|
||||
usedBy = share.usedBy
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
"bus_id": dev.BusID,
|
||||
"vendor_id": fmt.Sprintf("%04x", dev.VendorID),
|
||||
"product_id": fmt.Sprintf("%04x", dev.ProductID),
|
||||
"name": dev.DisplayName(),
|
||||
"status": status,
|
||||
"used_by": usedBy,
|
||||
"speed": dev.Speed,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user