package relay

import (
	"encoding/json"
	"log"
	"net/http"
	"time"

	"github.com/duffy/usb-server/internal/protocol"
	"github.com/gorilla/websocket"
)

// upgrader upgrades incoming HTTP requests on /ws to WebSocket connections.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  64 * 1024,
	WriteBufferSize: 64 * 1024,
	// CheckOrigin accepts every Origin header, so pages served from any site
	// can open a relay connection. NOTE(review): confirm this is intended.
	CheckOrigin: func(r *http.Request) bool {
		return true // relay accepts all origins
	},
}

// Server is the WebSocket relay server.
type Server struct {
	hub  *Hub
	addr string
}

// NewServer creates a relay server that will listen on addr and route
// connections through a freshly created Hub.
func NewServer(addr string) *Server {
	return &Server{
		hub:  NewHub(),
		addr: addr,
	}
}

// Run starts the relay server on s.addr, serving /ws (WebSocket relay) and
// /health (liveness probe). It blocks and returns the error from
// http.Server.ListenAndServe.
func (s *Server) Run() error {
	mux := http.NewServeMux()
	mux.HandleFunc("/ws", s.handleWebSocket)
	mux.HandleFunc("/health", s.handleHealth)

	srv := &http.Server{
		Addr:    s.addr,
		Handler: mux,
		// Bound how long a client may take to send request headers so slow or
		// idle connections cannot pin goroutines before the upgrade. No
		// Read/WriteTimeout is set: relay connections are long-lived and
		// handleWebSocket manages its own deadlines after the upgrade.
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Printf("[relay] starting on %s", s.addr)
	return srv.ListenAndServe()
}

// handleHealth reports liveness with status 200 and the JSON body
// {"status":"ok"}.
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(map[string]string{"status": "ok"}); err != nil {
		// The status line is already sent; all we can do is record the failure.
		log.Printf("[relay] health response write error: %v", err)
	}
}

// handleWebSocket upgrades the request and performs the registration
// handshake. It then registers the client with the hub and forwards every
// incoming text/binary frame to the hub until the connection errors or closes.
//
// The first message must be a JSON protocol.Register with Type
// protocol.MsgRegister, non-empty Hash and ClientID, and Mode equal to
// protocol.ModeShare or protocol.ModeUse. Otherwise a protocol.ErrorMsg is
// sent and the connection is closed.
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	const (
		maxMessageSize = 1024 * 1024      // 1MB max incoming message
		pongWait       = 60 * time.Second // read deadline, refreshed by pongs and inbound frames
		pingPeriod     = 30 * time.Second // must be shorter than pongWait
		writeWait      = 10 * time.Second // bound on pings and error replies
	)

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("[relay] upgrade error: %v", err)
		return
	}
	defer conn.Close()

	// Set read limits and deadlines.
	conn.SetReadLimit(maxMessageSize)
	conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	// reject sends a best-effort error reply before the handler returns and
	// closes conn. The write is deadline-bounded so a stalled peer cannot hang
	// this handler.
	reject := func(msg string) {
		conn.SetWriteDeadline(time.Now().Add(writeWait))
		if err := conn.WriteJSON(&protocol.ErrorMsg{Type: protocol.MsgError, Message: msg}); err != nil {
			log.Printf("[relay] error reply failed: %v", err)
		}
	}

	// Wait for the registration message.
	_, msgData, err := conn.ReadMessage()
	if err != nil {
		log.Printf("[relay] read error during registration: %v", err)
		return
	}

	var reg protocol.Register
	if err := json.Unmarshal(msgData, &reg); err != nil {
		log.Printf("[relay] invalid registration message: %v", err)
		reject("invalid registration")
		return
	}
	if reg.Type != protocol.MsgRegister {
		log.Printf("[relay] invalid registration message: unexpected type %v", reg.Type)
		reject("invalid registration")
		return
	}
	if reg.Hash == "" || reg.ClientID == "" || (reg.Mode != protocol.ModeShare && reg.Mode != protocol.ModeUse) {
		log.Printf("[relay] registration missing required fields")
		reject("missing required fields")
		return
	}

	client := &Client{
		ID:   reg.ClientID,
		Hash: reg.Hash,
		Mode: reg.Mode,
		Name: reg.Name,
		Conn: conn,
		Send: make(chan []byte, 256),
	}
	s.hub.Register(client)
	defer s.hub.Unregister(client)

	// Keepalive: ping periodically so the peer's pongs refresh the read
	// deadline. gorilla/websocket documents WriteControl as safe to call
	// concurrently with the other write methods, so no client lock is needed.
	// The per-call deadline keeps a stalled peer from blocking this goroutine
	// forever. done is closed when the handler returns, stopping the goroutine.
	done := make(chan struct{})
	defer close(done)
	go func() {
		ticker := time.NewTicker(pingPeriod)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil {
					return
				}
			case <-done:
				return
			}
		}
	}()

	// Read loop: forward frames to the hub until the connection fails.
	for {
		msgType, data, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
				log.Printf("[relay] read error from %s: %v", client.ID, err)
			}
			return
		}
		// Any inbound frame also counts as liveness.
		conn.SetReadDeadline(time.Now().Add(pongWait))

		switch msgType {
		case websocket.TextMessage:
			s.hub.HandleTextMessage(client, data)
		case websocket.BinaryMessage:
			s.hub.HandleBinaryMessage(client, data)
		}
	}
}