package main import ( "fmt" "log" "os" "time" "google.golang.org/protobuf/proto" uartframe "powerpod/gotool/uart" "powerpod/gotool/pb" ) const ( otaHostChunkSize = 200 otaFlashBlockSize = 4096 otaPrepareTimeout = 120 * time.Second otaDefaultTimeout = 15 * time.Second ) const ( otaStPreparing = 1 otaStReady = 2 otaStBlockAck = 3 otaStSuccess = 4 otaStFailed = 5 ) func runOTA(sp *serialPort, args []string) error { if len(args) < 1 { return fmt.Errorf("usage: ota ") } data, err := os.ReadFile(args[0]) if err != nil { return err } if len(data) == 0 { return fmt.Errorf("empty firmware file") } if err := sp.port.SetReadTimeout(otaPrepareTimeout); err != nil { return err } defer sp.port.SetReadTimeout(readTimeout) sp.mu.Lock() defer sp.mu.Unlock() fmt.Printf("OTA start: %d bytes firmware\n", len(data)) if err := writeUartMessageLocked(sp, &pb.UartMessage{ Type: pb.MessageType_OTA_START, Payload: &pb.UartMessage_OtaStart{ OtaStart: &pb.OtaStartPayload{TotalSize: uint32(len(data))}, }, }, "OTA_START"); err != nil { return err } if _, err := waitOtaStatusLocked(sp, otaStReady, otaPrepareTimeout); err != nil { return err } if err := sp.port.SetReadTimeout(otaDefaultTimeout); err != nil { return err } var seq uint32 blockNum := 0 for offset := 0; offset < len(data); { bytesInBlock := 0 for bytesInBlock < otaFlashBlockSize && offset < len(data) { n := otaHostChunkSize room := otaFlashBlockSize - bytesInBlock if n > room { n = room } if offset+n > len(data) { n = len(data) - offset } chunk := data[offset : offset+n] if err := writeUartMessageLocked(sp, &pb.UartMessage{ Type: pb.MessageType_OTA_PAYLOAD, Payload: &pb.UartMessage_OtaPayload{ OtaPayload: &pb.OtaPayload{Seq: seq, Data: chunk}, }, }, "OTA_PAYLOAD"); err != nil { return err } seq++ offset += n bytesInBlock += n } if bytesInBlock == otaFlashBlockSize { blockNum++ st, err := waitOtaStatusLocked(sp, otaStBlockAck, otaDefaultTimeout) if err != nil { return err } fmt.Printf(" block %d ack (%d bytes in flash, %d%%)\n", blockNum, st.GetBytesWritten(), offset*100/len(data)) } } if err := writeUartMessageLocked(sp, &pb.UartMessage{ Type: pb.MessageType_OTA_END, Payload: &pb.UartMessage_OtaEnd{ OtaEnd: &pb.OtaEndPayload{}, }, }, "OTA_END"); err != nil { return err } st, err := readOtaStatusLocked(sp) if err != nil { return err } if st.GetStatus() != otaStSuccess { return fmt.Errorf("OTA failed: status=%d error=%d written=%d", st.GetStatus(), st.GetError(), st.GetBytesWritten()) } fmt.Printf("OTA success: %d bytes written (slot %d) — reboot to boot new image\n", st.GetBytesWritten(), st.GetTargetSlot()) return nil } func writeUartMessageLocked(sp *serialPort, msg *pb.UartMessage, cmdName string) error { frame, err := encodeUartMessage(msg) if err != nil { return err } if !sp.quiet { log.Printf("sending %s (%d frame bytes)", cmdName, len(frame)) } _, err = sp.port.Write(frame) return err } func encodeUartMessage(msg *pb.UartMessage) ([]byte, error) { body, err := proto.Marshal(msg) if err != nil { return nil, err } payload := append([]byte{byte(msg.Type)}, body...) return uartframe.EncodeFrame(payload) } func decodeUartPayload(payload []byte) (*pb.UartMessage, error) { if len(payload) == 0 { return nil, fmt.Errorf("empty response") } var msg pb.UartMessage if err := proto.Unmarshal(payload[1:], &msg); err != nil { return nil, err } msg.Type = pb.MessageType(payload[0]) return &msg, nil } func waitOtaStatusLocked(sp *serialPort, want uint32, timeout time.Duration) (*pb.OtaStatusPayload, error) { deadline := time.Now().Add(timeout) for { if time.Now().After(deadline) { return nil, fmt.Errorf("timeout waiting for OTA status %d", want) } if err := sp.port.SetReadTimeout(time.Until(deadline)); err != nil { return nil, err } st, err := readOtaStatusLocked(sp) if err != nil { return nil, err } switch st.GetStatus() { case want: if want == otaStReady { fmt.Printf("OTA ready: inactive slot %d\n", st.GetTargetSlot()) } return st, nil case otaStPreparing: fmt.Printf("OTA preparing partition (erase may take ~30s)…\n") case otaStFailed: return nil, fmt.Errorf("OTA failed (error=%d)", st.GetError()) } } } func readOtaStatusLocked(sp *serialPort) (*pb.OtaStatusPayload, error) { payload, err := uartframe.ReadFrame(sp.port, nil) if err != nil { return nil, fmt.Errorf("read response: %w", err) } msg, err := decodeUartPayload(payload) if err != nil { return nil, err } if msg.GetType() != pb.MessageType_OTA_STATUS { return nil, fmt.Errorf("unexpected response type %v", msg.GetType()) } st := msg.GetOtaStatus() if st == nil { return nil, fmt.Errorf("missing ota_status") } return st, nil }