//
// Syd: rock-solid application kernel
// src/workers/mod.rs: Worker threads implementation
//
// Copyright (c) 2024, 2025 Ali Polatel <alip@chesswob.org>
// Based in part upon rusty_pool which is:
//     Copyright (c) Robin Friedli <robinfriedli@icloud.com>
//     SPDX-License-Identifier: Apache-2.0
//
// SPDX-License-Identifier: GPL-3.0

use std::{
    collections::{btree_map::Entry, BTreeMap, HashMap},
    option::Option,
    os::fd::{AsRawFd, OwnedFd, RawFd},
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex, RwLock,
    },
};

use nix::{
    errno::Errno,
    sys::{
        epoll::{Epoll, EpollFlags},
        socket::UnixAddr,
    },
    unistd::{gettid, Pid},
};

use crate::{
    cache::{
        signal_map_new, sys_interrupt_map_new, sys_result_map_new, ExecResult, SignalMap,
        SigreturnResult, SysInterrupt, SysInterruptMap, SysResultMap,
    },
    compat::epoll_ctl_safe,
    config::*,
    elf::ExecutableFile,
    fs::{pidfd_open, seccomp_notify_id_valid, CanonicalPath, PIDFD_THREAD},
    hash::SydRandomState,
    hook::RemoteProcess,
    proc::proc_tgid,
    ScmpNotifReq, SydMemoryMap, SydSigSet,
};

// syd_aes: Encryptor helper thread
pub(crate) mod aes;
// syd_int: Interrupter helper thread
pub(crate) mod int;
// syd_ipc: IPC thread
pub(crate) mod ipc;
// syd_emu: Main worker threads
pub(crate) mod emu;

/// A cache for worker threads.
///
/// Bundles the epoll instance and the seccomp-notify descriptor together
/// with the signal, interrupt and syscall-result maps that the methods on
/// this type read and update.
#[derive(Debug)]
pub(crate) struct WorkerCache<'a> {
    // Shared epoll instance; `PidFdMap::pidfd_open` registers PidFds on it.
    pub(crate) poll: Arc<Epoll>,
    // Seccomp-notify fd; used to validate seccomp request IDs.
    pub(crate) scmp: RawFd,
    // Signal handlers map (per-TGID count of signals being handled).
    pub(crate) signal_map: SignalMap,
    // System call interrupt map (restarting signals and blocked syscalls).
    pub(crate) sysint_map: SysInterruptMap,
    // System call result map (chdir, error, exec and sigreturn records).
    pub(crate) sysres_map: SysResultMap<'a>,
}

impl<'a> WorkerCache<'a> {
    // Build a cache around the given epoll instance and seccomp-notify fd,
    // with freshly constructed signal, interrupt and result maps.
    pub(crate) fn new(poll: Arc<Epoll>, scmp: RawFd) -> Self {
        let signal_map = signal_map_new();
        let sysint_map = sys_interrupt_map_new();
        let sysres_map = sys_result_map_new();

        Self {
            poll,
            scmp,
            signal_map,
            sysint_map,
            sysres_map,
        }
    }

    // Increment count of handled signals.
    pub(crate) fn inc_sig_handle(&self, request_tgid: Pid) {
        let mut map = self
            .signal_map
            .sig_handle
            .lock()
            .unwrap_or_else(|err| err.into_inner());
        map.entry(request_tgid)
            .and_modify(|v| *v = v.saturating_add(1))
            .or_insert(1);
        // let count = *count;
        drop(map);

        /*
        debug!("ctx": "count_signal",
            "msg": format!("forwarded {count} signals to TGID:{request_tgid}"),
            "pid": request_tgid.as_raw());
        */
    }

    // Decrement count of handled signals, return true if decremented, false if zero.
    #[allow(clippy::cognitive_complexity)]
    pub(crate) fn dec_sig_handle(&self, request_tgid: Pid) -> bool {
        let mut is_dec = false;

        let mut map = self
            .signal_map
            .sig_handle
            .lock()
            .unwrap_or_else(|err| err.into_inner());
        if let Entry::Occupied(mut entry) = map.entry(request_tgid) {
            let count = entry.get_mut();

            /*
            debug!(
                "ctx": "count_signal",
                "msg": format!("returned from one of {count} signals for TGID:{request_tgid}"),
                "pid": request_tgid.as_raw()
            );
            */

            *count = count.saturating_sub(1);
            is_dec = true;

            if *count == 0 {
                let _ = entry.remove();
            }
        } /* else {
              debug!(
                  "ctx": "count_signal",
                  "msg": format!("returned from unknown signal for TGID:{request_tgid}"),
                  "pid": request_tgid.as_raw()
              );
          }*/

        is_dec
    }

    // Forget the handled-signal counter of `tgid`, if any.
    pub(crate) fn retire_sig_handle(&self, tgid: Pid) {
        self.signal_map
            .sig_handle
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .remove(&tgid);
    }

    // Save the canonical path of a chdir for `process` in the chdir trace map.
    pub(crate) fn add_chdir<'b>(&'b self, process: RemoteProcess, path: CanonicalPath<'a>) {
        let mut chdir_map = self
            .sysres_map
            .trace_chdir
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        chdir_map.insert(process, path);
    }

    // Take the chdir record stored for `pid` out of the map.
    //
    // Returns the stored key together with its path, or None when nothing
    // is recorded. The lookup key carries `AT_FDCWD` as a placeholder
    // pid_fd; presumably `RemoteProcess` is compared by pid only -- confirm
    // against its definition.
    #[allow(clippy::type_complexity)]
    pub(crate) fn get_chdir<'b>(&'b self, pid: Pid) -> Option<(RemoteProcess, CanonicalPath<'a>)> {
        let key = RemoteProcess {
            pid,
            pid_fd: libc::AT_FDCWD,
        };

        let mut chdir_map = self
            .sysres_map
            .trace_chdir
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        chdir_map.remove_entry(&key)
    }

    // Save the (optional) errno of a traced call for `process` in the
    // error trace map.
    pub(crate) fn add_error(&self, process: RemoteProcess, errno: Option<Errno>) {
        let mut error_map = self
            .sysres_map
            .trace_error
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        error_map.insert(process, errno);
    }

    // Take the error record stored for `pid` out of the map.
    //
    // Returns the stored key together with its errno, or None when nothing
    // is recorded. The lookup key carries `AT_FDCWD` as a placeholder
    // pid_fd; presumably `RemoteProcess` is compared by pid only -- confirm
    // against its definition.
    #[allow(clippy::type_complexity)]
    pub(crate) fn get_error(&self, pid: Pid) -> Option<(RemoteProcess, Option<Errno>)> {
        let key = RemoteProcess {
            pid,
            pid_fd: libc::AT_FDCWD,
        };

        let mut error_map = self
            .sysres_map
            .trace_error
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        error_map.remove_entry(&key)
    }

    // Save the state captured for an exec of `process`.
    //
    // The arguments are bundled into an `ExecResult` and stored in the
    // exec trace map under `process`.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn add_exec(
        &self,
        process: RemoteProcess,
        file: ExecutableFile,
        arch: u32,
        ip: u64,
        sp: u64,
        args: [u64; 6],
        ip_mem: Option<[u8; 64]>,
        sp_mem: Option<[u8; 64]>,
        memmap: Option<Vec<SydMemoryMap>>,
    ) {
        let mut exec_map = self
            .sysres_map
            .trace_execv
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        exec_map.insert(
            process,
            ExecResult {
                file,
                arch,
                ip,
                sp,
                args,
                ip_mem,
                sp_mem,
                memmap,
            },
        );
    }

    // Take the exec record stored for `pid` out of the map.
    //
    // Returns the stored key together with its `ExecResult`, or None when
    // nothing is recorded. The lookup key carries `AT_FDCWD` as a
    // placeholder pid_fd; presumably `RemoteProcess` is compared by pid
    // only -- confirm against its definition.
    pub(crate) fn get_exec(&self, pid: Pid) -> Option<(RemoteProcess, ExecResult)> {
        let key = RemoteProcess {
            pid,
            pid_fd: libc::AT_FDCWD,
        };

        let mut exec_map = self
            .sysres_map
            .trace_execv
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        exec_map.remove_entry(&key)
    }

    // Save the state captured at a sigreturn entry of `process`.
    //
    // The arguments are bundled into a `SigreturnResult` and stored in the
    // sigreturn trace map under `process`.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn add_sigreturn(
        &self,
        process: RemoteProcess,
        is_realtime: bool,
        ip: u64,
        sp: u64,
        args: [u64; 6],
        ip_mem: Option<[u8; 64]>,
        sp_mem: Option<[u8; 64]>,
    ) {
        let mut sigret_map = self
            .sysres_map
            .trace_sigret
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        sigret_map.insert(
            process,
            SigreturnResult {
                is_realtime,
                ip,
                sp,
                args,
                ip_mem,
                sp_mem,
            },
        );
    }

    // Take the sigreturn record stored for `pid` out of the map.
    //
    // Returns the stored key together with its `SigreturnResult`, or None
    // when nothing is recorded. The lookup key carries `AT_FDCWD` as a
    // placeholder pid_fd; presumably `RemoteProcess` is compared by pid
    // only -- confirm against its definition.
    pub(crate) fn get_sigreturn(&self, pid: Pid) -> Option<(RemoteProcess, SigreturnResult)> {
        let key = RemoteProcess {
            pid,
            pid_fd: libc::AT_FDCWD,
        };

        let mut sigret_map = self
            .sysres_map
            .trace_sigret
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        sigret_map.remove_entry(&key)
    }

    // Mark `sig` as a restarting signal for `request_tgid`.
    //
    // The signal is added to the TGID's existing set, or to a new set
    // created with `SydSigSet::new(0)` when the TGID has none yet.
    pub(crate) fn add_sig_restart(&self, request_tgid: Pid, sig: libc::c_int) {
        let mut restarts = self
            .sysint_map
            .sig_restart
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        if let Some(pending) = restarts.get_mut(&request_tgid) {
            pending.add(sig);
        } else {
            let mut fresh = SydSigSet::new(0);
            fresh.add(sig);
            restarts.insert(request_tgid, fresh);
        }
    }

    // Unmark `sig` as a restarting signal for `request_tgid`.
    //
    // Does nothing when the TGID has no set. When removing the signal
    // leaves the set empty, the TGID's entry is removed as well.
    pub(crate) fn del_sig_restart(&self, request_tgid: Pid, sig: libc::c_int) {
        let mut restarts = self
            .sysint_map
            .sig_restart
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        let now_empty = match restarts.get_mut(&request_tgid) {
            Some(pending) => {
                pending.del(sig);
                pending.is_empty()
            }
            None => false,
        };

        if now_empty {
            restarts.remove(&request_tgid);
        }
    }

    // Forget every restarting signal recorded for `tgid`, if any.
    pub(crate) fn retire_sig_restart(&self, tgid: Pid) {
        self.sysint_map
            .sig_restart
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .remove(&tgid);
    }

    // Add a blocked syscall.
    #[allow(clippy::cast_possible_wrap)]
    pub(crate) fn add_sys_block(
        &self,
        request: ScmpNotifReq,
        ignore_restart: bool,
    ) -> Result<(), Errno> {
        let handler_tid = gettid();
        let request_tgid = proc_tgid(Pid::from_raw(request.pid as libc::pid_t))?;
        let interrupt = SysInterrupt::new(request, request_tgid, handler_tid, ignore_restart)?;

        let (ref lock, ref cvar) = *self.sysint_map.sys_block;
        let mut map = lock.lock().unwrap_or_else(|err| err.into_inner());

        map.insert(request.id, interrupt);

        cvar.notify_one();

        Ok(())
    }

    // Remove a blocked fifo.
    pub(crate) fn del_sys_block(&self, request_id: u64) {
        let (ref lock, ref _cvar) = *self.sysint_map.sys_block;
        let mut map = lock.lock().unwrap_or_else(|err| err.into_inner());
        map.remove(&request_id);
    }
}

// The absolute maximum number of workers. This corresponds to the
// maximum value that can be stored within half the bits of usize, as
// two counters (total workers and busy workers) are stored in one
// AtomicUsize.
const BITS: usize = 8 * std::mem::size_of::<usize>();
const MAX_SIZE: usize = usize::MAX >> (BITS / 2);

// Mask selecting the busy-worker counter (lower half of the bits).
const WORKER_BUSY_MASK: usize = MAX_SIZE;
// Adding this value bumps the total-worker counter (upper half) by one.
const INCREMENT_TOTAL: usize = 1 << (BITS / 2);
// Adding this value bumps the busy-worker counter (lower half) by one.
const INCREMENT_BUSY: usize = 1;

/// Worker counters shared between workers.
///
/// Wraps a single `AtomicUsize` that keeps the total worker count in the
/// upper half of its bits and the busy worker count in the lower half.
/// This allows both counters to be incremented or decremented in a single
/// atomic operation.
///
/// Every mutating method returns the value the affected counter(s) held
/// *before* the update. All operations use `Ordering::Relaxed`.
#[derive(Default)]
pub(crate) struct WorkerData(pub(crate) AtomicUsize);

impl WorkerData {
    /*
    fn increment_both(&self) -> (usize, usize) {
        let old_val = self
            .0
            .fetch_add(INCREMENT_TOTAL | INCREMENT_BUSY, Ordering::Relaxed);
        Self::split(old_val)
    }
    */

    /// Decrement both counters at once; returns the previous
    /// `(total, busy)` pair.
    pub(crate) fn decrement_both(&self) -> (usize, usize) {
        let step = INCREMENT_TOTAL + INCREMENT_BUSY;
        Self::split(self.0.fetch_sub(step, Ordering::Relaxed))
    }

    /// Increment the total worker count; returns the previous total.
    pub(crate) fn increment_worker_total(&self) -> usize {
        Self::total(self.0.fetch_add(INCREMENT_TOTAL, Ordering::Relaxed))
    }

    /// Decrement the total worker count; returns the previous total.
    #[allow(dead_code)]
    pub(crate) fn decrement_worker_total(&self) -> usize {
        Self::total(self.0.fetch_sub(INCREMENT_TOTAL, Ordering::Relaxed))
    }

    /// Increment the busy worker count; returns the previous busy count.
    pub(crate) fn increment_worker_busy(&self) -> usize {
        Self::busy(self.0.fetch_add(INCREMENT_BUSY, Ordering::Relaxed))
    }

    /// Decrement the busy worker count; returns the previous busy count.
    pub(crate) fn decrement_worker_busy(&self) -> usize {
        Self::busy(self.0.fetch_sub(INCREMENT_BUSY, Ordering::Relaxed))
    }

    /*
    fn get_total_count(&self) -> usize {
        Self::total(self.0.load(Ordering::Relaxed))
    }

    fn get_busy_count(&self) -> usize {
        Self::busy(self.0.load(Ordering::Relaxed))
    }
    */

    /// Unpack a raw counter value into `(total, busy)`.
    #[inline]
    pub(crate) fn split(val: usize) -> (usize, usize) {
        (Self::total(val), Self::busy(val))
    }

    /// Extract the total worker count (upper half of the bits).
    #[inline]
    fn total(val: usize) -> usize {
        val >> (BITS / 2)
    }

    /// Extract the busy worker count (lower half of the bits).
    #[inline]
    fn busy(val: usize) -> usize {
        val & WORKER_BUSY_MASK
    }
}

// [inode,path] map of unix binds: keyed by inode number (u64), each
// value is the UNIX socket address bound at that inode. The map is
// shared behind an `Arc<RwLock<..>>`.
//
// SAFETY:
// /proc/net/unix only gives inode information,
// and does not include information on device id
// or mount id so unfortunately we cannot check
// for that here.
pub(crate) type BindMap = Arc<RwLock<HashMap<u64, UnixAddr, SydRandomState>>>;

/// PidFd map, used to store pid file descriptors.
#[derive(Debug)]
#[allow(clippy::type_complexity)]
pub struct PidFdMap {
    /// Inner PidFd map: a mutex-protected `BTreeMap` from PID to the owned
    /// PidFd opened for it.
    pub pidfd: Arc<Mutex<BTreeMap<Pid, OwnedFd>>>,
    /// A reference to the `WorkerCache` to clean relevant data on process exit.
    pub(crate) cache: Arc<WorkerCache<'static>>,
}

impl PidFdMap {
    /// Create an empty PidFd map that cleans up records in `cache`.
    pub(crate) fn new(cache: Arc<WorkerCache<'static>>) -> Self {
        let pidfd = Arc::new(Mutex::new(BTreeMap::new()));

        Self { pidfd, cache }
    }

    /*
    #[inline]
    pub(crate) fn get_pidfd(&self, pid: Pid) -> Option<RawFd> {
        self.pidfd
            .lock()
            .unwrap_or_else(|err| err.into_inner())
            .get(&pid)
            .map(|fd| fd.as_raw_fd())
    }

    #[inline]
    pub(crate) fn add_pidfd(&self, pid: Pid, pid_fd: OwnedFd) {
        self.pidfd
            .lock()
            .unwrap_or_else(|err| err.into_inner())
            .insert(pid, pid_fd);
    }
    */

    /// Drop every record kept for `pid` and remove its PidFd from the map.
    ///
    /// The signal maps are cleared first, then any pending error, chdir,
    /// exec and sigreturn records, and finally the PidFd entry itself.
    #[inline]
    pub(crate) fn del_pidfd(&self, pid: Pid) {
        let cache = &self.cache;

        // Retire TGID from signal maps.
        cache.retire_sig_handle(pid);
        cache.retire_sig_restart(pid);

        // Discard any preexisting error, chdir, exec and sigreturn
        // records for pid; the removed values are dropped right away.
        let _ = cache.get_error(pid);
        let _ = cache.get_chdir(pid);
        let _ = cache.get_exec(pid);
        let _ = cache.get_sigreturn(pid);

        // Finally, remove the PidFd from the map.
        let mut pidfds = self
            .pidfd
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        pidfds.remove(&pid);
    }

    /// Return the raw PidFd for `request_pid`, opening and caching one if
    /// the map has none yet.
    ///
    /// * `request_pid`: key of the map entry, and the value stored in the
    ///   epoll event data.
    /// * `tgid`: only consulted when `HAVE_PIDFD_THREAD` is false. When
    ///   true, `request_pid` is opened as-is; when false, its TGID is
    ///   resolved with `proc_tgid` and opened instead.
    /// * `request_id`: if given, the seccomp request ID is validated after
    ///   the PidFd has been opened.
    ///
    /// The returned `RawFd` is borrowed from the `OwnedFd` stored in the
    /// map, so it is only meaningful while that entry exists.
    ///
    /// # Errors
    ///
    /// Propagates errors from `proc_tgid` and `pidfd_open`, and returns
    /// `ESRCH` when the request ID fails validation.
    ///
    /// # Panics
    ///
    /// Panics if the new PidFd cannot be added to the epoll instance.
    pub(crate) fn pidfd_open(
        &self,
        request_pid: Pid,
        tgid: bool,
        request_id: Option<u64>,
    ) -> Result<RawFd, Errno> {
        // The guard lives until the function returns, so lookup, open,
        // epoll registration and insert all happen under the map lock.
        let mut pidfd = self.pidfd.lock().unwrap_or_else(|err| err.into_inner());
        // Fast path: a PidFd for this pid is already cached.
        if let Some(fd) = pidfd.get(&request_pid) {
            return Ok(fd.as_raw_fd());
        }

        // Use PIDFD_THREAD if available.
        let (pid, flags) = if *HAVE_PIDFD_THREAD {
            (request_pid, PIDFD_THREAD)
        } else if tgid {
            (request_pid, 0)
        } else {
            (proc_tgid(request_pid)?, 0)
        };

        // Open the PIDFd.
        let pid_fd = pidfd_open(pid, flags)?;

        if let Some(request_id) = request_id {
            // SAFETY:
            // 1. Validate the PIDFd by validating the request ID if submitted.
            // 2. EAGAIN|EINTR is handled.
            // 3. ENOENT means child died mid-way.
            //
            // On failure the early return drops `pid_fd` without caching it.
            if seccomp_notify_id_valid(self.cache.scmp, request_id).is_err() {
                return Err(Errno::ESRCH);
            }
        }

        // SAFETY: Add the PIDFd to the epoll instance.
        //
        // Note: EPOLLEXCLUSIVE|EPOLLONESHOT is invalid!
        //
        // The event data carries the requested pid (not the resolved one).
        #[allow(clippy::cast_sign_loss)]
        let event = libc::epoll_event {
            events: (EpollFlags::EPOLLIN | EpollFlags::EPOLLONESHOT).bits() as u32,
            u64: request_pid.as_raw() as u64,
        };

        // Take the raw fd before `pid_fd` is moved into the map below.
        let pid_fd_raw = pid_fd.as_raw_fd();

        // SAFETY: In epoll(7) we trust.
        #[allow(clippy::disallowed_methods)]
        epoll_ctl_safe(&self.cache.poll.0, pid_fd_raw, Some(event))
            .expect("BUG: Failed to add PidFd to Epoll!");

        // Cache under the requested pid so the fast path above finds it.
        pidfd.insert(request_pid, pid_fd);

        Ok(pid_fd_raw)
    }
}
