""" | |
Unit tests for refactor.py. | |
""" | |
from __future__ import with_statement | |
import sys | |
import os | |
import codecs | |
import operator | |
import StringIO | |
import tempfile | |
import shutil | |
import unittest | |
import warnings | |
from lib2to3 import refactor, pygram, fixer_base | |
from lib2to3.pgen2 import token | |
from . import support | |
# Test fixtures live next to this file: <tests>/data and <tests>/data/fixers.
TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
FIXER_DIR = os.path.join(TEST_DATA_DIR, "fixers")

# The "myfixes" test package is only importable while FIXER_DIR is on
# sys.path, so expose it just long enough to enumerate its fixer modules.
sys.path.append(FIXER_DIR)
try:
    _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes")
finally:
    # Drop the entry appended above, even if fixer discovery failed.
    del sys.path[-1]

# Full names of every fixer shipped with lib2to3 itself.
_2TO3_FIXERS = refactor.get_fixers_from_package("lib2to3.fixes")
class TestRefactoringTool(unittest.TestCase): | |
def setUp(self): | |
sys.path.append(FIXER_DIR) | |
def tearDown(self): | |
sys.path.pop() | |
def check_instances(self, instances, classes): | |
for inst, cls in zip(instances, classes): | |
if not isinstance(inst, cls): | |
self.fail("%s are not instances of %s" % instances, classes) | |
def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None): | |
return refactor.RefactoringTool(fixers, options, explicit) | |
def test_print_function_option(self): | |
rt = self.rt({"print_function" : True}) | |
self.assertTrue(rt.grammar is pygram.python_grammar_no_print_statement) | |
self.assertTrue(rt.driver.grammar is | |
pygram.python_grammar_no_print_statement) | |
def test_fixer_loading_helpers(self): | |
contents = ["explicit", "first", "last", "parrot", "preorder"] | |
non_prefixed = refactor.get_all_fix_names("myfixes") | |
prefixed = refactor.get_all_fix_names("myfixes", False) | |
full_names = refactor.get_fixers_from_package("myfixes") | |
self.assertEqual(prefixed, ["fix_" + name for name in contents]) | |
self.assertEqual(non_prefixed, contents) | |
self.assertEqual(full_names, | |
["myfixes.fix_" + name for name in contents]) | |
def test_detect_future_features(self): | |
run = refactor._detect_future_features | |
fs = frozenset | |
empty = fs() | |
self.assertEqual(run(""), empty) | |
self.assertEqual(run("from __future__ import print_function"), | |
fs(("print_function",))) | |
self.assertEqual(run("from __future__ import generators"), | |
fs(("generators",))) | |
self.assertEqual(run("from __future__ import generators, feature"), | |
fs(("generators", "feature"))) | |
inp = "from __future__ import generators, print_function" | |
self.assertEqual(run(inp), fs(("generators", "print_function"))) | |
inp ="from __future__ import print_function, generators" | |
self.assertEqual(run(inp), fs(("print_function", "generators"))) | |
inp = "from __future__ import (print_function,)" | |
self.assertEqual(run(inp), fs(("print_function",))) | |
inp = "from __future__ import (generators, print_function)" | |
self.assertEqual(run(inp), fs(("generators", "print_function"))) | |
inp = "from __future__ import (generators, nested_scopes)" | |
self.assertEqual(run(inp), fs(("generators", "nested_scopes"))) | |
inp = """from __future__ import generators | |
from __future__ import print_function""" | |
self.assertEqual(run(inp), fs(("generators", "print_function"))) | |
invalid = ("from", | |
"from 4", | |
"from x", | |
"from x 5", | |
"from x im", | |
"from x import", | |
"from x import 4", | |
) | |
for inp in invalid: | |
self.assertEqual(run(inp), empty) | |
inp = "'docstring'\nfrom __future__ import print_function" | |
self.assertEqual(run(inp), fs(("print_function",))) | |
inp = "'docstring'\n'somng'\nfrom __future__ import print_function" | |
self.assertEqual(run(inp), empty) | |
inp = "# comment\nfrom __future__ import print_function" | |
self.assertEqual(run(inp), fs(("print_function",))) | |
inp = "# comment\n'doc'\nfrom __future__ import print_function" | |
self.assertEqual(run(inp), fs(("print_function",))) | |
inp = "class x: pass\nfrom __future__ import print_function" | |
self.assertEqual(run(inp), empty) | |
def test_get_headnode_dict(self): | |
class NoneFix(fixer_base.BaseFix): | |
pass | |
class FileInputFix(fixer_base.BaseFix): | |
PATTERN = "file_input< any * >" | |
class SimpleFix(fixer_base.BaseFix): | |
PATTERN = "'name'" | |
no_head = NoneFix({}, []) | |
with_head = FileInputFix({}, []) | |
simple = SimpleFix({}, []) | |
d = refactor._get_headnode_dict([no_head, with_head, simple]) | |
top_fixes = d.pop(pygram.python_symbols.file_input) | |
self.assertEqual(top_fixes, [with_head, no_head]) | |
name_fixes = d.pop(token.NAME) | |
self.assertEqual(name_fixes, [simple, no_head]) | |
for fixes in d.itervalues(): | |
self.assertEqual(fixes, [no_head]) | |
def test_fixer_loading(self): | |
from myfixes.fix_first import FixFirst | |
from myfixes.fix_last import FixLast | |
from myfixes.fix_parrot import FixParrot | |
from myfixes.fix_preorder import FixPreorder | |
rt = self.rt() | |
pre, post = rt.get_fixers() | |
self.check_instances(pre, [FixPreorder]) | |
self.check_instances(post, [FixFirst, FixParrot, FixLast]) | |
def test_naughty_fixers(self): | |
self.assertRaises(ImportError, self.rt, fixers=["not_here"]) | |
self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"]) | |
self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"]) | |
def test_refactor_string(self): | |
rt = self.rt() | |
input = "def parrot(): pass\n\n" | |
tree = rt.refactor_string(input, "<test>") | |
self.assertNotEqual(str(tree), input) | |
input = "def f(): pass\n\n" | |
tree = rt.refactor_string(input, "<test>") | |
self.assertEqual(str(tree), input) | |
def test_refactor_stdin(self): | |
class MyRT(refactor.RefactoringTool): | |
def print_output(self, old_text, new_text, filename, equal): | |
results.extend([old_text, new_text, filename, equal]) | |
results = [] | |
rt = MyRT(_DEFAULT_FIXERS) | |
save = sys.stdin | |
sys.stdin = StringIO.StringIO("def parrot(): pass\n\n") | |
try: | |
rt.refactor_stdin() | |
finally: | |
sys.stdin = save | |
expected = ["def parrot(): pass\n\n", | |
"def cheese(): pass\n\n", | |
"<stdin>", False] | |
self.assertEqual(results, expected) | |
def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS): | |
def read_file(): | |
with open(test_file, "rb") as fp: | |
return fp.read() | |
old_contents = read_file() | |
rt = self.rt(fixers=fixers) | |
rt.refactor_file(test_file) | |
self.assertEqual(old_contents, read_file()) | |
try: | |
rt.refactor_file(test_file, True) | |
new_contents = read_file() | |
self.assertNotEqual(old_contents, new_contents) | |
finally: | |
with open(test_file, "wb") as fp: | |
fp.write(old_contents) | |
return new_contents | |
def test_refactor_file(self): | |
test_file = os.path.join(FIXER_DIR, "parrot_example.py") | |
self.check_file_refactoring(test_file, _DEFAULT_FIXERS) | |
def test_refactor_dir(self): | |
def check(structure, expected): | |
def mock_refactor_file(self, f, *args): | |
got.append(f) | |
save_func = refactor.RefactoringTool.refactor_file | |
refactor.RefactoringTool.refactor_file = mock_refactor_file | |
rt = self.rt() | |
got = [] | |
dir = tempfile.mkdtemp(prefix="2to3-test_refactor") | |
try: | |
os.mkdir(os.path.join(dir, "a_dir")) | |
for fn in structure: | |
open(os.path.join(dir, fn), "wb").close() | |
rt.refactor_dir(dir) | |
finally: | |
refactor.RefactoringTool.refactor_file = save_func | |
shutil.rmtree(dir) | |
self.assertEqual(got, | |
[os.path.join(dir, path) for path in expected]) | |
check([], []) | |
tree = ["nothing", | |
"hi.py", | |
".dumb", | |
".after.py", | |
"notpy.npy", | |
"sappy"] | |
expected = ["hi.py"] | |
check(tree, expected) | |
tree = ["hi.py", | |
os.path.join("a_dir", "stuff.py")] | |
check(tree, tree) | |
def test_file_encoding(self): | |
fn = os.path.join(TEST_DATA_DIR, "different_encoding.py") | |
self.check_file_refactoring(fn) | |
def test_bom(self): | |
fn = os.path.join(TEST_DATA_DIR, "bom.py") | |
data = self.check_file_refactoring(fn) | |
self.assertTrue(data.startswith(codecs.BOM_UTF8)) | |
def test_crlf_newlines(self): | |
old_sep = os.linesep | |
os.linesep = "\r\n" | |
try: | |
fn = os.path.join(TEST_DATA_DIR, "crlf.py") | |
fixes = refactor.get_fixers_from_package("lib2to3.fixes") | |
self.check_file_refactoring(fn, fixes) | |
finally: | |
os.linesep = old_sep | |
def test_refactor_docstring(self): | |
rt = self.rt() | |
doc = """ | |
>>> example() | |
42 | |
""" | |
out = rt.refactor_docstring(doc, "<test>") | |
self.assertEqual(out, doc) | |
doc = """ | |
>>> def parrot(): | |
... return 43 | |
""" | |
out = rt.refactor_docstring(doc, "<test>") | |
self.assertNotEqual(out, doc) | |
def test_explicit(self): | |
from myfixes.fix_explicit import FixExplicit | |
rt = self.rt(fixers=["myfixes.fix_explicit"]) | |
self.assertEqual(len(rt.post_order), 0) | |
rt = self.rt(explicit=["myfixes.fix_explicit"]) | |
for fix in rt.post_order: | |
if isinstance(fix, FixExplicit): | |
break | |
else: | |
self.fail("explicit fixer not loaded") |