NumPy Source Code Analysis (93)
.\numpy\pavement.py
r"""
This paver file is intended to help with the release process as much as
possible. It relies on virtualenv to generate 'bootstrap' environments as
independent from the user system as possible (e.g. to make sure the sphinx doc
is built against the built numpy, not an installed one).
Building changelog + notes
==========================
Assumes you have git and the binaries/tarballs in installers/::
paver write_release
paver write_note
This automatically puts the checksums into README.rst and writes the Changelog.
TODO
====
- the script is messy, lots of global variables
- make it more easily customizable (through command line args)
- missing targets: install & test, sdist test, debian packaging
- fix bdist_mpkg: we build the same source twice -> how to make sure we use
the same underlying python for egg install in venv and for bdist_mpkg
"""
import os
import sys
import shutil
import hashlib
import textwrap
import paver
from paver.easy import Bunch, options, task, sh
RELEASE_NOTES = 'doc/source/release/2.1.0-notes.rst'
options(installers=Bunch(releasedir="release",
installersdir=os.path.join("release", "installers")),)
def _compute_hash(idirs, hashfunc):
"""Hash files using given hashfunc.
Parameters
----------
idirs : directory path
Directory containing files to be hashed.
hashfunc : hash function
Function to be used to hash the files.
"""
released = paver.path.path(idirs).listdir()
checksums = []
for fpath in sorted(released):
with open(fpath, 'rb') as fin:
fhash = hashfunc(fin.read())
checksums.append(
'%s %s' % (fhash.hexdigest(), os.path.basename(fpath)))
return checksums
def compute_md5(idirs):
"""Compute md5 hash of files in idirs.
Parameters
----------
idirs : directory path
Directory containing files to be hashed.
"""
return _compute_hash(idirs, hashlib.md5)
def compute_sha256(idirs):
"""Compute sha256 hash of files in idirs.
Parameters
----------
idirs : directory path
Directory containing files to be hashed.
"""
return _compute_hash(idirs, hashlib.sha256)
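The checksum lines produced by `_compute_hash` have the form `<hexdigest> <filename>`, one per file in the installers directory. A minimal standalone sketch of the same idea without paver (the directory name below is hypothetical):
import hashlib
import os

def sha256_lines(dirname):
    # One "<hexdigest> <basename>" line per file, mirroring _compute_hash above.
    lines = []
    for name in sorted(os.listdir(dirname)):
        path = os.path.join(dirname, name)
        if os.path.isfile(path):
            with open(path, 'rb') as fin:
                lines.append('%s %s' % (hashlib.sha256(fin.read()).hexdigest(), name))
    return lines

# print(sha256_lines('release/installers'))  # hypothetical release directory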
def write_release_task(options, filename='README'):
    """Append hashes of release files to release notes.

    This appends file hashes to the release notes and creates
    four README files of the result in various formats:

    - README.rst
    - README.rst.gpg
    - README.md
    - README.md.gpg

    The md files are created using `pandoc` so that the links are
    properly updated. The gpg files are kept separate, so that
    the unsigned files may be edited before signing if needed.

    Parameters
    ----------
    options : paver.easy.Bunch
        Set by ``task`` decorator; contains the release configuration.
    filename : str, optional
        Filename of the modified notes (default is 'README'). The file
        is written in the release directory.
    """
idirs = options.installers.installersdir
notes = paver.path.path(RELEASE_NOTES)
rst_readme = paver.path.path(filename + '.rst')
md_readme = paver.path.path(filename + '.md')
# append hashes
with open(rst_readme, 'w') as freadme:
with open(notes) as fnotes:
freadme.write(fnotes.read())
# Write MD5 hashes to README.rst
freadme.writelines(textwrap.dedent(
"""
Checksums
=========
MD5
---
::
"""))
freadme.writelines([f' {c}\n' for c in compute_md5(idirs)])
# Write SHA256 hashes to README.rst
freadme.writelines(textwrap.dedent(
"""
SHA256
------
::
"""))
freadme.writelines([f' {c}\n' for c in compute_sha256(idirs)])
# generate md file using pandoc before signing
sh(f"pandoc -s -o {md_readme} {rst_readme}")
# Sign files
if hasattr(options, 'gpg_key'):
# Generate a clear-signed, armored ASCII file with default or specified GPG key
        cmd = f'gpg --clearsign --armor --default-key {options.gpg_key}'
else:
# Generate a clear-signed, armored ASCII file
cmd = 'gpg --clearsign --armor'
# Sign README.rst and README.md using GPG
sh(cmd + f' --output {rst_readme}.gpg {rst_readme}')
sh(cmd + f' --output {md_readme}.gpg {md_readme}')
# Declare a task with the @task decorator; it generates the release notes files.
@task
# Define write_release, which generates the README files.
def write_release(options):
"""Write the README files.
Two README files are generated from the release notes, one in ``rst``
markup for the general release, the other in ``md`` markup for the github
release notes.
Parameters
----------
options :
Set by ``task`` decorator.
"""
    # Get the release directory path
rdir = options.installers.releasedir
    # Call write_release_task to generate the README files under the release directory
write_release_task(options, os.path.join(rdir, 'README'))
.\numpy\tools\changelog.py
"""
Script to generate contributor and pull request lists
This script generates contributor and pull request lists for release
changelogs using the Github v3 protocol. Usage requires an authentication
token in order to have sufficient bandwidth; you can get one by following the
directions at
`<https://help.github.com/articles/creating-an-access-token-for-command-line-use/>`_
Don't add any scope, as the default is read access to public information. The
token may be stored in an environment variable as you only get one chance to
see it.
Usage::
$ ./tools/changelog.py <token> <revision range>
The output is utf8 rst.
Dependencies
------------
- gitpython
- pygithub
- git >= 2.29.0
Some code was copied from scipy `tools/gh_list.py` and `tools/authors.py`.
Examples
--------
From the bash command line with $GITHUB token::
$ ./tools/changelog.py $GITHUB v1.13.0..v1.14.0 > 1.14.0-changelog.rst
"""
import os
import sys
import re
from git import Repo
from github import Github
this_repo = Repo(os.path.join(os.path.dirname(__file__), ".."))
author_msg =\
"""
A total of %d people contributed to this release. People with a "+" by their
names contributed a patch for the first time.
"""
pull_request_msg =\
"""
A total of %d pull requests were merged for this release.
"""
def get_authors(revision_range):
lst_release, cur_release = [r.strip() for r in revision_range.split('..')]
authors_pat = r'^.*\t(.*)$'
grp1 = '--group=author'
grp2 = '--group=trailer:co-authored-by'
cur = this_repo.git.shortlog('-s', grp1, grp2, revision_range)
pre = this_repo.git.shortlog('-s', grp1, grp2, lst_release)
authors_cur = set(re.findall(authors_pat, cur, re.M))
authors_pre = set(re.findall(authors_pat, pre, re.M))
authors_cur.discard('Homu')
authors_pre.discard('Homu')
authors_cur.discard('dependabot-preview')
authors_pre.discard('dependabot-preview')
authors_new = [s + ' +' for s in authors_cur - authors_pre]
authors_old = [s for s in authors_cur & authors_pre]
authors = authors_new + authors_old
authors.sort()
return authors
def get_pull_requests(repo, revision_range):
prnums = []
merges = this_repo.git.log('--oneline', '--merges', revision_range)
issues = re.findall(r"Merge pull request \#(\d*)", merges)
prnums.extend(int(s) for s in issues)
issues = re.findall(r"Auto merge of \#(\d*)", merges)
prnums.extend(int(s) for s in issues)
commits = this_repo.git.log('--oneline', '--no-merges', '--first-parent', revision_range)
issues = re.findall(r'^.*\((\#|gh-|gh-\#)(\d+)\)$', commits, re.M)
prnums.extend(int(s[1]) for s in issues)
prnums.sort()
prs = [repo.get_pull(n) for n in prnums]
return prs
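The PR numbers are recovered purely from commit subjects: merge commits, Homu's "Auto merge of" commits, and squash-merge suffixes such as `(#1234)` or `(gh-1234)`. A small standalone sketch of just that regex step, run on synthetic log lines (not real history):
import re

# Synthetic one-line log subjects (hypothetical examples, not real history).
merges = "Merge pull request #12345 from user/branch\nAuto merge of #23456 - user:branch"
commits = "abcd123 MAINT: tidy up the build (#34567)\nbcde234 BUG: fix overflow (gh-45678)"

prnums = []
prnums += [int(s) for s in re.findall(r"Merge pull request \#(\d*)", merges)]
prnums += [int(s) for s in re.findall(r"Auto merge of \#(\d*)", merges)]
prnums += [int(s[1]) for s in re.findall(r'^.*\((\#|gh-|gh-\#)(\d+)\)$', commits, re.M)]
print(sorted(prnums))  # [12345, 23456, 34567, 45678]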
def main(token, revision_range):
lst_release, cur_release = [r.strip() for r in revision_range.split('..')]
github = Github(token)
github_repo = github.get_repo('numpy/numpy')
authors = get_authors(revision_range)
heading = "Contributors"
print()
print(heading)
print("="*len(heading))
print(author_msg % len(authors))
for s in authors:
print('* ' + s)
pull_requests = get_pull_requests(github_repo, revision_range)
heading = "Pull requests merged"
pull_msg = "* `#{0} <{1}>`__: {2}"
print()
print(heading)
print("="*len(heading))
print(pull_request_msg % len(pull_requests))
for pull in pull_requests:
title = re.sub(r"\s+", " ", pull.title.strip())
if len(title) > 60:
remainder = re.sub(r"\s.*$", "...", title[60:])
if len(remainder) > 20:
remainder = title[:80] + "..."
else:
title = title[:60] + remainder
print(pull_msg.format(pull.number, pull.html_url, title))
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser(description="Generate author/pr lists for release")
parser.add_argument('token', help='github access token')
parser.add_argument('revision_range', help='<revision>..<revision>')
args = parser.parse_args()
main(args.token, args.revision_range)
.\numpy\tools\check_installed_files.py
"""
Check if all the test and .pyi files are installed after building.
Examples::
$ python check_installed_files.py install_dirname
install_dirname:
the relative path to the directory where NumPy is installed after
building and running `meson install`.
Notes
=====
The script will stop on encountering the first missing file in the install dir;
it will not give a full listing. This should be okay, because the script is
meant for use in CI, so it is not as if many files will be missing at once.
"""
import os
import glob
import sys
import json
CUR_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
ROOT_DIR = os.path.dirname(CUR_DIR)
NUMPY_DIR = os.path.join(ROOT_DIR, 'numpy')
changed_installed_path = {
}
def main(install_dir, tests_check):
INSTALLED_DIR = os.path.join(ROOT_DIR, install_dir)
if not os.path.exists(INSTALLED_DIR):
raise ValueError(
f"Provided install dir {INSTALLED_DIR} does not exist"
)
numpy_test_files = get_files(NUMPY_DIR, kind='test')
installed_test_files = get_files(INSTALLED_DIR, kind='test')
if tests_check == "--no-tests":
if len(installed_test_files) > 0:
raise Exception("Test files aren't expected to be installed in %s"
", found %s" % (INSTALLED_DIR, installed_test_files))
print("----------- No test files were installed --------------")
else:
for test_file in numpy_test_files.keys():
if test_file not in installed_test_files.keys():
raise Exception(
"%s is not installed" % numpy_test_files[test_file]
)
print("----------- All the test files were installed --------------")
numpy_pyi_files = get_files(NUMPY_DIR, kind='stub')
installed_pyi_files = get_files(INSTALLED_DIR, kind='stub')
for pyi_file in numpy_pyi_files.keys():
if pyi_file not in installed_pyi_files.keys():
if (tests_check == "--no-tests" and
"tests" in numpy_pyi_files[pyi_file]):
continue
raise Exception("%s is not installed" % numpy_pyi_files[pyi_file])
print("----------- All the necessary .pyi files "
"were installed --------------")
def get_files(dir_to_check, kind='test'):
files = dict()
patterns = {
'test': f'{dir_to_check}/**/test_*.py',
'stub': f'{dir_to_check}/**/*.pyi',
}
for path in glob.glob(patterns[kind], recursive=True):
relpath = os.path.relpath(path, dir_to_check)
files[relpath] = path
    if sys.version_info >= (3, 12):
        files = {
            k: v for k, v in files.items() if not k.startswith('distutils')
        }
    # ignore files from the vendored pythoncapi-compat directory
    files = {
        k: v for k, v in files.items() if 'pythoncapi-compat' not in k
    }
    return files
if __name__ == '__main__':
if len(sys.argv) < 2:
raise ValueError("Incorrect number of input arguments, need "
"check_installation.py relpath/to/installed/numpy")
install_dir = sys.argv[1]
tests_check = ""
if len(sys.argv) >= 3:
tests_check = sys.argv[2]
main(install_dir, tests_check)
all_tags = set()
with open(os.path.join('build', 'meson-info',
'intro-install_plan.json'), 'r') as f:
targets = json.load(f)
for key in targets.keys():
for values in list(targets[key].values()):
if not values['tag'] in all_tags:
all_tags.add(values['tag'])
if all_tags != set(['runtime', 'python-runtime', 'devel', 'tests']):
raise AssertionError(f"Found unexpected install tag: {all_tags}")
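The final block reads Meson's `intro-install_plan.json`, which groups install targets and records a `tag` for each entry; the check asserts that only the four expected tags occur. A minimal sketch of the same collection loop on a synthetic, hypothetical install plan:
# A minimal sketch of the tag-collection loop above, run on a synthetic,
# hypothetical install plan instead of build/meson-info/intro-install_plan.json.
targets = {
    "targets": {
        "numpy/_core/_multiarray_umath.so": {"tag": "runtime"},
        "numpy/_core/tests/test_numeric.py": {"tag": "tests"},
    },
    "python": {
        "numpy/__init__.py": {"tag": "python-runtime"},
        "numpy/__init__.pyi": {"tag": "devel"},
    },
}
all_tags = set()
for key in targets:
    for values in targets[key].values():
        all_tags.add(values["tag"])
print(all_tags == {"runtime", "python-runtime", "devel", "tests"})  # True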
.\numpy\tools\check_openblas_version.py
"""
usage: check_openblas_version.py <min_version>
Check that the BLAS NumPy was built against is scipy-openblas and that its
version is at least min_version
example: check_openblas_version.py 0.3.26
"""
import numpy
import pprint
import sys
version = sys.argv[1]
deps = numpy.show_config('dicts')['Build Dependencies']
assert "blas" in deps
print("Build Dependencies: blas")
pprint.pprint(deps["blas"])
assert deps["blas"]["version"].split(".") >= version.split(".")
assert deps["blas"]["name"] == "scipy-openblas"
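Note that the `split(".")` assertion compares the version components as strings, which is lexicographic (for example ["0", "3", "26"] sorts below ["0", "3", "9"] because "2" < "9"). A hedged sketch of a numeric comparison, assuming the reported version starts with three plain numeric components:
# Sketch (not part of the tool): compare version components numerically so
# that, e.g., 0.3.26 >= 0.3.9 holds, which the string comparison would get wrong.
def at_least(found, minimum):
    to_ints = lambda v: [int(x) for x in v.split(".")[:3]]
    return to_ints(found) >= to_ints(minimum)

print(at_least("0.3.26", "0.3.9"))          # True
print(["0", "3", "26"] >= ["0", "3", "9"])  # False: lexicographic comparison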
.\numpy\tools\ci\push_docs_to_repo.py
import argparse
import subprocess
import tempfile
import os
import sys
import shutil
parser = argparse.ArgumentParser(
description='Upload files to a remote repo, replacing existing content'
)
parser.add_argument('dir', help='directory whose content will be uploaded')
parser.add_argument('remote', help='remote to which content will be pushed')
parser.add_argument('--message', default='Commit bot upload',
help='commit message to use')
parser.add_argument('--committer', default='numpy-commit-bot',
help='Name of the git committer')
parser.add_argument('--email', default='numpy-commit-bot@nomail',
help='Email of the git committer')
parser.add_argument('--count', default=1, type=int,
help="minimum number of expected files, defaults to 1")
parser.add_argument(
'--force', action='store_true',
help='hereby acknowledge that remote repo content will be overwritten'
)
args = parser.parse_args()
args.dir = os.path.abspath(args.dir)
if not os.path.exists(args.dir):
print('Content directory does not exist')
sys.exit(1)
count = len([name for name in os.listdir(args.dir) if os.path.isfile(os.path.join(args.dir, name))])
if count < args.count:
print(f"Expected {args.count} top-directory files to upload, got {count}")
sys.exit(1)
def run(cmd, stdout=True):
pipe = None if stdout else subprocess.DEVNULL
try:
subprocess.check_call(cmd, stdout=pipe, stderr=pipe)
except subprocess.CalledProcessError:
print("\n! Error executing: `%s;` aborting" % ' '.join(cmd))
sys.exit(1)
workdir = tempfile.mkdtemp()
os.chdir(workdir)
run(['git', 'init'])
run(['git', 'checkout', '-b', 'main'])
run(['git', 'remote', 'add', 'origin', args.remote])
run(['git', 'config', '--local', 'user.name', args.committer])
run(['git', 'config', '--local', 'user.email', args.email])
print('- committing new content: "%s"' % args.message)
run(['cp', '-R', os.path.join(args.dir, '.'), '.'])
run(['git', 'add', '.'], stdout=False)
run(['git', 'commit', '--allow-empty', '-m', args.message], stdout=False)
print('- uploading as %s <%s>' % (args.committer, args.email))
if args.force:
run(['git', 'push', 'origin', 'main', '--force'])
else:
print('\n!! No `--force` argument specified; aborting')
print('!! Before enabling that flag, make sure you know what it does\n')
sys.exit(1)
shutil.rmtree(workdir)
.\numpy\tools\ci\test_all_newsfragments_used.py
import sys
import toml
import os
def main():
path = toml.load("pyproject.toml")["tool"]["towncrier"]["directory"]
fragments = os.listdir(path)
fragments.remove("README.rst")
fragments.remove("template.rst")
if fragments:
print("The following files were not found by towncrier:")
print(" " + "\n ".join(fragments))
sys.exit(1)
if __name__ == "__main__":
main()
.\numpy\tools\commitstats.py
command = 'svn log -l 2300 > output.txt'
import re
import numpy as np
import os
names = re.compile(r'r\d+\s\|\s(.*)\s\|\s200')
def get_count(filename, repo):
mystr = open(filename).read()
result = names.findall(mystr)
u = np.unique(result)
count = [(x, result.count(x), repo) for x in u]
return count
os.chdir('..')
os.system(command)
count = get_count('output.txt', 'NumPy')
os.chdir('../scipy')
os.system(command)
count.extend(get_count('output.txt', 'SciPy'))
os.chdir('../scikits')
os.system(command)
count.extend(get_count('output.txt', 'SciKits'))
count.sort()
print("** SciPy and NumPy **")
print("=====================")
for val in count:
print(val)
.\numpy\tools\c_coverage\c_coverage_report.py
"""
A script to create C code-coverage reports based on the output of
valgrind's callgrind tool.
"""
import os
import re
import sys
from xml.sax.saxutils import quoteattr, escape
try:
import pygments
if tuple([int(x) for x in pygments.__version__.split('.')]) < (0, 11):
raise ImportError()
from pygments import highlight
from pygments.lexers import CLexer
from pygments.formatters import HtmlFormatter
has_pygments = True
except ImportError:
print("This script requires pygments 0.11 or greater to generate HTML")
has_pygments = False
class FunctionHtmlFormatter(HtmlFormatter):
"""Custom HTML formatter to insert extra information with the lines."""
def __init__(self, lines, **kwargs):
HtmlFormatter.__init__(self, **kwargs)
self.lines = lines
def wrap(self, source, outfile):
for i, (c, t) in enumerate(HtmlFormatter.wrap(self, source, outfile)):
as_functions = self.lines.get(i-1, None)
if as_functions is not None:
yield 0, ('<div title=%s style="background: #ccffcc">[%2d]' %
(quoteattr('as ' + ', '.join(as_functions)),
len(as_functions)))
else:
yield 0, ' '
yield c, t
if as_functions is not None:
yield 0, '</div>'
class SourceFile:
def __init__(self, path):
self.path = path
self.lines = {}
def mark_line(self, lineno, as_func=None):
line = self.lines.setdefault(lineno, set())
if as_func is not None:
as_func = as_func.split("'", 1)[0]
line.add(as_func)
def write_text(self, fd):
source = open(self.path, "r")
for i, line in enumerate(source):
if i + 1 in self.lines:
fd.write("> ")
else:
fd.write("! ")
fd.write(line)
source.close()
def write_html(self, fd):
source = open(self.path, 'r')
code = source.read()
lexer = CLexer()
formatter = FunctionHtmlFormatter(
self.lines,
full=True,
linenos='inline')
fd.write(highlight(code, lexer, formatter))
source.close()
class SourceFiles:
def __init__(self):
self.files = {}
self.prefix = None
def get_file(self, path):
if path not in self.files:
self.files[path] = SourceFile(path)
if self.prefix is None:
self.prefix = path
else:
self.prefix = os.path.commonprefix([self.prefix, path])
return self.files[path]
def clean_path(self, path):
path = path[len(self.prefix):]
return re.sub(r"[^A-Za-z0-9\.]", '_', path)
def write_text(self, root):
for path, source in self.files.items():
fd = open(os.path.join(root, self.clean_path(path)), "w")
source.write_text(fd)
fd.close()
def write_html(self, root):
for path, source in self.files.items():
fd = open(os.path.join(root, self.clean_path(path) + ".html"), "w")
source.write_html(fd)
fd.close()
fd = open(os.path.join(root, 'index.html'), 'w')
fd.write("<html>")
paths = sorted(self.files.keys())
for path in paths:
fd.write('<p><a href="%s.html">%s</a></p>' %
(self.clean_path(path), escape(path[len(self.prefix):])))
fd.write("</html>")
fd.close()
def collect_stats(files, fd, pattern):
line_regexs = [
re.compile(r"(?P<lineno>[0-9]+)(\s[0-9]+)+"),
re.compile(r"((jump)|(jcnd))=([0-9]+)\s(?P<lineno>[0-9]+)")
]
current_file = None
current_function = None
for i, line in enumerate(fd):
if re.match("f[lie]=.+", line):
path = line.split('=', 2)[1].strip()
if os.path.exists(path) and re.search(pattern, path):
current_file = files.get_file(path)
else:
current_file = None
elif re.match("fn=.+", line):
current_function = line.split('=', 2)[1].strip()
elif current_file is not None:
for regex in line_regexs:
match = regex.match(line)
if match:
lineno = int(match.group('lineno'))
current_file.mark_line(lineno, current_function)
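`collect_stats` walks a callgrind annotation file: `fl=`/`fi=`/`fe=` lines name a source file, `fn=` lines name a function, and the numeric lines carry a line number followed by cost counts. A tiny standalone sketch of just the two line regexes on synthetic callgrind lines:
import re

# The two line formats matched above, tried on synthetic (hypothetical) callgrind lines.
line_regexs = [
    re.compile(r"(?P<lineno>[0-9]+)(\s[0-9]+)+"),
    re.compile(r"((jump)|(jcnd))=([0-9]+)\s(?P<lineno>[0-9]+)"),
]
for line in ["42 128", "jcnd=3 57", "fn=main"]:
    for regex in line_regexs:
        match = regex.match(line)
        if match:
            print(line, "-> line", match.group("lineno"))
            break
# 42 128 -> line 42
# jcnd=3 57 -> line 57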
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'callgrind_file', nargs='+',
help='One or more callgrind files')
parser.add_argument(
'-d', '--directory', default='coverage',
help='Destination directory for output (default: %(default)s)')
parser.add_argument(
'-p', '--pattern', default='numpy',
help='Regex pattern to match against source file paths '
'(default: %(default)s)')
parser.add_argument(
'-f', '--format', action='append', default=[],
choices=['text', 'html'],
help="Output format(s) to generate. "
"If option not provided, both will be generated.")
args = parser.parse_args()
files = SourceFiles()
for log_file in args.callgrind_file:
log_fd = open(log_file, 'r')
collect_stats(files, log_fd, args.pattern)
log_fd.close()
if not os.path.exists(args.directory):
os.makedirs(args.directory)
if args.format == []:
formats = ['text', 'html']
else:
formats = args.format
if 'text' in formats:
files.write_text(args.directory)
if 'html' in formats:
if not has_pygments:
print("Pygments 0.11 or later is required to generate HTML")
sys.exit(1)
files.write_html(args.directory)
.\numpy\tools\download-wheels.py
"""
Script to download NumPy wheels from the Anaconda staging area.
Usage::
$ ./tools/download-wheels.py <version> -w <optional-wheelhouse>
The default wheelhouse is ``release/installers``.
Dependencies
------------
- beautifulsoup4
- urllib3
Examples
--------
While in the repository root::
$ python tools/download-wheels.py 1.19.0
$ python tools/download-wheels.py 1.19.0 -w ~/wheelhouse
"""
import os
import re
import shutil
import argparse
import urllib3
from bs4 import BeautifulSoup
__version__ = "0.1"
STAGING_URL = "https://anaconda.org/multibuild-wheels-staging/numpy"
PREFIX = "numpy"
WHL = r"-.*\.whl$"
ZIP = r"\.zip$"
GZIP = r"\.tar\.gz$"
SUFFIX = rf"({WHL}|{GZIP}|{ZIP})"
def get_wheel_names(version):
""" Get wheel names from Anaconda HTML directory.
This looks in the Anaconda multibuild-wheels-staging page and
parses the HTML to get all the wheel names for a release version.
Parameters
----------
version : str
The release version. For instance, "1.18.3".
"""
http = urllib3.PoolManager(cert_reqs="CERT_REQUIRED")
tmpl = re.compile(rf"^.*{PREFIX}-{version}{SUFFIX}")
index_url = f"{STAGING_URL}/files"
index_html = http.request("GET", index_url)
soup = BeautifulSoup(index_html.data, "html.parser")
return soup.find_all(string=tmpl)
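`get_wheel_names` simply matches the anchor text on the staging page against `PREFIX`, the requested version, and `SUFFIX`. A standalone sketch of that filename filter on hypothetical artifact names:
import re

PREFIX = "numpy"
WHL = r"-.*\.whl$"
ZIP = r"\.zip$"
GZIP = r"\.tar\.gz$"
SUFFIX = rf"({WHL}|{GZIP}|{ZIP})"

version = "1.19.0"
tmpl = re.compile(rf"^.*{PREFIX}-{version}{SUFFIX}")
names = [
    "numpy-1.19.0-cp39-cp39-manylinux2014_x86_64.whl",   # hypothetical wheel name
    "numpy-1.19.0.tar.gz",
    "numpy-1.19.0rc1-cp310-cp310-win_amd64.whl",
]
print([bool(tmpl.match(name)) for name in names])  # [True, True, False]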
def download_wheels(version, wheelhouse):
"""Download release wheels.
The release wheels for the given NumPy version are downloaded
into the given directory.
Parameters
----------
version : str
The release version. For instance, "1.18.3".
wheelhouse : str
Directory in which to download the wheels.
"""
http = urllib3.PoolManager(cert_reqs="CERT_REQUIRED")
wheel_names = get_wheel_names(version)
for i, wheel_name in enumerate(wheel_names):
wheel_url = f"{STAGING_URL}/{version}/download/{wheel_name}"
wheel_path = os.path.join(wheelhouse, wheel_name)
with open(wheel_path, "wb") as f:
with http.request("GET", wheel_url, preload_content=False) as r:
print(f"{i + 1:<4}{wheel_name}")
shutil.copyfileobj(r, f)
print(f"\nTotal files downloaded: {len(wheel_names)}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"version",
help="NumPy version to download.")
parser.add_argument(
"-w", "--wheelhouse",
default=os.path.join(os.getcwd(), "release", "installers"),
help="Directory in which to store downloaded wheels\n"
"[defaults to <cwd>/release/installers]")
args = parser.parse_args()
wheelhouse = os.path.expanduser(args.wheelhouse)
if not os.path.isdir(wheelhouse):
raise RuntimeError(
f"{wheelhouse} wheelhouse directory is not present."
" Perhaps you need to use the '-w' flag to specify one.")
download_wheels(args.version, wheelhouse)
.\numpy\tools\find_deprecated_escaped_characters.py
r"""
Look for escape sequences deprecated in Python 3.6.
Python 3.6 deprecates a number of non-escape sequences starting with '\' that
were accepted before. For instance, '\(' was previously accepted but must now
be written as '\\(' or r'\('.
"""
def main(root):
"""Find deprecated escape sequences.
Checks for deprecated escape sequences in ``*.py files``. If `root` is a
file, that file is checked, if `root` is a directory all ``*.py`` files
found in a recursive descent are checked.
If a deprecated escape sequence is found, the file and line where found is
printed. Note that for multiline strings the line where the string ends is
printed and the error(s) are somewhere in the body of the string.
Parameters
----------
root : str
File or directory to check.
Returns
-------
None
"""
import ast
import tokenize
import warnings
from pathlib import Path
count = 0
base = Path(root)
paths = base.rglob("*.py") if base.is_dir() else [base]
for path in paths:
with tokenize.open(str(path)) as f:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
tree = ast.parse(f.read())
if w:
print("file: ", str(path))
for e in w:
print('line: ', e.lineno, ': ', e.message)
print()
count += len(w)
print("Errors Found", count)
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser(description="Find deprecated escaped characters")
parser.add_argument('root', help='directory or file to be checked')
args = parser.parse_args()
main(args.root)
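The check relies on the byte-code compiler itself emitting a warning when it sees an invalid escape such as '\('. A minimal standalone demonstration (the warning class is DeprecationWarning on Python 3.6-3.11 and SyntaxWarning on newer versions):
import warnings

# Compile source text containing the invalid escape '\(' and capture the warning
# the interpreter emits (DeprecationWarning up to ~3.11, SyntaxWarning later).
source = "s = '\\('"
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    compile(source, '<demo>', 'exec')
print([w.category.__name__ for w in caught])  # e.g. ['SyntaxWarning']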
.\numpy\tools\functions_missing_types.py
"""Find the functions in a module missing type annotations.
To use it run
./functions_missing_types.py <module>
and it will print out a list of functions in the module that don't
have types.
"""
import argparse
import ast
import importlib
import os
NUMPY_ROOT = os.path.dirname(os.path.join(
os.path.abspath(__file__), "..",
))
EXCLUDE_LIST = {
"numpy": {
"absolute_import",
"division",
"print_function",
"warnings",
"sys",
"os",
"math",
"Tester",
"_core",
"get_array_wrap",
"int_asbuffer",
"numarray",
"oldnumeric",
"safe_eval",
"test",
"typeDict",
"bool",
"complex",
"float",
"int",
"long",
"object",
"str",
"unicode",
"alltrue",
"sometrue",
}
}
class FindAttributes(ast.NodeVisitor):
"""Find top-level attributes/functions/classes in stubs files.
Do this by walking the stubs ast. See e.g.
https://greentreesnakes.readthedocs.io/en/latest/index.html
for more information on working with Python's ast.
"""
def __init__(self):
self.attributes = set()
def visit_FunctionDef(self, node):
if node.name == "__getattr__":
return
self.attributes.add(node.name)
return
def visit_ClassDef(self, node):
if not node.name.startswith("_"):
self.attributes.add(node.name)
return
def visit_AnnAssign(self, node):
self.attributes.add(node.target.id)
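`FindAttributes` records top-level function names (except `__getattr__`), public class names, and annotated assignments. A small sketch, assuming the module imports and the `FindAttributes` class above are in scope; the stub text is made up for illustration:
# Run FindAttributes on inline stub text (not an actual NumPy .pyi file).
demo_stub = """
def ones(shape): ...
class ndarray: ...
class _internal: ...
pi: float
def __getattr__(name): ...
"""
visitor = FindAttributes()
visitor.visit(ast.parse(demo_stub))
print(sorted(visitor.attributes))  # ['ndarray', 'ones', 'pi']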
def find_missing(module_name):
module_path = os.path.join(
NUMPY_ROOT,
module_name.replace(".", os.sep),
"__init__.pyi",
)
module = importlib.import_module(module_name)
module_attributes = {
attribute for attribute in dir(module) if not attribute.startswith("_")
}
if os.path.isfile(module_path):
with open(module_path) as f:
tree = ast.parse(f.read())
ast_visitor = FindAttributes()
ast_visitor.visit(tree)
stubs_attributes = ast_visitor.attributes
else:
stubs_attributes = set()
exclude_list = EXCLUDE_LIST.get(module_name, set())
missing = module_attributes - stubs_attributes - exclude_list
print("\n".join(sorted(missing)))
def main():
parser = argparse.ArgumentParser()
parser.add_argument("module")
args = parser.parse_args()
find_missing(args.module)
if __name__ == "__main__":
main()
.\numpy\tools\linter.py
import os
import sys
import subprocess
from argparse import ArgumentParser
from git import Repo, exc
CONFIG = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
'lint_diff.ini',
)
EXCLUDE = (
"numpy/typing/tests/data/",
"numpy/typing/_char_codes.py",
"numpy/__config__.py",
"numpy/f2py",
)
class DiffLinter:
def __init__(self, branch):
self.branch = branch
self.repo = Repo('.')
self.head = self.repo.head.commit
def get_branch_diff(self, uncommitted = False):
"""
Determine the first common ancestor commit.
Find diff between branch and FCA commit.
Note: if `uncommitted` is set, check only
uncommitted changes
"""
try:
commit = self.repo.merge_base(self.branch, self.head)[0]
except exc.GitCommandError:
print(f"Branch with name `{self.branch}` does not exist")
sys.exit(1)
exclude = [f':(exclude){i}' for i in EXCLUDE]
if uncommitted:
diff = self.repo.git.diff(
self.head, '--unified=0', '***.py', *exclude
)
else:
diff = self.repo.git.diff(
commit, self.head, '--unified=0', '***.py', *exclude
)
return diff
def run_pycodestyle(self, diff):
"""
Original Author: Josh Wilson (@person142)
Source:
https://github.com/scipy/scipy/blob/main/tools/lint_diff.py
Run pycodestyle on the given diff.
"""
res = subprocess.run(
['pycodestyle', '--diff', '--config', CONFIG],
input=diff,
stdout=subprocess.PIPE,
encoding='utf-8',
)
return res.returncode, res.stdout
def run_lint(self, uncommitted):
diff = self.get_branch_diff(uncommitted)
retcode, errors = self.run_pycodestyle(diff)
errors and print(errors)
sys.exit(retcode)
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--branch", type=str, default='main',
help="The branch to diff against")
parser.add_argument("--uncommitted", action='store_true',
help="Check only uncommitted changes")
args = parser.parse_args()
DiffLinter(args.branch).run_lint(args.uncommitted)
.\numpy\tools\refguide_check.py
"""
refguide_check.py [OPTIONS] [-- ARGS]
- Check for a NumPy submodule whether the objects in its __all__ dict
correspond to the objects included in the reference guide.
- Check docstring examples
- Check example blocks in RST files
Example of usage::
$ python tools/refguide_check.py
Note that this is a helper script to be able to check if things are missing;
the output of this script does need to be checked manually. In some cases
objects are left out of the refguide for a good reason (it's an alias of
another function, or deprecated, or ...)
Another use of this helper script is to check validity of code samples
in docstrings::
$ python tools/refguide_check.py --doctests ma
or in RST-based documentations::
$ python tools/refguide_check.py --rst doc/source
"""
import copy
import doctest
import inspect
import io
import os
import re
import shutil
import sys
import tempfile
import warnings
import docutils.core
from argparse import ArgumentParser
from contextlib import contextmanager, redirect_stderr
from doctest import NORMALIZE_WHITESPACE, ELLIPSIS, IGNORE_EXCEPTION_DETAIL
from docutils.parsers.rst import directives
import sphinx
import numpy as np
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'doc', 'sphinxext'))
from numpydoc.docscrape_sphinx import get_doc_object
SKIPBLOCK = doctest.register_optionflag('SKIPBLOCK')
from sphinx.directives.other import SeeAlso, Only
directives.register_directive('seealso', SeeAlso)
directives.register_directive('only', Only)
BASE_MODULE = "numpy"
PUBLIC_SUBMODULES = [
"f2py",
"linalg",
"lib",
"lib.format",
"lib.mixins",
"lib.recfunctions",
"lib.scimath",
"lib.stride_tricks",
"lib.npyio",
"lib.introspect",
"lib.array_utils",
"fft",
"char",
"rec",
"ma",
"ma.extras",
"ma.mrecords",
"polynomial",
"polynomial.chebyshev",
"polynomial.hermite",
"polynomial.hermite_e",
"polynomial.laguerre",
"polynomial.legendre",
"polynomial.polynomial",
"matrixlib",
"random",
"strings",
"testing",
]
OTHER_MODULE_DOCS = {
'fftpack.convolve': 'fftpack',
'io.wavfile': 'io',
'io.arff': 'io',
}
DOCTEST_SKIPDICT = {
    'numpy.lib.vectorize': None,
    'numpy.random.standard_gamma': None,
    'numpy.random.gamma': None,
    'numpy.random.vonmises': None,
    'numpy.random.power': None,
    'numpy.random.zipf': None,
    'numpy._core.from_dlpack': None,
    'numpy.lib.npyio.DataSource': None,
    'numpy.lib.Repository': None,
}
RST_SKIPLIST = [
'scipy-sphinx-theme',
'sphinxext',
'neps',
'changelog',
'doc/release',
'doc/source/release',
'doc/release/upcoming_changes',
'c-info.ufunc-tutorial.rst',
'c-info.python-as-glue.rst',
'f2py.getting-started.rst',
'f2py-examples.rst',
'arrays.nditer.cython.rst',
'how-to-verify-bug.rst',
'basics.dispatch.rst',
'basics.subclassing.rst',
'basics.interoperability.rst',
'misc.rst',
'TESTS.rst'
]
REFGUIDE_ALL_SKIPLIST = [
r'scipy\.sparse\.linalg',
r'scipy\.spatial\.distance',
r'scipy\.linalg\.blas\.[sdczi].*',
r'scipy\.linalg\.lapack\.[sdczi].*',
]
REFGUIDE_AUTOSUMMARY_SKIPLIST = [
r'numpy\.*',
]
for name in ('barthann', 'bartlett', 'blackmanharris', 'blackman', 'bohman',
'boxcar', 'chebwin', 'cosine', 'exponential', 'flattop',
'gaussian', 'general_gaussian', 'hamming', 'hann', 'hanning',
'kaiser', 'nuttall', 'parzen', 'slepian', 'triang', 'tukey'):
REFGUIDE_AUTOSUMMARY_SKIPLIST.append(r'scipy\.signal\.' + name)
HAVE_MATPLOTLIB = False
"""
The `names_dict` is updated by reference and accessible in calling method
Parameters
----------
module : ModuleType
The module, whose docstrings is to be searched
names_dict : dict
Dictionary which contains module name as key and a set of found
function names and directives as value
Returns
-------
None
"""
patterns = [
r"^\s\s\s([a-z_0-9A-Z]+)(\s+-+.*)?$",
r"^\.\. (?:data|function)::\s*([a-z_0-9A-Z]+)\s*$"
]
if module.__name__ == 'scipy.constants':
patterns += ["^``([a-z_0-9A-Z]+)``"]
patterns = [re.compile(pattern) for pattern in patterns]
module_name = module.__name__
for line in module.__doc__.splitlines():
res = re.search(r"^\s*\.\. (?:currentmodule|module):: ([a-z0-9A-Z_.]+)\s*$", line)
if res:
module_name = res.group(1)
continue
for pattern in patterns:
res = re.match(pattern, line)
if res is not None:
name = res.group(1)
entry = '.'.join([module_name, name])
names_dict.setdefault(module_name, set()).add(name)
break
def get_all_dict(module):
if hasattr(module, "__all__"):
all_dict = copy.deepcopy(module.__all__)
else:
all_dict = copy.deepcopy(dir(module))
all_dict = [name for name in all_dict
if not name.startswith("_")]
for name in ['absolute_import', 'division', 'print_function']:
try:
all_dict.remove(name)
except ValueError:
pass
if not all_dict:
all_dict.append('__doc__')
all_dict = [name for name in all_dict
if not inspect.ismodule(getattr(module, name, None))]
deprecated = []
not_deprecated = []
for name in all_dict:
f = getattr(module, name, None)
if callable(f) and is_deprecated(f):
deprecated.append(name)
else:
not_deprecated.append(name)
others = set(dir(module)).difference(set(deprecated)).difference(set(not_deprecated))
return not_deprecated, deprecated, others
def compare(all_dict, others, names, module_name):
only_all = set()
for name in all_dict:
if name not in names:
for pat in REFGUIDE_AUTOSUMMARY_SKIPLIST:
if re.match(pat, module_name + '.' + name):
break
else:
only_all.add(name)
only_ref = set()
missing = set()
for name in names:
if name not in all_dict:
for pat in REFGUIDE_ALL_SKIPLIST:
if re.match(pat, module_name + '.' + name):
if name not in others:
missing.add(name)
break
else:
only_ref.add(name)
return only_all, only_ref, missing
def is_deprecated(f):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("error")
try:
f(**{"not a kwarg": None})
except DeprecationWarning:
return True
except Exception:
pass
return False
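`is_deprecated` turns DeprecationWarning into an error, calls the object with a bogus keyword argument, and reports True only if that raises DeprecationWarning. A small sketch with toy functions (not NumPy APIs), assuming `is_deprecated` above is in scope:
import warnings

# Toy functions (not NumPy APIs) to illustrate the probe performed by is_deprecated.
def old_api(**kwargs):
    warnings.warn("old_api is deprecated", DeprecationWarning)

def new_api(**kwargs):
    pass

print(is_deprecated(old_api), is_deprecated(new_api))  # True False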
def check_items(all_dict, names, deprecated, others, module_name, dots=True):
"""
    Check that `all_dict` is consistent with the `names` in `module_name`;
    for instance, that there are no deprecated or extra objects.
Parameters
----------
    all_dict : list
        List of objects to be checked.
    names : set
        Set of object names found in the reference guide.
    deprecated : list
        List of deprecated objects.
    others : list
        List of other objects.
    module_name : str
        Name of the module.
    dots : bool
        Whether to print a dot for each check.

    Returns
    -------
    list
        List of [(name, success_flag, output), ...]
"""
num_all = len(all_dict)
num_ref = len(names)
output = ""
output += "Non-deprecated objects in __all__: %i\n" % num_all
output += "Objects in refguide: %i\n\n" % num_ref
only_all, only_ref, missing = compare(all_dict, others, names, module_name)
dep_in_ref = only_ref.intersection(deprecated)
only_ref = only_ref.difference(deprecated)
if len(dep_in_ref) > 0:
output += "Deprecated objects in refguide::\n\n"
for name in sorted(deprecated):
output += " " + name + "\n"
if len(only_all) == len(only_ref) == len(missing) == 0:
if dots:
output_dot('.')
return [(None, True, output)]
else:
if len(only_all) > 0:
output += "ERROR: objects in %s.__all__ but not in refguide::\n\n" % module_name
for name in sorted(only_all):
output += " " + name + "\n"
output += "\nThis issue can be fixed by adding these objects to\n"
output += "the function listing in __init__.py for this module\n"
if len(only_ref) > 0:
output += "ERROR: objects in refguide but not in %s.__all__::\n\n" % module_name
for name in sorted(only_ref):
output += " " + name + "\n"
output += "\nThis issue should likely be fixed by removing these objects\n"
output += "from the function listing in __init__.py for this module\n"
output += "or adding them to __all__.\n"
if len(missing) > 0:
output += "ERROR: missing objects::\n\n"
for name in sorted(missing):
output += " " + name + "\n"
if dots:
output_dot('F')
return [(None, False, output)]
def validate_rst_syntax(text, name, dots=True):
"""
Validates the doc string in a snippet of documentation
`text` from file `name`
Parameters
----------
    text : str
        Docstring content to validate.
    name : str
        Name of the file the documentation belongs to.
    dots : bool
        Whether to print a dot for each check.

    Returns
    -------
    (bool, str)
        Tuple whose first element is the validation result and whose second
        element is the associated output.
"""
if text is None:
if dots:
output_dot('E')
return False, "ERROR: %s: no documentation" % (name,)
ok_unknown_items = set([
'mod', 'doc', 'currentmodule', 'autosummary', 'data', 'attr',
'obj', 'versionadded', 'versionchanged', 'module', 'class',
'ref', 'func', 'toctree', 'moduleauthor', 'term', 'c:member',
'sectionauthor', 'codeauthor', 'eq', 'doi', 'DOI', 'arXiv', 'arxiv'
])
error_stream = io.StringIO()
def resolve(name, is_label=False):
return ("http://foo", name)
token = '<RST-VALIDATE-SYNTAX-CHECK>'
docutils.core.publish_doctree(
text, token,
settings_overrides = dict(halt_level=5,
traceback=True,
default_reference_context='title-reference',
default_role='emphasis',
link_base='',
resolve_name=resolve,
stylesheet_path='',
raw_enabled=0,
file_insertion_enabled=0,
warning_stream=error_stream))
error_msg = error_stream.getvalue()
errors = error_msg.split(token)
success = True
output = ""
for error in errors:
lines = error.splitlines()
if not lines:
continue
m = re.match(r'.*Unknown (?:interpreted text role|directive type) "(.*)".*$', lines[0])
if m:
if m.group(1) in ok_unknown_items:
continue
m = re.match(r'.*Error in "math" directive:.*unknown option: "label"', " ".join(lines), re.S)
if m:
continue
output += name + lines[0] + "::\n " + "\n ".join(lines[1:]).rstrip() + "\n"
success = False
if not success:
output += " " + "-"*72 + "\n"
for lineno, line in enumerate(text.splitlines()):
output += " %-4d %s\n" % (lineno+1, line)
output += " " + "-"*72 + "\n\n"
if dots:
output_dot('.' if success else 'F')
return success, output
def output_dot(msg='.', stream=sys.stderr):
stream.write(msg)
stream.flush()
def check_rest(module, names, dots=True):
"""
Check reStructuredText formatting of docstrings
Parameters
----------
    module : ModuleType
        The module to check.
    names : set
        Set of names to check.

    Returns
    -------
    result : list
        List of tuples (module_name, success_flag, output).
"""
try:
skip_types = (dict, str, unicode, float, int)
except NameError:
skip_types = (dict, str, float, int)
results = []
if module.__name__[6:] not in OTHER_MODULE_DOCS:
results += [(module.__name__,) +
validate_rst_syntax(inspect.getdoc(module),
module.__name__, dots=dots)]
for name in names:
full_name = module.__name__ + '.' + name
obj = getattr(module, name, None)
if obj is None:
results.append((full_name, False, "%s has no docstring" % (full_name,)))
continue
elif isinstance(obj, skip_types):
continue
if inspect.ismodule(obj):
text = inspect.getdoc(obj)
else:
try:
text = str(get_doc_object(obj))
except Exception:
import traceback
results.append((full_name, False,
"Error in docstring format!\n" +
traceback.format_exc()))
continue
m = re.search("([\x00-\x09\x0b-\x1f])", text)
if m:
msg = ("Docstring contains a non-printable character %r! "
"Maybe forgot r\"\"\"?" % (m.group(1),))
results.append((full_name, False, msg))
continue
try:
src_file = short_path(inspect.getsourcefile(obj))
except TypeError:
src_file = None
if src_file:
file_full_name = src_file + ':' + full_name
else:
file_full_name = full_name
results.append((full_name,) + validate_rst_syntax(text, file_full_name, dots=dots))
return results
DEFAULT_NAMESPACE = {'np': np}
CHECK_NAMESPACE = {
'np': np,
'numpy': np,
'assert_allclose': np.testing.assert_allclose,
'assert_equal': np.testing.assert_equal,
'array': np.array,
'matrix': np.matrix,
'int64': np.int64,
'uint64': np.uint64,
'int8': np.int8,
'int32': np.int32,
'float32': np.float32,
'float64': np.float64,
'dtype': np.dtype,
'nan': np.nan,
'inf': np.inf,
'StringIO': io.StringIO,
}
class DTRunner(doctest.DocTestRunner):
"""
The doctest runner
"""
DIVIDER = "\n"
def __init__(self, item_name, checker=None, verbose=None, optionflags=0):
self._item_name = item_name
doctest.DocTestRunner.__init__(self, checker=checker, verbose=verbose,
optionflags=optionflags)
def _report_item_name(self, out, new_line=False):
if self._item_name is not None:
if new_line:
out("\n")
self._item_name = None
def report_start(self, out, test, example):
self._checker._source = example.source
return doctest.DocTestRunner.report_start(self, out, test, example)
def report_success(self, out, test, example, got):
if self._verbose:
self._report_item_name(out, new_line=True)
return doctest.DocTestRunner.report_success(self, out, test, example, got)
def report_unexpected_exception(self, out, test, example, exc_info):
self._report_item_name(out)
return doctest.DocTestRunner.report_unexpected_exception(
self, out, test, example, exc_info)
def report_failure(self, out, test, example, got):
self._report_item_name(out)
return doctest.DocTestRunner.report_failure(self, out, test,
example, got)
class Checker(doctest.OutputChecker):
"""
    Custom output checker derived from ``doctest.OutputChecker``.
"""
obj_pattern = re.compile('at 0x[0-9a-fA-F]+>')
vanilla = doctest.OutputChecker()
rndm_markers = {'# random', '# Random', '#random', '#Random', "# may vary",
"# uninitialized", "#uninitialized", "# uninit"}
stopwords = {'plt.', '.hist', '.show', '.ylim', '.subplot(',
'set_title', 'imshow', 'plt.show', '.axis(', '.plot(',
'.bar(', '.title', '.ylabel', '.xlabel', 'set_ylim', 'set_xlim',
'# reformatted', '.set_xlabel(', '.set_ylabel(', '.set_zlabel(',
'.set(xlim=', '.set(ylim=', '.set(xlabel=', '.set(ylabel='}
def __init__(self, parse_namedtuples=True, ns=None, atol=1e-8, rtol=1e-2):
"""
        Initialization; sets the checker's attributes.

        Parameters:
        - parse_namedtuples: whether to parse namedtuples, default True
        - ns: namespace used for the checks; if None, CHECK_NAMESPACE is used
        - atol: absolute tolerance, default 1e-8
        - rtol: relative tolerance, default 1e-2
"""
self.parse_namedtuples = parse_namedtuples
self.atol, self.rtol = atol, rtol
if ns is None:
self.ns = CHECK_NAMESPACE
else:
self.ns = ns
def check_output(self, want, got, optionflags):
if want == got:
return True
if any(word in self._source for word in self.stopwords):
return True
if any(word in want for word in self.rndm_markers):
return True
if self.obj_pattern.search(got):
return True
if want.lstrip().startswith("#"):
return True
try:
if self.vanilla.check_output(want, got, optionflags):
return True
except Exception:
pass
try:
a_want = eval(want, dict(self.ns))
a_got = eval(got, dict(self.ns))
except Exception:
s_want = want.strip()
s_got = got.strip()
cond = (s_want.startswith("[") and s_want.endswith("]") and
s_got.startswith("[") and s_got.endswith("]"))
if cond:
s_want = ", ".join(s_want[1:-1].split())
s_got = ", ".join(s_got[1:-1].split())
return self.check_output(s_want, s_got, optionflags)
if not self.parse_namedtuples:
return False
try:
num = len(a_want)
regex = (r'[\w\d_]+\(' +
', '.join([r'[\w\d_]+=(.+)']*num) +
r'\)')
grp = re.findall(regex, got.replace('\n', ' '))
if len(grp) > 1:
return False
got_again = '(' + ', '.join(grp[0]) + ')'
return self.check_output(want, got_again, optionflags)
except Exception:
return False
try:
return self._do_check(a_want, a_got)
except Exception:
try:
return all(self._do_check(w, g) for w, g in zip(a_want, a_got))
except (TypeError, ValueError):
return False
def _do_check(self, want, got):
try:
if want == got:
return True
except Exception:
pass
return np.allclose(want, got, atol=self.atol, rtol=self.rtol)
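Taken together, `check_output` accepts output that matches exactly, carries a randomness marker, or evaluates in CHECK_NAMESPACE to values agreeing within `atol`/`rtol`. A small sketch assuming the `Checker` class above is in scope; `_source` is normally set by `DTRunner.report_start`, so it is assigned by hand here:
# Sketch of the tolerant comparison; _source is normally set by the runner.
chk = Checker()
chk._source = ">>> np.arange(3) / 3"
print(chk.check_output("array([0., 0.3333, 0.6667])",
                       "array([0.        , 0.33333333, 0.66666667])", 0))  # True
print(chk.check_output("array([0., 1.])", "array([0., 2.])", 0))           # False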
def _run_doctests(tests, full_name, verbose, doctest_warnings):
"""
Run modified doctests for the set of `tests`.
Parameters
----------
    tests : list
        List of doctest test objects to run.
    full_name : str
        Full name of the item under test.
    verbose : bool
        Whether to print verbose output.
    doctest_warnings : bool
        Whether to show doctest warnings.

    Returns
    -------
    tuple(bool, str)
        Tuple of the success flag and the captured output.
"""
flags = NORMALIZE_WHITESPACE | ELLIPSIS
runner = DTRunner(full_name, checker=Checker(), optionflags=flags,
verbose=verbose)
output = io.StringIO(newline='')
success = True
tmp_stderr = sys.stdout if doctest_warnings else output
@contextmanager
def temp_cwd():
cwd = os.getcwd()
tmpdir = tempfile.mkdtemp()
try:
os.chdir(tmpdir)
yield tmpdir
finally:
os.chdir(cwd)
shutil.rmtree(tmpdir)
cwd = os.getcwd()
with np.errstate(), np.printoptions(), temp_cwd() as tmpdir, \
redirect_stderr(tmp_stderr):
np.random.seed(None)
ns = {}
for t in tests:
t.globs.update(ns)
t.filename = short_path(t.filename, cwd)
if any([SKIPBLOCK in ex.options for ex in t.examples]):
continue
fails, successes = runner.run(t, out=output.write, clear_globs=False)
if fails > 0:
success = False
ns = t.globs
output.seek(0)
return success, output.read()
def check_doctests(module, verbose, ns=None,
dots=True, doctest_warnings=False):
"""
Check code in docstrings of the module's public symbols.
Parameters
----------
    module : ModuleType
        The module whose docstrings are to be checked.
    verbose : bool
        Whether to print verbose output.
    ns : dict
        Namespace of the module.
    dots : bool
        Whether to print a dot for each check.
    doctest_warnings : bool
        Whether to show doctest warnings.

    Returns
    -------
    results : list
        List of [(item_name, success_flag, output), ...]
"""
if ns is None:
ns = dict(DEFAULT_NAMESPACE)
results = []
for name in get_all_dict(module)[0]:
full_name = module.__name__ + '.' + name
if full_name in DOCTEST_SKIPDICT:
skip_methods = DOCTEST_SKIPDICT[full_name]
if skip_methods is None:
continue
else:
skip_methods = None
try:
obj = getattr(module, name)
except AttributeError:
import traceback
results.append((full_name, False,
"Missing item!\n" +
traceback.format_exc()))
continue
finder = doctest.DocTestFinder()
try:
tests = finder.find(obj, name, globs=dict(ns))
except Exception:
import traceback
results.append((full_name, False,
"Failed to get doctests!\n" +
traceback.format_exc()))
continue
if skip_methods is not None:
tests = [i for i in tests if
i.name.partition(".")[2] not in skip_methods]
success, output = _run_doctests(tests, full_name, verbose,
doctest_warnings)
if dots:
output_dot('.' if success else 'F')
results.append((full_name, success, output))
if HAVE_MATPLOTLIB:
import matplotlib.pyplot as plt
plt.close('all')
return results
def check_doctests_testfile(fname, verbose, ns=None,
dots=True, doctest_warnings=False):
"""
Check code in a text file.
Mimic `check_doctests` above, differing mostly in test discovery.
(which is borrowed from stdlib's doctest.testfile here,
https://github.com/python-git/python/blob/master/Lib/doctest.py)
Parameters
----------
fname : str
File name
    verbose : bool
        Whether to print verbose output.
    ns : dict
        Name space.
    dots : bool
        Whether to print dots.
    doctest_warnings : bool
        Whether to show doctest warnings.
Returns
-------
list
List of [(item_name, success_flag, output), ...]
Notes
-----
refguide can be signalled to skip testing code by adding
``#doctest: +SKIP`` to the end of the line. If the output varies or is
    random, add ``# may vary`` or ``# random`` to the comment. For example
>>> plt.plot(...) # doctest: +SKIP
>>> random.randint(0,10)
5 # random
We also try to weed out pseudocode:
* We maintain a list of exceptions which signal pseudocode,
* We split the text file into "blocks" of code separated by empty lines
and/or intervening text.
* If a block contains a marker, the whole block is then assumed to be
      pseudocode. It is then not doctested.
The rationale is that typically, the text looks like this:
blah
<BLANKLINE>
>>> from numpy import some_module # pseudocode!
>>> func = some_module.some_function
>>> func(42) # still pseudocode
146
<BLANKLINE>
blah
<BLANKLINE>
>>> 2 + 3 # real code, doctest it
5
"""
if ns is None:
ns = CHECK_NAMESPACE
results = []
_, short_name = os.path.split(fname)
if short_name in DOCTEST_SKIPDICT:
return results
full_name = fname
with open(fname, encoding='utf-8') as f:
text = f.read()
PSEUDOCODE = set(['some_function', 'some_module', 'import example',
'ctypes.CDLL',
'integrate.nquad(func,'
])
parser = doctest.DocTestParser()
good_parts = []
base_line_no = 0
for part in text.split('\n\n'):
try:
tests = parser.get_doctest(part, ns, fname, fname, base_line_no)
except ValueError as e:
if e.args[0].startswith('line '):
parts = e.args[0].split()
parts[1] = str(int(parts[1]) + base_line_no)
e.args = (' '.join(parts),) + e.args[1:]
raise
if any(word in ex.source for word in PSEUDOCODE
for ex in tests.examples):
pass
else:
good_parts.append((part, base_line_no))
base_line_no += part.count('\n') + 2
tests = []
for good_text, line_no in good_parts:
tests.append(parser.get_doctest(good_text, ns, fname, fname, line_no))
success, output = _run_doctests(tests, full_name, verbose,
doctest_warnings)
if dots:
output_dot('.' if success else 'F')
results.append((full_name, success, output))
if HAVE_MATPLOTLIB:
import matplotlib.pyplot as plt
plt.close('all')
return results
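The pseudocode filter splits the file on blank lines and drops any block whose examples mention one of the `PSEUDOCODE` markers. A standalone sketch of that filtering on a small synthetic text:
import doctest

# Synthetic text with one pseudocode block and one real block.
text = """Intro prose.

>>> from numpy import some_module   # pseudocode!
>>> func = some_module.some_function

>>> 2 + 3
5
"""
PSEUDOCODE = {'some_function', 'some_module', 'import example', 'ctypes.CDLL'}
parser = doctest.DocTestParser()
kept = []
for part in text.split('\n\n'):
    test = parser.get_doctest(part, {}, '<demo>', '<demo>', 0)
    if not any(word in ex.source for word in PSEUDOCODE for ex in test.examples):
        kept.append(part)
print(len(kept))  # 2: the prose-only block and the real ">>> 2 + 3" block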
def iter_included_files(base_path, verbose=0, suffixes=('.rst',)):
"""
Generator function to walk `base_path` and its subdirectories, skipping
files or directories in RST_SKIPLIST, and yield each file with a suffix in
`suffixes`
Parameters
----------
base_path : str
Base path of the directory to be processed
verbose : int
Verbosity level (default is 0)
suffixes : tuple
Tuple of suffixes to filter files (default is ('.rst',))
Yields
------
    path : str
        Path of each file found under `base_path` (or its subdirectories)
        whose suffix is in `suffixes`
"""
if os.path.exists(base_path) and os.path.isfile(base_path):
yield base_path
for dir_name, subdirs, files in os.walk(base_path, topdown=True):
if dir_name in RST_SKIPLIST:
if verbose > 0:
sys.stderr.write('skipping files in %s' % dir_name)
files = []
for p in RST_SKIPLIST:
if p in subdirs:
if verbose > 0:
sys.stderr.write('skipping %s and subdirs' % p)
subdirs.remove(p)
for f in files:
if (os.path.splitext(f)[1] in suffixes and
f not in RST_SKIPLIST):
yield os.path.join(dir_name, f)
def check_documentation(base_path, results, args, dots):
"""
Check examples in any *.rst located inside `base_path`.
Add the output to `results`.
See Also
--------
check_doctests_testfile
"""
for filename in iter_included_files(base_path, args.verbose):
if dots:
sys.stderr.write(filename + ' ')
sys.stderr.flush()
tut_results = check_doctests_testfile(
filename,
(args.verbose >= 2), dots=dots,
doctest_warnings=args.doctest_warnings)
def scratch():
pass
scratch.__name__ = filename
results.append((scratch, tut_results))
if dots:
sys.stderr.write('\n')
sys.stderr.flush()
def init_matplotlib():
"""
Check feasibility of matplotlib initialization.
"""
global HAVE_MATPLOTLIB
try:
import matplotlib
matplotlib.use('Agg')
HAVE_MATPLOTLIB = True
except ImportError:
HAVE_MATPLOTLIB = False
def main(argv):
"""
    Validates the docstrings of a pre-decided set of modules for errors
    and docstring standards.
"""
parser = ArgumentParser(usage=__doc__.lstrip())
parser.add_argument("module_names", metavar="SUBMODULES", default=[],
nargs='*', help="Submodules to check (default: all public)")
parser.add_argument("--doctests", action="store_true",
help="Run also doctests on ")
parser.add_argument("-v", "--verbose", action="count", default=0)
parser.add_argument("--doctest-warnings", action="store_true",
help="Enforce warning checking for doctests")
parser.add_argument("--rst", nargs='?', const='doc', default=None,
help=("Run also examples from *rst files "
"discovered walking the directory(s) specified, "
"defaults to 'doc'"))
args = parser.parse_args(argv)
modules = []
names_dict = {}
if not args.module_names:
args.module_names = list(PUBLIC_SUBMODULES) + [BASE_MODULE]
os.environ['SCIPY_PIL_IMAGE_VIEWER'] = 'true'
module_names = list(args.module_names)
for name in module_names:
if name in OTHER_MODULE_DOCS:
name = OTHER_MODULE_DOCS[name]
if name not in module_names:
module_names.append(name)
dots = True
success = True
results = []
errormsgs = []
if args.doctests or args.rst:
init_matplotlib()
for submodule_name in module_names:
prefix = BASE_MODULE + '.'
if not (
submodule_name.startswith(prefix) or
submodule_name == BASE_MODULE
):
module_name = prefix + submodule_name
else:
module_name = submodule_name
__import__(module_name)
module = sys.modules[module_name]
if submodule_name not in OTHER_MODULE_DOCS:
find_names(module, names_dict)
if submodule_name in args.module_names:
modules.append(module)
if args.doctests or not args.rst:
print("Running checks for %d modules:" % (len(modules),))
for module in modules:
if dots:
sys.stderr.write(module.__name__ + ' ')
sys.stderr.flush()
all_dict, deprecated, others = get_all_dict(module)
names = names_dict.get(module.__name__, set())
mod_results = []
mod_results += check_items(all_dict, names, deprecated, others,
module.__name__)
mod_results += check_rest(module, set(names).difference(deprecated),
dots=dots)
if args.doctests:
mod_results += check_doctests(module, (args.verbose >= 2), dots=dots,
doctest_warnings=args.doctest_warnings)
for v in mod_results:
assert isinstance(v, tuple), v
results.append((module, mod_results))
if dots:
sys.stderr.write('\n')
sys.stderr.flush()
if args.rst:
base_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')
rst_path = os.path.relpath(os.path.join(base_dir, args.rst))
if os.path.exists(rst_path):
print('\nChecking files in %s:' % rst_path)
check_documentation(rst_path, results, args, dots)
else:
sys.stderr.write(f'\ninvalid --rst argument "{args.rst}"')
errormsgs.append('invalid directory argument to --rst')
if dots:
sys.stderr.write("\n")
sys.stderr.flush()
for module, mod_results in results:
success = all(x[1] for x in mod_results)
if not success:
errormsgs.append(f'failed checking {module.__name__}')
if success and args.verbose == 0:
continue
print("")
print("=" * len(module.__name__))
print(module.__name__)
print("=" * len(module.__name__))
print("")
for name, success, output in mod_results:
if name is None:
if not success or args.verbose >= 1:
print(output.strip())
print("")
elif not success or (args.verbose >= 2 and output.strip()):
print(name)
print("-" * len(name))
print("")
print(output.strip())
print("")
if len(errormsgs) == 0:
print("\nOK: all checks passed!")
sys.exit(0)
else:
print('\nERROR: ', '\n '.join(errormsgs))
sys.exit(1)
if __name__ == '__main__':
main(argv=sys.argv[1:])
.\numpy\tools\swig\test\Array1.cxx
// Include the custom header Array1.h
#include "Array1.h"
// Include the standard input/output stream library
#include <iostream>
// Include the string stream library
#include <sstream>

// Array1 constructor, supporting default, length, and array input
Array1::Array1(int length, long* data) :
_ownData(false), _length(0), _buffer(0)
{
  // Resize the array and allocate memory
resize(length, data);
}
// Array1 copy constructor
Array1::Array1(const Array1 & source) :
_length(source._length)
{
  // Allocate memory and copy the source object's data
allocateMemory();
*this = source;
}
// Array1 destructor, releases memory
Array1::~Array1()
{
  // Release the memory held by this object
deallocateMemory();
}
// Array1 assignment operator
Array1 & Array1::operator=(const Array1 & source)
{
  // Compare the lengths and copy using the smaller one
int len = _length < source._length ? _length : source._length;
for (int i=0; i < len; ++i)
{
(*this)[i] = source[i];
}
return *this;
}
// Array1 equality operator
bool Array1::operator==(const Array1 & other) const
{
  // Compare the lengths and each element for equality
if (_length != other._length) return false;
for (int i=0; i < _length; ++i)
{
if ((*this)[i] != other[i]) return false;
}
return true;
}
// Length accessor
int Array1::length() const
{
return _length;
}
// Resize the array
void Array1::resize(int length, long* data)
{
  // Throw if the length is negative
if (length < 0) throw std::invalid_argument("Array1 length less than 0");
  // Return immediately if the length is unchanged
if (length == _length) return;
  // Release the current memory and reallocate
deallocateMemory();
_length = length;
  // If no data is supplied, allocate new memory
if (!data)
{
allocateMemory();
}
else
{
    // Otherwise use the supplied data and mark it as not owned
_ownData = false;
_buffer = data;
}
}
// Set item accessor
long & Array1::operator[](int i)
{
  // Throw if the index is out of range
if (i < 0 || i >= _length) throw std::out_of_range("Array1 index out of range");
return _buffer[i];
}
// Get item accessor (const version)
const long & Array1::operator[](int i) const
{
  // Throw if the index is out of range
if (i < 0 || i >= _length) throw std::out_of_range("Array1 index out of range");
return _buffer[i];
}
// String representation of the array
std::string Array1::asString() const
{
  // Build the string representation with a string stream
std::stringstream result;
result << "[";
for (int i=0; i < _length; ++i)
{
result << " " << _buffer[i];
if (i < _length-1) result << ",";
}
result << " ]";
return result.str();
}
// Get a view of the array
void Array1::view(long** data, int* length) const
{
  // Return the array buffer and its length
*data = _buffer;
*length = _length;
}
// Array1 private method: allocate memory
void Array1::allocateMemory()
{
  // If the length is zero, mark the data as not owned and null the buffer
if (_length == 0)
{
_ownData = false;
_buffer = 0;
}
else
{
    // Otherwise allocate new memory and mark it as owned
_ownData = true;
_buffer = new long[_length];
}
}
// Array1 private method: deallocate memory
void Array1::deallocateMemory()
{
  // Release the memory if this object owns the data and the buffer is non-null
if (_ownData && _length && _buffer)
{
delete [] _buffer;
}
  // Reset the ownership flag and the length, and null the buffer
_ownData = false;
_length = 0;
_buffer = 0;
}
.\numpy\tools\swig\test\Array1.h
class Array1
{
public:
// Default/length/array constructor
Array1(int length = 0, long* data = 0);
// Copy constructor
Array1(const Array1 & source);
// Destructor
~Array1();
// Assignment operator
Array1 & operator=(const Array1 & source);
// Equals operator
bool operator==(const Array1 & other) const;
// Length accessor
int length() const;
// Resize array
void resize(int length, long* data = 0);
// Set item accessor
long & operator[](int i);
// Get item accessor
const long & operator[](int i) const;
// String output
std::string asString() const;
// Get view
void view(long** data, int* length) const;
private:
// Members
bool _ownData; // whether this object owns its data buffer
int _length; // array length
long * _buffer; // data buffer
// Methods
void allocateMemory(); // allocate the buffer
void deallocateMemory(); // release the buffer
};
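Taken together, the interface above is enough for a small standalone driver. The following is a minimal C++ sketch (not part of the SWIG test suite; it assumes only Array1.h/Array1.cxx as shown) illustrating ownership, copying, asString() and view():
#include <iostream>
#include "Array1.h"

int main()
{
  // Length constructor: the object owns a buffer of 5 longs
  Array1 a(5);
  for (int i = 0; i < a.length(); ++i) a[i] = i * i;

  // Copy constructor plus equality operator
  Array1 b(a);
  std::cout << "b == a ? " << (b == a) << std::endl;      // prints 1

  // asString() renders "[ 0, 1, 4, 9, 16 ]"
  std::cout << a.asString() << std::endl;

  // view() exposes the raw buffer and its length without copying
  long* data = 0;
  int length = 0;
  a.view(&data, &length);
  std::cout << "first element via view: " << data[0] << std::endl;
  return 0;
}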
.\numpy\tools\swig\test\Array2.cxx
// Include the Array2 class declaration
// Include the string stream header
// Default constructor
Array2::Array2() :
_ownData(false), _nrows(0), _ncols(), _buffer(0), _rows(0)
{ }
// 大小和数组构造函数
Array2::Array2(int nrows, int ncols, long* data) :
_ownData(false), _nrows(0), _ncols(), _buffer(0), _rows(0)
{
// 调整数组大小,并使用传入的数据进行初始化
resize(nrows, ncols, data);
}
// 复制构造函数
Array2::Array2(const Array2 & source) :
_nrows(source._nrows), _ncols(source._ncols)
{
// 分配内存并将数据复制到当前对象
_ownData = true;
allocateMemory();
*this = source;
}
// 析构函数
Array2::~Array2()
{
// 释放对象占用的内存
deallocateMemory();
}
// 赋值运算符重载
Array2 & Array2::operator=(const Array2 & source)
{
// 按最小行列数复制数据到当前对象
int nrows = _nrows < source._nrows ? _nrows : source._nrows;
int ncols = _ncols < source._ncols ? _ncols : source._ncols;
for (int i=0; i < nrows; ++i)
{
for (int j=0; j < ncols; ++j)
{
(*this)[i][j] = source[i][j];
}
}
return *this;
}
// 等于运算符重载
bool Array2::operator==(const Array2 & other) const
{
// 检查是否行数、列数及所有元素相同
if (_nrows != other._nrows) return false;
if (_ncols != other._ncols) return false;
for (int i=0; i < _nrows; ++i)
{
for (int j=0; j < _ncols; ++j)
{
if ((*this)[i][j] != other[i][j]) return false;
}
}
return true;
}
// 获取行数
int Array2::nrows() const
{
return _nrows;
}
// 获取列数
int Array2::ncols() const
{
return _ncols;
}
// 调整数组大小
void Array2::resize(int nrows, int ncols, long* data)
{
// 检查行数和列数是否合法,如果相同则直接返回,否则重新分配内存
if (nrows < 0) throw std::invalid_argument("Array2 nrows less than 0");
if (ncols < 0) throw std::invalid_argument("Array2 ncols less than 0");
if (nrows == _nrows && ncols == _ncols) return;
deallocateMemory();
_nrows = nrows;
_ncols = ncols;
if (!data)
{
allocateMemory();
}
else
{
_ownData = false;
_buffer = data;
allocateRows();
}
}
// 仅调整数组大小(重载版本)
void Array2::resize(int nrows, int ncols)
{
resize(nrows, ncols, nullptr);
}
// 设置元素访问器
Array1 & Array2::operator[](int i)
{
// 检查行索引是否在合法范围内,然后返回对应行对象
if (i < 0 || i >= _nrows) throw std::out_of_range("Array2 row index out of range");
return _rows[i];
}
// 获取元素访问器
const Array1 & Array2::operator[](int i) const
{
// 检查行索引是否在合法范围内,然后返回对应行对象(常量版本)
if (i < 0 || i >= _nrows) throw std::out_of_range("Array2 row index out of range");
return _rows[i];
}
// 返回数组的字符串表示
std::string Array2::asString() const
{
std::stringstream result;
result << "[ ";
for (int i=0; i < _nrows; ++i)
{
if (i > 0) result << " ";
result << (*this)[i].asString();
if (i < _nrows-1) result << "," << std::endl;
}
result << " ]" << std::endl;
return result.str();
}
// 获取视图
void Array2::view(int* nrows, int* ncols, long** data) const
{
// 返回当前数组的行数、列数及数据指针
*nrows = _nrows;
*ncols = _ncols;
*data = _buffer;
}
// 私有方法:分配内存
void Array2::allocateMemory()
{
// 如果行列数为零,则将数据标记为非拥有,并清空缓冲区及行对象
if (_nrows * _ncols == 0)
{
_ownData = false;
_buffer = 0;
_rows = 0;
}
else
{
// 否则分配内存,并标记为拥有数据,然后分配每一行的内存
_ownData = true;
_buffer = new long[_nrows*_ncols];
allocateRows();
}
}
// 私有方法:分配每一行的内存
void Array2::allocateRows()
{
_rows = new Array1[_nrows];
for (int i=0; i < _nrows; ++i)
{
// Point row i at its slice of the shared buffer; resizing an Array1 with a
// data pointer makes that row a non-owning view of _buffer.
_rows[i].resize(_ncols, &_buffer[i*_ncols]);
}
}
void Array2::deallocateMemory()
{
// 检查是否需要释放内存:确保_ownData为true,且_nrows和_ncols大于0,_buffer非空
if (_ownData && _nrows*_ncols && _buffer)
{
// 删除_rows数组,释放行指针内存
delete [] _rows;
// 删除_buffer数组,释放数据内存
delete [] _buffer;
}
// 重置所有成员变量,表示内存已被释放
_ownData = false;
_nrows = 0;
_ncols = 0;
_buffer = 0;
_rows = 0;
}
.\numpy\tools\swig\test\Array2.h
class Array2
{
public:
// Default constructor
Array2();
// Size/array constructor
Array2(int nrows, int ncols, long* data=0);
// Copy constructor
Array2(const Array2 & source);
// Destructor
~Array2();
// Assignment operator
Array2 & operator=(const Array2 & source);
// Equals operator
bool operator==(const Array2 & other) const;
// Size accessors
int nrows() const;
int ncols() const;
// Resize array
void resize(int nrows, int ncols, long* data);
void resize(int nrows, int ncols);
// Set item accessor
Array1 & operator[](int i);
// Get item accessor
const Array1 & operator[](int i) const;
// String output
std::string asString() const;
// Get view
void view(int* nrows, int* ncols, long** data) const;
private:
// Members
bool _ownData; // whether this object owns its data buffer
int _nrows; // number of rows
int _ncols; // number of columns
long * _buffer; // data buffer
Array1 * _rows; // per-row Array1 views into the buffer
// Methods
void allocateMemory(); // allocate the buffer
void allocateRows(); // build the per-row views
void deallocateMemory(); // release the buffer
};
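As a sketch of how the pieces fit together (again assuming only the files shown here), Array2::operator[] hands back an Array1 row view, so chained indexing m[i][j] reaches individual elements, while view() exposes the flat row-major buffer:
#include <iostream>
#include "Array2.h"

int main()
{
  Array2 m(3, 4);                    // 3x4 array, buffer owned by m
  for (int i = 0; i < m.nrows(); ++i)
    for (int j = 0; j < m.ncols(); ++j)
      m[i][j] = i + j;               // operator[] returns an Array1 row view

  std::cout << m.asString() << std::endl;

  // view() exposes the flat row-major buffer
  int nrows = 0, ncols = 0;
  long* data = 0;
  m.view(&nrows, &ncols, &data);
  std::cout << "element (1,2) via flat buffer: "
            << data[1 * ncols + 2] << std::endl;
  return 0;
}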
.\numpy\tools\swig\test\ArrayZ.cxx
// Include the custom header ArrayZ.h and the standard I/O stream headers
// Default/length/array constructor
ArrayZ::ArrayZ(int length, std::complex<double>* data) :
_ownData(false), _length(0), _buffer(0)
{
// 调整数组大小并分配内存
resize(length, data);
}
// 拷贝构造函数
ArrayZ::ArrayZ(const ArrayZ & source) :
_length(source._length)
{
// 分配内存并复制源对象数据
allocateMemory();
*this = source;
}
// 析构函数
ArrayZ::~ArrayZ()
{
// 释放对象的内存
deallocateMemory();
}
// 赋值运算符重载
ArrayZ & ArrayZ::operator=(const ArrayZ & source)
{
// 比较并复制长度较小的数据
int len = _length < source._length ? _length : source._length;
for (int i=0; i < len; ++i)
{
(*this)[i] = source[i];
}
return *this;
}
// 等号运算符重载
bool ArrayZ::operator==(const ArrayZ & other) const
{
// 比较数组长度及每个元素是否相等
if (_length != other._length) return false;
for (int i=0; i < _length; ++i)
{
if ((*this)[i] != other[i]) return false;
}
return true;
}
// 返回数组长度的访问器
int ArrayZ::length() const
{
return _length;
}
// 调整数组大小
void ArrayZ::resize(int length, std::complex<double>* data)
{
// 检查长度是否合法
if (length < 0) throw std::invalid_argument("ArrayZ length less than 0");
// 若长度不变则直接返回
if (length == _length) return;
// 释放当前内存并重新分配
deallocateMemory();
_length = length;
if (!data)
{
// 分配新的内存
allocateMemory();
}
else
{
// 使用传入的数据作为数组缓冲区
_ownData = false;
_buffer = data;
}
}
// 设置元素访问器
std::complex<double> & ArrayZ::operator[](int i)
{
// 检查索引是否有效
if (i < 0 || i >= _length) throw std::out_of_range("ArrayZ index out of range");
return _buffer[i];
}
// 获取元素访问器(常量版本)
const std::complex<double> & ArrayZ::operator[](int i) const
{
// 检查索引是否有效
if (i < 0 || i >= _length) throw std::out_of_range("ArrayZ index out of range");
return _buffer[i];
}
// 生成数组的字符串表示
std::string ArrayZ::asString() const
{
std::stringstream result;
result << "[";
for (int i=0; i < _length; ++i)
{
result << " " << _buffer[i];
if (i < _length-1) result << ",";
}
result << " ]";
return result.str();
}
// 获取数组视图
void ArrayZ::view(std::complex<double>** data, int* length) const
{
// 返回数组缓冲区及其长度
*data = _buffer;
*length = _length;
}
// 私有方法:分配内存
void ArrayZ::allocateMemory()
{
if (_length == 0)
{
_ownData = false;
_buffer = 0;
}
else
{
// 分配新的数组内存
_ownData = true;
_buffer = new std::complex<double>[_length];
}
}
// 私有方法:释放内存
void ArrayZ::deallocateMemory()
{
// 如果拥有数据并且数组长度大于0,则释放内存
if (_ownData && _length && _buffer)
{
delete [] _buffer;
}
_ownData = false;
_length = 0;
_buffer = 0;
}
.\numpy\tools\swig\test\ArrayZ.h
class ArrayZ
{
public:
// Default/length/array constructor
ArrayZ(int length = 0, std::complex<double>* data = 0);
// Copy constructor
ArrayZ(const ArrayZ & source);
// Destructor
~ArrayZ();
// Assignment operator
ArrayZ & operator=(const ArrayZ & source);
// Equals operator
bool operator==(const ArrayZ & other) const;
// Length accessor
int length() const;
// Resize array
void resize(int length, std::complex<double>* data = 0);
// Set item accessor
std::complex<double> & operator[](int i);
// Get item accessor
const std::complex<double> & operator[](int i) const;
// String output
std::string asString() const;
// Get view
void view(std::complex<double>** data, int* length) const;
private:
// Members
bool _ownData; // whether this object owns its data buffer
int _length; // array length
std::complex<double> * _buffer; // data buffer
// Methods
void allocateMemory(); // allocate the buffer
void deallocateMemory(); // release the buffer
};
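ArrayZ is the complex-valued counterpart of Array1; a minimal sketch assuming ArrayZ.h above:
#include <complex>
#include <iostream>
#include "ArrayZ.h"

int main()
{
  ArrayZ z(3);
  for (int i = 0; i < z.length(); ++i)
    z[i] = std::complex<double>(i + 1, -(i + 1));   // (1,-1), (2,-2), (3,-3)

  // asString() uses the standard operator<< for std::complex<double>
  std::cout << z.asString() << std::endl;           // [ (1,-1), (2,-2), (3,-3) ]
  return 0;
}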
.\numpy\tools\swig\test\Farray.cxx
// Include the custom header "Farray.h" and the standard header <sstream>
// Farray size constructor
Farray::Farray(int nrows, int ncols) :
_nrows(nrows), _ncols(ncols), _buffer(0)
{
// 分配内存空间
allocateMemory();
}
// Farray 类的复制构造函数
Farray::Farray(const Farray & source) :
_nrows(source._nrows), _ncols(source._ncols)
{
// 分配内存空间
allocateMemory();
// 使用赋值运算符进行复制
*this = source;
}
// Farray 类的析构函数
Farray::~Farray()
{
// 释放动态分配的内存
delete [] _buffer;
}
// Farray 类的赋值运算符重载
Farray & Farray::operator=(const Farray & source)
{
// 确定有效的行和列数
int nrows = _nrows < source._nrows ? _nrows : source._nrows;
int ncols = _ncols < source._ncols ? _ncols : source._ncols;
// 逐个元素进行赋值
for (int i=0; i < nrows; ++i)
{
for (int j=0; j < ncols; ++j)
{
(*this)(i,j) = source(i,j);
}
}
return *this;
}
// Farray 类的相等运算符重载
bool Farray::operator==(const Farray & other) const
{
// 检查行数和列数是否相等
if (_nrows != other._nrows) return false;
if (_ncols != other._ncols) return false;
// 逐个元素比较
for (int i=0; i < _nrows; ++i)
{
for (int j=0; j < _ncols; ++j)
{
if ((*this)(i,j) != other(i,j)) return false;
}
}
return true;
}
// 获取行数的访问器
int Farray::nrows() const
{
return _nrows;
}
// 获取列数的访问器
int Farray::ncols() const
{
return _ncols;
}
// 设置元素的访问器
long & Farray::operator()(int i, int j)
{
// 检查行和列的索引是否有效
if (i < 0 || i >= _nrows) throw std::out_of_range("Farray row index out of range");
if (j < 0 || j >= _ncols) throw std::out_of_range("Farray col index out of range");
return _buffer[offset(i,j)];
}
// 获取元素的访问器
const long & Farray::operator()(int i, int j) const
{
// 检查行和列的索引是否有效
if (i < 0 || i >= _nrows) throw std::out_of_range("Farray row index out of range");
if (j < 0 || j >= _ncols) throw std::out_of_range("Farray col index out of range");
return _buffer[offset(i,j)];
}
// 以字符串形式输出 Farray 对象
std::string Farray::asString() const
{
std::stringstream result;
result << "[ ";
// 遍历数组元素并格式化输出
for (int i=0; i < _nrows; ++i)
{
if (i > 0) result << " ";
result << "[";
for (int j=0; j < _ncols; ++j)
{
result << " " << (*this)(i,j);
if (j < _ncols-1) result << ",";
}
result << " ]";
if (i < _nrows-1) result << "," << std::endl;
}
result << " ]" << std::endl;
return result.str();
}
// 获取视图的访问器
void Farray::view(int* nrows, int* ncols, long** data) const
{
// 返回行数、列数和数据指针的引用
*nrows = _nrows;
*ncols = _ncols;
*data = _buffer;
}
// 私有方法:分配内存空间
void Farray::allocateMemory()
{
// 检查行数和列数是否有效
if (_nrows <= 0) throw std::invalid_argument("Farray nrows <= 0");
if (_ncols <= 0) throw std::invalid_argument("Farray ncols <= 0");
// 分配内存空间
_buffer = new long[_nrows*_ncols];
}
// 内联方法:计算元素在数组中的偏移量
inline int Farray::offset(int i, int j) const
{
return i + j * _nrows;
}
.\numpy\tools\swig\test\Farray.h
class Farray
{
public:
// Size constructor
Farray(int nrows, int ncols);
// Copy constructor
Farray(const Farray & source);
// Destructor
~Farray();
// Assignment operator
Farray & operator=(const Farray & source);
// Equals operator
bool operator==(const Farray & other) const;
// Length accessors
int nrows() const;
int ncols() const;
// Set item accessor
long & operator()(int i, int j);
// Get item accessor
const long & operator()(int i, int j) const;
// String output
std::string asString() const;
// Get view
void view(int* nrows, int* ncols, long** data) const;
private:
// Members
int _nrows; // number of rows
int _ncols; // number of columns
long * _buffer; // data buffer
// Default constructor: not implemented
Farray();
// Methods
void allocateMemory(); // allocate the buffer
int offset(int i, int j) const; // compute the column-major index offset
};
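Because offset(i, j) is computed as i + j*_nrows, Farray stores its data in Fortran (column-major) order. A small sketch, assuming Farray.h above, that makes the layout visible through view():
#include <iostream>
#include "Farray.h"

int main()
{
  Farray f(2, 3);                     // 2 rows, 3 columns, column-major storage
  for (int i = 0; i < f.nrows(); ++i)
    for (int j = 0; j < f.ncols(); ++j)
      f(i, j) = 10 * i + j;           // row i, column j

  int nrows = 0, ncols = 0;
  long* data = 0;
  f.view(&nrows, &ncols, &data);

  // Walking the flat buffer shows column-major order:
  // 0, 10, 1, 11, 2, 12  (each column is contiguous)
  for (int k = 0; k < nrows * ncols; ++k)
    std::cout << data[k] << (k < nrows * ncols - 1 ? ", " : "\n");
  return 0;
}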
.\numpy\tools\swig\test\Flat.cxx
// The macro TEST_FUNCS(TYPE, SNAME) defines a family of functions for 1D
// arrays of the given type TYPE, with names of the form
// SNAMEProcess(TYPE * array, int size), which add one to every element in
// place. They are used to test the numpy interface for:
//
//  * in-place arrays with a fixed number of elements
//
#define TEST_FUNCS(TYPE, SNAME) \
\
void SNAME ## Process(TYPE * array, int size) { \
  for (int i=0; i<size; ++i) array[i] += 1; \
}
// Expand TEST_FUNCS(TYPE, SNAME) below to generate the concrete functions
// scharProcess, ucharProcess, shortProcess, ushortProcess, intProcess,
// uintProcess, longProcess, ulongProcess, longLongProcess, ulongLongProcess,
// floatProcess and doubleProcess.
TEST_FUNCS(signed char , schar )
TEST_FUNCS(unsigned char , uchar )
TEST_FUNCS(short , short )
TEST_FUNCS(unsigned short , ushort )
TEST_FUNCS(int , int )
TEST_FUNCS(unsigned int , uint )
TEST_FUNCS(long , long )
TEST_FUNCS(unsigned long , ulong )
TEST_FUNCS(long long , longLong )
TEST_FUNCS(unsigned long long, ulongLong)
TEST_FUNCS(float , float )
TEST_FUNCS(double , double )
.\numpy\tools\swig\test\Flat.h
// The following macro defines a family of function prototypes for handling
// arrays of different types:
//
//     void SNAMEProcess(TYPE * array, int size);
//
// for any specified type TYPE (for example: signed char, unsigned int,
// long long, ...) with a given short name SNAME (for example: schar, uint,
// longLong, ...). The macro is expanded for each TYPE/SNAME pair below to
// declare the functions used to test the numpy interface.
#define TEST_FUNC_PROTOS(TYPE, SNAME) \
\
void SNAME ## Process(TYPE * array, int size);
// Expand the macro for each supported TYPE/SNAME pair to declare the
// corresponding SNAMEProcess prototypes (scharProcess, ucharProcess, ...,
// doubleProcess).
TEST_FUNC_PROTOS(signed char , schar )
TEST_FUNC_PROTOS(unsigned char , uchar )
TEST_FUNC_PROTOS(short , short )
TEST_FUNC_PROTOS(unsigned short , ushort )
TEST_FUNC_PROTOS(int , int )
TEST_FUNC_PROTOS(unsigned int , uint )
TEST_FUNC_PROTOS(long , long )
TEST_FUNC_PROTOS(unsigned long , ulong )
TEST_FUNC_PROTOS(long long , longLong )
TEST_FUNC_PROTOS(unsigned long long, ulongLong)
TEST_FUNC_PROTOS(float , float )
TEST_FUNC_PROTOS(double , double )
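After these expansions, each SNAMEProcess function is an ordinary C function that adds one to every element of a flat buffer in place. A minimal sketch (assuming Flat.h and Flat.cxx above are compiled and linked together) using the long variant:
#include <iostream>
#include "Flat.h"

int main()
{
  long data[4] = { 0, 1, 2, 3 };
  longProcess(data, 4);            // adds 1 to every element in place
  for (int i = 0; i < 4; ++i)
    std::cout << data[i] << (i < 3 ? ", " : "\n");   // 1, 2, 3, 4
  return 0;
}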
.\numpy\tools\swig\test\Fortran.cxx
// Macro TEST_FUNCS(TYPE, SNAME): for the given type TYPE and short name
// SNAME, generate a SNAMESecondElement function that returns the second
// element of the matrix buffer.
#define TEST_FUNCS(TYPE, SNAME) \
\
TYPE SNAME ## SecondElement(TYPE * matrix, int rows, int cols) { \
  TYPE result = matrix[1]; \
  return result; \
}
// Expand the macro to generate a SecondElement function for each type
TEST_FUNCS(signed char , schar )
TEST_FUNCS(unsigned char , uchar )
TEST_FUNCS(short , short )
TEST_FUNCS(unsigned short , ushort )
TEST_FUNCS(int , int )
TEST_FUNCS(unsigned int , uint )
TEST_FUNCS(long , long )
TEST_FUNCS(unsigned long , ulong )
TEST_FUNCS(long long , longLong )
TEST_FUNCS(unsigned long long, ulongLong)
TEST_FUNCS(float , float )
TEST_FUNCS(double , double )
This defines a macro `TEST_FUNCS` that generates, for each data type and short name, a function returning the second element of the array it is given.
.\numpy\tools\swig\test\Fortran.h
// Define a macro TEST_FUNC_PROTOS that generates a function prototype whose
// name is SNAME followed by "SecondElement" and whose parameters are a
// pointer to a TYPE array, the number of rows, and the number of columns.
#define TEST_FUNC_PROTOS(TYPE, SNAME) \
\
TYPE SNAME ## SecondElement(TYPE * matrix, int rows, int cols);
// Expand the macro below to declare a SecondElement prototype for each type
TEST_FUNC_PROTOS(signed char , schar )
TEST_FUNC_PROTOS(unsigned char , uchar )
TEST_FUNC_PROTOS(short , short )
TEST_FUNC_PROTOS(unsigned short , ushort )
TEST_FUNC_PROTOS(int , int )
TEST_FUNC_PROTOS(unsigned int , uint )
TEST_FUNC_PROTOS(long , long )
TEST_FUNC_PROTOS(unsigned long , ulong )
TEST_FUNC_PROTOS(long long , longLong )
TEST_FUNC_PROTOS(unsigned long long, ulongLong)
TEST_FUNC_PROTOS(float , float )
TEST_FUNC_PROTOS(double , double )
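Each generated SecondElement function simply returns matrix[1], the second entry of the flat buffer; for a Fortran-ordered (column-major) matrix that is the element in row 1 of column 0. A minimal sketch using the long variant, assuming Fortran.h/Fortran.cxx above:
#include <iostream>
#include "Fortran.h"

int main()
{
  // Column-major 2x3 matrix: buffer[1] is row 1 of column 0
  long buffer[6] = { 10, 20, 30, 40, 50, 60 };
  std::cout << longSecondElement(buffer, 2, 3) << std::endl;   // 20
  return 0;
}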
.\numpy\tools\swig\test\Matrix.cxx
// Include the standard library headers
// Include the I/O stream header
// Include the custom matrix header
// The macro below defines a family of functions for 2D arrays of a given
// type TYPE with short name SNAME. For each TYPE/SNAME pair it expands to:
//
//     TYPE SNAMEDet(    TYPE matrix[2][2]);
//     TYPE SNAMEMax(    TYPE * matrix, int rows, int cols);
//     TYPE SNAMEMin(    int rows, int cols, TYPE * matrix);
//     void SNAMEScale(  TYPE array[3][3], TYPE val);
//     void SNAMEFloor(  TYPE * array, int rows, int cols, TYPE floor);
//     void SNAMECeil(   int rows, int cols, TYPE * array, TYPE ceil);
//     void SNAMELUSplit(TYPE matrix[3][3], TYPE lower[3][3], TYPE upper[3][3]);
//
// These functions compute, respectively:
//  * the determinant of a 2x2 matrix
//  * the maximum of a 2D array with arbitrary rows and columns
//  * the minimum of a 2D array with arbitrary rows and columns
//  * the product of every element of a 3x3 matrix with a given value
//  * an in-place floor / ceiling clamp of a 2D array
//  * a split of a 3x3 matrix into lower and upper triangular parts
#define TEST_FUNCS(TYPE, SNAME) \
\
TYPE SNAME ## Det(TYPE matrix[2][2]) { \
return matrix[0][0]*matrix[1][1] - matrix[0][1]*matrix[1][0]; \
} \
\
TYPE SNAME ## Max(TYPE * matrix, int rows, int cols) { \
int i, j, index; \
TYPE result = matrix[0]; \
// scan the 2D array for the maximum value
for (j=0; j<cols; ++j) { \
for (i=0; i<rows; ++i) { \
index = j*rows + i; \
if (matrix[index] > result) result = matrix[index]; \
} \
} \
return result; \
} \
\
TYPE SNAME ## Min(int rows, int cols, TYPE * matrix) { \
int i, j, index; \
TYPE result = matrix[0]; \
// scan the 2D array for the minimum value
for (j=0; j<cols; ++j) { \
for (i=0; i<rows; ++i) { \
index = j*rows + i; \
if (matrix[index] < result) result = matrix[index]; \
} \
} \
return result; \
} \
\
void SNAME ## Scale(TYPE array[3][3], TYPE val) { \
// multiply every element of the 3x3 array by the given value
for (int i=0; i<3; ++i) \
for (int j=0; j<3; ++j) \
array[i][j] *= val; \
} \
\
void SNAME ## Floor(TYPE * array, int rows, int cols, TYPE floor) { \
int i, j, index; \
for (j=0; j<cols; ++j) { \
for (i=0; i<rows; ++i) { \
index = j*rows + i; \
if (array[index] < floor) array[index] = floor; \
} \
} \
} \
\
void SNAME ## Ceil(int rows, int cols, TYPE * array, TYPE ceil) { \
int i, j, index; \
for (j=0; j<cols; ++j) { \
for (i=0; i<rows; ++i) { \
index = j*rows + i; \
if (array[index] > ceil) array[index] = ceil; \
} \
} \
} \
\
void SNAME ## LUSplit(TYPE matrix[3][3], TYPE lower[3][3], TYPE upper[3][3]) { \
for (int i=0; i<3; ++i) { \
for (int j=0; j<3; ++j) { \
if (i >= j) { \
lower[i][j] = matrix[i][j]; \
upper[i][j] = 0; \
} else { \
lower[i][j] = 0; \
upper[i][j] = matrix[i][j]; \
} \
} \
} \
}
TEST_FUNCS(signed char , schar )
TEST_FUNCS(unsigned char , uchar )
TEST_FUNCS(short , short )
TEST_FUNCS(unsigned short , ushort )
TEST_FUNCS(int , int )
TEST_FUNCS(unsigned int , uint )
TEST_FUNCS(long , long )
TEST_FUNCS(unsigned long , ulong )
TEST_FUNCS(long long , longLong )
TEST_FUNCS(unsigned long long, ulongLong)
TEST_FUNCS(float , float )
TEST_FUNCS(double , double )
.\numpy\tools\swig\test\Matrix.h
// Include guard: define MATRIX_H only if it has not been defined yet
// The following macro defines a set of function prototypes for working with
// 2D arrays, of the form
//
// TYPE SNAMEDet( TYPE matrix[2][2]);
// TYPE SNAMEMax( TYPE * matrix, int rows, int cols);
// TYPE SNAMEMin( int rows, int cols, TYPE * matrix);
// void SNAMEScale( TYPE array[3][3], TYPE val);
// void SNAMEFloor( TYPE * array, int rows, int cols, TYPE floor);
// void SNAMECeil( int rows, int cols, TYPE * array, TYPE ceil );
// void SNAMELUSplit(TYPE matrix[3][3], TYPE lower[3][3], TYPE upper[3][3]);
//
// for any specified type TYPE (e.g. short, unsigned int, long long, ...) and
// a given short name SNAME (e.g. schar, uint, longLong, ...). The macro is
// then expanded for each TYPE/SNAME pair. The resulting functions test the
// numpy interface for, respectively:
//
//  * 2D input arrays, hard-coded lengths
//  * 2D input arrays
//  * 2D input arrays, data last
//  * 2D in-place arrays, hard-coded lengths
//  * 2D in-place arrays
//  * 2D in-place arrays, data last
//  * 2D output (argout) arrays, hard-coded lengths
//
#define TEST_FUNC_PROTOS(TYPE, SNAME) \
\
TYPE SNAME ## Det(    TYPE matrix[2][2]); \
TYPE SNAME ## Max(    TYPE * matrix, int rows, int cols); \
TYPE SNAME ## Min(    int rows, int cols, TYPE * matrix); \
void SNAME ## Scale(  TYPE array[3][3], TYPE val); \
void SNAME ## Floor(  TYPE * array, int rows, int cols, TYPE floor); \
void SNAME ## Ceil(   int rows, int cols, TYPE * array, TYPE ceil); \
void SNAME ## LUSplit(TYPE matrix[3][3], TYPE lower[3][3], TYPE upper[3][3]);
// Expand TEST_FUNC_PROTOS for each supported type and short name to declare the corresponding prototypes
TEST_FUNC_PROTOS(signed char , schar )
TEST_FUNC_PROTOS(unsigned char , uchar )
TEST_FUNC_PROTOS(short , short )
TEST_FUNC_PROTOS(unsigned short , ushort )
TEST_FUNC_PROTOS(int , int )
TEST_FUNC_PROTOS(unsigned int , uint )
TEST_FUNC_PROTOS(long , long )
TEST_FUNC_PROTOS(unsigned long , ulong )
TEST_FUNC_PROTOS(long long , longLong )
TEST_FUNC_PROTOS(unsigned long long, ulongLong)
TEST_FUNC_PROTOS(float , float )
TEST_FUNC_PROTOS(double , double )
// End of the macro definition section
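For concreteness, expanding the macro with (double, double) yields doubleDet, doubleMax, doubleMin, doubleScale, doubleFloor, doubleCeil and doubleLUSplit. A minimal sketch that calls a few of them directly from C++ (assuming Matrix.h/Matrix.cxx above are compiled and linked):
#include <iostream>
#include "Matrix.h"

int main()
{
  // 2x2 determinant through the generated doubleDet
  double m[2][2] = { {1.0, 2.0}, {3.0, 4.0} };
  std::cout << "det = " << doubleDet(m) << std::endl;          // -2

  // Max/Min operate on a flat buffer plus explicit dimensions
  double flat[6] = { 3.0, -1.0, 7.0, 0.5, 2.0, -4.0 };
  std::cout << "max = " << doubleMax(flat, 2, 3) << std::endl; // 7
  std::cout << "min = " << doubleMin(2, 3, flat) << std::endl; // -4
  return 0;
}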
.\numpy\tools\swig\test\setup.py
from distutils.core import Extension, setup
import numpy
numpy_include = numpy.get_include()
_Array = Extension("_Array",
["Array_wrap.cxx",
"Array1.cxx",
"Array2.cxx",
"ArrayZ.cxx"],
include_dirs=[numpy_include],
)
_Farray = Extension("_Farray",
["Farray_wrap.cxx",
"Farray.cxx"],
include_dirs=[numpy_include],
)
_Vector = Extension("_Vector",
["Vector_wrap.cxx",
"Vector.cxx"],
include_dirs=[numpy_include],
)
_Matrix = Extension("_Matrix",
["Matrix_wrap.cxx",
"Matrix.cxx"],
include_dirs=[numpy_include],
)
_Tensor = Extension("_Tensor",
["Tensor_wrap.cxx",
"Tensor.cxx"],
include_dirs=[numpy_include],
)
_Fortran = Extension("_Fortran",
["Fortran_wrap.cxx",
"Fortran.cxx"],
include_dirs=[numpy_include],
)
_Flat = Extension("_Flat",
["Flat_wrap.cxx",
"Flat.cxx"],
include_dirs=[numpy_include],
)
setup(name="NumpyTypemapTests",
description="Functions that work on arrays",
author="Bill Spotz",
py_modules=["Array", "Farray", "Vector", "Matrix", "Tensor",
"Fortran", "Flat"],
ext_modules=[_Array, _Farray, _Vector, _Matrix, _Tensor,
_Fortran, _Flat]
)
.\numpy\tools\swig\test\SuperTensor.cxx
// Include the standard library header
// Include the I/O stream header
// Include the custom header "SuperTensor.h"
// Macro TEST_FUNCS(TYPE, SNAME): generates a family of functions that operate
// on 4D arrays ("supertensors") of the given type.
#define TEST_FUNCS(TYPE, SNAME) \
\
// Norm of a 2x2x2x2 supertensor
TYPE SNAME ## Norm(TYPE supertensor[2][2][2][2]) {
// define and initialize the accumulator
double result = 0;
// 循环遍历四维数组的所有元素
for (int l=0; l<2; ++l)
for (int k=0; k<2; ++k)
for (int j=0; j<2; ++j)
for (int i=0; i<2; ++i)
// 计算每个元素的平方,并累加到结果中
result += supertensor[l][k][j][i] * supertensor[l][k][j][i];
// 对结果进行平方根运算并返回,转换为指定类型
return (TYPE)sqrt(result/16);
}
// Find and return the maximum of the 4D array
TYPE SNAME ## Max(TYPE * supertensor, int cubes, int slices, int rows, int cols) {
// 定义变量保存结果,初始为数组第一个元素
TYPE result = supertensor[0];
// 多层循环遍历四维数组的每个元素
for (int l=0; l<cubes; ++l) {
for (int k=0; k<slices; ++k) {
for (int j=0; j<rows; ++j) {
for (int i=0; i<cols; ++i) {
// 计算当前元素在一维数组中的索引
int index = l*slices*rows*cols + k*rows*cols + j*cols + i;
// 如果当前元素大于保存的最大值,则更新最大值
if (supertensor[index] > result) result = supertensor[index];
}
}
}
}
// 返回找到的最大值
return result;
}
// Find and return the minimum of the 4D array
TYPE SNAME ## Min(int cubes, int slices, int rows, int cols, TYPE * supertensor) {
// 定义变量保存结果,初始为数组第一个元素
TYPE result = supertensor[0];
// 多层循环遍历四维数组的每个元素
for (int l=0; l<cubes; ++l) {
for (int k=0; k<slices; ++k) {
for (int j=0; j<rows; ++j) {
for (int i=0; i<cols; ++i) {
// 计算当前元素在一维数组中的索引
int index = l*slices*rows*cols + k*rows*cols + j*cols + i;
// 如果当前元素小于保存的最小值,则更新最小值
if (supertensor[index] < result) result = supertensor[index];
}
}
}
}
// 返回找到的最小值
return result;
}
\
void SNAME ## Scale(TYPE array[3][3][3][3], TYPE val) { \
// scale: multiply every element of the 3x3x3x3 array by the given value
for (int l=0; l<3; ++l) \
for (int k=0; k<3; ++k) \
for (int j=0; j<3; ++j) \
for (int i=0; i<3; ++i) \
array[l][k][j][i] *= val; \
} \
\
void SNAME ## Floor(TYPE * array, int cubes, int slices, int rows, int cols, TYPE floor) { \
// clamp from below: elements of the flat buffer smaller than floor are set to floor
int i, j, k, l, index; \
for (l=0; l<cubes; ++l) { \
for (k=0; k<slices; ++k) { \
for (j=0; j<rows; ++j) { \
for (i=0; i<cols; ++i) { \
index = l*slices*rows*cols + k*rows*cols + j*cols + i; \
if (array[index] < floor) array[index] = floor; \
} \
} \
} \
} \
} \
\
void SNAME ## Ceil(int cubes, int slices, int rows, int cols, TYPE * array, TYPE ceil) { \
// clamp from above: elements of the flat buffer larger than ceil are set to ceil
int i, j, k, l, index; \
for (l=0; l<cubes; ++l) { \
for (k=0; k<slices; ++k) { \
for (j=0; j<rows; ++j) { \
for (i=0; i<cols; ++i) { \
index = l*slices*rows*cols + k*rows*cols + j*cols + i; \
if (array[index] > ceil) array[index] = ceil; \
} \
} \
} \
} \
} \
\
void SNAME ## LUSplit(TYPE supertensor[2][2][2][2], TYPE lower[2][2][2][2], \
TYPE upper[2][2][2][2]) { \
// split the 2x2x2x2 supertensor into "lower" (i+j+k+l < 2) and "upper" parts
int sum; \
for (int l=0; l<2; ++l) { \
for (int k=0; k<2; ++k) { \
for (int j=0; j<2; ++j) { \
for (int i=0; i<2; ++i) { \
sum = i + j + k + l; \
if (sum < 2) { \
lower[l][k][j][i] = supertensor[l][k][j][i]; \
upper[l][k][j][i] = 0; \
} else { \
upper[l][k][j][i] = supertensor[l][k][j][i]; \
lower[l][k][j][i] = 0; \
} \
} \
} \
} \
} \
}
TEST_FUNCS(signed char , schar )
TEST_FUNCS(unsigned char , uchar )
TEST_FUNCS(short , short )
TEST_FUNCS(unsigned short , ushort )
TEST_FUNCS(int , int )
TEST_FUNCS(unsigned int , uint )
TEST_FUNCS(long , long )
TEST_FUNCS(unsigned long , ulong )
TEST_FUNCS(long long , longLong )
TEST_FUNCS(unsigned long long, ulongLong)
TEST_FUNCS(float , float )
TEST_FUNCS(double , double )
.\numpy\tools\swig\test\SuperTensor.h
// Include guard: prevents this header from being included more than once
// Macro description:
// The following macro defines a series of function prototypes for 4D arrays, of the form:
//
// TYPE SNAMENorm(TYPE supertensor[2][2][2][2]);
// TYPE SNAMEMax(TYPE *supertensor, int cubes, int slices, int rows, int cols);
// TYPE SNAMEMin(int cubes, int slices, int rows, int cols, TYPE *supertensor);
// void SNAMEScale(TYPE array[3][3][3][3], TYPE val);
// void SNAMEFloor(TYPE *array, int cubes, int slices, int rows, int cols, TYPE floor);
// void SNAMECeil(int cubes, int slices, int rows, int cols, TYPE *array, TYPE ceil);
// void SNAMELUSplit(TYPE in[2][2][2][2], TYPE lower[2][2][2][2], TYPE upper[2][2][2][2]);
//
// where TYPE is any specified type (e.g. short, unsigned int, long long, ...)
// and SNAME is its short name (e.g. short, uint, longLong, ...). The macro is
// expanded for each TYPE/SNAME pair into the corresponding prototypes. The
// generated functions test the numpy interface for, respectively:
//
//  * 4D input arrays, hard-coded lengths
//  * 4D input arrays
//  * 4D input arrays, data last
//  * 4D in-place arrays, hard-coded lengths
//  * 4D in-place arrays
//  * 4D in-place arrays, data last
//  * 4D output (argout) arrays, hard-coded lengths
//
#define TEST_FUNC_PROTOS(TYPE, SNAME) \
\
TYPE SNAME ## Norm(   TYPE supertensor[2][2][2][2]); \
TYPE SNAME ## Max(    TYPE * supertensor, int cubes, int slices, int rows, int cols); \
TYPE SNAME ## Min(    int cubes, int slices, int rows, int cols, TYPE * supertensor); \
void SNAME ## Scale(  TYPE array[3][3][3][3], TYPE val); \
void SNAME ## Floor(  TYPE * array, int cubes, int slices, int rows, int cols, TYPE floor); \
void SNAME ## Ceil(   int cubes, int slices, int rows, int cols, TYPE * array, TYPE ceil); \
void SNAME ## LUSplit(TYPE in[2][2][2][2], TYPE lower[2][2][2][2], TYPE upper[2][2][2][2]);
// Expand the macro for each TYPE/SNAME pair to generate the corresponding prototypes
TEST_FUNC_PROTOS(signed char, schar)
TEST_FUNC_PROTOS(unsigned char, uchar)
TEST_FUNC_PROTOS(short, short)
TEST_FUNC_PROTOS(unsigned short, ushort)
TEST_FUNC_PROTOS(int, int)
TEST_FUNC_PROTOS(unsigned int, uint)
TEST_FUNC_PROTOS(long, long)
TEST_FUNC_PROTOS(unsigned long, ulong)
TEST_FUNC_PROTOS(long long, longLong)
TEST_FUNC_PROTOS(unsigned long long, ulongLong)
TEST_FUNC_PROTOS(float, float)
TEST_FUNC_PROTOS(double, double)
// End of the conditional compilation (include guard)
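A quick numerical check of the Norm function for the double expansion: with every element of a 2x2x2x2 supertensor equal to 2, the sum of squares is 64, so sqrt(64/16) = 2. A minimal sketch (assuming SuperTensor.h/SuperTensor.cxx above are compiled and linked):
#include <iostream>
#include "SuperTensor.h"

int main()
{
  double t[2][2][2][2];
  for (int l = 0; l < 2; ++l)
    for (int k = 0; k < 2; ++k)
      for (int j = 0; j < 2; ++j)
        for (int i = 0; i < 2; ++i)
          t[l][k][j][i] = 2.0;                    // constant supertensor
  std::cout << "norm = " << doubleNorm(t) << std::endl;   // 2
  return 0;
}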
.\numpy\tools\swig\test\Tensor.cxx
// Include the standard library headers <stdlib.h> and <math.h>
// Include the I/O stream header <iostream>
// Include the custom header "Tensor.h"
// Macro TEST_FUNCS(TYPE, SNAME): generates a family of functions for 3D arrays
#define TEST_FUNCS(TYPE, SNAME) \
\
// Norm of a 2x2x2 tensor, returned as TYPE
TYPE SNAME ## Norm(TYPE tensor[2][2][2]) {
// 初始化结果变量为 0
double result = 0;
// 遍历三维数组的每个元素并计算平方和
for (int k=0; k<2; ++k)
for (int j=0; j<2; ++j)
for (int i=0; i<2; ++i)
result += tensor[k][j][i] * tensor[k][j][i];
// 返回平方和的平方根除以 8 的结果,强制转换为指定类型 TYPE
return (TYPE)sqrt(result/8);
}
// Maximum of the 3D array, returned as TYPE
TYPE SNAME ## Max(TYPE * tensor, int slices, int rows, int cols) {
// 声明变量 i, j, k, index
int i, j, k, index;
// 初始化结果变量为第一个元素的值
TYPE result = tensor[0];
// 嵌套循环遍历三维数组的每个元素
for (k=0; k<slices; ++k) {
for (j=0; j<rows; ++j) {
for (i=0; i<cols; ++i) {
// 计算当前元素在一维数组中的索引
index = k*rows*cols + j*cols + i;
// 如果当前元素大于结果变量,更新结果变量的值
if (tensor[index] > result) result = tensor[index];
}
}
}
// 返回最大值
return result;
}
// Minimum of the 3D array, returned as TYPE
TYPE SNAME ## Min(int slices, int rows, int cols, TYPE * tensor) {
// 声明变量 i, j, k, index
int i, j, k, index;
// 初始化结果变量为第一个元素的值
TYPE result = tensor[0];
// 嵌套循环遍历三维数组的每个元素
for (k=0; k<slices; ++k) {
for (j=0; j<rows; ++j) {
for (i=0; i<cols; ++i) {
// 计算当前元素在一维数组中的索引
index = k*rows*cols + j*cols + i;
// 如果当前元素小于结果变量,更新结果变量的值
if (tensor[index] < result) result = tensor[index];
}
}
}
// 返回最小值
return result;
}
void scharScale(signed char array[3][3][3], signed char val) { \
// multiply every element of the 3x3x3 array by the given value
for (int k=0; k<3; ++k) \
for (int j=0; j<3; ++j) \
for (int i=0; i<3; ++i) \
array[k][j][i] *= val; \
} \
\
void scharFloor(signed char * array, int slices, int rows, int cols, signed char floor) { \
int i, j, k, index; \
// clamp from below: set elements of the flat buffer smaller than floor to floor
for (k=0; k<slices; ++k) { \
for (j=0; j<rows; ++j) { \
for (i=0; i<cols; ++i) { \
index = k*rows*cols + j*cols + i; \
if (array[index] < floor) array[index] = floor; \
} \
} \
} \
} \
\
void scharCeil(int slices, int rows, int cols, signed char * array, signed char ceil) { \
int i, j, k, index; \
// clamp from above: set elements of the flat buffer larger than ceil to ceil
for (k=0; k<slices; ++k) { \
for (j=0; j<rows; ++j) { \
for (i=0; i<cols; ++i) { \
index = k*rows*cols + j*cols + i; \
if (array[index] > ceil) array[index] = ceil; \
} \
} \
} \
} \
\
void scharLUSplit(signed char tensor[2][2][2], signed char lower[2][2][2], \
signed char upper[2][2][2]) { \
int sum; \
// split into lower/upper parts depending on whether i+j+k < 2
for (int k=0; k<2; ++k) { \
for (int j=0; j<2; ++j) { \
for (int i=0; i<2; ++i) { \
sum = i + j + k; \
if (sum < 2) { \
lower[k][j][i] = tensor[k][j][i]; \
upper[k][j][i] = 0; \
} else { \
upper[k][j][i] = tensor[k][j][i]; \
lower[k][j][i] = 0; \
} \
} \
} \
} \
}
.\numpy\tools\swig\test\Tensor.h
// If TENSOR_H is not yet defined, enter the conditional block (include guard)
// Define TENSOR_H so this header is only included once
// The following macro defines a series of prototypes for 3D-array functions,
// generated for a specified type TYPE (e.g. short, unsigned int, long long, ...)
// and short name SNAME (e.g. short, uint, longLong, ...). The functions are
// used to test the numpy interface; for each TYPE/SNAME pair the macro
// expands to:
//
//     TYPE SNAMENorm(   TYPE tensor[2][2][2]);
//     TYPE SNAMEMax(    TYPE * tensor, int slices, int rows, int cols);
//     TYPE SNAMEMin(    int slices, int rows, int cols, TYPE * tensor);
//     void SNAMEScale(  TYPE array[3][3][3], TYPE val);
//     void SNAMEFloor(  TYPE * array, int slices, int rows, int cols, TYPE floor);
//     void SNAMECeil(   int slices, int rows, int cols, TYPE * array, TYPE ceil);
//     void SNAMELUSplit(TYPE tensor[2][2][2], TYPE lower[2][2][2], TYPE upper[2][2][2]);
#define TEST_FUNC_PROTOS(TYPE, SNAME) \
\
TYPE SNAME ## Norm(   TYPE tensor[2][2][2]); \
TYPE SNAME ## Max(    TYPE * tensor, int slices, int rows, int cols); \
TYPE SNAME ## Min(    int slices, int rows, int cols, TYPE * tensor); \
void SNAME ## Scale(  TYPE array[3][3][3], TYPE val); \
void SNAME ## Floor(  TYPE * array, int slices, int rows, int cols, TYPE floor); \
void SNAME ## Ceil(   int slices, int rows, int cols, TYPE * array, TYPE ceil); \
void SNAME ## LUSplit(TYPE tensor[2][2][2], TYPE lower[2][2][2], TYPE upper[2][2][2]);
// Expand the macro for each type / short name to declare the test function prototypes
TEST_FUNC_PROTOS(signed char , schar )
TEST_FUNC_PROTOS(unsigned char , uchar )
TEST_FUNC_PROTOS(short , short )
TEST_FUNC_PROTOS(unsigned short , ushort )
TEST_FUNC_PROTOS(int , int )
TEST_FUNC_PROTOS(unsigned int , uint )
TEST_FUNC_PROTOS(long , long )
TEST_FUNC_PROTOS(unsigned long , ulong )
TEST_FUNC_PROTOS(long long , longLong )
TEST_FUNC_PROTOS(unsigned long long, ulongLong)
TEST_FUNC_PROTOS(float , float )
TEST_FUNC_PROTOS(double , double )
// End of the conditional compilation directive
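The expanded schar functions shown in Tensor.cxx can be exercised directly; scharLUSplit sends an element to lower when i+j+k < 2 and to upper otherwise. A minimal sketch (assuming Tensor.h/Tensor.cxx above are compiled and linked):
#include <iostream>
#include "Tensor.h"

int main()
{
  signed char t[2][2][2], lower[2][2][2], upper[2][2][2];
  signed char v = 1;
  for (int k = 0; k < 2; ++k)
    for (int j = 0; j < 2; ++j)
      for (int i = 0; i < 2; ++i)
        t[k][j][i] = v++;                 // fill with 1..8

  scharLUSplit(t, lower, upper);          // i+j+k < 2 goes to lower, else upper
  std::cout << "lower[0][0][0] = " << (int)lower[0][0][0]    // 1
            << ", upper[1][1][1] = " << (int)upper[1][1][1]  // 8
            << std::endl;
  return 0;
}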
.\numpy\tools\swig\test\testArray.py
import sys
import unittest
import numpy as np
major, minor = [ int(d) for d in np.__version__.split(".")[:2] ]
if major == 0:
BadListError = TypeError
else:
BadListError = ValueError
import Array
class Array1TestCase(unittest.TestCase):
"""定义Array1类的单元测试用例"""
def setUp(self):
"""每个测试方法执行前的初始化操作"""
self.length = 5
self.array1 = Array.Array1(self.length)
def testConstructor0(self):
"""测试Array1的默认构造函数"""
a = Array.Array1()
self.assertTrue(isinstance(a, Array.Array1))
self.assertTrue(len(a) == 0)
def testConstructor1(self):
"""测试Array1的长度构造函数"""
self.assertTrue(isinstance(self.array1, Array.Array1))
def testConstructor2(self):
"""测试Array1的数组构造函数"""
na = np.arange(self.length)
aa = Array.Array1(na)
self.assertTrue(isinstance(aa, Array.Array1))
def testConstructor3(self):
"""测试Array1的拷贝构造函数"""
for i in range(self.array1.length()): self.array1[i] = i
arrayCopy = Array.Array1(self.array1)
self.assertTrue(arrayCopy == self.array1)
def testConstructorBad(self):
"""测试Array1的长度构造函数,负值情况"""
self.assertRaises(ValueError, Array.Array1, -4)
def testLength(self):
"""测试Array1的length方法"""
self.assertTrue(self.array1.length() == self.length)
def testLen(self):
"""测试Array1的__len__方法"""
self.assertTrue(len(self.array1) == self.length)
def testResize0(self):
"""测试Array1的resize方法,长度参数"""
newLen = 2 * self.length
self.array1.resize(newLen)
self.assertTrue(len(self.array1) == newLen)
def testResize1(self):
"""测试Array1的resize方法,数组参数"""
a = np.zeros((2*self.length,), dtype='l')
self.array1.resize(a)
self.assertTrue(len(self.array1) == a.size)
def testResizeBad(self):
"""测试Array1的resize方法,负值长度"""
self.assertRaises(ValueError, self.array1.resize, -5)
def testSetGet(self):
"""测试Array1的__setitem__和__getitem__方法"""
n = self.length
for i in range(n):
self.array1[i] = i*i
for i in range(n):
self.assertTrue(self.array1[i] == i*i)
def testSetBad1(self):
"""测试Array1的__setitem__方法,负索引"""
self.assertRaises(IndexError, self.array1.__setitem__, -1, 0)
def testSetBad2(self):
"""测试Array1的__setitem__方法,超出范围的索引"""
self.assertRaises(IndexError, self.array1.__setitem__, self.length+1, 0)
def testGetBad1(self):
"""测试Array1的__getitem__方法,负索引"""
self.assertRaises(IndexError, self.array1.__getitem__, -1)
def testGetBad2(self):
"Test Array1 __getitem__ method, out-of-range index"
self.assertRaises(IndexError, self.array1.__getitem__, self.length+1)
def testAsString(self):
"Test Array1 asString method"
for i in range(self.array1.length()): self.array1[i] = i+1
self.assertTrue(self.array1.asString() == "[ 1, 2, 3, 4, 5 ]")
def testStr(self):
"Test Array1 __str__ method"
for i in range(self.array1.length()): self.array1[i] = i-2
self.assertTrue(str(self.array1) == "[ -2, -1, 0, 1, 2 ]")
def testView(self):
"Test Array1 view method"
for i in range(self.array1.length()): self.array1[i] = i+1
a = self.array1.view()
self.assertTrue(isinstance(a, np.ndarray))
self.assertTrue(len(a) == self.length)
self.assertTrue((a == [1, 2, 3, 4, 5]).all())
class Array2TestCase(unittest.TestCase):
def setUp(self):
self.nrows = 5
self.ncols = 4
self.array2 = Array.Array2(self.nrows, self.ncols)
def testConstructor0(self):
"Test Array2 default constructor"
a = Array.Array2()
self.assertTrue(isinstance(a, Array.Array2))
self.assertTrue(len(a) == 0)
def testConstructor1(self):
"Test Array2 nrows, ncols constructor"
self.assertTrue(isinstance(self.array2, Array.Array2))
def testConstructor2(self):
"Test Array2 array constructor"
na = np.zeros((3, 4), dtype="l")
aa = Array.Array2(na)
self.assertTrue(isinstance(aa, Array.Array2))
def testConstructor3(self):
"Test Array2 copy constructor"
for i in range(self.nrows):
for j in range(self.ncols):
self.array2[i][j] = i * j
arrayCopy = Array.Array2(self.array2)
self.assertTrue(arrayCopy == self.array2)
def testConstructorBad1(self):
"Test Array2 nrows, ncols constructor, negative nrows"
self.assertRaises(ValueError, Array.Array2, -4, 4)
def testConstructorBad2(self):
"Test Array2 nrows, ncols constructor, negative ncols"
self.assertRaises(ValueError, Array.Array2, 4, -4)
def testNrows(self):
"Test Array2 nrows method"
self.assertTrue(self.array2.nrows() == self.nrows)
def testNcols(self):
"Test Array2 ncols method"
self.assertTrue(self.array2.ncols() == self.ncols)
def testLen(self):
"Test Array2 __len__ method"
self.assertTrue(len(self.array2) == self.nrows*self.ncols)
def testResize0(self):
"Test Array2 resize method, size"
newRows = 2 * self.nrows
newCols = 2 * self.ncols
self.array2.resize(newRows, newCols)
self.assertTrue(len(self.array2) == newRows * newCols)
def testResize1(self):
"Test Array2 resize method, array"
a = np.zeros((2*self.nrows, 2*self.ncols), dtype='l')
self.array2.resize(a)
self.assertTrue(len(self.array2) == a.size)
def testResizeBad1(self):
"Test Array2 resize method, negative nrows"
self.assertRaises(ValueError, self.array2.resize, -5, 5)
def testResizeBad2(self):
"Test Array2 resize method, negative ncols"
self.assertRaises(ValueError, self.array2.resize, 5, -5)
def testSetGet1(self):
"Test Array2 __setitem__, __getitem__ methods"
m = self.nrows
n = self.ncols
array1 = []
a = np.arange(n, dtype="l")
for i in range(m):
array1.append(Array.Array1(i * a))
for i in range(m):
self.array2[i] = array1[i]
for i in range(m):
self.assertTrue(self.array2[i] == array1[i])
def testSetGet2(self):
"Test Array2 chained __setitem__, __getitem__ methods"
m = self.nrows
n = self.ncols
for i in range(m):
for j in range(n):
self.array2[i][j] = i * j
for i in range(m):
for j in range(n):
self.assertTrue(self.array2[i][j] == i * j)
def testSetBad1(self):
"Test Array2 __setitem__ method, negative index"
a = Array.Array1(self.ncols)
self.assertRaises(IndexError, self.array2.__setitem__, -1, a)
def testSetBad2(self):
"Test Array2 __setitem__ method, out-of-range index"
a = Array.Array1(self.ncols)
self.assertRaises(IndexError, self.array2.__setitem__, self.nrows + 1, a)
def testGetBad1(self):
"Test Array2 __getitem__ method, negative index"
self.assertRaises(IndexError, self.array2.__getitem__, -1)
def testGetBad2(self):
"Test Array2 __getitem__ method, out-of-range index"
self.assertRaises(IndexError, self.array2.__getitem__, self.nrows + 1)
def testAsString(self):
"Test Array2 asString method"
result = """\
# 定义一个二维列表,表示一个包含整数的二维数组
[ [ 0, 1, 2, 3 ],
[ 1, 2, 3, 4 ],
[ 2, 3, 4, 5 ],
[ 3, 4, 5, 6 ],
[ 4, 5, 6, 7 ] ]
# 遍历二维数组的行
for i in range(self.nrows):
# 遍历二维数组的列
for j in range(self.ncols):
# 为每个元素赋值为行索引和列索引的和
self.array2[i][j] = i+j
# 使用断言验证数组转换为字符串后是否与指定的结果字符串相等
self.assertTrue(self.array2.asString() == result)
def testStr(self):
"Test Array2 __str__ method"
# 定义预期的结果字符串,表示包含整数的二维数组
result = """\
[ [ 0, -1, -2, -3 ],
[ 1, 0, -1, -2 ],
[ 2, 1, 0, -1 ],
[ 3, 2, 1, 0 ],
[ 4, 3, 2, 1 ] ]
"""
# 遍历二维数组的行
for i in range(self.nrows):
# 遍历二维数组的列
for j in range(self.ncols):
# 为每个元素赋值为行索引减去列索引
self.array2[i][j] = i-j
# 使用断言验证数组转换为字符串后是否与指定的结果字符串相等
self.assertTrue(str(self.array2) == result)
def testView(self):
"Test Array2 view method"
# obtain a view of the array
a = self.array2.view()
# the view must be a NumPy ndarray
self.assertTrue(isinstance(a, np.ndarray))
# len() of the view equals the number of rows
self.assertTrue(len(a) == self.nrows)
######################################################################
class ArrayZTestCase(unittest.TestCase):
def setUp(self):
# length used for the ArrayZ under test
self.length = 5
# create an ArrayZ of length self.length
self.array3 = Array.ArrayZ(self.length)
# (The following test methods could be annotated in the same way; only a subset of the annotations is shown here.)
def testSetBad1(self):
"Test ArrayZ __setitem__ method, negative index"
# 断言会抛出 IndexError 异常,因为索引为负数
self.assertRaises(IndexError, self.array3.__setitem__, -1, 0)
def testSetBad2(self):
"Test ArrayZ __setitem__ method, out-of-range index"
# 断言会抛出 IndexError 异常,因为索引超出范围
self.assertRaises(IndexError, self.array3.__setitem__, self.length+1, 0)
def testGetBad1(self):
"Test ArrayZ __getitem__ method, negative index"
# 断言会抛出 IndexError 异常,因为索引为负数
self.assertRaises(IndexError, self.array3.__getitem__, -1)
def testGetBad2(self):
"Test ArrayZ __getitem__ method, out-of-range index"
# 断言会抛出 IndexError 异常,因为索引超出范围
self.assertRaises(IndexError, self.array3.__getitem__, self.length+1)
def testAsString(self):
"Test ArrayZ asString method"
# 为 ArrayZ 对象的元素赋值为复数
for i in range(self.array3.length()): self.array3[i] = complex(i+1,-i-1)
# 断言 ArrayZ 对象的字符串表示与预期的字符串相等
self.assertTrue(self.array3.asString() == "[ (1,-1), (2,-2), (3,-3), (4,-4), (5,-5) ]")
def testStr(self):
"Test ArrayZ __str__ method"
# 为 ArrayZ 对象的元素赋值为复数
for i in range(self.array3.length()): self.array3[i] = complex(i-2,(i-2)*2)
# 断言 ArrayZ 对象的字符串表示与预期的字符串相等
self.assertTrue(str(self.array3) == "[ (-2,-4), (-1,-2), (0,0), (1,2), (2,4) ]")
def testView(self):
"Test ArrayZ view method"
# 为 ArrayZ 对象的元素赋值为复数
for i in range(self.array3.length()): self.array3[i] = complex(i+1,i+2)
# 获取 ArrayZ 对象的视图
a = self.array3.view()
# 断言视图对象是 numpy.ndarray 类型
self.assertTrue(isinstance(a, np.ndarray))
# 断言视图对象的长度与 ArrayZ 对象的长度相等
self.assertTrue(len(a) == self.length)
# 断言视图对象的所有元素与预期的复数数组相等
self.assertTrue((a == [1+2j, 2+3j, 3+4j, 4+5j, 5+6j]).all())
######################################################################
if __name__ == "__main__":
# executed only when the script is run directly, not when imported
# build the test suite
suite = unittest.TestSuite()
# add the Array1TestCase tests to the suite
suite.addTest(unittest.makeSuite(Array1TestCase))
# add the Array2TestCase tests to the suite
suite.addTest(unittest.makeSuite(Array2TestCase))
# add the ArrayZTestCase tests to the suite
suite.addTest(unittest.makeSuite(ArrayZTestCase))
# run the test suite
print("Testing Classes of Module Array")
# report the NumPy version
print("NumPy version", np.__version__)
print()
# run the suite and collect the results
result = unittest.TextTestRunner(verbosity=2).run(suite)
# exit status reflects whether any errors or failures occurred
sys.exit(bool(result.errors + result.failures))