powerpods/goTool/ota_upload.go
simon 0eea27a876 Fix web OTA upload and isolate OTA sessions across firmware and goTool.
Split ESP-NOW into core/master/slave modules, block non-OTA UART traffic during updates, and hold the host serial port exclusively so dashboard polling cannot interleave with firmware uploads.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-31 16:35:18 +02:00

619 lines
16 KiB
Go

package main
import (
"fmt"
"time"
"google.golang.org/protobuf/proto"
uartframe "powerpod/gotool/uart"
"powerpod/gotool/pb"
)
// Host-side tuning for the UART OTA transfer and the slave-distribution phase.
const (
	// otaHostChunkSize is the maximum payload of one OTA_PAYLOAD frame.
	otaHostChunkSize = 200
	// otaFlashBlockSize is how many bytes are sent before the host waits for
	// an otaStBlockAck status from the master.
	otaFlashBlockSize = 4096

	// otaPrepareTimeout bounds the wait for otaStReady after OTA_START.
	otaPrepareTimeout = 2 * time.Minute
	// otaDefaultTimeout is the read timeout during the chunk upload, the bound
	// on each block ack, and the default slave-progress query timeout.
	otaDefaultTimeout = 15 * time.Second
	// otaStatusPollTimeout is the default per-read timeout while polling for
	// OTA_STATUS frames.
	otaStatusPollTimeout = 3 * time.Second

	// otaDistReadTimeout is the per-read timeout while slaves are updated.
	otaDistReadTimeout = 400 * time.Millisecond
	// otaDistQueryInterval is the minimum gap between slave-progress queries.
	otaDistQueryInterval = 500 * time.Millisecond
	// otaDistQueryTimeout bounds one slave-progress query during distribution.
	otaDistQueryTimeout = 2 * time.Second
	// otaDistEmitMinInterval throttles non-forced slave progress events.
	otaDistEmitMinInterval = 150 * time.Millisecond
)
// Status codes carried in OTA_STATUS frames from the master.
const (
	otaStPreparing    = 1
	otaStReady        = 2
	otaStBlockAck     = 3
	otaStSuccess      = 4
	otaStFailed       = 5
	otaStDistributing = 6
)

// Discriminators found in the Error field of an otaStDistributing status.
// otaDistPerSlave marks a per-slave update whose TargetSlot field carries the
// slave's client id. otaDistAggregate is not referenced in this file;
// presumably it marks an aggregate update (confirm against firmware).
const (
	otaDistAggregate = 0
	otaDistPerSlave  = 1
)

// otaDistTimeout bounds the wait for the final OTA status after OTA_END,
// while the master distributes the image to its slaves.
const otaDistTimeout = 45 * time.Minute
// OtaSlaveDetail is per-slave ESP-NOW OTA state from OTA_SLAVE_PROGRESS.
// It is filled from the master's slave-progress response and refreshed from
// per-slave "distributing" OTA_STATUS frames; it is sent to the dashboard as
// part of OTAProgress.SlaveDetails, so field order and JSON tags are part of
// that payload.
type OtaSlaveDetail struct {
	BytesWritten uint32 `json:"bytes_written"` // bytes of the image the slave has written so far
	TotalBytes   uint32 `json:"total_bytes"`   // image size for this slave's transfer
	Status       uint32 `json:"status"`        // slave state code as reported by the master (NOTE(review): numeric meaning defined by firmware — document)
	Error        uint32 `json:"error"`         // slave error code as reported by the master
}
// OTAProgress is pushed to the dashboard during web uploads.
// Field order and JSON tags define the event payload.
type OTAProgress struct {
	Type          string                    `json:"type"`                     // always "ota_progress"
	Phase         string                    `json:"phase"`                    // preparing, ready, uploading, distributing, done or error
	Step          string                    `json:"step,omitempty"`           // otaStepMaster or otaStepSlaves; empty for done/error events
	Percent       int                       `json:"percent"`                  // percentage carried by this event (slave-distribution events send 0)
	MasterPercent int                       `json:"master_percent,omitempty"` // latest master (UART) upload percentage
	MasterDone    bool                      `json:"master_done,omitempty"`    // set once the master upload finished (slaves step, distributing, done)
	Message       string                    `json:"message"`                  // human-readable status for this event
	MasterMessage string                    `json:"master_message,omitempty"` // latest master upload status text
	Bytes         uint32                    `json:"bytes_written,omitempty"`  // bytes written as reported by the master
	Slot          uint32                    `json:"target_slot,omitempty"`    // boot slot targeted by the update
	Slaves        uint32                    `json:"slaves,omitempty"`         // number of slaves being tracked
	ImageSize     uint32                    `json:"image_size,omitempty"`     // size of the firmware image in bytes
	SlaveProgress map[uint32]uint32         `json:"slave_progress,omitempty"` // client_id -> bytes
	SlaveDetails  map[uint32]OtaSlaveDetail `json:"slave_details,omitempty"`  // client_id -> per-slave state
}
// otaProgressFn receives OTA progress events. The upload functions accept a
// nil otaProgressFn and then simply report nothing.
type otaProgressFn func(OTAProgress)

// Values of OTAProgress.Step.
const (
	otaStepMaster = "master" // host -> master UART upload
	otaStepSlaves = "slaves" // master -> slaves ESP-NOW distribution
)
// runOTAUpload uploads firmware to the master over the managed serial port.
//
// It blocks until the port is free and then holds m.mu for the entire upload,
// so dashboard/API polling cannot interleave frames on the serial port. While
// holding it, m.otaActive is set. The port is opened first if needed.
//
// It returns errOTAInProgress if an upload is already marked active. If the
// upload fails, the port is invalidated via m.invalidateLocked.
//
// onProgress may be nil. It is invoked while m.mu is held, so it must not
// touch the managed port itself.
func runOTAUpload(m *managedSerial, firmware []byte, onProgress otaProgressFn) error {
	push := func(phase, msg string) {
		if onProgress == nil {
			return
		}
		onProgress(OTAProgress{
			Type: "ota_progress", Phase: phase, Step: otaStepMaster,
			Percent: 0, Message: msg, MasterMessage: msg,
		})
	}
	push("preparing", "UART wird vorbereitet…")

	// Block until the UART is free, then hold m.mu for the entire upload so
	// dashboard/API polling cannot interleave on the serial port.
	// Unlock via defer so a panic in the upload or in onProgress (which
	// net/http recovers per request) cannot leave the port locked forever.
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.otaActive {
		return errOTAInProgress
	}
	m.otaActive = true
	// Runs before the deferred Unlock above (LIFO), matching the original
	// "clear flag, then unlock" order, also on panic.
	defer func() { m.otaActive = false }()

	if m.sp == nil {
		if err := m.openLocked(); err != nil {
			push("error", err.Error())
			return err
		}
	}
	err := runOTAOnPortUnlocked(m.sp, firmware, onProgress)
	if err != nil {
		m.invalidateLocked(err)
	}
	return err
}
// runOTAOnPortUnlocked runs one complete OTA session on sp:
//
//  1. send OTA_START with the image size and wait (up to otaPrepareTimeout)
//     for the master to report otaStReady;
//  2. stream the image as OTA_PAYLOAD frames of at most otaHostChunkSize
//     bytes, waiting for otaStBlockAck after every full otaFlashBlockSize
//     block;
//  3. send OTA_END and wait (up to otaDistTimeout) for otaStSuccess while the
//     master distributes the image to its ESP-NOW slaves, polling per-slave
//     progress along the way.
//
// "Unlocked" means this function takes no lock itself. The caller must already
// have exclusive use of the port; runOTAUpload holds managedSerial.mu for
// the whole call. Every failure is reported to onProgress as phase "error"
// and returned. onProgress may be nil.
func runOTAOnPortUnlocked(sp *serialPort, firmware []byte, onProgress otaProgressFn) error {
	imageSize := len(firmware)
	// Latest master-step progress; attached to every event as
	// MasterPercent/MasterMessage.
	masterPct := 0
	masterMsg := ""
	notify := func(phase, step string, percent int, msg string, extra ...OTAProgress) {
		if onProgress == nil {
			return
		}
		p := OTAProgress{
			Type: "ota_progress", Phase: phase, Step: step,
			Percent: percent, Message: msg,
			ImageSize: uint32(imageSize),
		}
		if step == otaStepMaster || phase == "preparing" || phase == "ready" || phase == "uploading" {
			masterPct = percent
			masterMsg = msg
		}
		p.MasterPercent = masterPct
		p.MasterMessage = masterMsg
		if step == otaStepSlaves || phase == "distributing" || phase == "done" {
			p.MasterDone = true
		}
		if len(extra) > 0 {
			e := extra[0]
			p.Bytes = e.Bytes
			p.Slot = e.Slot
			p.Slaves = e.Slaves
			p.SlaveProgress = e.SlaveProgress
			p.SlaveDetails = e.SlaveDetails
			if e.MasterPercent > 0 {
				p.MasterPercent = e.MasterPercent
			}
			if e.MasterMessage != "" {
				p.MasterMessage = e.MasterMessage
			}
		}
		onProgress(p)
	}
	// fail reports err as an "error" event and returns it, so every failure
	// path (including empty firmware) reaches the dashboard the same way.
	fail := func(err error) error {
		notify("error", "", 0, err.Error())
		return err
	}

	if imageSize == 0 {
		return fail(fmt.Errorf("empty firmware"))
	}
	if err := sp.port.SetReadTimeout(readTimeout); err != nil {
		return fail(err)
	}
	notify("preparing", otaStepMaster, 0, fmt.Sprintf("Master: OTA start (%d bytes)…", imageSize))
	flushSerialInput(sp)
	if err := writeUartMessage(sp, &pb.UartMessage{
		Type: pb.MessageType_OTA_START,
		Payload: &pb.UartMessage_OtaStart{
			OtaStart: &pb.OtaStartPayload{TotalSize: uint32(imageSize)},
		},
	}); err != nil {
		return fail(err)
	}
	if err := sp.port.SetReadTimeout(otaPrepareTimeout); err != nil {
		return fail(err)
	}
	// Restore the normal read timeout however the session ends.
	defer func() { _ = sp.port.SetReadTimeout(readTimeout) }()

	ready, err := waitOtaStatus(sp, otaStReady, otaPrepareTimeout, func(msg string) {
		notify("preparing", otaStepMaster, 2, msg)
	})
	if err != nil {
		return fail(err)
	}
	notify("ready", otaStepMaster, 5, fmt.Sprintf("Master: Slot %d bereit", ready.GetTargetSlot()))
	if err := sp.port.SetReadTimeout(otaDefaultTimeout); err != nil {
		return fail(err)
	}

	// uploadPct caps the chunk-upload progress at 99%; 100% is reported only
	// once every chunk has been sent.
	uploadPct := func(sent int) int {
		if pct := sent * 100 / imageSize; pct < 99 {
			return pct
		}
		return 99
	}

	var seq uint32
	for offset := 0; offset < imageSize; {
		bytesInBlock := 0
		for bytesInBlock < otaFlashBlockSize && offset < imageSize {
			n := otaHostChunkSize
			if room := otaFlashBlockSize - bytesInBlock; n > room {
				n = room
			}
			if offset+n > imageSize {
				n = imageSize - offset
			}
			if err := writeUartMessage(sp, &pb.UartMessage{
				Type: pb.MessageType_OTA_PAYLOAD,
				Payload: &pb.UartMessage_OtaPayload{
					OtaPayload: &pb.OtaPayload{Seq: seq, Data: firmware[offset : offset+n]},
				},
			}); err != nil {
				return fail(err)
			}
			seq++
			offset += n
			bytesInBlock += n
			notify("uploading", otaStepMaster, uploadPct(offset), fmt.Sprintf("Master: %d / %d bytes", offset, imageSize))
		}
		// Only full blocks are acknowledged here; a trailing partial block is
		// not waited for (NOTE(review): presumably committed by the master on
		// OTA_END — confirm against firmware).
		if bytesInBlock == otaFlashBlockSize {
			st, err := waitOtaStatus(sp, otaStBlockAck, otaDefaultTimeout, nil)
			if err != nil {
				return fail(err)
			}
			notify("uploading", otaStepMaster, uploadPct(offset),
				fmt.Sprintf("Master: Block geschrieben (%d bytes)", st.GetBytesWritten()),
				OTAProgress{Bytes: st.GetBytesWritten()})
		}
	}
	notify("uploading", otaStepMaster, 100, "Master: UART-Upload abgeschlossen")
	if err := writeUartMessage(sp, &pb.UartMessage{
		Type: pb.MessageType_OTA_END,
		Payload: &pb.UartMessage_OtaEnd{
			OtaEnd: &pb.OtaEndPayload{},
		},
	}); err != nil {
		return fail(err)
	}

	// Slave distribution: per-slave progress is collected from "distributing"
	// OTA_STATUS frames and from periodic OTA_SLAVE_PROGRESS queries.
	slaveBytes := make(map[uint32]uint32)
	slaveDetails := make(map[uint32]OtaSlaveDetail)
	emitSlaveOTA := func(msg string, aggBytes uint32, slaveCount uint32) {
		if slaveCount == 0 && len(slaveDetails) > 0 {
			slaveCount = uint32(len(slaveDetails))
		}
		notify("distributing", otaStepSlaves, 0, msg,
			OTAProgress{
				Bytes: aggBytes, Slaves: slaveCount,
				MasterPercent: 100, MasterMessage: masterMsg,
				SlaveProgress: copySlaveMap(slaveBytes),
				SlaveDetails:  copySlaveDetails(slaveDetails),
			})
	}
	// terminal holds an otaStSuccess/otaStFailed status that arrived while a
	// slave-progress query was reading the port. readUartMessageUntil passes
	// such frames to onDistStatus instead of returning them. Without this,
	// the final status would be lost and the wait below would only end at
	// otaDistTimeout.
	var terminal *pb.OtaStatusPayload
	onDistStatus := func(st *pb.OtaStatusPayload) {
		switch st.GetStatus() {
		case otaStSuccess, otaStFailed:
			if terminal == nil {
				terminal = st
			}
		default:
			applyDistributingOtaStatus(st, imageSize, slaveBytes, slaveDetails)
		}
	}
	var lastEmit, lastQuery time.Time
	slaveDistMessage := func() (msg string, aggBytes, slaveCount uint32) {
		slaveCount = uint32(len(slaveDetails))
		for _, d := range slaveDetails {
			if d.BytesWritten > aggBytes {
				aggBytes = d.BytesWritten
			}
		}
		if slaveCount == 0 {
			return "Keine verfügbaren Slaves — Verteilung übersprungen", 0, 0
		}
		return fmt.Sprintf("ESP-NOW: %d / %d bytes (%d Slaves)",
			aggBytes, imageSize, slaveCount), aggBytes, slaveCount
	}
	emitSlaveThrottled := func(force bool) {
		if !force && time.Since(lastEmit) < otaDistEmitMinInterval {
			return
		}
		lastEmit = time.Now()
		msg, agg, n := slaveDistMessage()
		emitSlaveOTA(msg, agg, n)
	}
	querySlaveProgress := func() {
		// No point querying once the session outcome is known.
		if terminal != nil || time.Since(lastQuery) < otaDistQueryInterval {
			return
		}
		lastQuery = time.Now()
		prog, err := queryOtaSlaveProgressLocked(sp, 0, onDistStatus, otaDistQueryTimeout)
		if err != nil {
			if len(slaveDetails) > 0 {
				emitSlaveThrottled(true)
			}
			return
		}
		mergeSlaveProgressResponse(prog, slaveBytes, slaveDetails)
		emitSlaveThrottled(true)
	}
	pushSlaveDist := func(st *pb.OtaStatusPayload) {
		onDistStatus(st)
		emitSlaveThrottled(false)
	}

	// lastQuery is still zero, so the first query runs immediately.
	querySlaveProgress()
	st, err := waitOtaCompleteOrCaptured(sp, otaDistTimeout, pushSlaveDist, querySlaveProgress,
		otaDistReadTimeout, func() *pb.OtaStatusPayload { return terminal })
	if err != nil {
		return fail(err)
	}
	if prog, err := queryOtaSlaveProgressLocked(sp, 0, nil, otaDistQueryTimeout); err == nil {
		mergeSlaveProgressResponse(prog, slaveBytes, slaveDetails)
	}
	notify("done", "", 100,
		fmt.Sprintf("Fertig — %d bytes, Boot-Slot %d. Master und Slaves neu starten.",
			st.GetBytesWritten(), st.GetTargetSlot()),
		OTAProgress{
			Bytes: st.GetBytesWritten(), Slot: st.GetTargetSlot(),
			MasterPercent: 100, MasterMessage: "Master: OK",
			SlaveProgress: copySlaveMap(slaveBytes),
			SlaveDetails:  copySlaveDetails(slaveDetails),
		})
	return nil
}

// waitOtaCompleteOrCaptured behaves like waitOtaComplete, with one addition.
// Before every read it consults captured(). A non-nil result is a terminal
// OTA_STATUS (otaStSuccess or otaStFailed) that another reader already took
// off the port, and it ends the wait exactly as if it had been read here.
// captured may be nil.
func waitOtaCompleteOrCaptured(sp *serialPort, timeout time.Duration,
	onDistributing func(*pb.OtaStatusPayload), onInterval func(),
	readTimeout time.Duration, captured func() *pb.OtaStatusPayload) (*pb.OtaStatusPayload, error) {
	if readTimeout <= 0 {
		readTimeout = otaStatusPollTimeout
	}
	deadline := time.Now().Add(timeout)
	for {
		if captured != nil {
			if st := captured(); st != nil {
				if st.GetStatus() == otaStFailed {
					return nil, fmt.Errorf("OTA failed (error=%d)", st.GetError())
				}
				return st, nil
			}
		}
		if time.Now().After(deadline) {
			return nil, fmt.Errorf("timeout waiting for OTA success (slave distribution?)")
		}
		readWait := time.Until(deadline)
		if readWait > readTimeout {
			readWait = readTimeout
		}
		if err := sp.port.SetReadTimeout(readWait); err != nil {
			return nil, err
		}
		st, err := readOtaStatus(sp)
		if err != nil {
			if onInterval != nil {
				onInterval()
			}
			continue
		}
		switch st.GetStatus() {
		case otaStSuccess:
			return st, nil
		case otaStFailed:
			return nil, fmt.Errorf("OTA failed (error=%d)", st.GetError())
		case otaStDistributing:
			if onDistributing != nil {
				onDistributing(st)
			}
			if onInterval != nil {
				onInterval()
			}
		}
	}
}
// QueryOtaSlaveProgress asks the master for per-slave ESP-NOW OTA progress.
//
// clientID is forwarded in the request. The OTA session itself passes 0
// (NOTE(review): presumably "all slaves" — confirm against firmware).
// sp.mu is held for the whole request/response exchange, and the response
// is awaited for up to otaDefaultTimeout.
func QueryOtaSlaveProgress(sp *serialPort, clientID uint32) (*pb.OtaSlaveProgressResponse, error) {
	sp.mu.Lock()
	defer sp.mu.Unlock()

	resp, err := queryOtaSlaveProgressLocked(sp, clientID, nil, otaDefaultTimeout)
	if err != nil {
		return nil, err
	}
	return resp, nil
}
// queryOtaSlaveProgressLocked sends an OTA_SLAVE_PROGRESS request for
// clientID and reads frames until the matching response arrives.
//
// The deadline is queryTimeout from the moment the request was written;
// a queryTimeout <= 0 means otaDefaultTimeout. OTA_STATUS frames seen while
// waiting go to onStatus, which may be nil. The caller must already have
// exclusive use of the port.
func queryOtaSlaveProgressLocked(sp *serialPort, clientID uint32,
	onStatus func(*pb.OtaStatusPayload), queryTimeout time.Duration) (*pb.OtaSlaveProgressResponse, error) {
	timeout := queryTimeout
	if timeout <= 0 {
		timeout = otaDefaultTimeout
	}

	request := &pb.UartMessage{
		Type: pb.MessageType_OTA_SLAVE_PROGRESS,
		Payload: &pb.UartMessage_OtaSlaveProgressRequest{
			OtaSlaveProgressRequest: &pb.OtaSlaveProgressRequest{ClientId: clientID},
		},
	}
	if err := writeUartMessage(sp, request); err != nil {
		return nil, err
	}

	reply, err := readUartMessageUntil(sp, time.Now().Add(timeout),
		pb.MessageType_OTA_SLAVE_PROGRESS, onStatus, otaDistReadTimeout)
	if err != nil {
		return nil, err
	}
	if resp := reply.GetOtaSlaveProgressResponse(); resp != nil {
		return resp, nil
	}
	return nil, fmt.Errorf("missing ota_slave_progress_response")
}
// applyDistributingOtaStatus merges one per-slave "distributing" OTA_STATUS
// into the slave progress maps.
//
// Only statuses with Status == otaStDistributing and Error == otaDistPerSlave
// are used. For these, TargetSlot carries the slave's client id and
// BytesWritten its progress. Everything else (including nil) is ignored.
// A missing TotalBytes defaults to imageSize.
func applyDistributingOtaStatus(st *pb.OtaStatusPayload, imageSize int,
	slaveBytes map[uint32]uint32, details map[uint32]OtaSlaveDetail) {
	perSlave := st != nil &&
		st.GetStatus() == otaStDistributing &&
		st.GetError() == otaDistPerSlave
	if !perSlave {
		return
	}

	clientID, written := st.GetTargetSlot(), st.GetBytesWritten()
	slaveBytes[clientID] = written

	entry := details[clientID]
	entry.BytesWritten = written
	if entry.TotalBytes == 0 {
		entry.TotalBytes = uint32(imageSize)
	}
	switch entry.Status {
	case 0, 1, 2:
		// A slave still in an early state is promoted to state 3.
		// NOTE(review): numeric slave states come from firmware; confirm meaning.
		entry.Status = 3
	}
	details[clientID] = entry
}
// readUartMessageUntil reads frames from sp until a message of type want
// arrives or deadline passes.
//
// Each read waits at most readChunk (otaStatusPollTimeout when <= 0), so the
// deadline is honored even on a quiet line. OTA_STATUS frames are never
// returned; they are passed to onStatus (may be nil) instead. Undecodable
// frames and messages of other types are skipped.
//
// Read errors are retried until the deadline, like in waitOtaStatus, because
// ReadFrame also fails when a chunk's read timeout expires without a frame.
// Returning on the first one would cut every query short to a single
// readChunk. The last read error is wrapped into the timeout error.
// NOTE(review): a persistent hard error (e.g. port gone) therefore spins
// until the deadline, as it already does in waitOtaStatus.
func readUartMessageUntil(sp *serialPort, deadline time.Time, want pb.MessageType,
	onStatus func(*pb.OtaStatusPayload), readChunk time.Duration) (*pb.UartMessage, error) {
	if readChunk <= 0 {
		readChunk = otaStatusPollTimeout
	}
	var lastErr error
	for {
		if time.Now().After(deadline) {
			if lastErr != nil {
				return nil, fmt.Errorf("timeout waiting for %v: %w", want, lastErr)
			}
			return nil, fmt.Errorf("timeout waiting for %v", want)
		}
		wait := time.Until(deadline)
		if wait <= 0 {
			// Deadline reached between the check and now; never hand a
			// non-positive timeout to the port.
			continue
		}
		if wait > readChunk {
			wait = readChunk
		}
		if err := sp.port.SetReadTimeout(wait); err != nil {
			return nil, err
		}
		payload, err := uartframe.ReadFrame(sp.port, nil)
		if err != nil {
			lastErr = err
			continue
		}
		msg, err := decodeUartPayload(payload)
		if err != nil {
			continue
		}
		if msg.GetType() == pb.MessageType_OTA_STATUS {
			if onStatus != nil {
				if st := msg.GetOtaStatus(); st != nil {
					onStatus(st)
				}
			}
			continue
		}
		if msg.GetType() == want {
			return msg, nil
		}
	}
}
// mergeSlaveProgressResponse copies every slave entry of r into the progress
// maps. bytesOut maps client id to bytes written, and detailsOut maps client
// id to the full detail. Existing values for those ids are replaced; a nil r
// is a no-op.
func mergeSlaveProgressResponse(r *pb.OtaSlaveProgressResponse,
	bytesOut map[uint32]uint32, detailsOut map[uint32]OtaSlaveDetail) {
	if r == nil {
		return
	}
	for _, slave := range r.GetSlaves() {
		clientID := slave.GetClientId()
		detail := OtaSlaveDetail{
			BytesWritten: slave.GetBytesWritten(),
			TotalBytes:   slave.GetTotalBytes(),
			Status:       slave.GetStatus(),
			Error:        slave.GetError(),
		}
		bytesOut[clientID] = detail.BytesWritten
		detailsOut[clientID] = detail
	}
}
// copySlaveDetails returns a shallow copy of m. Progress events then do not
// share a map that the OTA loop keeps mutating. A nil m yields an empty,
// non-nil map.
func copySlaveDetails(m map[uint32]OtaSlaveDetail) map[uint32]OtaSlaveDetail {
	dup := make(map[uint32]OtaSlaveDetail, len(m))
	for clientID, detail := range m {
		dup[clientID] = detail
	}
	return dup
}
// copySlaveMap returns a shallow copy of m (client id -> bytes written).
// Progress events then do not share a map that the OTA loop keeps mutating.
// A nil m yields an empty, non-nil map.
func copySlaveMap(m map[uint32]uint32) map[uint32]uint32 {
	dup := make(map[uint32]uint32, len(m))
	for clientID, written := range m {
		dup[clientID] = written
	}
	return dup
}
// waitOtaComplete reads OTA_STATUS frames until one of three things happens:
// the master reports otaStSuccess (the status is returned), it reports
// otaStFailed (an error is returned), or timeout elapses.
//
// Each read waits at most readTimeout (otaStatusPollTimeout when <= 0).
// otaStDistributing statuses are passed to onDistributing. onInterval runs
// after each of them and after every failed or non-status read, which gives
// callers a periodic hook. Both callbacks may be nil; other statuses are
// ignored.
func waitOtaComplete(sp *serialPort, timeout time.Duration,
	onDistributing func(*pb.OtaStatusPayload), onInterval func(),
	readTimeout time.Duration) (*pb.OtaStatusPayload, error) {
	perRead := readTimeout
	if perRead <= 0 {
		perRead = otaStatusPollTimeout
	}
	tick := func() {
		if onInterval != nil {
			onInterval()
		}
	}

	deadline := time.Now().Add(timeout)
	for !time.Now().After(deadline) {
		wait := time.Until(deadline)
		if wait > perRead {
			wait = perRead
		}
		if err := sp.port.SetReadTimeout(wait); err != nil {
			return nil, err
		}
		st, err := readOtaStatus(sp)
		if err != nil {
			tick()
			continue
		}
		status := st.GetStatus()
		if status == otaStSuccess {
			return st, nil
		}
		if status == otaStFailed {
			return nil, fmt.Errorf("OTA failed (error=%d)", st.GetError())
		}
		if status == otaStDistributing {
			if onDistributing != nil {
				onDistributing(st)
			}
			tick()
		}
	}
	return nil, fmt.Errorf("timeout waiting for OTA success (slave distribution?)")
}
// writeUartMessage frames msg with encodeUartMessage and writes the whole
// frame to the port in a single Write call.
func writeUartMessage(sp *serialPort, msg *pb.UartMessage) error {
	frame, err := encodeUartMessage(msg)
	if err == nil {
		_, err = sp.port.Write(frame)
	}
	return err
}
// waitOtaStatus polls for an OTA_STATUS frame whose status equals want and
// returns it. It returns an error if the master reports otaStFailed first or
// if timeout elapses.
//
// Each read waits at most otaStatusPollTimeout. Read errors, undecodable
// frames, non-status messages and empty statuses are skipped. Whenever the
// master reports otaStPreparing, onPreparing (may be nil) receives a fixed
// user-facing message.
func waitOtaStatus(sp *serialPort, want uint32, timeout time.Duration, onPreparing func(string)) (*pb.OtaStatusPayload, error) {
	deadline := time.Now().Add(timeout)
	for !time.Now().After(deadline) {
		wait := time.Until(deadline)
		if wait > otaStatusPollTimeout {
			wait = otaStatusPollTimeout
		}
		if err := sp.port.SetReadTimeout(wait); err != nil {
			return nil, err
		}
		payload, err := uartframe.ReadFrame(sp.port, nil)
		if err != nil {
			continue
		}
		msg, err := decodeUartPayload(payload)
		if err != nil || msg.GetType() != pb.MessageType_OTA_STATUS {
			continue
		}
		st := msg.GetOtaStatus()
		// Case order matters: want is matched before the preparing/failed codes.
		switch {
		case st == nil:
			// Status frame without payload; keep polling.
		case st.GetStatus() == want:
			return st, nil
		case st.GetStatus() == otaStPreparing:
			if onPreparing != nil {
				onPreparing("Partition wird vorbereitet (~30s)…")
			}
		case st.GetStatus() == otaStFailed:
			return nil, fmt.Errorf("OTA failed (error=%d)", st.GetError())
		}
	}
	return nil, fmt.Errorf("timeout waiting for OTA status %d", want)
}
// readOtaStatus reads a single frame and returns its OTA_STATUS payload. It
// does not change the port's read timeout.
//
// It fails if the read fails (wrapped as "read response"), the payload does
// not decode, the message is not OTA_STATUS, or the status payload is missing.
func readOtaStatus(sp *serialPort) (*pb.OtaStatusPayload, error) {
	payload, err := uartframe.ReadFrame(sp.port, nil)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}
	msg, err := decodeUartPayload(payload)
	switch {
	case err != nil:
		return nil, err
	case msg.GetType() != pb.MessageType_OTA_STATUS:
		return nil, fmt.Errorf("unexpected response type %v", msg.GetType())
	}
	if st := msg.GetOtaStatus(); st != nil {
		return st, nil
	}
	return nil, fmt.Errorf("missing ota_status")
}
// encodeUartMessage builds a UART frame for msg. The frame payload is one
// message-type byte followed by the protobuf encoding of msg, and
// uartframe.EncodeFrame applies the framing.
func encodeUartMessage(msg *pb.UartMessage) ([]byte, error) {
	body, err := proto.Marshal(msg)
	if err != nil {
		return nil, err
	}
	payload := make([]byte, 0, 1+len(body))
	payload = append(payload, byte(msg.Type))
	payload = append(payload, body...)
	return uartframe.EncodeFrame(payload)
}
// flushSerialInput discards raw bytes already waiting on the receive side.
// It reads for at most ~50ms and stops at the first empty or failed read.
// It works on raw bytes rather than frames, so a stale partial frame cannot
// make ReadFrame block. It sets the port's read timeout to 10ms and does not
// restore it. A nil sp is a no-op.
func flushSerialInput(sp *serialPort) {
	if sp == nil {
		return
	}
	// Error ignored on purpose: draining is best effort.
	_ = sp.port.SetReadTimeout(10 * time.Millisecond)
	scratch := make([]byte, 256)
	for stop := time.Now().Add(50 * time.Millisecond); time.Now().Before(stop); {
		if n, err := sp.port.Read(scratch); n == 0 || err != nil {
			break
		}
	}
}
// decodeUartPayload parses a frame payload in the format encodeUartMessage
// produces: byte 0 is the message type and the rest is a protobuf-encoded
// UartMessage. The leading type byte overrides any Type carried in the
// protobuf body.
func decodeUartPayload(payload []byte) (*pb.UartMessage, error) {
	if len(payload) == 0 {
		return nil, fmt.Errorf("empty response")
	}
	typeByte, body := payload[0], payload[1:]
	msg := new(pb.UartMessage)
	if err := proto.Unmarshal(body, msg); err != nil {
		return nil, err
	}
	msg.Type = pb.MessageType(typeByte)
	return msg, nil
}