Skip to main content

entracte_lib/
dnd.rs

/// Reports whether the operating system's "do not disturb" / focus mode is
/// currently switched on.
///
/// The platform probe is selected at compile time:
/// - macOS uses `macos::check`;
/// - Windows uses `windows::check`;
/// - every other target has no probe and always answers `false`.
pub fn is_active() -> bool {
    // Exactly one of the blocks below survives cfg-stripping on any target,
    // and that surviving block becomes the function's tail expression.
    #[cfg(target_os = "macos")]
    {
        macos::check()
    }
    #[cfg(target_os = "windows")]
    {
        windows::check()
    }
    #[cfg(not(any(target_os = "macos", target_os = "windows")))]
    {
        false
    }
}
9
/// macOS probe: inspects the Do Not Disturb assertion store under the
/// current user's home directory.
#[cfg(target_os = "macos")]
mod macos {
    use std::path::Path;

    /// Location of the assertion store, relative to `$HOME`.
    const ASSERTIONS_FILE: &str = "Library/DoNotDisturb/DB/Assertions.json";

    /// Returns `true` when at least one element of the store's top-level
    /// `data` array has a non-empty `storeAssertionRecords` array.
    ///
    /// Every failure along the way yields `false`:
    /// - `HOME` unset;
    /// - file unreadable;
    /// - invalid JSON;
    /// - `data` missing or not an array.
    pub fn check() -> bool {
        read_assertions().unwrap_or(false)
    }

    /// Does the actual lookup; `None` means "could not determine".
    fn read_assertions() -> Option<bool> {
        let home = std::env::var_os("HOME")?;
        let store = Path::new(&home).join(ASSERTIONS_FILE);
        let raw = std::fs::read_to_string(&store).ok()?;
        let doc: serde_json::Value = serde_json::from_str(&raw).ok()?;
        let entries = doc.get("data")?.as_array()?;
        Some(entries.iter().any(has_records))
    }

    /// `true` if `entry["storeAssertionRecords"]` is an array with at least
    /// one element; `false` if the key is missing, not an array, or empty.
    fn has_records(entry: &serde_json::Value) -> bool {
        matches!(
            entry
                .get("storeAssertionRecords")
                .and_then(serde_json::Value::as_array),
            Some(records) if !records.is_empty()
        )
    }
}
38
/// Windows probe for Focus Assist, built on the undocumented Windows
/// Notification Facility (WNF) query `NtQueryWnfStateData` exported by
/// `ntdll.dll`.
#[cfg(target_os = "windows")]
mod windows {
    use std::sync::OnceLock;
    use windows_sys::Wdk::System::SystemServices::RtlGetVersion;
    use windows_sys::Win32::System::LibraryLoader::{GetModuleHandleA, GetProcAddress};
    use windows_sys::Win32::System::SystemInformation::OSVERSIONINFOW;

    /// Encoded WNF state name `WNF_SHEL_QUIETHOURS_ACTIVE_PROFILE_CHANGED`,
    /// whose payload is the active Focus Assist profile.
    ///
    /// WNF state names are stored XOR-ed with `0x41C64E6D_A3BC0074`.
    /// Decoding this value gives `0x4C454853_00031C01`:
    /// - high half: owner tag `SHEL` (little-endian ASCII);
    /// - low bits: version 1, well-known lifetime, system scope.
    ///
    /// A constant that does not decode to an ASCII owner tag is not a
    /// well-known name. The query would then presumably always fail, and
    /// `check()` would silently return `false`.
    const WNF_FOCUS_ASSIST: u64 = 0x0D83063E_A3BF1C75;

    // Builds older than Windows 10 1809 (17763) are not probed: the
    // six-argument NtQueryWnfStateData signature transmuted below is only
    // assumed stable from this build onward.
    const MIN_SUPPORTED_BUILD: u32 = 17763;

    /// Size of the Focus Assist payload: a single little-endian `u32`.
    const PAYLOAD_LEN: usize = std::mem::size_of::<u32>();

    /// `NTSTATUS NtQueryWnfStateData(StateName, TypeId, ExplicitScope,
    /// ChangeStamp, Buffer, BufferSize)`.
    type NtQueryWnfStateDataFn = unsafe extern "system" fn(
        *const u64,
        *const u8,
        *const u8,
        *mut u32,
        *mut u8,
        *mut u32,
    ) -> i32;

    /// Returns the OS build number, or `None` if `RtlGetVersion` fails.
    fn os_build() -> Option<u32> {
        // SAFETY: OSVERSIONINFOW consists only of integer fields and a u16
        // array, so the all-zero bit pattern is a valid value.
        let mut info: OSVERSIONINFOW = unsafe { std::mem::zeroed() };
        info.dwOSVersionInfoSize = std::mem::size_of::<OSVERSIONINFOW>() as u32;
        // SAFETY: `info` is a live, writable OSVERSIONINFOW whose size field
        // has been initialised as RtlGetVersion requires.
        let status = unsafe { RtlGetVersion(&mut info) };
        // Any status other than STATUS_SUCCESS (0) is treated as failure.
        (status == 0).then_some(info.dwBuildNumber)
    }

    /// Whether this OS build is new enough to call the WNF query.
    ///
    /// The result, and at most one log line explaining a `false`, is
    /// computed once per process.
    fn version_supported() -> bool {
        static CACHED: OnceLock<bool> = OnceLock::new();
        *CACHED.get_or_init(|| match os_build() {
            Some(build) if build >= MIN_SUPPORTED_BUILD => true,
            Some(build) => {
                log::info!(
                    "dnd: Windows build {build} < {MIN_SUPPORTED_BUILD}; \
                     skipping Focus Assist probe"
                );
                false
            }
            None => {
                log::info!("dnd: RtlGetVersion failed; skipping Focus Assist probe");
                false
            }
        })
    }

    /// Resolves `NtQueryWnfStateData` from `ntdll.dll`, caching the result.
    ///
    /// Returns `None` in three cases:
    /// - the build is unsupported;
    /// - `ntdll.dll` has no module handle;
    /// - the export is missing.
    fn query_fn() -> Option<NtQueryWnfStateDataFn> {
        static CACHED: OnceLock<Option<NtQueryWnfStateDataFn>> = OnceLock::new();
        *CACHED.get_or_init(|| {
            if !version_supported() {
                return None;
            }
            // SAFETY: the argument is a NUL-terminated C string literal.
            // ntdll.dll is mapped into every Win32 process and is never
            // unloaded, so the handle and any address resolved from it stay
            // valid for the lifetime of the process, which makes caching sound.
            let ntdll = unsafe { GetModuleHandleA(c"ntdll.dll".as_ptr().cast()) };
            if ntdll.is_null() {
                return None;
            }
            // SAFETY: `ntdll` is a valid module handle and the export name is
            // a NUL-terminated C string literal.
            let proc = unsafe { GetProcAddress(ntdll, c"NtQueryWnfStateData".as_ptr().cast()) };
            // SAFETY: NtQueryWnfStateData is undocumented. Its six-argument
            // signature is assumed stable on the builds admitted by
            // `version_supported`; on any other build we returned early above
            // and never reinterpret the symbol.
            proc.map(|p| unsafe { std::mem::transmute::<_, NtQueryWnfStateDataFn>(p) })
        })
    }

    /// Returns `true` when the Focus Assist WNF state holds a non-zero value.
    ///
    /// Reverse-engineering write-ups commonly describe the values as
    /// 0 = off, 1 = priority only, 2 = alarms only. This is undocumented;
    /// confirm on target builds.
    ///
    /// Any failure, or an unsupported OS build, yields `false`.
    pub fn check() -> bool {
        let Some(query) = query_fn() else {
            return false;
        };
        let state_name = WNF_FOCUS_ASSIST;
        let mut buffer = [0u8; PAYLOAD_LEN];
        let mut buffer_size = PAYLOAD_LEN as u32;
        let mut change_stamp: u32 = 0;
        // SAFETY: `query` was resolved from ntdll on a supported build. Every
        // pointer refers to a live local. `buffer_size` tells the callee how
        // many bytes `buffer` can hold; the callee writes back the actual size.
        let status = unsafe {
            query(
                &state_name,
                std::ptr::null(),
                std::ptr::null(),
                &mut change_stamp,
                buffer.as_mut_ptr(),
                &mut buffer_size,
            )
        };
        // A non-success status, or a payload shorter than a u32, means the
        // state cannot be read reliably: report "not active".
        if status != 0 || (buffer_size as usize) < PAYLOAD_LEN {
            return false;
        }
        u32::from_le_bytes(buffer) > 0
    }
}