Patrick Kelley 8fd444092b initial
2025-05-07 15:35:15 -04:00

427 lines
11 KiB
Python

"""
These are meant to be private utility methods for internal use.
"""
import errno
import importlib.machinery
import os
import shutil
import string
import tarfile
import types
import git
import semantic_version as semver
def make_dir(path):
    """Create a directory, tolerating one that already exists.

    Args:
        path (str): the directory to create

    Raises:
        OSError: if the directory cannot be created, including the case
            where ``path`` already exists but is a regular file
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_there = err.errno == errno.EEXIST
        # An existing directory is fine; any other failure, or an existing
        # regular file occupying the path, is propagated to the caller.
        if not already_there or os.path.isfile(path):
            raise
def normalize_version_tag(tag):
    """Strip the leading "v" from a version string such as "vX.Y.Z".

    Only a "v" that is immediately followed by a digit is removed; every
    other input string is returned unchanged.
    """
    # Slicing (rather than indexing) keeps short and empty strings safe.
    has_v_prefix = tag[:1] == "v" and tag[1:2].isdigit()
    if has_v_prefix:
        return tag[1:]
    return tag
def delete_path(path):
    """Remove whatever exists at ``path``; do nothing if nothing is there.

    A symbolic link is unlinked itself (its target is left alone), a
    directory is removed recursively, and anything else is removed as a
    file.

    Args:
        path (str): the filesystem path to remove
    """
    is_link = os.path.islink(path)

    if not is_link and not os.path.exists(path):
        return

    if not is_link and os.path.isdir(path):
        shutil.rmtree(path)
        return

    # Symlinks (even ones pointing at directories) and plain files.
    os.remove(path)
def copy_over_path(src, dst, ignore=None):
    """Replace ``dst`` with a recursive copy of the directory ``src``.

    Anything already present at ``dst`` is deleted first. Symbolic links
    inside ``src`` are copied as links rather than being followed.

    Args:
        src (str): the directory tree to copy
        dst (str): the destination path, removed beforehand if it exists
        ignore (callable): passed through as ``shutil.copytree``'s
            ``ignore`` argument
    """
    delete_path(dst)
    shutil.copytree(src, dst, ignore=ignore, symlinks=True)
def make_symlink(target_path, link_path, force=True):
    """Create a symbolic link at ``link_path`` pointing to ``target_path``.

    Args:
        target_path (str): what the link should point to
        link_path (str): where to create the link
        force (bool): if true, an existing symbolic link at ``link_path``
            is replaced; existing non-link entries are never replaced

    Raises:
        OSError: if the link cannot be created
    """
    try:
        os.symlink(target_path, link_path)
    except OSError as error:
        replaceable = (
            error.errno == errno.EEXIST and force and os.path.islink(link_path)
        )
        if not replaceable:
            raise error

        # Swap out the stale link for one pointing at the new target.
        os.remove(link_path)
        os.symlink(target_path, link_path)
def safe_tarfile_extractall(tfile, destdir):
    """Wrapper to tarfile.extractall(), checking for path traversal.

    This adds the safeguards the Python docs for tarfile.extractall warn about:

        Never extract archives from untrusted sources without prior
        inspection. It is possible that files are created outside of path,
        e.g. members that have absolute filenames starting with "/" or
        filenames with two dots "..".

    Only member names are inspected. All members are checked before anything
    is extracted, so a rejected archive leaves ``destdir`` untouched.

    Args:
        tfile (str): the tar file to extract

        destdir (str): the destination directory into which to place contents

    Raises:
        Exception: if the tarfile would extract outside destdir
    """

    def is_within_directory(directory, target):
        abs_directory = os.path.abspath(directory)
        abs_target = os.path.abspath(target)

        # commonpath() compares whole path components. The character-wise
        # commonprefix() would wrongly accept a sibling such as
        # "<destdir>_evil/x" because it shares a string prefix with destdir.
        try:
            common = os.path.commonpath([abs_directory, abs_target])
        except ValueError:
            # No common path at all (e.g. different drives on Windows).
            return False

        return common == abs_directory

    with tarfile.open(tfile) as tar:
        for member in tar.getmembers():
            member_path = os.path.join(destdir, member.name)

            if not is_within_directory(destdir, member_path):
                raise Exception("attempted path traversal in tarfile")

        tar.extractall(destdir)
def find_sentence_end(s):
    """Return the index of the first period that ends a sentence in ``s``.

    A period counts as a sentence end when it is the last character of the
    string or is directly followed by whitespace. Periods followed by any
    other character (as in "1.2") are skipped.

    Args:
        s (str): the text to search

    Returns:
        int: index of the sentence-ending period, or -1 if there is none
    """
    last_idx = len(s) - 1
    search_from = 0

    while True:
        dot = s.find(".", search_from)

        # Either no period is left (-1), or the period closes the string.
        if dot == -1 or dot == last_idx:
            return dot

        if s[dot + 1].isspace():
            return dot

        search_from = dot + 1
def git_clone(git_url, dst_path, shallow=False):
    """Clone a git repository, including submodules, into ``dst_path``.

    Args:
        git_url (str): URL or local path of the repository to clone
        dst_path (str): directory in which to place the clone
        shallow (bool): if true, clone with a depth of 1 (all branches)

    Returns:
        git.Repo: the resulting clone, with its "origin" remote pointing
        at ``git_url``

    Raises:
        git.GitCommandError: if the clone fails and no fallback applies
    """
    if not shallow:
        git.Git().clone(git_url, dst_path, recursive=True)
    else:
        try:
            git.Git().clone(
                git_url,
                dst_path,
                "--no-single-branch",
                recursive=True,
                depth=1,
            )
        except git.GitCommandError:
            looks_local = git_url.startswith((".", "/"))

            if not looks_local:
                # Not a local repo
                raise

            shallow_marker = os.path.join(git_url, ".git", "shallow")

            if not os.path.exists(shallow_marker):
                raise

            # Some git versions cannot clone from a shallow-clone, so copy
            # and reset/clean it to a pristine condition.
            copy_over_path(git_url, dst_path)
            copied = git.Repo(dst_path)
            copied.git.reset("--hard")
            copied.git.clean("-ffdx")

    repo = git.Repo(dst_path)

    # This setting of the "origin" remote will be a no-op in most cases, but
    # for some reason, when cloning from a local directory, the clone may
    # inherit the "origin" instead of using the local directory as its new
    # "origin". This is bad in some cases since we do not want to try
    # fetching from a remote location (e.g. when unbundling). This
    # unintended inheritance of "origin" seems to only happen when cloning a
    # local git repo that has submodules ?
    repo.git.remote("set-url", "origin", git_url)
    return repo
def git_checkout(clone, version):
    """Check out a version of a git repo and bring its submodules in line.

    After the checkout, submodule URLs are re-synced and all submodules are
    recursively initialized and updated.

    Args:
        clone (git.Repo): the git clone on which to operate

        version (str): the branch, tag, or commit to checkout

    Raises:
        git.GitCommandError: if the git repo is invalid
    """
    git_cmd = clone.git
    git_cmd.checkout(version)
    git_cmd.submodule("sync", "--recursive")
    git_cmd.submodule("update", "--recursive", "--init")
def git_default_branch(repo):
    """Return default branch of a git repo, like 'main' or 'master'.

    If the Git repository has a remote named 'origin', the default branch
    is taken from the value of its HEAD reference (if it has one), with any
    leading "origin/" removed.

    If the Git repository has no remote named 'origin' or that remote has no
    HEAD, the default branch is selected in this order: 'main' if it exists,
    'master' if it exists, the currently checked out branch if any, else the
    current detached commit.

    Args:
        repo (git.Repo): the git clone on which to operate

    Returns:
        str: a branch name, or a commit hash when HEAD is detached
    """
    try:
        origin = repo.remote("origin")
    except ValueError:
        origin = None

    if origin:
        # Technically possible that remote has no HEAD, so guard against that.
        try:
            origin_head = origin.refs.HEAD.ref.name
        except Exception:
            origin_head = None

        if origin_head:
            prefix = "origin/"
            strip = origin_head.startswith(prefix)
            return origin_head[len(prefix) :] if strip else origin_head

    known_refs = {ref.name for ref in repo.refs}

    for candidate in ("main", "master"):
        if candidate in known_refs:
            return candidate

    try:
        # See if there's a branch currently checked out
        return repo.head.ref.name
    except TypeError:
        # No branch checked out, return commit hash
        return repo.head.object.hexsha
def git_version_tags(repo):
    """Returns semver-sorted list of version tag strings in the given repo.

    Tags are returned under their original names (a leading "v" is kept),
    ordered from lowest to highest version. Tags that cannot be interpreted
    as semantic versions are left out.

    Args:
        repo (git.Repo): the git clone whose tags to list
    """
    versioned = []

    for tagref in repo.tags:
        name = str(tagref.name)
        normalized = normalize_version_tag(name)

        try:
            parsed = semver.Version.coerce(normalized)
        except ValueError:
            # Skip tags that aren't compatible semantic versions.
            continue

        versioned.append((parsed, name))

    versioned.sort(key=lambda pair: pair[0])
    return [name for _, name in versioned]
def git_pull(repo):
    """Does a git pull followed by a submodule update.

    After pulling, submodule URLs are re-synced and all submodules are
    recursively initialized and updated.

    Args:
        repo (git.Repo): the git clone on which to operate

    Raises:
        git.GitCommandError: in case of git trouble
    """
    git_cmd = repo.git
    git_cmd.pull()
    git_cmd.submodule("sync", "--recursive")
    git_cmd.submodule("update", "--recursive", "--init")
def git_remote_urls(repo):
    """Returns a map of remote name -> URL string for configured remotes.

    You'd normally use repo.remotes[n].urls for this, but with old git versions
    (<2.7, e.g. on CentOS 7 at the time of writing) this triggers a "git remote
    show" (without -n) that will query the remotes, which can fail test
    cases. We use the config subsystem to query the URLs directly -- one of the
    fallback mechanisms in GitPython's Remote.urls() implementation.

    Args:
        repo (git.Repo): the git clone on which to operate

    Returns:
        dict: remote names mapped to their URL strings; output lines that do
        not have the form "remote.<name>.url <url>" are ignored
    """
    key_prefix = "remote."
    key_suffix = ".url"

    remote_details = repo.git.config("--get-regexp", r"remote\..+\.url")
    remotes = {}

    for line in remote_details.split("\n"):
        fields = line.split(maxsplit=1)

        if len(fields) != 2:
            # Blank line, or a key without a value.
            continue

        key, url = fields

        if not (key.startswith(key_prefix) and key.endswith(key_suffix)):
            continue

        # Take everything between "remote." and ".url" so that remote names
        # which themselves contain dots (e.g. "my.fork") are kept whole
        # instead of being cut at the first dot.
        name = key[len(key_prefix) : -len(key_suffix)]

        if name:
            remotes[name] = url

    return remotes
def is_sha1(s):
    """Tell whether ``s`` looks like a full SHA-1 hash.

    Args:
        s (str): the string to check

    Returns:
        bool: True if ``s`` consists of exactly 40 lowercase hexadecimal
        characters, False otherwise (including for empty or None input)
    """
    if not s or len(s) != 40:
        return False

    allowed = set(string.hexdigits.lower())
    return all(char in allowed for char in s)
def is_exe(path):
    """Tell whether ``path`` is an existing file that may be executed.

    Args:
        path (str): the filesystem path to check

    Returns:
        bool: True if ``path`` is a file and the execute permission check
        (``os.X_OK``) succeeds, False otherwise
    """
    if not os.path.isfile(path):
        return False

    return os.access(path, os.X_OK)
def find_program(prog_name):
    """Locate an executable, searching the PATH environment variable.

    If ``prog_name`` contains a directory component it is checked as given,
    without searching PATH. Otherwise each PATH entry is tried in order.

    Args:
        prog_name (str): a program name or a path to a program

    Returns:
        str: the path of the executable found, or an empty string if there
        is none (including when PATH is not set in the environment)
    """
    directory, _ = os.path.split(prog_name)

    if directory:
        return prog_name if is_exe(prog_name) else ""

    # A missing PATH means there is nowhere to search; report "not found"
    # the same way as any other miss rather than raising KeyError.
    search_path = os.environ.get("PATH")

    if search_path is None:
        return ""

    for entry in search_path.split(os.pathsep):
        # PATH entries may be wrapped in double quotes; drop them.
        candidate = os.path.join(entry.strip('"'), prog_name)

        if is_exe(candidate):
            return candidate

    return ""
class ZeekInfo:
    """
    Helper class holding information about a Zeek installation.

    The path to the zeek executable is given at construction time; an empty
    value is accepted there and only reported once the path is requested.
    """

    def __init__(self, *, zeek: str):
        self._zeek = zeek

    @property
    def zeek(self) -> str:
        """Path to zeek executable.

        Raises:
            LookupError: if no path to a zeek executable is available
        """
        if self._zeek:
            return self._zeek

        raise LookupError('No "zeek" executable in PATH')
# Module-wide ZeekInfo instance, created on the first get_zeek_info() call.
_zeek_info = None


def get_zeek_info() -> ZeekInfo:
    """Return the shared ZeekInfo instance, creating it on first use.

    The instance is built from the result of looking up a "zeek" program
    and is reused by all later calls.
    """
    global _zeek_info

    if _zeek_info is not None:
        return _zeek_info

    _zeek_info = ZeekInfo(zeek=find_program("zeek"))
    return _zeek_info
def std_encoding(stream):
    """Determine the text encoding to use for a stream.

    Args:
        stream: an object with an ``encoding`` attribute

    Returns:
        str: the stream's own encoding if it has one; otherwise "utf-8" when
        the default locale reports no encoding, else the locale's preferred
        encoding
    """
    own_encoding = stream.encoding

    if own_encoding:
        return own_encoding

    import locale

    _, locale_encoding = locale.getdefaultlocale()

    if locale_encoding is None:
        return "utf-8"

    return locale.getpreferredencoding()
def read_zeek_config_line(stdout):
    """Read the next line from ``stdout`` with surrounding whitespace removed.

    Args:
        stdout: a text stream offering ``readline()``

    Returns:
        str: the stripped line
    """
    line = stdout.readline()
    return line.strip()
def get_zeek_version():
    """Return the version reported by "zeek-config --version".

    Returns:
        str: the first line of the command's output (stderr is merged into
        stdout), stripped of surrounding whitespace, or an empty string if
        no zeek-config executable is found
    """
    zeek_config = find_program("zeek-config")

    if not zeek_config:
        return ""

    import subprocess

    # The context manager closes the pipe and waits for the child on exit,
    # so no open file descriptor or unreaped process is left behind.
    with subprocess.Popen(
        [zeek_config, "--version"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        bufsize=1,
        universal_newlines=True,
    ) as cmd:
        return read_zeek_config_line(cmd.stdout)
def load_source(filename):
    """Loads given Python script from disk.

    The script is executed in a fresh module object; that module is not
    registered in ``sys.modules`` by this function.

    Args:
        filename (str): name of a Python script file

    Returns:
        types.ModuleType: a module representing the loaded file
    """
    # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
    # We currently require Python 3.5+, where the following looks sufficient:
    script_path = os.path.abspath(filename)
    script_dir = os.path.dirname(script_path)

    # Naming here is unimportant, since we access members of the new
    # module via the returned instance.
    module_name = "template_" + script_dir
    loader = importlib.machinery.SourceFileLoader(module_name, script_path)

    module = types.ModuleType(loader.name)
    loader.exec_module(module)
    return module
def configparser_section_dict(parser, section):
    """Returns a dict representing a ConfigParser section.

    Args:
        parser (configparser.ConfigParser): a ConfigParser instance

        section (str): the name of a config section

    Returns:
        dict: a dict with key/val entries corresponding to the requested
        section, or an empty dict if the given parser has no such section.
    """
    if parser.has_section(section):
        return dict(parser.items(section))

    return {}