diff --git a/mlflow_stagein.py b/mlflow_stagein.py index a953546..64c9725 100644 --- a/mlflow_stagein.py +++ b/mlflow_stagein.py @@ -47,6 +47,7 @@ def copy_model(connection_id, **context): ssh_hook = get_connection(conn_id=connection_id, **context) + clt = ssh_hook.get_conn() sftp_client = ssh_hook.get_conn().open_sftp() with open(ret, "rb") as sr: @@ -55,7 +56,7 @@ def copy_model(connection_id, **context): if file_exist(sftp=sftp_client, name=target_name): print(target_name," exists. Overwritting.") - sftp_client.exec_command(command=f"touch {target_name}") + clt.exec_command(command=f"touch {target_name}") with sftp_client.open(target_name, "wb") as tr: tr.set_pipelined(pipelined=True) copy_streams(inp=sr, outp=tr)