Clean up spawn as root and CLI interface

This commit is contained in:
2023-05-22 20:48:32 -04:00
parent e62f2924b0
commit 311cd49c06
8 changed files with 199 additions and 163 deletions

6
antidrift.desktop Normal file
View File

@@ -0,0 +1,6 @@
[Desktop Entry]
Name=AntiDrift
Exec=antidrift
Terminal=false
Type=Application
StartupNotify=false

View File

@@ -1,73 +0,0 @@
import logging
import shutil
import sys
import os
import signal
import subprocess
from pathlib import Path
import antidrift.client as client
from antidrift.config import Config
from antidrift.daemon import AntiDriftDaemon
from dbus.mainloop.glib import DBusGMainLoop
DBusGMainLoop(set_as_default=True)
signal.signal(signal.SIGINT, signal.SIG_DFL)
def init_logging(log_file: Path):
class DuplicateFilter(logging.Filter):
def filter(self, record) -> bool:
current_log = (record.module, record.levelno, record.msg)
if current_log != getattr(self, "last_log", None):
self.last_log = current_log
return True
return False
format_str = '[bold pale_green3]%(asctime)s[/bold pale_green3] | ' \
'[light_steel_blue]%(levelname)-8s[/light_steel_blue] | ' \
'%(message)s'
logging.basicConfig(filename=log_file,
format=format_str,
datefmt='%a %H:%M:%S',
encoding='utf-8',
level=logging.DEBUG)
logger = logging.getLogger()
logger.addFilter(DuplicateFilter())
def check_for_xdotool():
""" Check if xdotool is in path and exit if not """
result = shutil.which("xdotool")
if not result:
logging.critical("Please install xdotool")
sys.exit(1)
def main_daemon(config):
init_logging(config.daemon_log_file)
add = AntiDriftDaemon(config)
add.run()
def main() -> None:
""" Main routine that dispatches to client or daemon """
config = Config.load(os.path.expanduser("~/.config/antidrift.yaml"))
check_for_xdotool()
if client.antidrift_daemon_is_running():
init_logging(config.client_log_file)
client.client_mode(config)
elif len(sys.argv) == 1:
if os.geteuid() == 0:
newpid = os.fork()
if newpid == 0:
main_daemon(config)
else:
cmd = ["sudo", "antidrift"]
subprocess.Popen(cmd)
elif len(sys.argv) == 2 and sys.argv[1] == '--daemon_user':
main_daemon(config)
else:
print("ad not running")
if __name__ == "__main__":
main()

View File

@@ -1,25 +1,13 @@
from antidrift.config import Config
from typing import Optional
from rich import print
import argparse
import time import time
import dbus import antidrift.daemon
import dbus.service from antidrift.config import Config
from argparse import Namespace
from rich import print
def get_dbus_interface() -> Optional[dbus.Interface]:
try:
bus = dbus.SessionBus()
bus_object = bus.get_object("com.antidrift", "/com/antidrift")
interface = dbus.Interface(bus_object, "com.antidrift")
return interface
except dbus.exceptions.DBusException:
return None
def antidrift_daemon_is_running() -> bool: def antidrift_daemon_is_running() -> bool:
""" Check if AntiDrift is running via the DBUS """ """Check if AntiDrift is running via the DBUS"""
interface = get_dbus_interface() interface = antidrift.daemon.get_dbus_interface()
if interface is None: if interface is None:
return False return False
reply = interface.status() reply = interface.status()
@@ -29,7 +17,7 @@ def antidrift_daemon_is_running() -> bool:
def tailf(config): def tailf(config):
with open(config.daemon_log_file, 'r') as f: with open(config.daemon_log_file, "r") as f:
f.seek(0, 2) f.seek(0, 2)
while True: while True:
line = f.readline() line = f.readline()
@@ -39,21 +27,12 @@ def tailf(config):
print(line.strip()) print(line.strip())
def client_mode(config: Config): def run(args: Namespace, config: Config):
parser = argparse.ArgumentParser(description='AntiDrift CLI.') interface = antidrift.daemon.get_dbus_interface()
parser.add_argument('--start', metavar='whiteblock', nargs='+', reply = "🟡 ad daemon active but no command"
help='start whiteblocks') if interface is None:
parser.add_argument('--stop', action='store_true', reply = "🔴 ad inactive"
help='stop session') elif args.start:
parser.add_argument('--schedule', metavar='blackblock',
help='schedule blackblock')
parser.add_argument('--status', action='store_true',
help='get status from daemon')
parser.add_argument('--tailf', action='store_true',
help='tail -f log file')
args = parser.parse_args()
interface = get_dbus_interface()
if args.start:
reply = interface.start(args.start) reply = interface.start(args.start)
elif args.stop: elif args.stop:
reply = interface.stop() reply = interface.stop()
@@ -63,6 +42,4 @@ def client_mode(config: Config):
tailf(config) tailf(config)
elif args.status: elif args.status:
reply = interface.status() reply = interface.status()
else:
reply = '[red]no command[/red]'
print(reply) print(reply)

View File

@@ -12,7 +12,7 @@ class Block(BaseModel):
delay: int = 0 delay: int = 0
class Config: class Config:
extra = 'forbid' extra = "forbid"
class Config(BaseModel): class Config(BaseModel):
@@ -25,7 +25,7 @@ class Config(BaseModel):
enforce_delay_ms: int = 5000 enforce_delay_ms: int = 5000
class Config: class Config:
extra = 'forbid' extra = "forbid"
@classmethod @classmethod
def load(cls, config_file: str) -> Config: def load(cls, config_file: str) -> Config:
@@ -43,4 +43,4 @@ class State(BaseModel):
inactive_blackblocks: List[Block] = [] inactive_blackblocks: List[Block] = []
class Config: class Config:
extra = 'forbid' extra = "forbid"

View File

@@ -4,12 +4,13 @@ import pwd
import re import re
import sys import sys
import antidrift.xwindow as xwindow import antidrift.xwindow as xwindow
from antidrift.config import Config, State, Block
from gi.repository import GLib, Gio
from typing import List
import dbus import dbus
import dbus.service import dbus.service
from antidrift.config import Config, State, Block
from gi.repository import GLib, Gio
from typing import List, Optional
BUS_NAME = "com.antidrift" BUS_NAME = "com.antidrift"
IFACE = "com.antidrift" IFACE = "com.antidrift"
OPATH = "/com/antidrift" OPATH = "/com/antidrift"
@@ -19,12 +20,21 @@ def reload_callback(m, f, o, event):
filename = f.get_basename() filename = f.get_basename()
m = f"[dark_orange3]Restart after change in '{filename}'.[/dark_orange3]" m = f"[dark_orange3]Restart after change in '{filename}'.[/dark_orange3]"
logging.warning(m) logging.warning(m)
os.execv(sys.executable, ['python3'] + sys.argv) os.execv(sys.executable, ["python3"] + sys.argv)
def get_dbus_interface() -> Optional[dbus.Interface]:
try:
bus = dbus.SessionBus()
bus_object = bus.get_object(BUS_NAME, OPATH)
interface = dbus.Interface(bus_object, IFACE)
return interface
except dbus.exceptions.DBusException:
return None
class AntiDriftDaemon(dbus.service.Object): class AntiDriftDaemon(dbus.service.Object):
def __init__(self, config: Config): def __init__(self, config: Config):
user_name = os.environ.get("SUDO_USER", pwd.getpwuid(os.getuid()).pw_name) user_name = os.environ.get("SUDO_USER", pwd.getpwuid(os.getuid()).pw_name)
user_uid = pwd.getpwnam(user_name).pw_uid user_uid = pwd.getpwnam(user_name).pw_uid
euid = os.geteuid() euid = os.geteuid()
@@ -44,10 +54,10 @@ class AntiDriftDaemon(dbus.service.Object):
self.state = State( self.state = State(
active_blackblocks=self.config.blackblocks, active_blackblocks=self.config.blackblocks,
active_whiteblocks=[], active_whiteblocks=[],
inactive_blackblocks=[]) inactive_blackblocks=[],
)
@dbus.service.method(dbus_interface=IFACE, @dbus.service.method(dbus_interface=IFACE, in_signature="as", out_signature="s")
in_signature="as", out_signature="s")
def start(self, whiteblocks: List[str]) -> str: def start(self, whiteblocks: List[str]) -> str:
self.reset_block_state() self.reset_block_state()
all_whiteblocks = {wb.name: wb for wb in self.config.whiteblocks} all_whiteblocks = {wb.name: wb for wb in self.config.whiteblocks}
@@ -59,7 +69,7 @@ class AntiDriftDaemon(dbus.service.Object):
else: else:
fail_blocks.append(block_name) fail_blocks.append(block_name)
if success_wbs: if success_wbs:
wbs = ', '.join(success_wbs) wbs = ", ".join(success_wbs)
r = f"Start whiteblocks [sky_blue3]{wbs}[/sky_blue3]." r = f"Start whiteblocks [sky_blue3]{wbs}[/sky_blue3]."
logging.info(r) logging.info(r)
else: else:
@@ -69,10 +79,9 @@ class AntiDriftDaemon(dbus.service.Object):
logging.warning(m) logging.warning(m)
return r return r
@dbus.service.method(dbus_interface=IFACE, @dbus.service.method(dbus_interface=IFACE, in_signature="s", out_signature="s")
in_signature="s", out_signature="s")
def schedule(self, blackblock_name: str) -> str: def schedule(self, blackblock_name: str) -> str:
""" Schedule blackblock based if it has a non-zero timeout value. """ """Schedule blackblock based if it has a non-zero timeout value."""
all_blackblocks = {bb.name: bb for bb in self.config.blackblocks} all_blackblocks = {bb.name: bb for bb in self.config.blackblocks}
if blackblock_name not in all_blackblocks: if blackblock_name not in all_blackblocks:
m = f"No blackblock [red3]{blackblock_name}[/red3]." m = f"No blackblock [red3]{blackblock_name}[/red3]."
@@ -87,41 +96,42 @@ class AntiDriftDaemon(dbus.service.Object):
def allow(): def allow():
self.allow_blackblock(blackblock) self.allow_blackblock(blackblock)
delay_ms = blackblock.delay * 1000 * 60 delay_ms = blackblock.delay * 1000 * 60
GLib.timeout_add(delay_ms, allow) GLib.timeout_add(delay_ms, allow)
m = f"Scheduled [sky_blue3]{blackblock_name}[/sky_blue3] in {blackblock.delay} minutes." m = f"Scheduled [sky_blue3]{blackblock_name}[/sky_blue3] in {blackblock.delay} minutes."
logging.info(m) logging.info(m)
return m return m
@dbus.service.method(dbus_interface=IFACE, @dbus.service.method(dbus_interface=IFACE, in_signature="", out_signature="s")
in_signature="", out_signature="s")
def stop(self) -> str: def stop(self) -> str:
self.reset_block_state() self.reset_block_state()
m = 'Blacklist only mode.' m = "Blacklist only mode."
logging.info(m) logging.info(m)
return m return m
@dbus.service.method(dbus_interface=IFACE, @dbus.service.method(dbus_interface=IFACE, in_signature="", out_signature="s")
in_signature="", out_signature="s")
def status(self) -> str: def status(self) -> str:
white_active = bool(self.state.active_whiteblocks) white_active = bool(self.state.active_whiteblocks)
black_active = bool(self.state.active_blackblocks) black_active = bool(self.state.active_blackblocks)
m = 'ad ' m = "🟢 ad "
inactive_bbs = ' '.join(map(lambda b: "-" + b.name, self.state.inactive_blackblocks)) inactive_bbs = " ".join(
map(lambda b: "-" + b.name, self.state.inactive_blackblocks)
)
match (white_active, black_active): match (white_active, black_active):
case (True, _): case (True, _):
m += 'wb: ' m += "wb: "
m += ' '.join(map(lambda b: b.name, self.state.active_whiteblocks)) m += " ".join(map(lambda b: b.name, self.state.active_whiteblocks))
if inactive_bbs: if inactive_bbs:
m += ' ' m += " "
m += inactive_bbs m += inactive_bbs
case (False, True): case (False, True):
m += 'bb' m += "bb"
if inactive_bbs: if inactive_bbs:
m += ': ' m += ": "
m += inactive_bbs m += inactive_bbs
case _: case _:
m = 'inactive' m = "inactive"
return m return m
def allow_blackblock(self, blackblock: Block): def allow_blackblock(self, blackblock: Block):
@@ -132,15 +142,20 @@ class AntiDriftDaemon(dbus.service.Object):
m = f"Blackblock [sky_blue3]{blackblock.name}[/sky_blue3] is now allowed." m = f"Blackblock [sky_blue3]{blackblock.name}[/sky_blue3] is now allowed."
logging.info(m) logging.info(m)
def run(self): def run(self, debug: bool = False):
def _enforce(): def _enforce():
self.enforce() self.enforce()
GLib.timeout_add(self.config.polling_cycle_ms, _enforce) GLib.timeout_add(self.config.polling_cycle_ms, _enforce)
# autorestart on file change for development # autorestart on file change for development
monitors = [] monitors = []
files = ["antidrift.py", "antidrift/daemon.py", "antidrift/client.py", files = [
"antidrift/config.py"] "antidrift.py",
"antidrift/daemon.py",
"antidrift/client.py",
"antidrift/config.py",
]
if debug:
for filename in files: for filename in files:
gio_file = Gio.File.new_for_path(filename) gio_file = Gio.File.new_for_path(filename)
monitor = gio_file.monitor_file(Gio.FileMonitorFlags.NONE, None) monitor = gio_file.monitor_file(Gio.FileMonitorFlags.NONE, None)
@@ -178,7 +193,7 @@ def window_is_blocked(state: State, silent: bool = False) -> bool:
return False return False
def keyword_matches_window(keyword: str, window: xwindow.XWindow): def keyword_matches_window(keyword: str, window: xwindow.XWindow):
if keyword.startswith('/') and keyword.endswith('/'): if keyword.startswith("/") and keyword.endswith("/"):
try: try:
r = re.compile(keyword[1:-1], re.IGNORECASE) r = re.compile(keyword[1:-1], re.IGNORECASE)
if r.findall(window.name): if r.findall(window.name):
@@ -186,7 +201,7 @@ def window_is_blocked(state: State, silent: bool = False) -> bool:
else: else:
return False return False
except re.error: except re.error:
m = f'Invalid regex [red3]{keyword}[/red3].' m = f"Invalid regex [red3]{keyword}[/red3]."
logging.warning(m) logging.warning(m)
return False return False
else: else:
@@ -204,8 +219,10 @@ def window_is_blocked(state: State, silent: bool = False) -> bool:
elif keyword_matches_window(k, window): elif keyword_matches_window(k, window):
if not silent: if not silent:
xwindow.notify(f"{window.name[:30]} blocked by {b.name}.") xwindow.notify(f"{window.name[:30]} blocked by {b.name}.")
logging.warning(f"[red]{window.name[:50]}[/red] " logging.warning(
f"blocked by [red]{b.name}[/red].") f"[red]{window.name[:50]}[/red] "
f"blocked by [red]{b.name}[/red]."
)
return True return True
if not whiteblocks: if not whiteblocks:
if not silent: if not silent:
@@ -215,9 +232,17 @@ def window_is_blocked(state: State, silent: bool = False) -> bool:
for k in w.keywords: for k in w.keywords:
if keyword_matches_window(k, window): if keyword_matches_window(k, window):
if not silent: if not silent:
logging.debug(f"[pale_green3]{window.name[:30]}[/pale_green3] " logging.debug(
f"allowed by [sky_blue3]{w.name}[/sky_blue3].") f"[pale_green3]{window.name[:30]}[/pale_green3] "
f"allowed by [sky_blue3]{w.name}[/sky_blue3]."
)
return False return False
if not silent: if not silent:
xwindow.notify(f"'{window.name[:30]}' not on any whiteblock.") xwindow.notify(f"'{window.name[:30]}' not on any whiteblock.")
return True return True
def run(config: Config):
add = AntiDriftDaemon(config)
xwindow.notify(f"AntiDrift.run()")
add.run()

View File

@@ -1,4 +1,3 @@
def is_window_blocked(window_name: str, blocked: List[re.Pattern]) -> bool: def is_window_blocked(window_name: str, blocked: List[re.Pattern]) -> bool:
for b in blocked: for b in blocked:
if b.findall(window_name): if b.findall(window_name):
@@ -23,11 +22,8 @@ def kill_sequence(blocked: List[re.Pattern]) -> None:
def notify(message: str) -> None: def notify(message: str) -> None:
""" Notify user via the Xorg notify-send command. """ """Notify user via the Xorg notify-send command."""
env = { env = {**os.environ, "DBUS_SESSION_BUS_ADDRESS": "unix:path=/run/user/1000/bus"}
**os.environ,
"DBUS_SESSION_BUS_ADDRESS": "unix:path=/run/user/1000/bus"
}
user = env.get("SUDO_USER", None) user = env.get("SUDO_USER", None)
if user is None: if user is None:
cmd = ["notify-send", message] cmd = ["notify-send", message]
@@ -141,7 +137,7 @@ def write_window_names(config: Config, window_names: Set[str]) -> None:
def main() -> None: def main() -> None:
""" Run main_root as root except while debugging. """ """Run main_root as root except while debugging."""
config_path = "~/.config/aw-focus/config.json" config_path = "~/.config/aw-focus/config.json"
config = Config.load_config(config_path) config = Config.load_config(config_path)
@@ -167,5 +163,3 @@ def main_root(config: Config) -> None:
kill_window_if_blocked(blocked) kill_window_if_blocked(blocked)
if config.enforce_aw_commit: if config.enforce_aw_commit:
enforce_aw_commit() enforce_aw_commit()

View File

@@ -36,7 +36,7 @@ class XWindow:
def notify(message: str) -> None: def notify(message: str) -> None:
""" Notify user via the Xorg notify-send command. """ """Notify user via the Xorg notify-send command."""
logging.debug(f"{message} - [grey]notify[/grey]") logging.debug(f"{message} - [grey]notify[/grey]")
env = dict(os.environ) env = dict(os.environ)
user = env.get("SUDO_USER", None) user = env.get("SUDO_USER", None)

107
main.py Normal file
View File

@@ -0,0 +1,107 @@
import logging
import shutil
import sys
import os
import signal
import subprocess
import argparse
import rich
from rich.logging import RichHandler
from pathlib import Path
import antidrift.client
import antidrift.daemon
from antidrift.config import Config
from dbus.mainloop.glib import DBusGMainLoop
DBusGMainLoop(set_as_default=True)
signal.signal(signal.SIGINT, signal.SIG_DFL)
def get_args():
parser = argparse.ArgumentParser(description="AntiDrift CLI.")
parser.add_argument("--daemon", action="store_true", help="run daemon")
parser.add_argument("--status", action="store_true", help="get status from daemon")
parser.add_argument("--tailf", action="store_true", help="tail -f log file")
parser.add_argument(
"--start", metavar="whiteblock", nargs="+", help="start whiteblocks"
)
parser.add_argument("--stop", action="store_true", help="stop session")
parser.add_argument("--schedule", metavar="blackblock", help="schedule blackblock")
args = parser.parse_args()
return args
def init_logging(log_file: Path, dev_mode: bool = False):
class DuplicateFilter(logging.Filter):
def filter(self, record) -> bool:
current_log = (record.module, record.levelno, record.msg)
if current_log != getattr(self, "last_log", None):
self.last_log = current_log
return True
return False
if dev_mode:
format_str = "%(message)s" # RichHandler will handle the formatting
logging.basicConfig(
level=logging.DEBUG,
format=format_str,
datefmt="%a %H:%M:%S",
handlers=[RichHandler(rich_tracebacks=True, markup=True)],
)
else:
format_str = (
"[bold pale_green3]%(asctime)s[/bold pale_green3] | "
"[light_steel_blue]%(levelname)-8s[/light_steel_blue] | "
"%(message)s"
)
logging.basicConfig(
filename=log_file,
format=format_str,
datefmt="%a %H:%M:%S",
encoding="utf-8",
level=logging.DEBUG,
)
logger = logging.getLogger()
logger.addFilter(DuplicateFilter())
def check_for_xdotool():
"""Check if xdotool is in path and exit if not"""
result = shutil.which("xdotool")
if not result:
logging.critical("Please install xdotool")
sys.exit(1)
def main_daemon(config):
if os.geteuid() == 0:
newpid = os.fork()
if newpid == 0:
init_logging(config.daemon_log_file)
antidrift.daemon.run(config)
else:
if sys.argv[0] == "antidrift":
cmd = ["sudo", "antidrift", "--daemon"]
subprocess.Popen(cmd)
else:
init_logging(config.daemon_log_file, dev_mode=True)
logging.warning("[red]Running in development mode.[/red]")
antidrift.daemon.run(config)
def main() -> None:
"""Main routine that dispatches to client or daemon"""
config = Config.load(os.path.expanduser("~/.config/antidrift.yaml"))
check_for_xdotool()
args = get_args()
if args.daemon:
main_daemon(config)
else:
antidrift.client.run(args, config)
if __name__ == "__main__":
main()