//
// Syd: rock-solid application kernel
// src/kernel/mem.rs: Memory syscall handlers
//
// Copyright (c) 2023, 2024, 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

use std::{
    fs::File,
    io::Seek,
    os::fd::{AsRawFd, RawFd},
};

use libseccomp::ScmpNotifResp;
use memchr::arch::all::is_prefix;
use nix::{errno::Errno, fcntl::OFlag};

use crate::{
    config::{PAGE_SIZE, PROC_FILE},
    elf::ExecutableFile,
    error,
    fs::{is_writable_fd, safe_open_magicsym, CanonicalPath},
    hook::UNotifyEventRequest,
    kernel::sandbox_path,
    path::XPathBuf,
    proc::{proc_mem_limit, proc_statm},
    sandbox::{Action, Capability, IntegrityError},
    warn,
};

// `PROT_EXEC` widened to `u64` so it can be masked directly against the raw
// seccomp syscall argument words (`req.data.args`), e.g. the mmap(2) `prot` argument.
const PROT_EXEC: u64 = libc::PROT_EXEC as u64;
// `MAP_ANONYMOUS` widened to `u64` for the same reason, to test the mmap(2) `flags` argument.
const MAP_ANONYMOUS: u64 = libc::MAP_ANONYMOUS as u64;

pub(crate) fn sys_brk(request: UNotifyEventRequest) -> ScmpNotifResp {
    // brk(2) only grows/shrinks the heap: only Memory sandboxing applies.
    let wanted = Capability::CAP_MEM;
    syscall_mem_handler(request, "brk", wanted)
}

pub(crate) fn sys_mmap(request: UNotifyEventRequest) -> ScmpNotifResp {
    // mmap(2) may map a file executable, so in addition to Memory
    // sandboxing the Exec, Force and TPE checks are requested too.
    let wanted =
        Capability::CAP_MEM | Capability::CAP_EXEC | Capability::CAP_FORCE | Capability::CAP_TPE;
    syscall_mem_handler(request, "mmap", wanted)
}

pub(crate) fn sys_mmap2(request: UNotifyEventRequest) -> ScmpNotifResp {
    // mmap2(2) shares mmap(2)'s argument layout for prot/flags/fd, so
    // the same set of checks is requested.
    let wanted =
        Capability::CAP_MEM | Capability::CAP_EXEC | Capability::CAP_FORCE | Capability::CAP_TPE;
    syscall_mem_handler(request, "mmap2", wanted)
}

pub(crate) fn sys_mremap(request: UNotifyEventRequest) -> ScmpNotifResp {
    // mremap(2) arguments: (old_address, old_size, new_size, flags, ...).
    let args = request.scmpreq.data.args;
    let (old_size, new_size) = (args[1], args[2]);

    // Growing (or keeping the same size) is subject to Memory sandboxing.
    if new_size >= old_size {
        return syscall_mem_handler(request, "mremap", Capability::CAP_MEM);
    }

    // SAFETY: System call wants to shrink memory.
    // No pointer dereference in size check.
    unsafe { request.continue_syscall() }
}

#[allow(clippy::cognitive_complexity)]
fn syscall_mem_handler(
    request: UNotifyEventRequest,
    name: &str,
    caps: Capability,
) -> ScmpNotifResp {
    let req = request.scmpreq;

    // Get mem & vm max.
    let sandbox = request.get_sandbox();
    let verbose = sandbox.verbose;
    let caps = sandbox.getcaps(caps);
    let exec = caps.contains(Capability::CAP_EXEC);
    let force = caps.contains(Capability::CAP_FORCE);
    let tpe = caps.contains(Capability::CAP_TPE);
    let mem = caps.contains(Capability::CAP_MEM);
    let mem_max = sandbox.mem_max;
    let mem_vm_max = sandbox.mem_vm_max;
    let mem_act = sandbox.default_action(Capability::CAP_MEM);
    let restrict_memory = !sandbox.allow_unsafe_memory();
    let restrict_stack = !sandbox.allow_unsafe_stack();

    if !exec
        && !force
        && !tpe
        && !restrict_memory
        && !restrict_stack
        && (!mem || (mem_max == 0 && mem_vm_max == 0))
    {
        // SAFETY: No pointer dereference in security check.
        // This is safe to continue.
        return unsafe { request.continue_syscall() };
    }

    if (exec || force || tpe || restrict_memory || restrict_stack)
        && is_prefix(name.as_bytes(), b"mmap")
        && req.data.args[2] & PROT_EXEC != 0
        && req.data.args[3] & MAP_ANONYMOUS == 0
    {
        // Check file descriptor for Exec access.
        // Read and Write were already checked at open(2).
        #[allow(clippy::cast_possible_truncation)]
        let remote_fd = req.data.args[4] as RawFd;
        if remote_fd < 0 {
            return request.fail_syscall(Errno::EBADF);
        }

        // SAFETY: Get the file descriptor before access check
        // as it may change after which is a TOCTOU vector.
        let fd = match request.get_fd(remote_fd) {
            Ok(fd) => fd,
            Err(_) => return request.fail_syscall(Errno::EBADF),
        };

        // Step 1: Check if file is open for write,
        // but set as PROT_READ|PROT_EXEC which breaks W^X!
        // We do not need to check for PROT_WRITE here as
        // this is already enforced at kernel-level when
        // trace/allow_unsafe_memory:1 is not set at startup.
        if restrict_memory && is_writable_fd(&fd).unwrap_or(true) {
            return request.fail_syscall(Errno::EACCES);
        }

        let mut path = match CanonicalPath::new_fd(fd.into(), req.pid(), remote_fd) {
            Ok(path) => path,
            Err(errno) => return request.fail_syscall(errno),
        };

        // Step 2: Check for Exec sandboxing.
        if exec {
            if let Err(errno) = sandbox_path(
                Some(&request),
                &sandbox,
                request.scmpreq.pid(), // Unused when request.is_some()
                path.abs(),
                Capability::CAP_EXEC,
                false,
                name,
            ) {
                return request.fail_syscall(errno);
            }
        }

        // Step 3: Check for TPE sandboxing.
        if tpe {
            // MUST_PATH ensures path.dir is Some.
            #[allow(clippy::disallowed_methods)]
            let file = path.dir.as_ref().unwrap();
            let (action, msg) = sandbox.check_tpe(file, path.abs());
            if !matches!(action, Action::Allow | Action::Filter) {
                let msg = msg.as_deref().unwrap_or("?");
                if verbose {
                    error!("ctx": "trusted_path_execution",
                        "err": format!("library load from untrusted path blocked: {msg}"),
                        "sys": name, "path": &path,
                        "req": &request,
                        "tip": "move the library to a safe location or use `sandbox/tpe:off'");
                } else {
                    error!("ctx": "trusted_path_execution",
                        "err": format!("library load from untrusted path blocked: {msg}"),
                        "sys": name, "path": &path,
                        "pid": request.scmpreq.pid,
                        "tip": "move the library to a safe location or use `sandbox/tpe:off'");
                }
            }
            match action {
                Action::Allow | Action::Warn => {}
                Action::Deny | Action::Filter => return request.fail_syscall(Errno::EACCES),
                Action::Panic => panic!(),
                Action::Exit => std::process::exit(libc::EACCES),
                action => {
                    // Stop|Kill
                    let _ = request.kill(action);
                    return request.fail_syscall(Errno::EACCES);
                }
            }
        }

        if force || restrict_stack {
            // The following checks require the contents of the file.
            // SAFETY:
            // 1. Reopen the file via `/proc/thread-self/fd` to avoid sharing the file offset.
            // 2. `path` is a remote-fd transfer which asserts `path.dir` is Some.
            #[allow(clippy::disallowed_methods)]
            let fd = path.dir.take().unwrap();
            let pfd = XPathBuf::from_self_fd(fd.as_raw_fd());

            let mut file =
                match safe_open_magicsym(PROC_FILE(), &pfd, OFlag::O_RDONLY).map(File::from) {
                    Ok(file) => file,
                    Err(_) => {
                        return request.fail_syscall(Errno::EBADF);
                    }
                };

            if restrict_stack {
                // Step 4: Check for non-executable stack.
                // An execstack library that is dlopened into an executable
                // that is otherwise mapped no-execstack can change the
                // stack permissions to executable! This has been
                // (ab)used in at least one CVE:
                // https://www.qualys.com/2023/07/19/cve-2023-38408/rce-openssh-forwarded-ssh-agent.txt
                let result = (|file: &mut File| -> Result<(), Errno> {
                    let exe = ExecutableFile::parse(&mut *file, true).or(Err(Errno::EACCES))?;
                    if matches!(exe, ExecutableFile::Elf { xs: true, .. }) {
                        if !sandbox.filter_path(Capability::CAP_EXEC, path.abs()) {
                            if verbose {
                                error!("ctx": "check_lib",
                                    "err": "library load with executable stack blocked",
                                    "sys": name, "path": path.abs(),
                                    "tip": "configure `trace/allow_unsafe_stack:1'",
                                    "lib": format!("{exe}"),
                                    "req": &request);
                            } else {
                                error!("ctx": "check_lib",
                                    "err": "library load with executable stack blocked",
                                    "sys": name, "path": path.abs(),
                                    "tip": "configure `trace/allow_unsafe_stack:1'",
                                    "lib": format!("{exe}"),
                                    "pid": request.scmpreq.pid);
                            }
                        }
                        Err(Errno::EACCES)
                    } else {
                        Ok(())
                    }
                })(&mut file);

                if let Err(errno) = result {
                    return request.fail_syscall(errno);
                }
            }

            if force {
                // Step 5: Check for Force sandboxing.
                if restrict_stack && file.rewind().is_err() {
                    drop(sandbox); // release the read-lock.
                    return request.fail_syscall(Errno::EBADF);
                }
                let result = sandbox.check_force2(path.abs(), &mut file);

                let deny = match result {
                    Ok(action) => {
                        if !matches!(action, Action::Allow | Action::Filter) {
                            if verbose {
                                warn!("ctx": "verify_lib", "act": action,
                                    "sys": name, "path": path.abs(),
                                    "tip": format!("configure `force+{}:<checksum>'", path.abs()),
                                    "sys": name, "req": &request);
                            } else {
                                warn!("ctx": "verify_lib", "act": action,
                                    "sys": name, "path": path.abs(),
                                    "tip": format!("configure `force+{}:<checksum>'", path.abs()),
                                    "pid": request.scmpreq.pid);
                            }
                        }
                        match action {
                            Action::Allow | Action::Warn => false,
                            Action::Deny | Action::Filter => true,
                            Action::Panic => panic!(),
                            Action::Exit => std::process::exit(libc::EACCES),
                            _ => {
                                // Stop|Kill
                                let _ = request.kill(action);
                                true
                            }
                        }
                    }
                    Err(IntegrityError::Sys(errno)) => {
                        if verbose {
                            error!("ctx": "verify_lib",
                                "err": format!("system error during library checksum calculation: {errno}"),
                                "sys": name, "path": path.abs(),
                                "tip": format!("configure `force+{}:<checksum>'", path.abs()),
                                "req": &request);
                        } else {
                            error!("ctx": "verify_lib",
                                "err": format!("system error during library checksum calculation: {errno}"),
                                "sys": name, "path": path.abs(),
                                "tip": format!("configure `force+{}:<checksum>'", path.abs()),
                                "pid": request.scmpreq.pid);
                        }
                        true
                    }
                    Err(IntegrityError::Hash {
                        action,
                        expected,
                        found,
                    }) => {
                        if action != Action::Filter {
                            if sandbox.verbose {
                                error!("ctx": "verify_lib", "act": action,
                                    "err": format!("library checksum mismatch: {found} is not {expected}"),
                                    "sys": name, "path": path.abs(),
                                    "tip": format!("configure `force+{}:<checksum>'", path.abs()),
                                    "req": &request);
                            } else {
                                error!("ctx": "verify_lib", "act": action,
                                    "err": format!("library checksum mismatch: {found} is not {expected}"),
                                    "sys": name, "path": path.abs(),
                                    "tip": format!("configure `force+{}:<checksum>'", path.abs()),
                                    "pid": request.scmpreq.pid);
                            }
                        }
                        match action {
                            // Allow cannot happen.
                            Action::Warn => false,
                            Action::Deny | Action::Filter => true,
                            Action::Panic => panic!(),
                            Action::Exit => std::process::exit(libc::EACCES),
                            _ => {
                                // Stop|Kill
                                let _ = request.kill(action);
                                true
                            }
                        }
                    }
                };

                if deny {
                    return request.fail_syscall(Errno::EACCES);
                }
            }
        }
    }
    drop(sandbox); // release the read-lock.

    if !mem || (mem_max == 0 && mem_vm_max == 0) {
        // SAFETY:
        // (a) Exec and Memory sandboxing are both disabled.
        // (b) Exec granted access, Memory sandboxing is disabled.
        // The first candidate is safe as sandboxing is disabled,
        // however (b) should theoretically suffer from VFS TOCTOU as
        // the fd can change after the access check. However, our tests
        // show this is not the case, see vfsmod_toctou_mmap integration
        // test.
        return unsafe { request.continue_syscall() };
    }

    // Check VmSize
    if mem_vm_max > 0 {
        let mem_vm_cur = match proc_statm(req.pid()) {
            Ok(statm) => statm.size.saturating_mul(*PAGE_SIZE),
            Err(errno) => return request.fail_syscall(errno),
        };
        if mem_vm_cur >= mem_vm_max {
            if mem_act != Action::Filter {
                if verbose {
                    warn!("ctx": "access", "cap": Capability::CAP_MEM, "act": mem_act,
                        "sys": name, "mem_vm_max": mem_vm_max, "mem_vm_cur": mem_vm_cur,
                        "tip": "increase `mem/vm_max'",
                        "req": &request);
                } else {
                    warn!("ctx": "access", "cap": Capability::CAP_MEM, "act": mem_act,
                        "sys": name, "mem_vm_max": mem_vm_max, "mem_vm_cur": mem_vm_cur,
                        "tip": "increase `mem/vm_max'",
                        "pid": request.scmpreq.pid);
                }
            }
            match mem_act {
                // Allow cannot happen.
                Action::Warn => {}
                Action::Deny | Action::Filter => return request.fail_syscall(Errno::ENOMEM),
                Action::Panic => panic!(),
                Action::Exit => std::process::exit(libc::ENOMEM),
                _ => {
                    // Stop|Kill
                    let _ = request.kill(mem_act);
                    return request.fail_syscall(Errno::ENOMEM);
                }
            }
        }
    }

    // Check PSS
    if mem_max > 0 {
        match proc_mem_limit(req.pid(), mem_max) {
            Ok(false) => {
                // SAFETY: No pointer dereference in security check.
                unsafe { request.continue_syscall() }
            }
            Ok(true) => {
                if mem_act != Action::Filter {
                    if verbose {
                        warn!("ctx": "access", "cap": Capability::CAP_MEM, "act": mem_act,
                            "sys": name, "mem_max": mem_max,
                            "tip": "increase `mem/max'",
                            "req": &request);
                    } else {
                        warn!("ctx": "access", "cap": Capability::CAP_MEM, "act": mem_act,
                            "sys": name, "mem_max": mem_max,
                            "tip": "increase `mem/max'",
                            "pid": request.scmpreq.pid);
                    }
                }
                match mem_act {
                    // Allow cannot happen.
                    Action::Warn => {
                        // SAFETY: No pointer dereference in security check.
                        unsafe { request.continue_syscall() }
                    }
                    Action::Deny | Action::Filter => request.fail_syscall(Errno::ENOMEM),
                    Action::Panic => panic!(),
                    Action::Exit => std::process::exit(libc::ENOMEM),
                    _ => {
                        // Stop|Kill
                        let _ = request.kill(mem_act);
                        request.fail_syscall(Errno::ENOMEM)
                    }
                }
            }
            Err(errno) => request.fail_syscall(errno),
        }
    } else {
        // SAFETY: No pointer dereference in security check.
        unsafe { request.continue_syscall() }
    }
}
