usb-server/internal/client/socket_linux.go

157 lines
4.5 KiB
Go

//go:build linux
package client
import (
"context"
"fmt"
"log"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/duffy/usb-server/internal/protocol"
"github.com/duffy/usb-server/internal/usbip"
"golang.org/x/sys/unix"
)
// createVHCIAttachment creates a VHCI attachment on Linux using socketpair + sysfs.
//
// One end of a fresh Unix socketpair is handed to the VHCI driver through
// usbip.AttachDevice on a free port chosen for granted.Speed. The other end is
// wrapped as a net.Conn and returned as the tunnel, together with the VHCI
// port number. On any failure it returns (nil, -1, err) and releases what it
// created: the socket descriptors and, if attach already succeeded, the port
// (via usbip.DetachDevice).
func createVHCIAttachment(_ context.Context, granted *protocol.DeviceGranted, _ *RemoteDevice) (net.Conn, int, error) {
	pair, err := createSocketPair()
	if err != nil {
		return nil, -1, fmt.Errorf("creating socketpair: %w", err)
	}
	kernelEnd, userEnd := pair[0], pair[1]

	vhciPort, err := usbip.FindFreePort(granted.Speed)
	if err != nil {
		closeFDs(pair)
		return nil, -1, fmt.Errorf("finding free VHCI port: %w", err)
	}

	attachErr := usbip.AttachDevice(vhciPort, kernelEnd, granted.DevID, granted.Speed)
	if attachErr != nil {
		closeFDs(pair)
		return nil, -1, fmt.Errorf("VHCI attach: %w", attachErr)
	}

	// After a successful attach the VHCI driver keeps its own kernel reference
	// to the socket (sockfd_lookup), so our copy of the descriptor is no longer
	// needed and would otherwise leak.
	unix.Close(kernelEnd)

	// net.FileConn duplicates the descriptor, so the *os.File wrapper (and
	// the original userEnd fd it owns) can be closed straight away.
	f := fdToFile(userEnd, "usb-tunnel")
	conn, convErr := net.FileConn(f)
	f.Close()
	if convErr != nil {
		usbip.DetachDevice(vhciPort)
		return nil, -1, fmt.Errorf("creating tunnel conn: %w", convErr)
	}
	return conn, vhciPort, nil
}
// createSocketPair creates a connected pair of AF_UNIX stream sockets and
// returns their descriptors.
//
// Both descriptors are created close-on-exec (SOCK_CLOEXEC, set atomically by
// the kernel) so they are not inherited by child processes. Without the flag,
// an exec'd child would keep the VHCI socket and the tunnel alive behind our
// back. CLOEXEC only affects execve; it does not change the descriptor as seen
// by this process, which is what the VHCI attach consumes.
func createSocketPair() ([2]int, error) {
	fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0)
	if err != nil {
		return [2]int{}, fmt.Errorf("socketpair: %w", err)
	}
	return fds, nil
}
// closeFDs closes both descriptors of a socket pair, first then second.
// Close errors are deliberately ignored: in this file it is only used to
// release descriptors on error paths, where there is nothing better to do.
func closeFDs(fds [2]int) {
	for _, fd := range fds {
		unix.Close(fd)
	}
}
// fdToFile wraps the raw descriptor fd in an *os.File called name.
// The returned file takes ownership of fd: closing the file closes fd.
// As with os.NewFile, the result is nil for a negative descriptor.
func fdToFile(fd int, name string) *os.File {
	f := os.NewFile(uintptr(fd), name)
	return f
}
// fixVHCIDevicePermissions waits for VHCI-attached devices to create their
// device nodes (e.g. /dev/video*, /dev/input/event*, /dev/hidraw*, /dev/snd/*)
// and sets them to world-accessible (0666). VHCI-created devices don't get
// normal udev rules applied, so they default to root-only access.
//
// Polling: the VHCI sysfs tree is rescanned every 500ms, for at most 15
// attempts (7.5s). The function returns once a scan makes at least one node
// accessible, but never before the third scan, so interfaces that finish
// enumerating slightly later are still handled.
//
// Scope: port is only used in the final log message. The scan covers the
// whole vhci_hcd.0 tree, so nodes of every device attached to that controller
// are touched, not just the device on port.
//
// Logging: each distinct log line (success or chmod failure for a node) is
// emitted at most once per call, even though nodes are re-chmodded on every
// scan.
func fixVHCIDevicePermissions(port int) {
	const (
		vhciSysfsRoot = "/sys/devices/platform/vhci_hcd.0"
		pollInterval  = 500 * time.Millisecond
		maxAttempts   = 15
		// Minimum number of scans before a successful scan may end polling.
		minAttempts = 3
	)

	// logOnce prints msg the first time it is seen during this call.
	logged := make(map[string]bool)
	logOnce := func(msg string) {
		if !logged[msg] {
			logged[msg] = true
			log.Print(msg)
		}
	}

	// makeAccessible chmods devPath to 0666 and reports whether it succeeded.
	makeAccessible := func(devPath string) bool {
		if err := os.Chmod(devPath, 0666); err != nil {
			logOnce(fmt.Sprintf("[use] chmod %s failed: %v", devPath, err))
			return false
		}
		logOnce(fmt.Sprintf("[use] set permissions 0666 on %s", devPath))
		return true
	}

	// soundCardNodes lists the ALSA nodes in /dev/snd that belong to the
	// sysfs entry "cardN": controlCN plus every pcmCND*/hwCND*/midiCND* node.
	// Previously every node under /dev/snd (other cards, seq, timer, ...)
	// was made world-writable whenever any VHCI sound card appeared.
	soundCardNodes := func(card string) []string {
		num := strings.TrimPrefix(card, "card")
		if num == "" || strings.Trim(num, "0123456789") != "" {
			return nil // not a "cardN" entry
		}
		nodes := []string{filepath.Join("/dev/snd", "controlC"+num)}
		// Glob only fails on a malformed pattern; num is all digits, so the
		// pattern is always well-formed.
		matches, _ := filepath.Glob(filepath.Join("/dev/snd", "*C"+num+"D*"))
		return append(nodes, matches...)
	}

	// scan walks the VHCI sysfs tree once, fixing every device node it can
	// map from a sysfs entry, and reports whether any node was made
	// accessible. Paths look like:
	//   vhci_hcd.0/usb3/3-1/3-1:1.0/video4linux/video0
	scan := func() bool {
		found := false
		// Best effort: unreadable entries, or a missing root (for example
		// vhci-hcd not loaded), are skipped, so WalkDir never returns an error.
		_ = filepath.WalkDir(vhciSysfsRoot, func(path string, d os.DirEntry, err error) error {
			if err != nil {
				return nil
			}
			name := d.Name()
			parent := filepath.Base(filepath.Dir(path))

			var nodes []string
			switch {
			case parent == "video4linux" && strings.HasPrefix(name, "video"):
				nodes = []string{filepath.Join("/dev", name)}
			case parent == "hidraw" && strings.HasPrefix(name, "hidraw"):
				nodes = []string{filepath.Join("/dev", name)}
			case strings.HasPrefix(name, "event") && strings.Contains(path, "/input/input"):
				nodes = []string{filepath.Join("/dev/input", name)}
			case parent == "sound" && strings.HasPrefix(name, "card"):
				nodes = soundCardNodes(name)
			}
			for _, n := range nodes {
				if makeAccessible(n) {
					found = true
				}
			}
			return nil
		})
		return found
	}

	for attempt := 1; attempt <= maxAttempts; attempt++ {
		// Give the kernel time to enumerate descriptors and bind drivers.
		time.Sleep(pollInterval)
		found := scan()
		if found && attempt >= minAttempts {
			return
		}
	}
	log.Printf("[use] fixVHCIDevicePermissions: no device nodes found after %v (port %d)",
		maxAttempts*pollInterval, port)
}