usb-server/internal/relay/server.go

139 lines
3.2 KiB
Go

package relay
import (
"encoding/json"
"log"
"net/http"
"time"
"github.com/duffy/usb-server/internal/protocol"
"github.com/gorilla/websocket"
)
// upgrader turns /ws HTTP requests into WebSocket connections. Both the
// read and write buffers are 64 KiB.
var upgrader = websocket.Upgrader{
	WriteBufferSize: 64 << 10,
	ReadBufferSize:  64 << 10,
	// The relay deliberately accepts every origin; no Origin header check is
	// performed. NOTE(review): this lets any web page open a relay connection
	// from a visitor's browser — confirm that is intended.
	CheckOrigin: func(*http.Request) bool { return true },
}
// Server is the WebSocket relay server. It serves the /ws and /health
// endpoints (see Run) and hands every registered client and every message
// it receives to its Hub.
type Server struct {
hub *Hub // receives Register/Unregister calls and relayed text/binary messages
addr string // listen address used by Run
}
// NewServer returns a relay Server configured to listen on addr when Run is
// called. The server gets its own Hub from NewHub.
func NewServer(addr string) *Server {
	srv := new(Server)
	srv.hub = NewHub()
	srv.addr = addr
	return srv
}
// Run starts the relay server on s.addr and blocks until it stops. The
// returned error is always non-nil, as with http.Server.ListenAndServe.
//
// Routes:
//   - /ws     WebSocket relay endpoint
//   - /health JSON liveness probe
//
// ReadHeaderTimeout stops a client that never finishes sending its request
// headers from holding a connection open forever. No ReadTimeout or
// WriteTimeout is set because /ws connections are long-lived; their read
// deadlines are managed per connection by the WebSocket handler.
func (s *Server) Run() error {
	mux := http.NewServeMux()
	mux.HandleFunc("/ws", s.handleWebSocket)
	mux.HandleFunc("/health", s.handleHealth)

	srv := &http.Server{
		Addr:              s.addr,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}

	log.Printf("[relay] starting on %s", s.addr)
	return srv.ListenAndServe()
}
// handleHealth answers liveness probes. It replies with HTTP 200 and the
// JSON body {"status":"ok"}. The request itself is not inspected.
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(map[string]string{"status": "ok"}); err != nil {
		// The status line is already sent, so the failure can only be logged.
		log.Printf("[relay] health response write error: %v", err)
	}
}
// handleWebSocket upgrades the request to a WebSocket and performs the
// registration handshake. It then relays the peer's frames to the hub until
// the connection fails or is closed.
//
// Registration: the first frame must be a JSON protocol.Register that
// satisfies all of the following:
//   - Type == protocol.MsgRegister
//   - non-empty Hash and ClientID
//   - Mode is protocol.ModeShare or protocol.ModeUse
//
// Otherwise a protocol.ErrorMsg is sent (best effort) and the connection is
// closed.
//
// After registration:
//   - the client is added to the hub;
//   - a keepalive goroutine pings the peer every pingPeriod;
//   - each text frame goes to Hub.HandleTextMessage;
//   - each binary frame goes to Hub.HandleBinaryMessage.
//
// On return, deferred cleanup runs in this order: the keepalive goroutine is
// signalled to stop, the client is unregistered, and the connection is closed.
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	const (
		maxMessageSize = 1024 * 1024      // largest frame accepted from the peer (1 MiB)
		pongWait       = 60 * time.Second // read deadline, extended by every pong and frame
		pingPeriod     = 30 * time.Second // keepalive interval; must stay below pongWait
		pingWriteWait  = 10 * time.Second // upper bound on a single ping write
	)

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("[relay] upgrade error: %v", err)
		return
	}
	defer conn.Close()

	conn.SetReadLimit(maxMessageSize)
	conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	// The first frame must be the registration message.
	_, msgData, err := conn.ReadMessage()
	if err != nil {
		log.Printf("[relay] read error during registration: %v", err)
		return
	}
	var reg protocol.Register
	if err := json.Unmarshal(msgData, &reg); err != nil || reg.Type != protocol.MsgRegister {
		log.Printf("[relay] invalid registration message")
		// Best effort: the connection is closed right after, so a failed
		// write changes nothing.
		_ = conn.WriteJSON(&protocol.ErrorMsg{Type: protocol.MsgError, Message: "invalid registration"})
		return
	}
	if reg.Hash == "" || reg.ClientID == "" || (reg.Mode != protocol.ModeShare && reg.Mode != protocol.ModeUse) {
		_ = conn.WriteJSON(&protocol.ErrorMsg{Type: protocol.MsgError, Message: "missing required fields"})
		return
	}

	client := &Client{
		ID:   reg.ClientID,
		Hash: reg.Hash,
		Mode: reg.Mode,
		Name: reg.Name,
		Conn: conn,
		Send: make(chan []byte, 256),
	}
	s.hub.Register(client)
	defer s.hub.Unregister(client)

	// Keepalive pings are serialized with other writers via client.mu. The
	// write is bounded by pingWriteWait, so a peer that stops reading cannot
	// block this goroutine (while it holds the lock) indefinitely.
	done := make(chan struct{})
	go func() {
		ticker := time.NewTicker(pingPeriod)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				client.mu.Lock()
				err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(pingWriteWait))
				client.mu.Unlock()
				if err != nil {
					return
				}
			case <-done:
				return
			}
		}
	}()
	defer close(done)

	// Read loop: every received frame extends the read deadline.
	for {
		msgType, data, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
				log.Printf("[relay] read error from %s: %v", client.ID, err)
			}
			return
		}
		conn.SetReadDeadline(time.Now().Add(pongWait))
		switch msgType {
		case websocket.TextMessage:
			s.hub.HandleTextMessage(client, data)
		case websocket.BinaryMessage:
			s.hub.HandleBinaryMessage(client, data)
		}
	}
}