Firmware buffers 200-byte chunks into 4 KiB blocks for esp_ota_write; goTool uploads with per-block ACK flow control and larger UART buffers to avoid stalls. Co-authored-by: Cursor <cursoragent@cursor.com>
205 lines
4.8 KiB
Go
205 lines
4.8 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"time"
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
uartframe "powerpod/gotool/uart"
|
|
"powerpod/gotool/pb"
|
|
)
|
|
|
|
// Sizes and timeouts used by the host side of the OTA transfer.
const (
	// otaHostChunkSize is the largest number of image bytes placed in a
	// single OTA_PAYLOAD message.
	otaHostChunkSize = 200

	// otaFlashBlockSize is the number of image bytes sent before the host
	// stops and waits for a block acknowledgement (4 KiB).
	// NOTE(review): presumably mirrors the firmware's flash write buffer —
	// confirm against the firmware before changing.
	otaFlashBlockSize = 4 * 1024

	// otaPrepareTimeout bounds the wait for the device to report ready
	// after OTA_START.
	otaPrepareTimeout = 2 * time.Minute

	// otaDefaultTimeout bounds every other wait: each block
	// acknowledgement and the read timeout used once the device is ready.
	otaDefaultTimeout = 15 * time.Second
)
|
|
|
|
// Status codes compared against OtaStatusPayload.GetStatus() in this file.
// They are deliberately untyped so they can be compared with the status
// field and passed as the uint32 "want" argument of waitOtaStatusLocked.
// NOTE(review): the numeric values presumably have to match the firmware's
// status enum — confirm before renumbering.
const (
	// otaStPreparing: the device is still preparing; waiters print a
	// progress line and keep waiting.
	otaStPreparing = 1

	// otaStReady: awaited after OTA_START before any payload is sent.
	otaStReady = 2

	// otaStBlockAck: awaited after each full otaFlashBlockSize block.
	otaStBlockAck = 3

	// otaStSuccess: required as the reply to OTA_END.
	otaStSuccess = 4

	// otaStFailed: aborts a wait; the status carries a device error code.
	otaStFailed = 5
)
|
|
|
|
func runOTA(sp *serialPort, args []string) error {
|
|
if len(args) < 1 {
|
|
return fmt.Errorf("usage: ota <firmware.bin>")
|
|
}
|
|
data, err := os.ReadFile(args[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(data) == 0 {
|
|
return fmt.Errorf("empty firmware file")
|
|
}
|
|
|
|
if err := sp.port.SetReadTimeout(otaPrepareTimeout); err != nil {
|
|
return err
|
|
}
|
|
defer sp.port.SetReadTimeout(readTimeout)
|
|
|
|
sp.mu.Lock()
|
|
defer sp.mu.Unlock()
|
|
|
|
fmt.Printf("OTA start: %d bytes firmware\n", len(data))
|
|
if err := writeUartMessageLocked(sp, &pb.UartMessage{
|
|
Type: pb.MessageType_OTA_START,
|
|
Payload: &pb.UartMessage_OtaStart{
|
|
OtaStart: &pb.OtaStartPayload{TotalSize: uint32(len(data))},
|
|
},
|
|
}, "OTA_START"); err != nil {
|
|
return err
|
|
}
|
|
if _, err := waitOtaStatusLocked(sp, otaStReady, otaPrepareTimeout); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := sp.port.SetReadTimeout(otaDefaultTimeout); err != nil {
|
|
return err
|
|
}
|
|
|
|
var seq uint32
|
|
blockNum := 0
|
|
for offset := 0; offset < len(data); {
|
|
bytesInBlock := 0
|
|
for bytesInBlock < otaFlashBlockSize && offset < len(data) {
|
|
n := otaHostChunkSize
|
|
room := otaFlashBlockSize - bytesInBlock
|
|
if n > room {
|
|
n = room
|
|
}
|
|
if offset+n > len(data) {
|
|
n = len(data) - offset
|
|
}
|
|
chunk := data[offset : offset+n]
|
|
|
|
if err := writeUartMessageLocked(sp, &pb.UartMessage{
|
|
Type: pb.MessageType_OTA_PAYLOAD,
|
|
Payload: &pb.UartMessage_OtaPayload{
|
|
OtaPayload: &pb.OtaPayload{Seq: seq, Data: chunk},
|
|
},
|
|
}, "OTA_PAYLOAD"); err != nil {
|
|
return err
|
|
}
|
|
seq++
|
|
offset += n
|
|
bytesInBlock += n
|
|
}
|
|
|
|
if bytesInBlock == otaFlashBlockSize {
|
|
blockNum++
|
|
st, err := waitOtaStatusLocked(sp, otaStBlockAck, otaDefaultTimeout)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
fmt.Printf(" block %d ack (%d bytes in flash, %d%%)\n",
|
|
blockNum, st.GetBytesWritten(), offset*100/len(data))
|
|
}
|
|
}
|
|
|
|
if err := writeUartMessageLocked(sp, &pb.UartMessage{
|
|
Type: pb.MessageType_OTA_END,
|
|
Payload: &pb.UartMessage_OtaEnd{
|
|
OtaEnd: &pb.OtaEndPayload{},
|
|
},
|
|
}, "OTA_END"); err != nil {
|
|
return err
|
|
}
|
|
st, err := readOtaStatusLocked(sp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if st.GetStatus() != otaStSuccess {
|
|
return fmt.Errorf("OTA failed: status=%d error=%d written=%d",
|
|
st.GetStatus(), st.GetError(), st.GetBytesWritten())
|
|
}
|
|
fmt.Printf("OTA success: %d bytes written (slot %d) — reboot to boot new image\n",
|
|
st.GetBytesWritten(), st.GetTargetSlot())
|
|
return nil
|
|
}
|
|
|
|
func writeUartMessageLocked(sp *serialPort, msg *pb.UartMessage, cmdName string) error {
|
|
frame, err := encodeUartMessage(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !sp.quiet {
|
|
log.Printf("sending %s (%d frame bytes)", cmdName, len(frame))
|
|
}
|
|
_, err = sp.port.Write(frame)
|
|
return err
|
|
}
|
|
|
|
func encodeUartMessage(msg *pb.UartMessage) ([]byte, error) {
|
|
body, err := proto.Marshal(msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
payload := append([]byte{byte(msg.Type)}, body...)
|
|
return uartframe.EncodeFrame(payload)
|
|
}
|
|
|
|
func decodeUartPayload(payload []byte) (*pb.UartMessage, error) {
|
|
if len(payload) == 0 {
|
|
return nil, fmt.Errorf("empty response")
|
|
}
|
|
var msg pb.UartMessage
|
|
if err := proto.Unmarshal(payload[1:], &msg); err != nil {
|
|
return nil, err
|
|
}
|
|
msg.Type = pb.MessageType(payload[0])
|
|
return &msg, nil
|
|
}
|
|
|
|
func waitOtaStatusLocked(sp *serialPort, want uint32, timeout time.Duration) (*pb.OtaStatusPayload, error) {
|
|
deadline := time.Now().Add(timeout)
|
|
for {
|
|
if time.Now().After(deadline) {
|
|
return nil, fmt.Errorf("timeout waiting for OTA status %d", want)
|
|
}
|
|
if err := sp.port.SetReadTimeout(time.Until(deadline)); err != nil {
|
|
return nil, err
|
|
}
|
|
st, err := readOtaStatusLocked(sp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
switch st.GetStatus() {
|
|
case want:
|
|
if want == otaStReady {
|
|
fmt.Printf("OTA ready: inactive slot %d\n", st.GetTargetSlot())
|
|
}
|
|
return st, nil
|
|
case otaStPreparing:
|
|
fmt.Printf("OTA preparing partition (erase may take ~30s)…\n")
|
|
case otaStFailed:
|
|
return nil, fmt.Errorf("OTA failed (error=%d)", st.GetError())
|
|
}
|
|
}
|
|
}
|
|
|
|
func readOtaStatusLocked(sp *serialPort) (*pb.OtaStatusPayload, error) {
|
|
payload, err := uartframe.ReadFrame(sp.port, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read response: %w", err)
|
|
}
|
|
msg, err := decodeUartPayload(payload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if msg.GetType() != pb.MessageType_OTA_STATUS {
|
|
return nil, fmt.Errorf("unexpected response type %v", msg.GetType())
|
|
}
|
|
st := msg.GetOtaStatus()
|
|
if st == nil {
|
|
return nil, fmt.Errorf("missing ota_status")
|
|
}
|
|
return st, nil
|
|
}
|