427 lines
11 KiB
Python
427 lines
11 KiB
Python
"""
|
|
These are meant to be private utility methods for internal use.
|
|
"""
|
|
|
|
import errno
|
|
import importlib.machinery
|
|
import os
|
|
import shutil
|
|
import string
|
|
import tarfile
|
|
import types
|
|
|
|
import git
|
|
import semantic_version as semver
|
|
|
|
|
|
def make_dir(path):
|
|
"""Create a directory or do nothing if it already exists.
|
|
|
|
Raises:
|
|
OSError: if directory cannot be created
|
|
"""
|
|
try:
|
|
os.makedirs(path)
|
|
except OSError as exception:
|
|
if exception.errno != errno.EEXIST:
|
|
raise
|
|
elif os.path.isfile(path):
|
|
raise
|
|
|
|
|
|
def normalize_version_tag(tag):
|
|
"""Given version string "vX.Y.Z", returns "X.Y.Z".
|
|
Returns other input strings unchanged.
|
|
"""
|
|
if len(tag) > 1 and tag[0] == "v" and tag[1].isdigit():
|
|
return tag[1:]
|
|
|
|
return tag
|
|
|
|
|
|
def delete_path(path):
|
|
if os.path.islink(path):
|
|
os.remove(path)
|
|
return
|
|
|
|
if not os.path.exists(path):
|
|
return
|
|
|
|
if os.path.isdir(path):
|
|
shutil.rmtree(path)
|
|
else:
|
|
os.remove(path)
|
|
|
|
|
|
def copy_over_path(src, dst, ignore=None):
|
|
delete_path(dst)
|
|
shutil.copytree(src, dst, symlinks=True, ignore=ignore)
|
|
|
|
|
|
def make_symlink(target_path, link_path, force=True):
|
|
try:
|
|
os.symlink(target_path, link_path)
|
|
except OSError as error:
|
|
if error.errno == errno.EEXIST and force and os.path.islink(link_path):
|
|
os.remove(link_path)
|
|
os.symlink(target_path, link_path)
|
|
else:
|
|
raise error
|
|
|
|
|
|
def safe_tarfile_extractall(tfile, destdir):
|
|
"""Wrapper to tarfile.extractall(), checking for path traversal.
|
|
|
|
This adds the safeguards the Python docs for tarfile.extractall warn about:
|
|
|
|
Never extract archives from untrusted sources without prior inspection. It
|
|
is possible that files are created outside of path, e.g. members that have
|
|
absolute filenames starting with "/" or filenames with two dots "..".
|
|
|
|
Args:
|
|
tfile (str): the tar file to extract
|
|
|
|
destdir (str): the destination directory into which to place contents
|
|
|
|
Raises:
|
|
Exception: if the tarfile would extract outside destdir
|
|
"""
|
|
|
|
def is_within_directory(directory, target):
|
|
abs_directory = os.path.abspath(directory)
|
|
abs_target = os.path.abspath(target)
|
|
prefix = os.path.commonprefix([abs_directory, abs_target])
|
|
return prefix == abs_directory
|
|
|
|
with tarfile.open(tfile) as tar:
|
|
for member in tar.getmembers():
|
|
member_path = os.path.join(destdir, member.name)
|
|
if not is_within_directory(destdir, member_path):
|
|
raise Exception("attempted path traversal in tarfile")
|
|
|
|
tar.extractall(destdir)
|
|
|
|
|
|
def find_sentence_end(s):
|
|
beg = 0
|
|
|
|
while True:
|
|
period_idx = s.find(".", beg)
|
|
|
|
if period_idx == -1:
|
|
return -1
|
|
|
|
if period_idx == len(s) - 1:
|
|
return period_idx
|
|
|
|
next_char = s[period_idx + 1]
|
|
|
|
if next_char.isspace():
|
|
return period_idx
|
|
|
|
beg = period_idx + 1
|
|
|
|
|
|
def git_clone(git_url, dst_path, shallow=False):
|
|
if shallow:
|
|
try:
|
|
git.Git().clone(
|
|
git_url,
|
|
dst_path,
|
|
"--no-single-branch",
|
|
recursive=True,
|
|
depth=1,
|
|
)
|
|
except git.GitCommandError:
|
|
if not git_url.startswith(".") and not git_url.startswith("/"):
|
|
# Not a local repo
|
|
raise
|
|
|
|
if not os.path.exists(os.path.join(git_url, ".git", "shallow")):
|
|
raise
|
|
|
|
# Some git versions cannot clone from a shallow-clone, so copy
|
|
# and reset/clean it to a pristine condition.
|
|
copy_over_path(git_url, dst_path)
|
|
rval = git.Repo(dst_path)
|
|
rval.git.reset("--hard")
|
|
rval.git.clean("-ffdx")
|
|
else:
|
|
git.Git().clone(git_url, dst_path, recursive=True)
|
|
|
|
rval = git.Repo(dst_path)
|
|
|
|
# This setting of the "origin" remote will be a no-op in most cases, but
|
|
# for some reason, when cloning from a local directory, the clone may
|
|
# inherit the "origin" instead of using the local directory as its new
|
|
# "origin". This is bad in some cases since we do not want to try
|
|
# fetching from a remote location (e.g. when unbundling). This
|
|
# unintended inheritence of "origin" seems to only happen when cloning a
|
|
# local git repo that has submodules ?
|
|
rval.git.remote("set-url", "origin", git_url)
|
|
return rval
|
|
|
|
|
|
def git_checkout(clone, version):
|
|
"""Checkout a version of a git repo along with any associated submodules.
|
|
|
|
Args:
|
|
clone (git.Repo): the git clone on which to operate
|
|
|
|
version (str): the branch, tag, or commit to checkout
|
|
|
|
Raises:
|
|
git.GitCommandError: if the git repo is invalid
|
|
"""
|
|
clone.git.checkout(version)
|
|
clone.git.submodule("sync", "--recursive")
|
|
clone.git.submodule("update", "--recursive", "--init")
|
|
|
|
|
|
def git_default_branch(repo):
|
|
"""Return default branch of a git repo, like 'main' or 'master'.
|
|
|
|
If the Git repository has a remote named 'origin', the default branch
|
|
is taken from the value of its HEAD reference (if it has one).
|
|
|
|
If the Git repository has no remote named 'origin' or that remote has no
|
|
HEAD, the default branch is selected in this order: 'main' if it exists,
|
|
'master' if it exists, the currently checked out branch if any, else the
|
|
current detached commit.
|
|
|
|
Args:
|
|
repo (git.Repo): the git clone on which to operate
|
|
"""
|
|
|
|
try:
|
|
remote = repo.remote("origin")
|
|
except ValueError:
|
|
remote = None
|
|
|
|
if remote:
|
|
# Technically possible that remote has no HEAD, so guard against that.
|
|
try:
|
|
head_ref_name = remote.refs.HEAD.ref.name
|
|
except Exception:
|
|
head_ref_name = None
|
|
|
|
if head_ref_name:
|
|
remote_prefix = "origin/"
|
|
|
|
if head_ref_name.startswith(remote_prefix):
|
|
return head_ref_name[len(remote_prefix) :]
|
|
|
|
return head_ref_name
|
|
|
|
ref_names = [ref.name for ref in repo.refs]
|
|
|
|
if "main" in ref_names:
|
|
return "main"
|
|
|
|
if "master" in ref_names:
|
|
return "master"
|
|
|
|
try:
|
|
# See if there's a branch currently checked out
|
|
return repo.head.ref.name
|
|
except TypeError:
|
|
# No branch checked out, return commit hash
|
|
return repo.head.object.hexsha
|
|
|
|
|
|
def git_version_tags(repo):
|
|
"""Returns semver-sorted list of version tag strings in the given repo."""
|
|
tags = []
|
|
|
|
for tagref in repo.tags:
|
|
tag = str(tagref.name)
|
|
normal_tag = normalize_version_tag(tag)
|
|
|
|
try:
|
|
sv = semver.Version.coerce(normal_tag)
|
|
except ValueError:
|
|
# Skip tags that aren't compatible semantic versions.
|
|
continue
|
|
else:
|
|
tags.append((normal_tag, tag, sv))
|
|
|
|
return [t[1] for t in sorted(tags, key=lambda e: e[2])]
|
|
|
|
|
|
def git_pull(repo):
|
|
"""Does a git pull followed up a submodule update.
|
|
|
|
Args:
|
|
clone (git.Repo): the git clone on which to operate
|
|
|
|
Raises:
|
|
git.GitCommandError: in case of git trouble
|
|
"""
|
|
repo.git.pull()
|
|
repo.git.submodule("sync", "--recursive")
|
|
repo.git.submodule("update", "--recursive", "--init")
|
|
|
|
|
|
def git_remote_urls(repo):
|
|
"""Returns a map of remote name -> URL string for configured remotes.
|
|
|
|
You'd normally use repo.remotes[n].urls for this, but with old git versions
|
|
(<2.7, e.g. on CentOS 7 at the time of writing) this triggers a "git remote
|
|
show" (without -n) that will query the remotes, which can fail test
|
|
cases. We use the config subsystem to query the URLs directly -- one of the
|
|
fallback mechanisms in GitPython's Remote.urls() implementation.
|
|
"""
|
|
remote_details = repo.git.config("--get-regexp", r"remote\..+\.url")
|
|
remotes = {}
|
|
|
|
for line in remote_details.split("\n"):
|
|
try:
|
|
remote, url = line.split(maxsplit=1)
|
|
remote = remote.split(".")[1]
|
|
remotes[remote] = url
|
|
except (ValueError, IndexError):
|
|
pass
|
|
|
|
return remotes
|
|
|
|
|
|
def is_sha1(s):
|
|
if not s:
|
|
return False
|
|
|
|
if len(s) != 40:
|
|
return False
|
|
|
|
hexdigits = set(string.hexdigits.lower())
|
|
return all(c in hexdigits for c in s)
|
|
|
|
|
|
def is_exe(path):
|
|
return os.path.isfile(path) and os.access(path, os.X_OK)
|
|
|
|
|
|
def find_program(prog_name):
|
|
path, _ = os.path.split(prog_name)
|
|
|
|
if path:
|
|
return prog_name if is_exe(prog_name) else ""
|
|
|
|
for path in os.environ["PATH"].split(os.pathsep):
|
|
path = os.path.join(path.strip('"'), prog_name)
|
|
|
|
if is_exe(path):
|
|
return path
|
|
|
|
return ""
|
|
|
|
|
|
class ZeekInfo:
|
|
"""
|
|
Helper class holding information about a Zeek installation.
|
|
"""
|
|
|
|
def __init__(self, *, zeek: str):
|
|
self._zeek = zeek
|
|
|
|
@property
|
|
def zeek(self) -> str:
|
|
"""Path to zeek executable."""
|
|
if not self._zeek:
|
|
raise LookupError('No "zeek" executable in PATH')
|
|
return self._zeek
|
|
|
|
|
|
_zeek_info = None
|
|
|
|
|
|
def get_zeek_info() -> ZeekInfo:
|
|
global _zeek_info
|
|
|
|
if _zeek_info is None:
|
|
_zeek_info = ZeekInfo(
|
|
zeek=find_program("zeek"),
|
|
)
|
|
|
|
return _zeek_info
|
|
|
|
|
|
def std_encoding(stream):
|
|
if stream.encoding:
|
|
return stream.encoding
|
|
|
|
import locale
|
|
|
|
if locale.getdefaultlocale()[1] is None:
|
|
return "utf-8"
|
|
|
|
return locale.getpreferredencoding()
|
|
|
|
|
|
def read_zeek_config_line(stdout):
|
|
return stdout.readline().strip()
|
|
|
|
|
|
def get_zeek_version():
|
|
zeek_config = find_program("zeek-config")
|
|
|
|
if not zeek_config:
|
|
return ""
|
|
|
|
import subprocess
|
|
|
|
cmd = subprocess.Popen(
|
|
[zeek_config, "--version"],
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.STDOUT,
|
|
bufsize=1,
|
|
universal_newlines=True,
|
|
)
|
|
|
|
return read_zeek_config_line(cmd.stdout)
|
|
|
|
|
|
def load_source(filename):
|
|
"""Loads given Python script from disk.
|
|
|
|
Args:
|
|
filename (str): name of a Python script file
|
|
|
|
Returns:
|
|
types.ModuleType: a module representing the loaded file
|
|
"""
|
|
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
|
|
# We currrently require Python 3.5+, where the following looks sufficient:
|
|
absname = os.path.abspath(filename)
|
|
dirname = os.path.dirname(absname)
|
|
|
|
# Naming here is unimportant, since we access members of the new
|
|
# module via the returned instance.
|
|
loader = importlib.machinery.SourceFileLoader("template_" + dirname, absname)
|
|
mod = types.ModuleType(loader.name)
|
|
loader.exec_module(mod)
|
|
|
|
return mod
|
|
|
|
|
|
def configparser_section_dict(parser, section):
|
|
"""Returns a dict representing a ConfigParser section.
|
|
|
|
Args:
|
|
parser (configparser.ConfigParser): a ConfigParser instance
|
|
|
|
section (str): the name of a config section
|
|
|
|
Returns:
|
|
dict: a dict with key/val entries corresponding to the requested
|
|
section, or an empty dict if the given parser has no such section.
|
|
"""
|
|
res = {}
|
|
|
|
if not parser.has_section(section):
|
|
return {}
|
|
|
|
for key, val in parser.items(section):
|
|
res[key] = val
|
|
|
|
return res
|