diff --git a/smpmgr/common.py b/smpmgr/common.py index b4813f4..17f3eb8 100644 --- a/smpmgr/common.py +++ b/smpmgr/common.py @@ -5,6 +5,7 @@ import typer from rich.progress import Progress, SpinnerColumn, TextColumn from serial import SerialException # type: ignore +from smp.exceptions import SMPBadStartDelimiter from smpclient import SMPClient from smpclient.generics import SMPRequest, TEr0, TEr1, TErr, TRep from smpclient.transport.serial import SMPSerialTransport @@ -75,3 +76,7 @@ async def smp_request( except asyncio.TimeoutError: progress.update(task, description=f"{description} timeout", completed=True) raise typer.Exit(code=1) + except SMPBadStartDelimiter: + progress.update(task, description=f"{description} SMP error", completed=True) + typer.echo("Is the device an SMP server?") + raise typer.Exit(code=1) diff --git a/smpmgr/image_management.py b/smpmgr/image_management.py index 5c1eda2..61585b2 100644 --- a/smpmgr/image_management.py +++ b/smpmgr/image_management.py @@ -1,6 +1,8 @@ """The image subcommand group.""" import asyncio +from io import BufferedReader +from pathlib import Path from typing import cast import typer @@ -13,8 +15,10 @@ TimeRemainingColumn, TransferSpeedColumn, ) +from smp.exceptions import SMPBadStartDelimiter from smpclient import SMPClient from smpclient.generics import error, success +from smpclient.mcuboot import ImageInfo from smpclient.requests.image_management import ImageStatesRead from typing_extensions import Annotated @@ -50,7 +54,9 @@ async def f() -> None: asyncio.run(f()) -async def upload_with_progress_bar(smpclient: SMPClient, file: typer.FileBinaryRead) -> None: +async def upload_with_progress_bar( + smpclient: SMPClient, file: typer.FileBinaryRead | BufferedReader, slot: int = 0 +) -> None: """Animate a progress bar while uploading the FW image.""" with Progress( @@ -67,21 +73,34 @@ async def upload_with_progress_bar(smpclient: SMPClient, file: typer.FileBinaryR image = file.read() file.close() task = progress.add_task("Uploading", total=len(image), filename=file.name, start=True) - async for offset in smpclient.upload(image): - progress.update(task, completed=offset) + try: + async for offset in smpclient.upload(image, slot): + progress.update(task, completed=offset) + except SMPBadStartDelimiter: + progress.stop() + typer.echo("Got an unexpected response, is the device an SMP server?") + raise typer.Exit(code=1) @app.command() def upload( ctx: typer.Context, - file: Annotated[typer.FileBinaryRead, typer.Argument(help="Path to FW image")], + file: Annotated[Path, typer.Argument(help="Path to FW image")], + slot: Annotated[int, typer.Option(help="The image slot to upload to")] = 0, ) -> None: """Upload a FW image.""" + try: + ImageInfo.load_file(str(file)) + except Exception as e: + typer.echo(f"Inspection of FW image failed: {e}") + raise typer.Exit(code=1) + smpclient = get_smpclient(cast(Options, ctx.obj)) async def f() -> None: await connect_with_spinner(smpclient) - await upload_with_progress_bar(smpclient, file) + with open(file, "rb") as f: + await upload_with_progress_bar(smpclient, f, slot) asyncio.run(f()) diff --git a/smpmgr/main.py b/smpmgr/main.py index ee897eb..be6aae5 100644 --- a/smpmgr/main.py +++ b/smpmgr/main.py @@ -1,17 +1,26 @@ """Entry point for the `smpmgr` application.""" import asyncio +from pathlib import Path from typing import cast import typer from rich import print from smp.os_management import OS_MGMT_RET_RC from smpclient.generics import error, success +from smpclient.mcuboot import IMAGE_TLV, ImageInfo, TLVNotFound +from smpclient.requests.image_management import ImageStatesWrite from smpclient.requests.os_management import ResetWrite from typing_extensions import Annotated from smpmgr import image_management, os_management -from smpmgr.common import Options, TransportDefinition, connect_with_spinner, get_smpclient +from smpmgr.common import ( + Options, + TransportDefinition, + connect_with_spinner, + get_smpclient, + smp_request, +) from smpmgr.image_management import upload_with_progress_bar app = typer.Typer() @@ -38,26 +47,61 @@ def options( @app.command() def upgrade( ctx: typer.Context, - file: Annotated[typer.FileBinaryRead, typer.Argument(help="Path to FW image")], + file: Annotated[Path, typer.Argument(help="Path to FW image")], + slot: Annotated[int, typer.Option(help="The image slot to upload to")] = 0, ) -> None: """Upload a FW image, mark it for next boot, and reset the device.""" - smpclient = get_smpclient(cast(Options, ctx.obj)) + try: + image_info = ImageInfo.load_file(str(file)) + except Exception as e: + typer.echo(f"Inspection of FW image failed: {e}") + raise typer.Exit(code=1) + + try: + image_tlv_sha256 = image_info.get_tlv(IMAGE_TLV.SHA256) + except TLVNotFound: + typer.echo("Could not find IMAGE_TLV_SHA256 in image.") + raise typer.Exit(code=1) + + options = cast(Options, ctx.obj) + smpclient = get_smpclient(options) async def f() -> None: await connect_with_spinner(smpclient) - await upload_with_progress_bar(smpclient, file) - r = await smpclient.request(ResetWrite()) # type: ignore + with open(file, "rb") as f: + await upload_with_progress_bar(smpclient, f, slot) + + if slot != 0: + # mark the new image for testing (swap) + r = await smp_request( + smpclient, + options, + ImageStatesWrite(hash=image_tlv_sha256.value), # type: ignore + "Marking uploaded image for test upgrade...", + ) # type: ignore + if error(r): + print(r) + raise typer.Exit(code=1) + elif success(r): + pass + else: + raise Exception("Unreachable") + + r = await smp_request(smpclient, options, ResetWrite()) # type: ignore if error(r): if r.rc != OS_MGMT_RET_RC.OK: print(r) - return + typer.Exit(code=1) elif success(r): pass else: raise Exception("Unreachable") - print("Upgrade complete. The device may take a few minutes to complete FW swap.") + print("Upgrade complete.") + + if slot != 0: + print("The device may take a few minutes to complete FW swap.") asyncio.run(f())