gh-137530: generate an __annotate__ function for dataclasses __init__ (GH-137711)
This commit is contained in:
@@ -368,6 +368,14 @@ collections.abc
|
|||||||
previously emitted if it was merely imported or accessed from the
|
previously emitted if it was merely imported or accessed from the
|
||||||
:mod:`!collections.abc` module.
|
:mod:`!collections.abc` module.
|
||||||
|
|
||||||
|
|
||||||
|
dataclasses
|
||||||
|
-----------
|
||||||
|
|
||||||
|
* Annotations for generated ``__init__`` methods no longer include internal
|
||||||
|
type names.
|
||||||
|
|
||||||
|
|
||||||
dbm
|
dbm
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -441,9 +441,11 @@ class _FuncBuilder:
|
|||||||
self.locals = {}
|
self.locals = {}
|
||||||
self.overwrite_errors = {}
|
self.overwrite_errors = {}
|
||||||
self.unconditional_adds = {}
|
self.unconditional_adds = {}
|
||||||
|
self.method_annotations = {}
|
||||||
|
|
||||||
def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
|
def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
|
||||||
overwrite_error=False, unconditional_add=False, decorator=None):
|
overwrite_error=False, unconditional_add=False, decorator=None,
|
||||||
|
annotation_fields=None):
|
||||||
if locals is not None:
|
if locals is not None:
|
||||||
self.locals.update(locals)
|
self.locals.update(locals)
|
||||||
|
|
||||||
@@ -464,16 +466,14 @@ class _FuncBuilder:
|
|||||||
|
|
||||||
self.names.append(name)
|
self.names.append(name)
|
||||||
|
|
||||||
if return_type is not MISSING:
|
if annotation_fields is not None:
|
||||||
self.locals[f'__dataclass_{name}_return_type__'] = return_type
|
self.method_annotations[name] = (annotation_fields, return_type)
|
||||||
return_annotation = f'->__dataclass_{name}_return_type__'
|
|
||||||
else:
|
|
||||||
return_annotation = ''
|
|
||||||
args = ','.join(args)
|
args = ','.join(args)
|
||||||
body = '\n'.join(body)
|
body = '\n'.join(body)
|
||||||
|
|
||||||
# Compute the text of the entire function, add it to the text we're generating.
|
# Compute the text of the entire function, add it to the text we're generating.
|
||||||
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
|
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')
|
||||||
|
|
||||||
def add_fns_to_class(self, cls):
|
def add_fns_to_class(self, cls):
|
||||||
# The source to all of the functions we're generating.
|
# The source to all of the functions we're generating.
|
||||||
@@ -509,6 +509,15 @@ class _FuncBuilder:
|
|||||||
# Now that we've generated the functions, assign them into cls.
|
# Now that we've generated the functions, assign them into cls.
|
||||||
for name, fn in zip(self.names, fns):
|
for name, fn in zip(self.names, fns):
|
||||||
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
|
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
annotation_fields, return_type = self.method_annotations[name]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
annotate_fn = _make_annotate_function(cls, name, annotation_fields, return_type)
|
||||||
|
fn.__annotate__ = annotate_fn
|
||||||
|
|
||||||
if self.unconditional_adds.get(name, False):
|
if self.unconditional_adds.get(name, False):
|
||||||
setattr(cls, name, fn)
|
setattr(cls, name, fn)
|
||||||
else:
|
else:
|
||||||
@@ -524,6 +533,44 @@ class _FuncBuilder:
|
|||||||
raise TypeError(error_msg)
|
raise TypeError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_annotate_function(__class__, method_name, annotation_fields, return_type):
|
||||||
|
# Create an __annotate__ function for a dataclass
|
||||||
|
# Try to return annotations in the same format as they would be
|
||||||
|
# from a regular __init__ function
|
||||||
|
|
||||||
|
def __annotate__(format, /):
|
||||||
|
Format = annotationlib.Format
|
||||||
|
match format:
|
||||||
|
case Format.VALUE | Format.FORWARDREF | Format.STRING:
|
||||||
|
cls_annotations = {}
|
||||||
|
for base in reversed(__class__.__mro__):
|
||||||
|
cls_annotations.update(
|
||||||
|
annotationlib.get_annotations(base, format=format)
|
||||||
|
)
|
||||||
|
|
||||||
|
new_annotations = {}
|
||||||
|
for k in annotation_fields:
|
||||||
|
new_annotations[k] = cls_annotations[k]
|
||||||
|
|
||||||
|
if return_type is not MISSING:
|
||||||
|
if format == Format.STRING:
|
||||||
|
new_annotations["return"] = annotationlib.type_repr(return_type)
|
||||||
|
else:
|
||||||
|
new_annotations["return"] = return_type
|
||||||
|
|
||||||
|
return new_annotations
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(format)
|
||||||
|
|
||||||
|
# This is a flag for _add_slots to know it needs to regenerate this method
|
||||||
|
# In order to remove references to the original class when it is replaced
|
||||||
|
__annotate__.__generated_by_dataclasses__ = True
|
||||||
|
__annotate__.__qualname__ = f"{__class__.__qualname__}.{method_name}.__annotate__"
|
||||||
|
|
||||||
|
return __annotate__
|
||||||
|
|
||||||
|
|
||||||
def _field_assign(frozen, name, value, self_name):
|
def _field_assign(frozen, name, value, self_name):
|
||||||
# If we're a frozen class, then assign to our fields in __init__
|
# If we're a frozen class, then assign to our fields in __init__
|
||||||
# via object.__setattr__. Otherwise, just use a simple
|
# via object.__setattr__. Otherwise, just use a simple
|
||||||
@@ -612,7 +659,7 @@ def _init_param(f):
|
|||||||
elif f.default_factory is not MISSING:
|
elif f.default_factory is not MISSING:
|
||||||
# There's a factory function. Set a marker.
|
# There's a factory function. Set a marker.
|
||||||
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
|
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
|
||||||
return f'{f.name}:__dataclass_type_{f.name}__{default}'
|
return f'{f.name}{default}'
|
||||||
|
|
||||||
|
|
||||||
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
||||||
@@ -635,11 +682,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
|||||||
raise TypeError(f'non-default argument {f.name!r} '
|
raise TypeError(f'non-default argument {f.name!r} '
|
||||||
f'follows default argument {seen_default.name!r}')
|
f'follows default argument {seen_default.name!r}')
|
||||||
|
|
||||||
locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
|
annotation_fields = [f.name for f in fields if f.init]
|
||||||
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
|
|
||||||
'__dataclass_builtins_object__': object,
|
locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
|
||||||
}
|
'__dataclass_builtins_object__': object}
|
||||||
}
|
|
||||||
|
|
||||||
body_lines = []
|
body_lines = []
|
||||||
for f in fields:
|
for f in fields:
|
||||||
@@ -670,7 +716,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
|||||||
[self_name] + _init_params,
|
[self_name] + _init_params,
|
||||||
body_lines,
|
body_lines,
|
||||||
locals=locals,
|
locals=locals,
|
||||||
return_type=None)
|
return_type=None,
|
||||||
|
annotation_fields=annotation_fields)
|
||||||
|
|
||||||
|
|
||||||
def _frozen_get_del_attr(cls, fields, func_builder):
|
def _frozen_get_del_attr(cls, fields, func_builder):
|
||||||
@@ -1337,6 +1384,25 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
|
|||||||
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
|
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Get new annotations to remove references to the original class
|
||||||
|
# in forward references
|
||||||
|
newcls_ann = annotationlib.get_annotations(
|
||||||
|
newcls, format=annotationlib.Format.FORWARDREF)
|
||||||
|
|
||||||
|
# Fix references in dataclass Fields
|
||||||
|
for f in getattr(newcls, _FIELDS).values():
|
||||||
|
try:
|
||||||
|
ann = newcls_ann[f.name]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
f.type = ann
|
||||||
|
|
||||||
|
# Fix the class reference in the __annotate__ method
|
||||||
|
init_annotate = newcls.__init__.__annotate__
|
||||||
|
if getattr(init_annotate, "__generated_by_dataclasses__", False):
|
||||||
|
_update_func_cell_for__class__(init_annotate, cls, newcls)
|
||||||
|
|
||||||
return newcls
|
return newcls
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2471,6 +2471,135 @@ class TestInit(unittest.TestCase):
|
|||||||
self.assertEqual(D(5).a, 10)
|
self.assertEqual(D(5).a, 10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitAnnotate(unittest.TestCase):
|
||||||
|
# Tests for the generated __annotate__ function for __init__
|
||||||
|
# See: https://github.com/python/cpython/issues/137530
|
||||||
|
|
||||||
|
def test_annotate_function(self):
|
||||||
|
# No forward references
|
||||||
|
@dataclass
|
||||||
|
class A:
|
||||||
|
a: int
|
||||||
|
|
||||||
|
value_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.VALUE)
|
||||||
|
forwardref_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.FORWARDREF)
|
||||||
|
string_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.STRING)
|
||||||
|
|
||||||
|
self.assertEqual(value_annos, {'a': int, 'return': None})
|
||||||
|
self.assertEqual(forwardref_annos, {'a': int, 'return': None})
|
||||||
|
self.assertEqual(string_annos, {'a': 'int', 'return': 'None'})
|
||||||
|
|
||||||
|
self.assertTrue(getattr(A.__init__.__annotate__, "__generated_by_dataclasses__"))
|
||||||
|
|
||||||
|
def test_annotate_function_forwardref(self):
|
||||||
|
# With forward references
|
||||||
|
@dataclass
|
||||||
|
class B:
|
||||||
|
b: undefined
|
||||||
|
|
||||||
|
# VALUE annotations should raise while unresolvable
|
||||||
|
with self.assertRaises(NameError):
|
||||||
|
_ = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
|
||||||
|
|
||||||
|
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
|
||||||
|
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
|
||||||
|
|
||||||
|
self.assertEqual(forwardref_annos, {'b': support.EqualToForwardRef('undefined', owner=B, is_class=True), 'return': None})
|
||||||
|
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
|
||||||
|
|
||||||
|
# Now VALUE and FORWARDREF should resolve, STRING should be unchanged
|
||||||
|
undefined = int
|
||||||
|
|
||||||
|
value_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
|
||||||
|
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
|
||||||
|
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
|
||||||
|
|
||||||
|
self.assertEqual(value_annos, {'b': int, 'return': None})
|
||||||
|
self.assertEqual(forwardref_annos, {'b': int, 'return': None})
|
||||||
|
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
|
||||||
|
|
||||||
|
def test_annotate_function_init_false(self):
|
||||||
|
# Check `init=False` attributes don't get into the annotations of the __init__ function
|
||||||
|
@dataclass
|
||||||
|
class C:
|
||||||
|
c: str = field(init=False)
|
||||||
|
|
||||||
|
self.assertEqual(annotationlib.get_annotations(C.__init__), {'return': None})
|
||||||
|
|
||||||
|
def test_annotate_function_contains_forwardref(self):
|
||||||
|
# Check string annotations on objects containing a ForwardRef
|
||||||
|
@dataclass
|
||||||
|
class D:
|
||||||
|
d: list[undefined]
|
||||||
|
|
||||||
|
with self.assertRaises(NameError):
|
||||||
|
annotationlib.get_annotations(D.__init__)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
|
||||||
|
{"d": list[support.EqualToForwardRef("undefined", is_class=True, owner=D)], "return": None}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
|
||||||
|
{"d": "list[undefined]", "return": "None"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now test when it is defined
|
||||||
|
undefined = str
|
||||||
|
|
||||||
|
# VALUE should now resolve
|
||||||
|
self.assertEqual(
|
||||||
|
annotationlib.get_annotations(D.__init__),
|
||||||
|
{"d": list[str], "return": None}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
|
||||||
|
{"d": list[str], "return": None}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
|
||||||
|
{"d": "list[undefined]", "return": "None"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_annotate_function_not_replaced(self):
|
||||||
|
# Check that __annotate__ is not replaced on non-generated __init__ functions
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class E:
|
||||||
|
x: str
|
||||||
|
def __init__(self, x: int) -> None:
|
||||||
|
self.x = x
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
annotationlib.get_annotations(E.__init__), {"x": int, "return": None}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(hasattr(E.__init__.__annotate__, "__generated_by_dataclasses__"))
|
||||||
|
|
||||||
|
def test_init_false_forwardref(self):
|
||||||
|
# Test forward references in fields not required for __init__ annotations.
|
||||||
|
|
||||||
|
# At the moment this raises a NameError for VALUE annotations even though the
|
||||||
|
# undefined annotation is not required for the __init__ annotations.
|
||||||
|
# Ideally this will be fixed but currently there is no good way to resolve this
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class F:
|
||||||
|
not_in_init: list[undefined] = field(init=False, default=None)
|
||||||
|
in_init: int
|
||||||
|
|
||||||
|
annos = annotationlib.get_annotations(F.__init__, format=annotationlib.Format.FORWARDREF)
|
||||||
|
self.assertEqual(
|
||||||
|
annos,
|
||||||
|
{"in_init": int, "return": None},
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(NameError):
|
||||||
|
annos = annotationlib.get_annotations(F.__init__) # NameError on not_in_init
|
||||||
|
|
||||||
|
|
||||||
class TestRepr(unittest.TestCase):
|
class TestRepr(unittest.TestCase):
|
||||||
def test_repr(self):
|
def test_repr(self):
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -3831,7 +3960,15 @@ class TestSlots(unittest.TestCase):
|
|||||||
|
|
||||||
return SlotsTest
|
return SlotsTest
|
||||||
|
|
||||||
for make in (make_simple, make_with_annotations, make_with_annotations_and_method):
|
def make_with_forwardref():
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class SlotsTest:
|
||||||
|
x: undefined
|
||||||
|
y: list[undefined]
|
||||||
|
|
||||||
|
return SlotsTest
|
||||||
|
|
||||||
|
for make in (make_simple, make_with_annotations, make_with_annotations_and_method, make_with_forwardref):
|
||||||
with self.subTest(make=make):
|
with self.subTest(make=make):
|
||||||
C = make()
|
C = make()
|
||||||
support.gc_collect()
|
support.gc_collect()
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
:mod:`dataclasses` Fix annotations for generated ``__init__`` methods by replacing the annotations that were in-line in the generated source code with ``__annotate__`` functions attached to the methods.
|
||||||
Reference in New Issue
Block a user