//
// Syd: rock-solid application kernel
// src/syd-fd.rs: Interact with remote file descriptors
//
// Copyright (c) 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

use std::{
    env,
    ffi::OsString,
    os::{
        fd::{AsRawFd, FromRawFd, OwnedFd, RawFd},
        unix::process::CommandExt,
    },
    process::{Command, ExitCode},
};

use memchr::memchr;
use nix::{
    errno::Errno,
    fcntl::{open, readlinkat, OFlag},
    sys::stat::Mode,
    unistd::{dup2_raw, getpid, Pid},
};
use syd::{
    compat::getdents64,
    config::*,
    err::SydResult,
    fs::{duprand, parse_fd, pidfd_getfd, pidfd_open, set_cloexec, PIDFD_THREAD},
    path::{XPath, XPathBuf},
};

fn main() -> SydResult<ExitCode> {
    use lexopt::prelude::*;

    syd::set_sigpipe_dfl()?;

    // Parse CLI options.
    //
    // Note, option parsing is POSIXly correct:
    // POSIX recommends that no more options are parsed after the first
    // positional argument. The other arguments are then all treated as
    // positional arguments.
    // See: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/V1_chap12.html#tag_12_02
    let mut opt_pid = None;
    let mut opt_cmd = env::var_os(ENV_SH).unwrap_or(OsString::from(SYD_SH));
    let mut opt_arg = Vec::new();
    let mut opt_fds = Vec::new();

    let mut parser = lexopt::Parser::from_env();
    while let Some(arg) = parser.next()? {
        match arg {
            Short('h') => {
                help();
                return Ok(ExitCode::SUCCESS);
            }
            Short('p') => {
                let pid = parser.value()?;
                opt_pid = match pid.parse::<libc::pid_t>() {
                    Ok(pid) if pid > 0 => Some(Pid::from_raw(pid)),
                    _ => {
                        eprintln!("syd-fd: Invalid PID specified with -p!");
                        return Err(Errno::EINVAL.into());
                    }
                };
            }
            Short('f') => {
                let fd = parser.value()?;

                // Validate UTF-8.
                let fd = match fd.to_str() {
                    Some(fd) => fd,
                    None => {
                        eprintln!("syd-fd: Invalid UTF-8 in FD argument!");
                        return Err(Errno::EINVAL.into());
                    }
                };

                if let Some(idx) = memchr(b':', fd.as_bytes()) {
                    // Parse remote fd.
                    let remote_fd = &fd[..idx];
                    let remote_fd = match remote_fd.parse::<RawFd>() {
                        Ok(fd) if fd >= 0 => fd,
                        _ => {
                            eprintln!("syd-fd: Invalid FD specified with -f!");
                            return Err(Errno::EINVAL.into());
                        }
                    };

                    // Parse optional local fd.
                    let local_fd = &fd[idx + 1..];
                    let local_fd = match local_fd {
                        "rand" => Some(libc::AT_FDCWD),
                        fd => match fd.parse::<RawFd>() {
                            Ok(fd) if fd >= 0 => Some(fd),
                            _ => {
                                eprintln!("syd-fd: Invalid FD specified with -f!");
                                return Err(Errno::EINVAL.into());
                            }
                        },
                    };

                    opt_fds.push((remote_fd, local_fd));
                } else {
                    // Parse remote fd.
                    let remote_fd = match fd.parse::<RawFd>() {
                        Ok(fd) if fd >= 0 => fd,
                        _ => {
                            eprintln!("syd-fd: Invalid FD specified with -f!");
                            return Err(Errno::EINVAL.into());
                        }
                    };

                    opt_fds.push((remote_fd, None));
                }
            }
            Value(prog) => {
                opt_cmd = prog;
                opt_arg.extend(parser.raw_args()?);
            }
            _ => return Err(arg.unexpected().into()),
        }
    }

    let pid = if opt_fds.is_empty() {
        // List /proc/$pid/fd.
        let fds = proc_pid_fd(opt_pid)?;

        // Serialize as line-oriented compact JSON.
        for fd in fds {
            #[allow(clippy::disallowed_methods)]
            let fd = serde_json::to_string(&fd).expect("JSON");
            println!("{fd}");
        }

        return Ok(ExitCode::SUCCESS);
    } else if let Some(pid) = opt_pid {
        pid
    } else {
        eprintln!("PID must be specified with -p!");
        return Err(Errno::EINVAL.into());
    };

    // Open a PIDFd to the specified PID or TID.
    let flags = if *HAVE_PIDFD_THREAD { PIDFD_THREAD } else { 0 };
    let pidfd = pidfd_open(pid, flags)?;

    // Transfer remote fds.
    for (remote_fd, local_fd) in opt_fds {
        // Transfer fd with pidfd_getfd(2).
        let fd = pidfd_getfd(pidfd.as_raw_fd(), remote_fd)?;

        // Handle local fd.
        let fd = match local_fd {
            Some(libc::AT_FDCWD) => {
                let fd_rand = duprand(fd.as_raw_fd(), OFlag::empty())?;
                drop(fd);
                // SAFETY: duprand returns a valid FD on success.
                unsafe { OwnedFd::from_raw_fd(fd_rand) }
            }
            Some(newfd) => {
                // SAFETY: User should ensure no double-close happens.
                let fd_dup = unsafe { dup2_raw(&fd, newfd) }?;
                drop(fd);
                fd_dup
            }
            None => fd,
        };

        // Log progress.
        eprintln!("syd-fd: GETFD {remote_fd} -> {}", fd.as_raw_fd());

        // Prepare to pass the fd to the child.
        set_cloexec(&fd, false)?;

        // Leak fd on purpose, child will take over.
        std::mem::forget(fd);
    }

    // Log progress.
    eprintln!("syd-fd: EXEC {}", XPathBuf::from(opt_cmd.clone()));

    // Execute command, /bin/sh by default.
    Ok(ExitCode::from(
        127 + Command::new(opt_cmd)
            .args(opt_arg)
            .exec()
            .raw_os_error()
            .unwrap_or(0) as u8,
    ))
}

/// Print the usage message of `syd-fd` to standard output.
fn help() {
    println!("Usage: syd-fd [-h] [-p pid] [-f remote_fd[:local_fd]].. {{command [args...]}}");
    println!("Interact with remote file descriptors");
    println!("Execute the given command or `/bin/sh' with inherited remote fds.");
    println!("List remote file descriptors with the given PID if no -f is given.");
    println!("Use -p to specify PID, required with -f.");
    println!("Use -f remote_fd to specify remote file descriptor to transfer.");
    // The argument parser splits remote and local fd on ':', not ','.
    println!("Optionally specify colon-delimited local fd as target.");
    println!("Use `rand' as target fd to duplicate to a random valid slot.");
}

// List `/proc/pid/fd` contents.
//
// Return a vector of `(RawFd, XPathBuf)` tuples, where each `RawFd`
// is the file descriptor number and the `XPathBuf` is the path it points to.
//
// Useful for debugging file descriptor leaks.
#[allow(clippy::type_complexity)]
fn proc_pid_fd(pid: Option<Pid>) -> Result<Vec<(RawFd, XPathBuf)>, Errno> {
    let pid = pid.unwrap_or_else(getpid);

    let mut dir = XPathBuf::from("/proc");
    dir.push_pid(pid);
    dir.push(b"fd");

    #[allow(clippy::disallowed_methods)]
    let dir = open(
        &dir,
        OFlag::O_RDONLY | OFlag::O_DIRECTORY | OFlag::O_CLOEXEC,
        Mode::empty(),
    )?;

    let mut dot = 0u8;
    let mut res = vec![];
    loop {
        let mut entries = match getdents64(&dir, DIRENT_BUF_SIZE) {
            Ok(entries) => entries,
            Err(Errno::ECANCELED) => break, // EOF or empty directory
            Err(errno) => return Err(errno),
        };

        for entry in &mut entries {
            if dot < 2 && entry.is_dot() {
                dot += 1;
                continue;
            }
            let fd = parse_fd(XPath::from_bytes(entry.name_bytes()))?;
            let target = readlinkat(&dir, entry.name_bytes()).map(XPathBuf::from)?;

            res.push((fd, target));
        }
    }

    Ok(res)
}
