#!/usr/bin/env python3 from dataclasses import dataclass import sys import subprocess import re import heapq from os import mkdir, path, listdir from argparse import ArgumentParser from typing import Iterable from venv import EnvBuilder parser = ArgumentParser(prog="py-setup", add_help=True) subparsers = parser.add_subparsers(dest="command") new = subparsers.add_parser("new", add_help=True) new.add_argument("name") new.add_argument("--with-venv", metavar="path") install = subparsers.add_parser("install", add_help=True) install.add_argument("-r", "--requirements", default="requirements.txt") install.add_argument("packages", nargs="+") def create_new(args): try: mkdir(args.name) except FileExistsError: if len(listdir(args.name)) > 0: print(f"directory `{args.name}` exists and is not empty, exiting") sys.exit() if args.with_venv is not None: venv_dir = path.join(args.name, args.with_venv) EnvBuilder().create(venv_dir) print(f"created {args.name} at {path.realpath(args.name)}") def grep(source, pattern): pattern = re.compile(pattern) return [line for line in source.split(b"\n") if pattern.search(line)] def pip(command): return subprocess.check_output([sys.executable, "-m", "pip", *command.split(" ")]) def flatten(xss): return [x for xs in xss for x in xs] def take_until(value: bytes, delim: Iterable[bytes]) -> Iterable[bytes]: for byte in value: if byte in delim: return yield byte specifiers = ["~", "=", "!", ">", "<"] @dataclass(frozen=True) class Requirement: value: bytes @property def name(self) -> bytes: return bytes(take_until(self.value, specifiers)) def __str__(self): return self.value.decode("utf-8") def create_requirements(packages) -> Requirement: output = pip("freeze") requirements = [grep(output, bytes(package, "utf-8")) for package in packages] requirements = [req.replace(b"==", b"~=") for req in flatten(requirements)] return [Requirement(req) for req in requirements] def read_requirements(path) -> Requirement: try: with open(path, "rb") as f: lines = [line.replace(b"\n", b"") for line in f.readlines() if line != b"\n"] return [Requirement(line) for line in lines] except: return [] def install_libraries(args): pip(f"install {" ".join(args.packages)}") new_requirements = create_requirements(args.packages) new_req_strs = [str(req) for req in new_requirements] print(f"installed {", ".join(new_req_strs)}") existing_requirements = filter( lambda req: req not in new_requirements, read_requirements(args.requirements) ) requirements = heapq.merge( existing_requirements, new_requirements, key=lambda x: x.name ) requirements = [str(req) for req in requirements] with open(args.requirements, "w") as f: f.write("\n".join(requirements)) count = len(new_requirements) entry_word = "entry" if count == 1 else "entries" print(f"updated {args.requirements} with {count} {entry_word}") def main(): args = parser.parse_args() match args.command: case "new": create_new(args) case "install": install_libraries(args) if __name__ == "__main__": main()