From e564f73cd013a86b29ad376ba0d64341730b67da Mon Sep 17 00:00:00 2001 From: umagnus Date: Tue, 26 Sep 2023 08:21:59 +0000 Subject: [PATCH] smb restore from snapshot --- pkg/azurefile/controllerserver.go | 62 ++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/pkg/azurefile/controllerserver.go b/pkg/azurefile/controllerserver.go index bdb8e2ff3d..10206df498 100644 --- a/pkg/azurefile/controllerserver.go +++ b/pkg/azurefile/controllerserver.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "net/url" + "os/exec" "strconv" "strings" "time" @@ -718,7 +719,7 @@ func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, a vs := req.VolumeContentSource switch vs.Type.(type) { case *csi.VolumeContentSource_Snapshot: - return status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported") + return d.restoreSnapshot(ctx, req, accountKey, shareOptions, storageEndpointSuffix) case *csi.VolumeContentSource_Volume: return d.copyFileShare(ctx, req, accountKey, shareOptions, storageEndpointSuffix) default: @@ -1072,6 +1073,65 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques return nil, status.Error(codes.Unimplemented, "") } +// restoreSnapshot restores from a snapshot +func (d *Driver) restoreSnapshot(ctx context.Context, req *csi.CreateVolumeRequest, accountKey string, shareOptions *fileclient.ShareOptions, storageEndpointSuffix string) error { + if shareOptions.Protocol == storage.EnabledProtocolsNFS { + return fmt.Errorf("protocol nfs is not supported for snapshot restore") + } + var sourceSnapshotID string + if req.GetVolumeContentSource() != nil && req.GetVolumeContentSource().GetSnapshot() != nil { + sourceSnapshotID = req.GetVolumeContentSource().GetSnapshot().GetSnapshotId() + } + resourceGroupName, accountName, srcFileShareName, _, _, _, err := GetFileShareInfo(sourceSnapshotID) //nolint:dogsled + if err != nil { + return status.Error(codes.NotFound, err.Error()) + } + dstFileShareName := shareOptions.Name + if srcFileShareName == "" || dstFileShareName == "" { + return fmt.Errorf("srcFileShareName(%s) or dstFileShareName(%s) is empty", srcFileShareName, dstFileShareName) + } + + klog.V(2).Infof("generate sas token for account(%s)", accountName) + accountSasToken, genErr := generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes) + if genErr != nil { + return genErr + } + + timeAfter := time.After(waitForCopyTimeout) + timeTick := time.Tick(waitForCopyInterval) + srcPath := fmt.Sprintf("https://%s.file.%s/%s%s", accountName, storageEndpointSuffix, srcFileShareName, accountSasToken) + dstPath := fmt.Sprintf("https://%s.file.%s/%s%s", accountName, storageEndpointSuffix, dstFileShareName, accountSasToken) + + jobState, percent, err := getAzcopyJob(dstFileShareName) + klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err) + if jobState == AzcopyJobError || jobState == AzcopyJobCompleted { + return err + } + klog.V(2).Infof("begin to copy fileshare %s to %s", srcFileShareName, dstFileShareName) + for { + select { + case <-timeTick: + jobState, percent, err := getAzcopyJob(dstFileShareName) + klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err) + switch jobState { + case AzcopyJobError, AzcopyJobCompleted: + return err + case AzcopyJobNotFound: + klog.V(2).Infof("copy fileshare %s to %s", srcFileShareName, dstFileShareName) + out, copyErr := exec.Command("azcopy", "copy", srcPath, dstPath, "--recursive", "--check-length=false").CombinedOutput() + if copyErr != nil { + klog.Warningf("CopyFileShare(%s, %s, %s) failed with error(%v): %v", resourceGroupName, accountName, dstFileShareName, copyErr, string(out)) + } else { + klog.V(2).Infof("copied fileshare %s to %s successfully", srcFileShareName, dstFileShareName) + } + return copyErr + } + case <-timeAfter: + return fmt.Errorf("timeout waiting for copy fileshare %s to %s succeed", srcFileShareName, dstFileShareName) + } + } +} + // ControllerExpandVolume controller expand volume func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { volumeID := req.GetVolumeId()