usb-server/internal/client/socket_windows.go

204 lines
5.8 KiB
Go

//go:build windows
package client
import (
"context"
"fmt"
"io"
"log"
"net"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
"github.com/duffy/usb-server/internal/protocol"
"github.com/duffy/usb-server/internal/usbip"
)
// createVHCIAttachment creates a VHCI attachment on Windows using usbip-win2.
//
// It starts a one-shot TCP proxy on 127.0.0.1 and launches usbip.exe to attach
// through it. It answers the USB/IP management phase (OP_REQ_IMPORT) locally
// via handleImportRequest and returns the accepted TCP connection for the
// transfer-phase bridge.
//
// The returned int is the VHCI port parsed from usbip.exe's output, or -1 if
// it could not be parsed. On error the returned conn is nil. Any connection
// the proxy had already accepted is closed, so the caller owns the conn only
// on success. Cancelling ctx kills usbip.exe and abandons the wait for the
// management phase.
func createVHCIAttachment(ctx context.Context, granted *protocol.DeviceGranted, devInfo *RemoteDevice) (net.Conn, int, error) {
	// Find usbip.exe
	usbipExe, err := usbip.FindUsbipExe()
	if err != nil {
		return nil, -1, err
	}

	// Start TCP listener on localhost with a random port.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, -1, fmt.Errorf("starting TCP listener: %w", err)
	}
	// Closing on every return path unblocks Accept if usbip.exe never
	// connected. Closing an already-closed listener is harmless.
	defer listener.Close()

	tcpPort := listener.Addr().(*net.TCPAddr).Port
	log.Printf("[vhci-win] TCP proxy listening on 127.0.0.1:%d", tcpPort)

	// acceptResult carries the accepted connection after the management phase.
	type acceptResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the accept goroutine can always deliver its result and
	// exit, even when nobody is waiting for it any more.
	resultCh := make(chan acceptResult, 1)

	// Accept a single connection and handle the management phase.
	go func() {
		conn, err := listener.Accept()
		listener.Close() // only accept one connection
		if err != nil {
			resultCh <- acceptResult{nil, fmt.Errorf("accepting connection: %w", err)}
			return
		}
		log.Printf("[vhci-win] usbip.exe connected, handling management phase")

		// Handle OP_REQ_IMPORT from usbip.exe
		if err := handleImportRequest(conn, granted, devInfo); err != nil {
			conn.Close()
			resultCh <- acceptResult{nil, fmt.Errorf("management phase: %w", err)}
			return
		}
		log.Printf("[vhci-win] management phase complete, entering transfer phase")
		resultCh <- acceptResult{conn, nil}
	}()

	// discardResult closes any connection the accept goroutine delivers after
	// we have given up on it. It must not block the caller, because the
	// goroutine may still be in the middle of the management phase.
	discardResult := func() {
		go func() {
			if res := <-resultCh; res.conn != nil {
				res.conn.Close()
			}
		}()
	}

	// Launch usbip.exe attach against the local proxy.
	cmd := exec.CommandContext(ctx, usbipExe,
		"--tcp-port", strconv.Itoa(tcpPort),
		"attach", "-r", "127.0.0.1", "-b", granted.BusID)
	output, err := cmd.CombinedOutput()
	outputStr := strings.TrimSpace(string(output))
	if err != nil {
		// The deferred listener.Close unblocks Accept. If usbip.exe got as far
		// as connecting, that connection still has to be released.
		discardResult()
		return nil, -1, fmt.Errorf("usbip.exe attach failed: %w (output: %s)", err, outputStr)
	}
	log.Printf("[vhci-win] usbip.exe output: %s", outputStr)

	// Parse VHCI port from usbip.exe output (e.g. "succesfully attached to port 0").
	vhciPort := parsePortFromOutput(outputStr)

	// Wait for the management phase to complete, unless the caller gives up.
	select {
	case result := <-resultCh:
		if result.err != nil {
			return nil, -1, result.err
		}
		log.Printf("[vhci-win] device attached on VHCI port %d", vhciPort)
		return result.conn, vhciPort, nil
	case <-ctx.Done():
		discardResult()
		return nil, -1, fmt.Errorf("waiting for management phase: %w", ctx.Err())
	}
}
// handleImportRequest serves the USB/IP management phase for one client
// connection.
//
// It reads an op header from conn and rejects anything other than
// OP_REQ_IMPORT. It then consumes the 32-byte bus ID that follows and logs it.
// Finally it answers with an OP_REP_IMPORT (status 0) whose device descriptor
// is built from granted and devInfo. Read, build and write failures are
// returned wrapped with the step that failed.
func handleImportRequest(conn net.Conn, granted *protocol.DeviceGranted, devInfo *RemoteDevice) error {
	header, err := usbip.ReadOpHeader(conn)
	if err != nil {
		return fmt.Errorf("reading op header: %w", err)
	}
	if header.Command != usbip.OpReqImport {
		return fmt.Errorf("unexpected command: 0x%04x (expected OP_REQ_IMPORT 0x%04x)", header.Command, usbip.OpReqImport)
	}

	// The request body is a fixed-size, NUL-padded bus ID field.
	var rawBusID [32]byte
	if _, err = io.ReadFull(conn, rawBusID[:]); err != nil {
		return fmt.Errorf("reading bus ID: %w", err)
	}
	log.Printf("[vhci-win] OP_REQ_IMPORT for bus ID: %s", usbip.GetBusID(rawBusID))

	descriptor := buildDeviceDescriptor(granted, devInfo)
	reply, err := usbip.BuildImportReply(0, &descriptor)
	if err != nil {
		return fmt.Errorf("building import reply: %w", err)
	}
	if _, err = conn.Write(reply); err != nil {
		return fmt.Errorf("writing import reply: %w", err)
	}
	return nil
}
// buildDeviceDescriptor creates a USB/IP DeviceDescriptor from the
// information available in the DeviceGranted message and RemoteDevice.
//
// Bus ID, path and speed always come from granted. Bus and device numbers are
// derived from granted.DevID (upper 16 bits and lower 16 bits respectively)
// unless devInfo is non-nil, in which case devInfo's values are used instead.
//
// devInfo's VendorID/ProductID are parsed as hexadecimal, with or without a
// leading "0x"/"0X" and surrounding whitespace. A non-empty ID that still
// fails to parse is logged and left as 0. Fields the protocol does not carry
// get fixed defaults.
func buildDeviceDescriptor(granted *protocol.DeviceGranted, devInfo *RemoteDevice) usbip.DeviceDescriptor {
	var desc usbip.DeviceDescriptor
	usbip.SetBusID(&desc.BusID, granted.BusID)
	usbip.SetPath(&desc.Path, "/sys/bus/usb/"+granted.BusID)
	desc.Speed = granted.Speed
	desc.BusNum = granted.DevID >> 16
	desc.DevNum = granted.DevID & 0xFFFF

	// parseHexID parses a 16-bit hex ID such as "046d" or "0x046D".
	// It returns 0 for an empty or unparsable string and logs the latter.
	parseHexID := func(kind, raw string) uint16 {
		s := strings.TrimSpace(raw)
		if s == "" {
			return 0
		}
		// strconv.ParseUint with an explicit base 16 rejects the "0x" prefix,
		// so strip it ourselves.
		if len(s) > 2 && (s[:2] == "0x" || s[:2] == "0X") {
			s = s[2:]
		}
		v, err := strconv.ParseUint(s, 16, 16)
		if err != nil {
			log.Printf("[vhci-win] ignoring unparsable %s ID %q: %v", kind, raw, err)
			return 0
		}
		return uint16(v)
	}

	// Fill from RemoteDevice if available
	if devInfo != nil {
		desc.BusNum = devInfo.BusNum
		desc.DevNum = devInfo.DevNum
		desc.IDVendor = parseHexID("vendor", devInfo.VendorID)
		desc.IDProduct = parseHexID("product", devInfo.ProductID)
		desc.BDeviceClass = devInfo.Class
		desc.BDeviceSubClass = devInfo.SubClass
		desc.BDeviceProtocol = devInfo.Protocol
		desc.BNumInterfaces = devInfo.NumInterfaces
	}

	// Defaults for fields not available in the protocol
	desc.BcdDevice = 0x0100
	desc.BConfigurationValue = 1
	desc.BNumConfigurations = 1
	if desc.BNumInterfaces == 0 {
		desc.BNumInterfaces = 1
	}
	return desc
}
// vhciPortOutputRe matches a "port <digits>" phrase in usbip.exe output,
// case-insensitively, e.g. the "port 0" in "succesfully attached to port 0".
// It is compiled once at package initialization rather than on every call.
var vhciPortOutputRe = regexp.MustCompile(`(?i)port\s+(\d+)`)

// parsePortFromOutput extracts the VHCI port number from usbip.exe output.
// The first "port <digits>" occurrence wins. It returns -1 if there is no such
// occurrence or the digits do not fit in an int.
func parsePortFromOutput(output string) int {
	m := vhciPortOutputRe.FindStringSubmatch(output)
	if m == nil {
		return -1
	}
	port, err := strconv.Atoi(m[1])
	if err != nil {
		return -1
	}
	return port
}
// createSocketPair is a Windows stub that exists only so platform-shared code
// compiles. It always yields a zero-valued descriptor pair together with an
// error.
func createSocketPair() ([2]int, error) {
	var none [2]int
	return none, fmt.Errorf("socketpair not available on Windows")
}
// closeFDs is a no-op on Windows. It is present only so platform-shared code
// compiles. createSocketPair never produces descriptors here, so there is
// nothing to close.
func closeFDs(_ [2]int) {}
// fdToFile is a Windows stub present only so platform-shared code compiles.
// It ignores both arguments and always returns a nil *os.File.
func fdToFile(_ int, _ string) *os.File {
	var f *os.File
	return f
}
// fixVHCIDevicePermissions does nothing on Windows. It is present only so
// platform-shared code compiles. The port argument is ignored.
func fixVHCIDevicePermissions(_ int) {}