From 8781993968fd964ac723ff5f360b6f259e809a3e Mon Sep 17 00:00:00 2001 From: Aleksa Sarai Date: Mon, 1 Jul 2024 15:12:01 +1000 Subject: [PATCH] [1.1] rootfs: consolidate mountpoint creation logic The logic for how we create mountpoints is spread over each mountpoint preparation function, when in reality the behaviour is pretty uniform with only a handful of exceptions. So just move it all to one function that is easier to understand. Signed-off-by: Aleksa Sarai --- libcontainer/container_linux.go | 28 ++---- libcontainer/rootfs_linux.go | 160 ++++++++++++++----------------- libcontainer/utils/utils_unix.go | 15 +++ 3 files changed, 94 insertions(+), 109 deletions(-) diff --git a/libcontainer/container_linux.go b/libcontainer/container_linux.go index 7eb3a99a9d6..0c07ae6c875 100644 --- a/libcontainer/container_linux.go +++ b/libcontainer/container_linux.go @@ -1276,8 +1276,7 @@ func (c *linuxContainer) restoreNetwork(req *criurpc.CriuReq, criuOpts *CriuOpts // restore using CRIU. This function is inspired from the code in // rootfs_linux.go func (c *linuxContainer) makeCriuRestoreMountpoints(m *configs.Mount) error { - switch m.Device { - case "cgroup": + if m.Device == "cgroup" { // No mount point(s) need to be created: // // * for v1, mount points are saved by CRIU because @@ -1286,26 +1285,11 @@ func (c *linuxContainer) makeCriuRestoreMountpoints(m *configs.Mount) error { // * for v2, /sys/fs/cgroup is a real mount, but // the mountpoint appears as soon as /sys is mounted return nil - case "bind": - // The prepareBindMount() function checks if source - // exists. So it cannot be used for other filesystem types. - // TODO: pass something else than nil? Not sure if criu is - // impacted by issue #2484 - if err := prepareBindMount(m, c.config.Rootfs, nil); err != nil { - return err - } - default: - // for all other filesystems just create the mountpoints - dest, err := securejoin.SecureJoin(c.config.Rootfs, m.Destination) - if err != nil { - return err - } - if err := checkProcMount(c.config.Rootfs, dest, m, ""); err != nil { - return err - } - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } + } + // TODO: pass something else than nil? Not sure if criu is + // impacted by issue #2484 + if _, err := createMountpoint(c.config.Rootfs, m, nil, ""); err != nil { + return fmt.Errorf("create criu restore mount for %s mount: %w", m.Destination, err) } return nil } diff --git a/libcontainer/rootfs_linux.go b/libcontainer/rootfs_linux.go index f66267307e5..e6269bea812 100644 --- a/libcontainer/rootfs_linux.go +++ b/libcontainer/rootfs_linux.go @@ -224,36 +224,6 @@ func mountCmd(cmd configs.Command) error { return nil } -func prepareBindMount(m *configs.Mount, rootfs string, mountFd *int) error { - source := m.Source - if mountFd != nil { - source = "/proc/self/fd/" + strconv.Itoa(*mountFd) - } - - stat, err := os.Stat(source) - if err != nil { - // error out if the source of a bind mount does not exist as we will be - // unable to bind anything to it. - return err - } - // ensure that the destination of the bind mount is resolved of symlinks at mount time because - // any previous mounts can invalidate the next mount's destination. - // this can happen when a user specifies mounts within other mounts to cause breakouts or other - // evil stuff to try to escape the container's rootfs. - var dest string - if dest, err = securejoin.SecureJoin(rootfs, m.Destination); err != nil { - return err - } - if err := checkProcMount(rootfs, dest, m, source); err != nil { - return err - } - if err := createIfNotExists(dest, stat.IsDir()); err != nil { - return err - } - - return nil -} - func mountCgroupV1(m *configs.Mount, c *mountConfig) error { binds, err := getCgroupMounts(m) if err != nil { @@ -282,6 +252,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error { for _, b := range binds { if c.cgroupns { subsystemPath := filepath.Join(c.root, b.Destination) + subsystemName := filepath.Base(b.Destination) if err := os.MkdirAll(subsystemPath, 0o755); err != nil { return err } @@ -292,7 +263,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error { } var ( source = "cgroup" - data = filepath.Base(subsystemPath) + data = subsystemName ) if data == "systemd" { data = cgroups.CgroupNamePrefix + data @@ -322,14 +293,7 @@ func mountCgroupV1(m *configs.Mount, c *mountConfig) error { } func mountCgroupV2(m *configs.Mount, c *mountConfig) error { - dest, err := securejoin.SecureJoin(c.root, m.Destination) - if err != nil { - return err - } - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } - err = utils.WithProcfd(c.root, m.Destination, func(procfd string) error { + err := utils.WithProcfd(c.root, m.Destination, func(procfd string) error { return mount(m.Source, m.Destination, procfd, "cgroup2", uintptr(m.Flags), m.Data) }) if err == nil || !(errors.Is(err, unix.EPERM) || errors.Is(err, unix.EBUSY)) { @@ -411,6 +375,70 @@ func doTmpfsCopyUp(m *configs.Mount, rootfs, mountLabel string) (Err error) { }) } +var errRootfsToFile = errors.New("config tries to change rootfs to file") + +func createMountpoint(rootfs string, m *configs.Mount, mountFd *int, source string) (string, error) { + dest, err := securejoin.SecureJoin(rootfs, m.Destination) + if err != nil { + return "", err + } + if err := checkProcMount(rootfs, dest, m, source); err != nil { + return "", fmt.Errorf("check proc-safety of %s mount: %w", m.Destination, err) + } + + switch m.Device { + case "bind": + source := m.Source + if mountFd != nil { + source = "/proc/self/fd/" + strconv.Itoa(*mountFd) + } + + fi, err := os.Stat(source) + if err != nil { + // Error out if the source of a bind mount does not exist as we + // will be unable to bind anything to it. + return "", fmt.Errorf("bind mount source stat: %w", err) + } + // If the original source is not a directory, make the target a file. + if !fi.IsDir() { + // Make sure we aren't tricked into trying to make the root a file. + if rootfs == dest { + return "", fmt.Errorf("%w: file bind mount over rootfs", errRootfsToFile) + } + // Make the parent directory. + if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { + return "", fmt.Errorf("make parent dir of file bind-mount: %w", err) + } + // Make the target file. + f, err := os.OpenFile(dest, os.O_CREATE, 0o755) + if err != nil { + return "", fmt.Errorf("create target of file bind-mount: %w", err) + } + _ = f.Close() + // Nothing left to do. + return dest, nil + } + + case "tmpfs": + // If the original target exists, copy the mode for the tmpfs mount. + if stat, err := os.Stat(dest); err == nil { + dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode())) + if m.Data != "" { + dt = dt + "," + m.Data + } + m.Data = dt + + // Nothing left to do. + return dest, nil + } + } + + if err := os.MkdirAll(dest, 0o755); err != nil { + return "", err + } + return dest, nil +} + func mountToRootfs(m *configs.Mount, c *mountConfig) error { rootfs := c.root @@ -442,46 +470,27 @@ func mountToRootfs(m *configs.Mount, c *mountConfig) error { return mountPropagate(m, rootfs, "", nil) } - mountLabel := c.label mountFd := c.fd - dest, err := securejoin.SecureJoin(rootfs, m.Destination) + dest, err := createMountpoint(rootfs, m, mountFd, m.Source) if err != nil { - return err + return fmt.Errorf("create mount destination for %s mount: %w", m.Destination, err) } + mountLabel := c.label switch m.Device { case "mqueue": - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } if err := mountPropagate(m, rootfs, "", nil); err != nil { return err } return label.SetFileLabel(dest, mountLabel) case "tmpfs": - if stat, err := os.Stat(dest); err != nil { - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } - } else { - dt := fmt.Sprintf("mode=%04o", syscallMode(stat.Mode())) - if m.Data != "" { - dt = dt + "," + m.Data - } - m.Data = dt - } - if m.Extensions&configs.EXT_COPYUP == configs.EXT_COPYUP { err = doTmpfsCopyUp(m, rootfs, mountLabel) } else { err = mountPropagate(m, rootfs, mountLabel, nil) } - return err case "bind": - if err := prepareBindMount(m, rootfs, mountFd); err != nil { - return err - } if err := mountPropagate(m, rootfs, mountLabel, mountFd); err != nil { return err } @@ -509,12 +518,6 @@ func mountToRootfs(m *configs.Mount, c *mountConfig) error { } return mountCgroupV1(m, c) default: - if err := checkProcMount(rootfs, dest, m, m.Source); err != nil { - return err - } - if err := os.MkdirAll(dest, 0o755); err != nil { - return err - } return mountPropagate(m, rootfs, mountLabel, mountFd) } if err := setRecAttr(m, rootfs); err != nil { @@ -745,6 +748,9 @@ func createDeviceNode(rootfs string, node *devices.Device, bind bool) error { if err != nil { return err } + if dest == rootfs { + return fmt.Errorf("%w: mknod over rootfs", errRootfsToFile) + } if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { return err } @@ -1011,26 +1017,6 @@ func chroot() error { return nil } -// createIfNotExists creates a file or a directory only if it does not already exist. -func createIfNotExists(path string, isDir bool) error { - if _, err := os.Stat(path); err != nil { - if os.IsNotExist(err) { - if isDir { - return os.MkdirAll(path, 0o755) - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - f, err := os.OpenFile(path, os.O_CREATE, 0o755) - if err != nil { - return err - } - _ = f.Close() - } - } - return nil -} - // readonlyPath will make a path read only. func readonlyPath(path string) error { if err := mount(path, path, "", "", unix.MS_BIND|unix.MS_REC, ""); err != nil { diff --git a/libcontainer/utils/utils_unix.go b/libcontainer/utils/utils_unix.go index bf3237a2911..0d95a203789 100644 --- a/libcontainer/utils/utils_unix.go +++ b/libcontainer/utils/utils_unix.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "strconv" + "strings" _ "unsafe" // for go:linkname "golang.org/x/sys/unix" @@ -115,3 +116,17 @@ func NewSockPair(name string) (parent *os.File, child *os.File, err error) { } return os.NewFile(uintptr(fds[1]), name+"-p"), os.NewFile(uintptr(fds[0]), name+"-c"), nil } + +// IsLexicallyInRoot is shorthand for strings.HasPrefix(path+"/", root+"/"), +// but properly handling the case where path or root are "/". +// +// NOTE: The return value only make sense if the path doesn't contain "..". +func IsLexicallyInRoot(root, path string) bool { + if root != "/" { + root += "/" + } + if path != "/" { + path += "/" + } + return strings.HasPrefix(path, root) +}