Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions pkg/driver/mounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package driver

import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"

Expand Down Expand Up @@ -63,12 +61,8 @@ func (m *mountServiceMounter) Mount(target string) (Unmounter, error) {
filers[i] = string(address)
}

cacheBase := m.driver.CacheDir
if cacheBase == "" {
cacheBase = os.TempDir()
}
cacheDir := filepath.Join(cacheBase, m.volumeID)
localSocket := mountmanager.LocalSocketPath(m.driver.volumeSocketDir, m.volumeID)
cacheDir := GetCacheDir(m.driver.CacheDir, m.volumeID)
localSocket := GetLocalSocket(m.driver.volumeSocketDir, m.volumeID)

args, err := m.buildMountArgs(target, cacheDir, localSocket, filers)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions pkg/driver/nodeserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ func (ns *NodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstag

// make sure there is no any garbage
_ = mount.CleanupMountPoint(stagingTargetPath, mountutil, true)

// Also clean up cache directory and socket if they exist
CleanupVolumeResources(ns.Driver, volumeID)
} else {
if err := volume.(*Volume).Unstage(stagingTargetPath); err != nil {
return nil, status.Error(codes.Internal, err.Error())
Expand Down
47 changes: 45 additions & 2 deletions pkg/driver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,22 @@ import (

"github.com/container-storage-interface/spec/lib/go/csi"
"github.com/seaweedfs/seaweedfs-csi-driver/pkg/datalocality"
"github.com/seaweedfs/seaweedfs-csi-driver/pkg/mountmanager"
"github.com/seaweedfs/seaweedfs/weed/glog"
"golang.org/x/net/context"
"google.golang.org/grpc"
"k8s.io/mount-utils"
)

func NewNodeServer(n *SeaweedFsDriver) *NodeServer {
if err := removeDirContent(n.CacheDir); err != nil {
glog.Warning("error cleaning up cache dir")
if n.CacheDir != "" {
cleanCacheDir := filepath.Clean(n.CacheDir)
cleanTempDir := filepath.Clean(os.TempDir())
if cleanCacheDir != cleanTempDir {
if err := removeDirContent(cleanCacheDir); err != nil {
glog.Warningf("error cleaning up cache dir %s: %v", cleanCacheDir, err)
}
}
}

return &NodeServer{
Expand All @@ -26,6 +33,42 @@ func NewNodeServer(n *SeaweedFsDriver) *NodeServer {
}
}

func GetCacheDir(cacheBase, volumeID string) string {
if cacheBase == "" {
cacheBase = os.TempDir()
}
return filepath.Join(cacheBase, volumeID)
}

func GetLocalSocket(volumeSocketDir, volumeID string) string {
return mountmanager.LocalSocketPath(volumeSocketDir, volumeID)
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid code duplication in NodeUnstageVolume and Volume.Unstage, you can add a new helper function here to handle the cleanup of cache directories and socket files. This centralizes the cleanup logic.

func CleanupVolumeResources(driver *SeaweedFsDriver, volumeID string) {
	cacheDir := GetCacheDir(driver.CacheDir, volumeID)
	if err := os.RemoveAll(cacheDir); err != nil {
		glog.Warningf("failed to remove cache dir %s for volume %s: %v", cacheDir, volumeID, err)
	}

	localSocket := GetLocalSocket(driver.volumeSocketDir, volumeID)
	if err := os.Remove(localSocket); err != nil && !os.IsNotExist(err) {
		glog.Warningf("failed to remove local socket %s for volume %s: %v", localSocket, volumeID, err)
	}
}

func CleanupVolumeResources(driver *SeaweedFsDriver, volumeID string) {
cacheDir := GetCacheDir(driver.CacheDir, volumeID)

// Validate that cacheDir is within cacheBase to prevent path traversal
cacheBase := driver.CacheDir
if cacheBase == "" {
cacheBase = os.TempDir()
}
cleanCacheBase := filepath.Clean(cacheBase)
cleanCacheDir := filepath.Clean(cacheDir)
rel, err := filepath.Rel(cleanCacheBase, cleanCacheDir)
if err == nil && rel != "." && !strings.HasPrefix(rel, "..") {
if err := os.RemoveAll(cleanCacheDir); err != nil {
glog.Warningf("failed to remove cache dir %s for volume %s: %v", cleanCacheDir, volumeID, err)
}
} else {
glog.Warningf("skipping cache dir removal for volume %s: invalid path %s (rel: %s, err: %v)", volumeID, cleanCacheDir, rel, err)
}

localSocket := GetLocalSocket(driver.volumeSocketDir, volumeID)
if err := os.Remove(localSocket); err != nil && !os.IsNotExist(err) {
glog.Warningf("failed to remove local socket %s for volume %s: %v", localSocket, volumeID, err)
}
}

func NewIdentityServer(d *SeaweedFsDriver) *IdentityServer {
return &IdentityServer{
Driver: d,
Expand Down
21 changes: 11 additions & 10 deletions pkg/driver/volume.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,20 @@ func (vol *Volume) Unstage(stagingTargetPath string) error {
glog.Errorf("error cleaning up mount point for volume %s: %v", vol.VolumeId, err)
return err
}
} else {
if err := vol.unmounter.Unmount(); err != nil {
glog.Errorf("error unmounting volume during unstage: %s, err: %v", stagingTargetPath, err)
return err
}

return nil
}

if err := vol.unmounter.Unmount(); err != nil {
glog.Errorf("error unmounting volume during unstage: %s, err: %v", stagingTargetPath, err)
return err
if err := os.Remove(stagingTargetPath); err != nil && !os.IsNotExist(err) {
glog.Errorf("error removing staging path for volume %s at %s, err: %v", vol.VolumeId, stagingTargetPath, err)
return err
}
}

if err := os.Remove(stagingTargetPath); err != nil && !os.IsNotExist(err) {
glog.Errorf("error removing staging path for volume %s at %s, err: %v", vol.VolumeId, stagingTargetPath, err)
return err
}
// Always attempt to remove the cache directory and socket file
CleanupVolumeResources(vol.driver, vol.VolumeId)

return nil
}
Loading