Skip to content

Commit

Permalink
feat: support upgrades in different slots; automate marking for swap …
Browse files Browse the repository at this point in the history
…and reset in the upgrade command
  • Loading branch information
JPHutchins committed Dec 13, 2023
1 parent d0579db commit 8abe00a
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 12 deletions.
5 changes: 5 additions & 0 deletions smpmgr/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typer
from rich.progress import Progress, SpinnerColumn, TextColumn
from serial import SerialException # type: ignore
from smp.exceptions import SMPBadStartDelimiter
from smpclient import SMPClient
from smpclient.generics import SMPRequest, TEr0, TEr1, TErr, TRep
from smpclient.transport.serial import SMPSerialTransport
Expand Down Expand Up @@ -75,3 +76,7 @@ async def smp_request(
except asyncio.TimeoutError:
progress.update(task, description=f"{description} timeout", completed=True)
raise typer.Exit(code=1)
except SMPBadStartDelimiter:
progress.update(task, description=f"{description} SMP error", completed=True)
typer.echo("Is the device an SMP server?")
raise typer.Exit(code=1)
29 changes: 24 additions & 5 deletions smpmgr/image_management.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""The image subcommand group."""

import asyncio
from io import BufferedReader
from pathlib import Path
from typing import cast

import typer
Expand All @@ -13,8 +15,10 @@
TimeRemainingColumn,
TransferSpeedColumn,
)
from smp.exceptions import SMPBadStartDelimiter
from smpclient import SMPClient
from smpclient.generics import error, success
from smpclient.mcuboot import ImageInfo
from smpclient.requests.image_management import ImageStatesRead
from typing_extensions import Annotated

Expand Down Expand Up @@ -50,7 +54,9 @@ async def f() -> None:
asyncio.run(f())


async def upload_with_progress_bar(smpclient: SMPClient, file: typer.FileBinaryRead) -> None:
async def upload_with_progress_bar(
smpclient: SMPClient, file: typer.FileBinaryRead | BufferedReader, slot: int = 0
) -> None:
"""Animate a progress bar while uploading the FW image."""

with Progress(
Expand All @@ -67,21 +73,34 @@ async def upload_with_progress_bar(smpclient: SMPClient, file: typer.FileBinaryR
image = file.read()
file.close()
task = progress.add_task("Uploading", total=len(image), filename=file.name, start=True)
async for offset in smpclient.upload(image):
progress.update(task, completed=offset)
try:
async for offset in smpclient.upload(image, slot):
progress.update(task, completed=offset)
except SMPBadStartDelimiter:
progress.stop()
typer.echo("Got an unexpected response, is the device an SMP server?")
raise typer.Exit(code=1)


@app.command()
def upload(
ctx: typer.Context,
file: Annotated[typer.FileBinaryRead, typer.Argument(help="Path to FW image")],
file: Annotated[Path, typer.Argument(help="Path to FW image")],
slot: Annotated[int, typer.Option(help="The image slot to upload to")] = 0,
) -> None:
"""Upload a FW image."""

try:
ImageInfo.load_file(str(file))
except Exception as e:
typer.echo(f"Inspection of FW image failed: {e}")
raise typer.Exit(code=1)

smpclient = get_smpclient(cast(Options, ctx.obj))

async def f() -> None:
await connect_with_spinner(smpclient)
await upload_with_progress_bar(smpclient, file)
with open(file, "rb") as f:
await upload_with_progress_bar(smpclient, f, slot)

asyncio.run(f())
58 changes: 51 additions & 7 deletions smpmgr/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
"""Entry point for the `smpmgr` application."""

import asyncio
from pathlib import Path
from typing import cast

import typer
from rich import print
from smp.os_management import OS_MGMT_RET_RC
from smpclient.generics import error, success
from smpclient.mcuboot import IMAGE_TLV, ImageInfo, TLVNotFound
from smpclient.requests.image_management import ImageStatesWrite
from smpclient.requests.os_management import ResetWrite
from typing_extensions import Annotated

from smpmgr import image_management, os_management
from smpmgr.common import Options, TransportDefinition, connect_with_spinner, get_smpclient
from smpmgr.common import (
Options,
TransportDefinition,
connect_with_spinner,
get_smpclient,
smp_request,
)
from smpmgr.image_management import upload_with_progress_bar

app = typer.Typer()
Expand All @@ -38,26 +47,61 @@ def options(
@app.command()
def upgrade(
ctx: typer.Context,
file: Annotated[typer.FileBinaryRead, typer.Argument(help="Path to FW image")],
file: Annotated[Path, typer.Argument(help="Path to FW image")],
slot: Annotated[int, typer.Option(help="The image slot to upload to")] = 0,
) -> None:
"""Upload a FW image, mark it for next boot, and reset the device."""

smpclient = get_smpclient(cast(Options, ctx.obj))
try:
image_info = ImageInfo.load_file(str(file))
except Exception as e:
typer.echo(f"Inspection of FW image failed: {e}")
raise typer.Exit(code=1)

try:
image_tlv_sha256 = image_info.get_tlv(IMAGE_TLV.SHA256)
except TLVNotFound:
typer.echo("Could not find IMAGE_TLV_SHA256 in image.")
raise typer.Exit(code=1)

options = cast(Options, ctx.obj)
smpclient = get_smpclient(options)

async def f() -> None:
await connect_with_spinner(smpclient)
await upload_with_progress_bar(smpclient, file)

r = await smpclient.request(ResetWrite()) # type: ignore
with open(file, "rb") as f:
await upload_with_progress_bar(smpclient, f, slot)

if slot != 0:
# mark the new image for testing (swap)
r = await smp_request(
smpclient,
options,
ImageStatesWrite(hash=image_tlv_sha256.value), # type: ignore
"Marking uploaded image for test upgrade...",
) # type: ignore
if error(r):
print(r)
raise typer.Exit(code=1)
elif success(r):
pass
else:
raise Exception("Unreachable")

r = await smp_request(smpclient, options, ResetWrite()) # type: ignore
if error(r):
if r.rc != OS_MGMT_RET_RC.OK:
print(r)
return
typer.Exit(code=1)
elif success(r):
pass
else:
raise Exception("Unreachable")

print("Upgrade complete. The device may take a few minutes to complete FW swap.")
print("Upgrade complete.")

if slot != 0:
print("The device may take a few minutes to complete FW swap.")

asyncio.run(f())

0 comments on commit 8abe00a

Please # to comment.