Skip to content

Commit

Permalink
smb restore from snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
umagnus committed Sep 26, 2023
1 parent 5cdcab0 commit e564f73
Showing 1 changed file with 61 additions and 1 deletion.
62 changes: 61 additions & 1 deletion pkg/azurefile/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"net/url"
"os/exec"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -718,7 +719,7 @@ func (d *Driver) copyVolume(ctx context.Context, req *csi.CreateVolumeRequest, a
vs := req.VolumeContentSource
switch vs.Type.(type) {
case *csi.VolumeContentSource_Snapshot:
return status.Errorf(codes.InvalidArgument, "copy volume from volumeSnapshot is not supported")
return d.restoreSnapshot(ctx, req, accountKey, shareOptions, storageEndpointSuffix)
case *csi.VolumeContentSource_Volume:
return d.copyFileShare(ctx, req, accountKey, shareOptions, storageEndpointSuffix)
default:
Expand Down Expand Up @@ -1072,6 +1073,65 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques
return nil, status.Error(codes.Unimplemented, "")
}

// restoreSnapshot restores from a snapshot
func (d *Driver) restoreSnapshot(ctx context.Context, req *csi.CreateVolumeRequest, accountKey string, shareOptions *fileclient.ShareOptions, storageEndpointSuffix string) error {
if shareOptions.Protocol == storage.EnabledProtocolsNFS {
return fmt.Errorf("protocol nfs is not supported for snapshot restore")
}
var sourceSnapshotID string
if req.GetVolumeContentSource() != nil && req.GetVolumeContentSource().GetSnapshot() != nil {
sourceSnapshotID = req.GetVolumeContentSource().GetSnapshot().GetSnapshotId()
}
resourceGroupName, accountName, srcFileShareName, _, _, _, err := GetFileShareInfo(sourceSnapshotID) //nolint:dogsled
if err != nil {
return status.Error(codes.NotFound, err.Error())
}
dstFileShareName := shareOptions.Name
if srcFileShareName == "" || dstFileShareName == "" {
return fmt.Errorf("srcFileShareName(%s) or dstFileShareName(%s) is empty", srcFileShareName, dstFileShareName)
}

klog.V(2).Infof("generate sas token for account(%s)", accountName)
accountSasToken, genErr := generateSASToken(accountName, accountKey, storageEndpointSuffix, d.sasTokenExpirationMinutes)
if genErr != nil {
return genErr
}

timeAfter := time.After(waitForCopyTimeout)
timeTick := time.Tick(waitForCopyInterval)
srcPath := fmt.Sprintf("https://%s.file.%s/%s%s", accountName, storageEndpointSuffix, srcFileShareName, accountSasToken)
dstPath := fmt.Sprintf("https://%s.file.%s/%s%s", accountName, storageEndpointSuffix, dstFileShareName, accountSasToken)

jobState, percent, err := getAzcopyJob(dstFileShareName)
klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err)
if jobState == AzcopyJobError || jobState == AzcopyJobCompleted {
return err
}
klog.V(2).Infof("begin to copy fileshare %s to %s", srcFileShareName, dstFileShareName)
for {
select {
case <-timeTick:
jobState, percent, err := getAzcopyJob(dstFileShareName)
klog.V(2).Infof("azcopy job status: %s, copy percent: %s%%, error: %v", jobState, percent, err)
switch jobState {
case AzcopyJobError, AzcopyJobCompleted:
return err
case AzcopyJobNotFound:
klog.V(2).Infof("copy fileshare %s to %s", srcFileShareName, dstFileShareName)
out, copyErr := exec.Command("azcopy", "copy", srcPath, dstPath, "--recursive", "--check-length=false").CombinedOutput()
if copyErr != nil {
klog.Warningf("CopyFileShare(%s, %s, %s) failed with error(%v): %v", resourceGroupName, accountName, dstFileShareName, copyErr, string(out))
} else {
klog.V(2).Infof("copied fileshare %s to %s successfully", srcFileShareName, dstFileShareName)
}
return copyErr
}
case <-timeAfter:
return fmt.Errorf("timeout waiting for copy fileshare %s to %s succeed", srcFileShareName, dstFileShareName)
}
}
}

// ControllerExpandVolume controller expand volume
func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
volumeID := req.GetVolumeId()
Expand Down

0 comments on commit e564f73

Please # to comment.