234 lines
5.3 KiB
Go
234 lines
5.3 KiB
Go
package virtiofsd
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/samber/mo"
|
|
)
|
|
|
|
// Manager keeps track of the virtiofsd processes started for a set of mounts.
type Manager struct {
	// stateDir is the directory that holds the per-mount PID files.
	stateDir string

	// pids maps a mount tag to the PID of the virtiofsd process started for it.
	pids map[string]int
}
|
|
|
|
func NewManager(stateDir string) *Manager {
|
|
return &Manager{
|
|
stateDir: stateDir,
|
|
pids: make(map[string]int),
|
|
}
|
|
}
|
|
|
|
// findVirtiofsd returns the path of the virtiofsd executable.
//
// It looks for "virtiofsd" on PATH first. If that fails it runs
// `nix path-info nixpkgs#virtiofsd` and expects the binary at
// <store path>/bin/virtiofsd.
//
// An error is returned when the nix command fails, when it prints no store
// path, or when the expected binary cannot be stat'ed. The underlying error is
// wrapped so callers can inspect it with errors.Is / errors.As.
func findVirtiofsd() (string, error) {
	// First try PATH.
	if path, err := exec.LookPath("virtiofsd"); err == nil {
		return path, nil
	}

	// Fall back to nix.
	output, err := exec.Command("nix", "path-info", "nixpkgs#virtiofsd").Output()
	if err != nil {
		return "", fmt.Errorf("virtiofsd not found in PATH and nix lookup failed: %w", err)
	}

	storePath := strings.TrimSpace(string(output))
	if storePath == "" {
		// Joining an empty store path would yield the relative path
		// "bin/virtiofsd", which could accidentally match a file in the
		// current working directory.
		return "", fmt.Errorf("virtiofsd not found in PATH and nix path-info returned no store path")
	}

	virtiofsdPath := filepath.Join(storePath, "bin", "virtiofsd")
	if _, err := os.Stat(virtiofsdPath); err != nil {
		return "", fmt.Errorf("virtiofsd binary not found at %s: %w", virtiofsdPath, err)
	}

	return virtiofsdPath, nil
}
|
|
|
|
func (m *Manager) StartMount(mount Mount) mo.Result[int] {
|
|
if err := m.CleanStale([]Mount{mount}); err != nil {
|
|
return mo.Err[int](fmt.Errorf("failed to clean stale socket for %s: %w", mount.Tag, err))
|
|
}
|
|
|
|
if err := os.MkdirAll(mount.HostPath, 0755); err != nil {
|
|
return mo.Err[int](fmt.Errorf("failed to create host directory %s: %w", mount.HostPath, err))
|
|
}
|
|
|
|
virtiofsd, err := findVirtiofsd()
|
|
if err != nil {
|
|
return mo.Err[int](err)
|
|
}
|
|
|
|
cmd := exec.Command(virtiofsd,
|
|
"--socket-path="+mount.SocketPath,
|
|
"--shared-dir="+mount.HostPath,
|
|
"--cache=auto",
|
|
)
|
|
|
|
if err := cmd.Start(); err != nil {
|
|
return mo.Err[int](fmt.Errorf("failed to start virtiofsd for %s: %w", mount.Tag, err))
|
|
}
|
|
|
|
pid := cmd.Process.Pid
|
|
m.pids[mount.Tag] = pid
|
|
|
|
pidFile := m.pidFilePath(mount.Tag)
|
|
if err := os.WriteFile(pidFile, []byte(strconv.Itoa(pid)), 0644); err != nil {
|
|
_ = cmd.Process.Kill()
|
|
return mo.Err[int](fmt.Errorf("failed to write PID file for %s: %w", mount.Tag, err))
|
|
}
|
|
|
|
for i := 0; i < 50; i++ {
|
|
if _, err := os.Stat(mount.SocketPath); err == nil {
|
|
return mo.Ok(pid)
|
|
}
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
|
|
_ = m.StopMount(mount)
|
|
return mo.Err[int](fmt.Errorf("virtiofsd socket for %s did not appear within 5 seconds", mount.Tag))
|
|
}
|
|
|
|
func (m *Manager) StopMount(mount Mount) mo.Result[struct{}] {
|
|
pidFile := m.pidFilePath(mount.Tag)
|
|
pidBytes, err := os.ReadFile(pidFile)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return mo.Ok(struct{}{})
|
|
}
|
|
return mo.Err[struct{}](fmt.Errorf("failed to read PID file for %s: %w", mount.Tag, err))
|
|
}
|
|
|
|
pid, err := strconv.Atoi(strings.TrimSpace(string(pidBytes)))
|
|
if err != nil {
|
|
return mo.Err[struct{}](fmt.Errorf("invalid PID in file for %s: %w", mount.Tag, err))
|
|
}
|
|
|
|
process, err := os.FindProcess(pid)
|
|
if err != nil {
|
|
_ = os.Remove(pidFile)
|
|
_ = os.Remove(mount.SocketPath)
|
|
return mo.Ok(struct{}{})
|
|
}
|
|
|
|
if err := process.Signal(syscall.SIGTERM); err != nil {
|
|
_ = os.Remove(pidFile)
|
|
_ = os.Remove(mount.SocketPath)
|
|
return mo.Ok(struct{}{})
|
|
}
|
|
|
|
done := make(chan bool, 1)
|
|
go func() {
|
|
_, _ = process.Wait()
|
|
done <- true
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(5 * time.Second):
|
|
_ = process.Signal(syscall.SIGKILL)
|
|
<-done
|
|
}
|
|
|
|
_ = os.Remove(pidFile)
|
|
_ = os.Remove(mount.SocketPath)
|
|
delete(m.pids, mount.Tag)
|
|
|
|
return mo.Ok(struct{}{})
|
|
}
|
|
|
|
func (m *Manager) StartAll(mounts []Mount) mo.Result[struct{}] {
|
|
started := []Mount{}
|
|
|
|
for _, mount := range mounts {
|
|
result := m.StartMount(mount)
|
|
if result.IsError() {
|
|
for i := len(started) - 1; i >= 0; i-- {
|
|
_ = m.StopMount(started[i])
|
|
}
|
|
return mo.Err[struct{}](fmt.Errorf("failed to start mount %s: %w", mount.Tag, result.Error()))
|
|
}
|
|
started = append(started, mount)
|
|
}
|
|
|
|
return mo.Ok(struct{}{})
|
|
}
|
|
|
|
func (m *Manager) StopAll() mo.Result[struct{}] {
|
|
files, err := filepath.Glob(filepath.Join(m.stateDir, "virtiofsd-*.pid"))
|
|
if err != nil {
|
|
return mo.Err[struct{}](fmt.Errorf("failed to list PID files: %w", err))
|
|
}
|
|
|
|
for _, pidFile := range files {
|
|
pidBytes, err := os.ReadFile(pidFile)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
pid, err := strconv.Atoi(strings.TrimSpace(string(pidBytes)))
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if process, err := os.FindProcess(pid); err == nil {
|
|
_ = process.Signal(syscall.SIGTERM)
|
|
time.Sleep(100 * time.Millisecond)
|
|
_ = process.Signal(syscall.SIGKILL)
|
|
}
|
|
|
|
_ = os.Remove(pidFile)
|
|
}
|
|
|
|
sockFiles, err := filepath.Glob(filepath.Join(m.stateDir, "*.sock"))
|
|
if err == nil {
|
|
for _, sockFile := range sockFiles {
|
|
_ = os.Remove(sockFile)
|
|
}
|
|
}
|
|
|
|
m.pids = make(map[string]int)
|
|
|
|
return mo.Ok(struct{}{})
|
|
}
|
|
|
|
func (m *Manager) CleanStale(mounts []Mount) error {
|
|
for _, mount := range mounts {
|
|
if _, err := os.Stat(mount.SocketPath); err == nil {
|
|
pidFile := m.pidFilePath(mount.Tag)
|
|
pidBytes, err := os.ReadFile(pidFile)
|
|
if err != nil {
|
|
_ = os.Remove(mount.SocketPath)
|
|
continue
|
|
}
|
|
|
|
pid, err := strconv.Atoi(strings.TrimSpace(string(pidBytes)))
|
|
if err != nil {
|
|
_ = os.Remove(mount.SocketPath)
|
|
_ = os.Remove(pidFile)
|
|
continue
|
|
}
|
|
|
|
process, err := os.FindProcess(pid)
|
|
if err != nil {
|
|
_ = os.Remove(mount.SocketPath)
|
|
_ = os.Remove(pidFile)
|
|
continue
|
|
}
|
|
|
|
if err := process.Signal(syscall.Signal(0)); err != nil {
|
|
_ = os.Remove(mount.SocketPath)
|
|
_ = os.Remove(pidFile)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) pidFilePath(tag string) string {
|
|
return filepath.Join(m.stateDir, "virtiofsd-"+tag+".pid")
|
|
}
|