Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/pyrefly_types/src/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ pub enum FunctionKind {
IsSubclass,
Dataclass,
DataclassField,
DataclassReplace,
/// `typing.dataclass_transform`. Note that this is `dataclass_transform` itself, *not* the
/// decorator created by a `dataclass_transform(...)` call. See
/// https://typing.python.org/en/latest/spec/dataclasses.html#specification.
Expand Down Expand Up @@ -810,6 +811,7 @@ impl FunctionKind {
("builtins", None, "classmethod") => Self::ClassMethod,
("dataclasses", None, "dataclass") => Self::Dataclass,
("dataclasses", None, "field") => Self::DataclassField,
("dataclasses", None, "replace") => Self::DataclassReplace,
("typing", None, "overload") => Self::Overload,
("typing", None, "override") => Self::Override,
("typing", None, "cast") => Self::Cast,
Expand Down Expand Up @@ -840,6 +842,7 @@ impl FunctionKind {
Self::ClassMethod => ModuleName::builtins(),
Self::Dataclass => ModuleName::dataclasses(),
Self::DataclassField => ModuleName::dataclasses(),
Self::DataclassReplace => ModuleName::dataclasses(),
Self::DataclassTransform => ModuleName::typing(),
Self::Final => ModuleName::typing(),
Self::Overload => ModuleName::typing(),
Expand All @@ -865,6 +868,7 @@ impl FunctionKind {
Self::ClassMethod => Cow::Owned(Name::new_static("classmethod")),
Self::Dataclass => Cow::Owned(Name::new_static("dataclass")),
Self::DataclassField => Cow::Owned(Name::new_static("field")),
Self::DataclassReplace => Cow::Owned(Name::new_static("replace")),
Self::DataclassTransform => Cow::Owned(Name::new_static("dataclass_transform")),
Self::Final => Cow::Owned(Name::new_static("final")),
Self::Overload => Cow::Owned(Name::new_static("overload")),
Expand All @@ -890,6 +894,7 @@ impl FunctionKind {
Self::ClassMethod => None,
Self::Dataclass => None,
Self::DataclassField => None,
Self::DataclassReplace => None,
Self::DataclassTransform => None,
Self::Final => None,
Self::Overload => None,
Expand Down
8 changes: 8 additions & 0 deletions pyrefly/lib/alt/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
errors,
)
}
Some(CalleeKind::Function(FunctionKind::DataclassReplace)) => {
self.call_dataclasses_replace(
&x.arguments.args,
&x.arguments.keywords,
x.arguments.range,
errors,
)
}
// Treat assert_type and reveal_type like pseudo-builtins for convenience. Note that we still
// log a name-not-found error, but we also assert/reveal the type as requested.
None if ty.is_error() && is_special_name(&x.func, "assert_type") => self
Expand Down
233 changes: 233 additions & 0 deletions pyrefly/lib/alt/class/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ use dupe::Dupe;
use pyrefly_python::dunder;
use pyrefly_util::prelude::SliceExt;
use ruff_python_ast::Arguments;
use ruff_python_ast::Expr;
use ruff_python_ast::Expr::EllipsisLiteral;
use ruff_python_ast::Keyword;
use ruff_python_ast::name::Name;
use ruff_text_size::Ranged;
use ruff_text_size::TextRange;
use starlark_map::small_map::SmallMap;
use starlark_map::small_set::SmallSet;
Expand Down Expand Up @@ -43,17 +46,21 @@ use crate::error::context::TypeCheckKind;
use crate::types::callable::Callable;
use crate::types::callable::FuncMetadata;
use crate::types::callable::Function;
use crate::types::callable::FunctionKind;
use crate::types::callable::Param;
use crate::types::callable::ParamList;
use crate::types::callable::Params;
use crate::types::callable::Required;
use crate::types::class::Class;
use crate::types::class::ClassKind;
use crate::types::class::ClassType;
use crate::types::display::ClassDisplayContext;
use crate::types::keywords::ConverterMap;
use crate::types::keywords::DataclassFieldKeywords;
use crate::types::keywords::TypeMap;
use crate::types::literal::Lit;
use crate::types::types::AnyStyle;
use crate::types::types::CalleeKind;
use crate::types::types::Type;

impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Expand Down Expand Up @@ -180,6 +187,232 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Some(ClassSynthesizedFields::new(fields))
}

pub fn call_dataclasses_replace(
&self,
args: &[Expr],
keywords: &[Keyword],
range: TextRange,
errors: &ErrorCollector,
) -> Type {
let call_args = args.map(CallArg::expr_maybe_starred);
let call_keywords = keywords.map(CallKeyword::new);
let obj_name = Name::new_static("obj");

if args.is_empty() {
return self.callable_infer(
Callable::list(
ParamList::new(vec![
Param::PosOnly(
Some(obj_name.clone()),
Type::Any(AnyStyle::Implicit),
Required::Required,
),
Param::Kwargs(None, Type::Any(AnyStyle::Implicit)),
]),
Type::Any(AnyStyle::Implicit),
),
Some(&FunctionKind::DataclassReplace),
None,
None,
&call_args,
&call_keywords,
range,
errors,
errors,
None,
None,
None,
);
}

let obj_expr = &args[0];
let obj_ty = self.expr_infer(obj_expr, errors);

let is_dataclass = |cls: &ClassType| {
let cls_metadata = self.get_metadata_for_class(cls.class_object());
cls_metadata.dataclass_metadata().is_some_and(|m| {
m.field_specifiers.iter().any(|k| {
matches!(
k,
CalleeKind::Function(FunctionKind::DataclassField)
| CalleeKind::Class(ClassKind::DataclassField)
)
})
})
};

let mut saw_any = false;
let mut saw_dataclass = false;
let mut saw_non_dataclass = false;
self.map_over_union(&obj_ty, |ty| match ty {
Type::Any(_) => saw_any = true,
Type::ClassType(cls) if is_dataclass(cls) => saw_dataclass = true,
_ => saw_non_dataclass = true,
});

if saw_any {
return self.callable_infer(
Callable::list(
ParamList::new(vec![
Param::PosOnly(Some(obj_name.clone()), obj_ty.clone(), Required::Required),
Param::Kwargs(None, Type::Any(AnyStyle::Implicit)),
]),
Type::Any(AnyStyle::Implicit),
),
Some(&FunctionKind::DataclassReplace),
None,
None,
&call_args,
&call_keywords,
range,
errors,
errors,
None,
None,
None,
);
}

if !saw_dataclass {
self.error(
errors,
obj_expr.range(),
ErrorInfo::Kind(ErrorKind::InvalidArgument),
"dataclasses.replace() should be called on dataclass instances".to_owned(),
);
return self.callable_infer(
Callable::list(
ParamList::new(vec![
Param::PosOnly(Some(obj_name.clone()), obj_ty.clone(), Required::Required),
Param::Kwargs(None, Type::Any(AnyStyle::Implicit)),
]),
Type::any_error(),
),
Some(&FunctionKind::DataclassReplace),
None,
None,
&call_args,
&call_keywords,
range,
errors,
errors,
None,
None,
None,
);
}

if saw_non_dataclass {
self.error(
errors,
obj_expr.range(),
ErrorInfo::Kind(ErrorKind::InvalidArgument),
format!(
"dataclasses.replace() expects a dataclass instance; got {}",
self.for_display(obj_ty.clone())
),
);
return Type::any_error();
}

// For unions of dataclasses, typecheck each member individually. We treat the first argument
// as the member type to avoid rejecting `A | B` as not assignable to `A`.
self.distribute_over_union(&obj_ty, |ty| match ty {
Type::ClassType(cls) if is_dataclass(cls) => {
let Some(callable) = self.build_replace_callable(&obj_name, cls, errors) else {
return Type::any_error();
};
let obj_ty = ty.clone();
let mut member_args = Vec::with_capacity(args.len());
member_args.push(CallArg::ty(&obj_ty, obj_expr.range()));
member_args.extend(args[1..].map(CallArg::expr_maybe_starred));
self.callable_infer(
callable,
Some(&FunctionKind::DataclassReplace),
None,
None,
&member_args,
&call_keywords,
range,
errors,
errors,
None,
None,
None,
)
}
_ => Type::any_error(),
})
}

fn build_replace_callable(
&self,
obj_name: &Name,
dataclass_type: &ClassType,
errors: &ErrorCollector,
) -> Option<Callable> {
let metadata = self.get_metadata_for_class(dataclass_type.class_object());
let dataclass_metadata = metadata.dataclass_metadata()?;

let mut params = vec![Param::PosOnly(
Some(obj_name.clone()),
dataclass_type.clone().to_type(),
Required::Required,
)];

let subst = dataclass_type.targs().substitution_map();
let self_type = dataclass_type.clone().to_type();
let type_transform = |mut ty: Type| {
ty.subst_self_type_mut(&self_type);
ty.subst_mut(&subst);
ty
};

let strict_default = dataclass_metadata.kws.strict;
for (name, field, field_flags) in
self.iter_fields(dataclass_type.class_object(), dataclass_metadata, true)
{
if !field_flags.init {
continue;
}

let strict = field_flags.strict.unwrap_or(strict_default);
let has_default = !field.is_init_var() || field_flags.default.is_some();
if field_flags.init_by_name {
params.push(self.as_param(
&field,
&name,
has_default,
true,
strict,
field_flags.converter_param.clone(),
&type_transform,
errors,
));
}
if let Some(alias) = &field_flags.init_by_alias {
params.push(self.as_param(
&field,
alias,
has_default,
true,
strict,
field_flags.converter_param.clone(),
&type_transform,
errors,
));
}
}
if dataclass_metadata.kws.extra {
params.push(Param::Kwargs(None, Type::Any(AnyStyle::Implicit)));
}

Some(Callable::list(
ParamList::new(params),
dataclass_type.clone().to_type(),
))
}

pub fn validate_frozen_dataclass_inheritance(
&self,
cls: &Class,
Expand Down
Loading