"""
Type checker for the yapCAD DSL.
Traverses the AST and validates types, collecting diagnostics
for type errors, undefined identifiers, and other semantic issues.
"""
from typing import List, Optional, Dict, Any, Union
from dataclasses import dataclass
from .ast import (
AstNode, Module, Command, FunctionDef, Parameter, Block,
Statement, LetStatement, VarDecl, AssignmentStatement,
RequireStatement, AssertStatement,
EmitStatement, ForStatement, WhileStatement, IfStatement,
ExpressionStatement, ReturnStatement, PassStatement,
PythonBlock, NativeBlock, NativeFunctionDecl, NativeFunction, Decorator,
ElifBranch,
Expression, Literal, Identifier, BinaryOp, UnaryOp,
FunctionCall, MethodCall, MemberAccess, IndexAccess,
ListLiteral, ListComprehension, RangeExpr, IfExpr, MatchExpr,
MatchArm, Pattern, LiteralPattern, IdentifierPattern, WildcardPattern,
LambdaExpr, PythonExpr, DictLiteral,
TypeNode, SimpleType, GenericType, OptionalType as AstOptionalType,
UseStatement, ExportUseStatement,
)
from .tokens import TokenType, SourceSpan
from .types import (
Type, PrimitiveType, GeometricPrimitiveType, CurveType,
CompoundCurveType, SurfaceType, SolidType,
ListType, DictType, OptionalTypeWrapper, FunctionType,
ERROR, UNKNOWN, NONE,
INT, FLOAT, BOOL, STRING,
POINT, POINT2D, POINT3D, VECTOR, VECTOR2D, VECTOR3D, TRANSFORM,
SOLID, REGION2D, DICT,
resolve_type_name, make_list_type, make_optional_type,
is_numeric, is_curve, is_geometry, common_type,
)
from .symbols import (
SymbolTable, Symbol, SymbolKind, FunctionSignature,
get_method_signature,
)
from .errors import (
Diagnostic, DiagnosticCollector, ErrorSeverity,
DslError, TypeError as DslTypeError,
)
@dataclass
class CheckResult:
    """Result of type checking a module."""

    diagnostics: List[Diagnostic]  # all errors and warnings recorded during the check
    has_errors: bool  # mirrors DiagnosticCollector.has_errors after checking
    has_warnings: bool  # mirrors DiagnosticCollector.has_warnings after checking
    has_python_blocks: bool  # Module contains Python blocks (requires review)
class TypeChecker:
    """
    Type checker for the yapCAD DSL.

    Traverses the AST and validates:
    - Type compatibility in assignments and function calls
    - Return type matching
    - Require expression boolean constraint
    - Emit target type matching command return type
    - Undefined identifier detection
    - Python block flagging

    ERROR and UNKNOWN are treated as *indeterminate* types. ERROR means a
    problem was already reported for that expression. UNKNOWN means the type
    has not been inferred, e.g. an unannotated parameter or a command without
    a declared return type. Checks involving an indeterminate type are not
    reported. This prevents cascades of follow-on diagnostics and false
    positives when annotations are omitted.

    NOTE(review): code E297 is used both for a non-boolean if-expression
    condition and for an unknown generic type name. Both codes are kept
    unchanged because callers may filter diagnostics by code.
    """

    # Operator groups recognised by _check_binary_op.
    _ARITHMETIC_OPS = (
        TokenType.PLUS, TokenType.MINUS, TokenType.STAR, TokenType.SLASH,
        TokenType.PERCENT, TokenType.DOUBLE_SLASH, TokenType.DOUBLE_STAR,
    )
    _ORDERING_OPS = (TokenType.LT, TokenType.GT, TokenType.LE, TokenType.GE)
    _EQUALITY_OPS = (TokenType.EQ, TokenType.NE)
    _LOGICAL_OPS = (TokenType.AND, TokenType.OR)

    def __init__(self, max_errors: int = 20):
        """
        Args:
            max_errors: Once this many errors have been recorded, statements
                are skipped and expressions evaluate to ERROR.
        """
        self.symbols = SymbolTable()
        self.diagnostics = DiagnosticCollector()
        self.max_errors = max_errors
        self._current_command: Optional[Command] = None
        self._has_python_blocks = False
        # Memo of resolved annotations: id(node) -> (node, Type).
        # Parameter and return annotations are resolved both when a command is
        # registered and when it is checked (and on every emit). The memo
        # ensures each annotation's diagnostics are reported only once. The
        # node is stored with its Type so its id cannot be recycled while the
        # entry exists.
        self._resolved_types: Dict[int, Any] = {}

    def check(self, module: Module) -> CheckResult:
        """
        Type check a complete module.

        Diagnostics and symbols accumulate on this checker, so use a fresh
        TypeChecker per module unless that accumulation is intended.
        """
        self._check_module(module)
        return CheckResult(
            diagnostics=self.diagnostics.diagnostics,
            has_errors=self.diagnostics.has_errors,
            has_warnings=self.diagnostics.has_warnings,
            has_python_blocks=self._has_python_blocks,
        )

    # =========================================================================
    # Module/Command Checking
    # =========================================================================

    def _check_module(self, module: Module) -> None:
        """Check a complete module.

        All native functions and command signatures are registered before
        any body is checked, so commands may call each other regardless of
        their order in the source.
        """
        # First pass: native block exports (they provide functions to commands).
        for native_block in module.native_blocks:
            self._register_native_block(native_block)
        # @native decorated functions (the attribute may be absent on a Module).
        for native_func in getattr(module, 'native_functions', []):
            self._register_native_function(native_func)
        # Second pass: register all commands/functions.
        for command in module.commands:
            self._register_command(command)
        # Third pass: check command/function bodies.
        for command in module.commands:
            self._check_command(command)

    def _build_function_type(self, parameters: List[Any],
                             return_type_node: Optional[TypeNode]) -> FunctionType:
        """Build a FunctionType from parameter annotations and a return annotation.

        Missing annotations resolve to UNKNOWN. The return annotation is
        resolved first, so its diagnostics precede those of the parameters.
        """
        return_type = self._resolve_type_node(return_type_node)
        param_types = tuple(
            self._resolve_type_node(param.type_annotation) for param in parameters
        )
        return FunctionType(param_types, return_type)

    def _register_native_block(self, native_block: NativeBlock) -> None:
        """Register exported functions from a native block and flag it for review."""
        self._has_python_blocks = True  # Native blocks contain Python
        for func_decl in native_block.exports:
            func_type = self._build_function_type(func_decl.parameters, func_decl.return_type)
            symbol = Symbol(
                name=func_decl.name,
                kind=SymbolKind.FUNCTION,
                type=func_type,
                span=func_decl.span,
            )
            if not self.symbols.define(symbol):
                self._error(
                    f"Function '{func_decl.name}' is already defined",
                    func_decl.span,
                    "E205",
                )
        self._warning(
            "Native Python block requires manual review for type safety",
            native_block.span,
            "W212",
        )

    def _register_native_function(self, native_func: NativeFunction) -> None:
        """Register a @native decorated function and flag it for review."""
        self._has_python_blocks = True  # Native functions contain Python
        func_type = self._build_function_type(native_func.parameters, native_func.return_type)
        symbol = Symbol(
            name=native_func.name,
            kind=SymbolKind.FUNCTION,
            type=func_type,
            span=native_func.span,
        )
        if not self.symbols.define(symbol):
            self._error(
                f"Function '{native_func.name}' is already defined",
                native_func.span,
                "E205",
            )
        self._warning(
            "Native function requires manual review for type safety",
            native_func.span,
            "W213",
        )

    def _register_command(self, command: Command) -> None:
        """Register a command/function signature in the symbol table.

        Unannotated parameters and a missing return type become UNKNOWN.
        """
        func_type = self._build_function_type(command.parameters, command.return_type)
        symbol = Symbol(
            name=command.name,
            kind=SymbolKind.COMMAND,
            type=func_type,
            span=command.span,
        )
        if not self.symbols.define(symbol):
            self._error(
                f"Command '{command.name}' is already defined",
                command.span,
                "E201",
            )

    def _check_command(self, command: Command) -> None:
        """Type check a command/function: its parameters, defaults and body."""
        self._current_command = command
        self.symbols.push_scope(f"function {command.name}")
        try:
            for param in command.parameters:
                self._check_parameter(param)
            self._check_block(command.body)
        finally:
            self.symbols.pop_scope()
            self._current_command = None

    def _check_parameter(self, param: Parameter) -> None:
        """Define a command parameter in the current scope and check its default value."""
        param_type = self._resolve_type_node(param.type_annotation)
        symbol = Symbol(
            name=param.name,
            kind=SymbolKind.PARAMETER,
            type=param_type,
            span=param.span,
        )
        if not self.symbols.define(symbol):
            self._error(
                f"Duplicate parameter '{param.name}'",
                param.span,
                "E202",
            )
        if param.default_value is not None:
            default_type = self._check_expression(param.default_value)
            if not self._is_compatible(param_type, default_type):
                self._error(
                    f"Default value type '{default_type}' is not assignable to "
                    f"parameter type '{param_type}'",
                    param.default_value.span,
                    "E203",
                )

    # =========================================================================
    # Type helpers
    # =========================================================================

    @staticmethod
    def _is_indeterminate(t: Type) -> bool:
        """True for ERROR (already reported) and UNKNOWN (not inferred)."""
        return t == ERROR or t == UNKNOWN

    def _is_compatible(self, target: Type, value: Type) -> bool:
        """Whether ``value`` may be stored into ``target`` without a diagnostic.

        Indeterminate types on either side are accepted silently.
        """
        if self._is_indeterminate(target) or self._is_indeterminate(value):
            return True
        return target.is_assignable_from(value)

    @staticmethod
    def _iteration_element_type(iterable_type: Type) -> Type:
        """Element type produced by iterating a value of ``iterable_type``.

        Lists yield their element type; range expressions are typed as
        list<int>. ERROR stays ERROR so a bad iterable causes no follow-on
        errors. Any other type is assumed to be a numeric (int) iteration.
        """
        if isinstance(iterable_type, ListType):
            return iterable_type.element_type
        if iterable_type == ERROR:
            return ERROR
        return INT

    def _check_condition(self, expr: Expression, what: str, code: str) -> Type:
        """Check ``expr`` and report ``code`` unless it is boolean (or indeterminate).

        ``what`` is the message prefix, e.g. "While condition".
        """
        cond_type = self._check_expression(expr)
        if cond_type != BOOL and not self._is_indeterminate(cond_type):
            self._error(
                f"{what} must be boolean, got '{cond_type}'",
                expr.span,
                code,
            )
        return cond_type

    # =========================================================================
    # Statement Checking
    # =========================================================================

    def _check_block(self, block: Block) -> Optional[Type]:
        """Check a block of statements, returning the final expression type if any."""
        for stmt in block.statements:
            self._check_statement(stmt)
        if block.final_expression is not None:
            return self._check_expression(block.final_expression)
        return None

    def _check_scoped_block(self, block: Block, scope_name: str) -> Optional[Type]:
        """Check ``block`` inside a fresh scope so its bindings do not leak out."""
        self.symbols.push_scope(scope_name)
        try:
            return self._check_block(block)
        finally:
            self.symbols.pop_scope()

    def _check_statement(self, stmt: Statement) -> None:
        """Dispatch a statement to its checker (skipped once max_errors is reached)."""
        if self.diagnostics.error_count >= self.max_errors:
            return
        if isinstance(stmt, LetStatement):  # Also handles VarDecl (alias)
            self._check_let_statement(stmt)
        elif isinstance(stmt, AssignmentStatement):
            self._check_assignment_statement(stmt)
        elif isinstance(stmt, RequireStatement):  # Also handles AssertStatement (alias)
            self._check_require_statement(stmt)
        elif isinstance(stmt, EmitStatement):
            self._check_emit_statement(stmt)
        elif isinstance(stmt, ForStatement):
            self._check_for_statement(stmt)
        elif isinstance(stmt, WhileStatement):
            self._check_while_statement(stmt)
        elif isinstance(stmt, IfStatement):
            self._check_if_statement(stmt)
        elif isinstance(stmt, PassStatement):
            pass  # PassStatement has no semantic content to check
        elif isinstance(stmt, ExpressionStatement):
            self._check_expression(stmt.expression)
        elif isinstance(stmt, ReturnStatement):
            self._check_return_statement(stmt)
        elif isinstance(stmt, PythonBlock):
            self._check_python_block(stmt)
        else:
            self._warning(
                f"Unknown statement type: {type(stmt).__name__}",
                stmt.span,
                "W201",
            )

    def _check_let_statement(self, stmt: LetStatement) -> None:
        """Check a let statement and bind the variable in the current scope."""
        init_type = self._check_expression(stmt.initializer)
        if stmt.type_annotation is not None:
            declared_type = self._resolve_type_node(stmt.type_annotation)
            if not self._is_compatible(declared_type, init_type):
                self._error(
                    f"Cannot assign '{init_type}' to variable of type '{declared_type}'",
                    stmt.initializer.span,
                    "E210",
                )
            var_type = declared_type
        else:
            var_type = init_type
        symbol = Symbol(
            name=stmt.name,
            kind=SymbolKind.VARIABLE,
            type=var_type,
            span=stmt.span,
            is_mutable=False,
        )
        if not self.symbols.define(symbol):
            self._error(
                f"Variable '{stmt.name}' is already defined in this scope",
                stmt.span,
                "E211",
            )

    def _check_assignment_statement(self, stmt: AssignmentStatement) -> None:
        """Check that the assigned value fits the target's type.

        Reassignment of `let` bindings is currently allowed: the symbols'
        ``is_mutable`` flag is not enforced here.
        """
        target_type = self._check_expression(stmt.target)
        value_type = self._check_expression(stmt.value)
        if not self._is_compatible(target_type, value_type):
            self._error(
                f"Cannot assign '{value_type}' to target of type '{target_type}'",
                stmt.value.span,
                "E212",
            )

    def _check_require_statement(self, stmt: RequireStatement) -> None:
        """Check a require statement: boolean condition and optional string message."""
        self._check_condition(stmt.condition, "Require condition", "E220")
        if stmt.message is not None:
            msg_type = self._check_expression(stmt.message)
            if msg_type != STRING and not self._is_indeterminate(msg_type):
                self._error(
                    f"Require message must be string, got '{msg_type}'",
                    stmt.message.span,
                    "E221",
                )

    def _check_emit_statement(self, stmt: EmitStatement) -> None:
        """Check an emit statement against the enclosing command's return type."""
        value_type = self._check_expression(stmt.value)
        command = self._current_command
        if command is not None and command.return_type is not None:
            # Memoised: an invalid return annotation is not re-reported per emit.
            return_type = self._resolve_type_node(command.return_type)
            if not self._is_compatible(return_type, value_type):
                self._error(
                    f"Emit value type '{value_type}' does not match command "
                    f"return type '{return_type}'",
                    stmt.value.span,
                    "E230",
                )
        # Metadata is either a DictLiteral (old syntax) or a plain dict of
        # key -> expression (new syntax: emit value, name="x", material="y").
        metadata = stmt.metadata
        if isinstance(metadata, DictLiteral):
            self._check_expression(metadata)
        elif isinstance(metadata, dict):
            for value_expr in metadata.values():
                self._check_expression(value_expr)

    def _check_for_statement(self, stmt: ForStatement) -> None:
        """Check a for statement; the loop variable lives in the loop's scope."""
        iterable_type = self._check_expression(stmt.iterable)
        elem_type = self._iteration_element_type(iterable_type)
        self.symbols.push_scope("for loop")
        try:
            self.symbols.define(Symbol(
                name=stmt.variable,
                kind=SymbolKind.VARIABLE,
                type=elem_type,
                span=stmt.span,
                is_mutable=False,
            ))
            self._check_block(stmt.body)
        finally:
            self.symbols.pop_scope()

    def _check_while_statement(self, stmt: WhileStatement) -> None:
        """Check a while statement: boolean condition and a scoped body."""
        self._check_condition(stmt.condition, "While condition", "E225")
        self._check_scoped_block(stmt.body, "while loop")

    def _check_if_statement(self, stmt: IfStatement) -> None:
        """Check a block-level if statement; every branch gets its own scope."""
        self._check_condition(stmt.condition, "If condition", "E226")
        self._check_scoped_block(stmt.then_branch, "if then")
        for elif_branch in stmt.elif_branches:
            self._check_condition(elif_branch.condition, "Elif condition", "E227")
            self._check_scoped_block(elif_branch.body, "elif")
        if stmt.else_branch is not None:
            self._check_scoped_block(stmt.else_branch, "else")

    def _check_return_statement(self, stmt: ReturnStatement) -> None:
        """Check a return statement (for Python blocks).

        The declared type is the bridge type. The user's annotation is
        trusted; it is resolved only so an invalid annotation is reported.
        The statement flags the module for review.
        """
        self._has_python_blocks = True
        if stmt.value is not None:
            self._check_expression(stmt.value)
        self._resolve_type_node(stmt.return_type)

    def _check_python_block(self, stmt: PythonBlock) -> None:
        """Check a Python block (just flag it for review)."""
        self._has_python_blocks = True
        self._warning(
            "Python block requires manual review for type safety",
            stmt.span,
            "W210",
        )

    # =========================================================================
    # Expression Checking
    # =========================================================================

    def _check_expression(self, expr: Expression) -> Type:
        """Check an expression and return its type (ERROR once max_errors is reached)."""
        if self.diagnostics.error_count >= self.max_errors:
            return ERROR
        if isinstance(expr, Literal):
            return self._check_literal(expr)
        elif isinstance(expr, Identifier):
            return self._check_identifier(expr)
        elif isinstance(expr, BinaryOp):
            return self._check_binary_op(expr)
        elif isinstance(expr, UnaryOp):
            return self._check_unary_op(expr)
        elif isinstance(expr, FunctionCall):
            return self._check_function_call(expr)
        elif isinstance(expr, MethodCall):
            return self._check_method_call(expr)
        elif isinstance(expr, MemberAccess):
            return self._check_member_access(expr)
        elif isinstance(expr, IndexAccess):
            return self._check_index_access(expr)
        elif isinstance(expr, ListLiteral):
            return self._check_list_literal(expr)
        elif isinstance(expr, ListComprehension):
            return self._check_list_comprehension(expr)
        elif isinstance(expr, RangeExpr):
            return self._check_range_expr(expr)
        elif isinstance(expr, IfExpr):
            return self._check_if_expr(expr)
        elif isinstance(expr, MatchExpr):
            return self._check_match_expr(expr)
        elif isinstance(expr, DictLiteral):
            return self._check_dict_literal(expr)
        elif isinstance(expr, LambdaExpr):
            return self._check_lambda_expr(expr)
        elif isinstance(expr, PythonExpr):
            return self._check_python_expr(expr)
        self._warning(
            f"Unknown expression type: {type(expr).__name__}",
            expr.span,
            "W220",
        )
        return ERROR

    def _check_literal(self, expr: Literal) -> Type:
        """Map a literal's token type to its DSL type (ERROR for anything else)."""
        if expr.literal_type == TokenType.INT_LITERAL:
            return INT
        elif expr.literal_type == TokenType.FLOAT_LITERAL:
            return FLOAT
        elif expr.literal_type == TokenType.STRING_LITERAL:
            return STRING
        elif expr.literal_type == TokenType.BOOL_LITERAL:
            return BOOL
        return ERROR

    def _check_identifier(self, expr: Identifier) -> Type:
        """Check an identifier reference."""
        # First check local symbols.
        symbol = self.symbols.lookup(expr.name)
        if symbol is not None:
            return symbol.type
        # Types used as identifiers are constructors (e.g. point(...)); the
        # concrete type is resolved at the call site.
        if resolve_type_name(expr.name) is not None:
            return UNKNOWN
        # Built-in functions are likewise resolved at the call site.
        if self.symbols.is_builtin(expr.name):
            return UNKNOWN
        self._error(
            f"Undefined identifier '{expr.name}'",
            expr.span,
            "E240",
        )
        return ERROR

    def _arithmetic_result_type(self, op: TokenType, left_type: Type,
                                right_type: Type) -> Optional[Type]:
        """Result type of a supported arithmetic combination, or None if unsupported."""
        if is_numeric(left_type) and is_numeric(right_type):
            # // (integer division) always returns int; / always returns float;
            # otherwise int op int -> int and any float operand -> float.
            if op == TokenType.DOUBLE_SLASH:
                return INT
            if left_type == FLOAT or right_type == FLOAT or op == TokenType.SLASH:
                return FLOAT
            return INT
        if isinstance(left_type, GeometricPrimitiveType):
            if is_numeric(right_type):
                return left_type  # scalar multiplication
            if (isinstance(right_type, GeometricPrimitiveType)
                    and op in (TokenType.PLUS, TokenType.MINUS)):
                combined = common_type(left_type, right_type)
                return combined if combined is not None else left_type
        if (op == TokenType.PLUS and isinstance(left_type, ListType)
                and isinstance(right_type, ListType)):
            # Lists of compatible element types can be concatenated; with no
            # common element type the left list's type is kept.
            elem_common = common_type(left_type.element_type, right_type.element_type)
            if elem_common is not None:
                return make_list_type(elem_common)
            return left_type
        return None

    def _check_binary_op(self, expr: BinaryOp) -> Type:
        """Check a binary operation."""
        left_type = self._check_expression(expr.left)
        right_type = self._check_expression(expr.right)
        op = expr.operator

        if op in self._ARITHMETIC_OPS:
            result = self._arithmetic_result_type(op, left_type, right_type)
            if result is not None:
                return result
            if ERROR in (left_type, right_type):
                return ERROR
            if UNKNOWN in (left_type, right_type):
                return UNKNOWN  # operand type not inferred; cannot judge
            self._error(
                f"Cannot apply '{op.name}' to '{left_type}' and '{right_type}'",
                expr.span,
                "E250",
            )
            return ERROR

        if op in self._ORDERING_OPS:
            if is_numeric(left_type) and is_numeric(right_type):
                return BOOL
            if ERROR in (left_type, right_type):
                return ERROR
            if UNKNOWN in (left_type, right_type):
                return BOOL  # a comparison yields bool even if operands are unknown
            self._error(
                f"Cannot compare '{left_type}' and '{right_type}' with '{op.name}'",
                expr.span,
                "E251",
            )
            return ERROR

        if op in self._EQUALITY_OPS:
            # Most types can be compared for equality.
            return BOOL

        if op in self._LOGICAL_OPS:
            if left_type != BOOL and not self._is_indeterminate(left_type):
                self._error(
                    f"Left operand of '{op.name}' must be boolean, got '{left_type}'",
                    expr.left.span,
                    "E252",
                )
            if right_type != BOOL and not self._is_indeterminate(right_type):
                self._error(
                    f"Right operand of '{op.name}' must be boolean, got '{right_type}'",
                    expr.right.span,
                    "E253",
                )
            return BOOL

        # Operators not listed above are not typed; ERROR silences follow-on checks.
        return ERROR

    def _check_unary_op(self, expr: UnaryOp) -> Type:
        """Check a unary operation (logical not, arithmetic negation)."""
        operand_type = self._check_expression(expr.operand)
        if expr.operator == TokenType.NOT:
            if operand_type != BOOL and not self._is_indeterminate(operand_type):
                self._error(
                    f"Operand of '!' must be boolean, got '{operand_type}'",
                    expr.operand.span,
                    "E260",
                )
            return BOOL
        if expr.operator == TokenType.MINUS:
            # Numbers and vectors/points negate to themselves; indeterminate
            # operands propagate without a further diagnostic.
            if (is_numeric(operand_type)
                    or isinstance(operand_type, GeometricPrimitiveType)
                    or self._is_indeterminate(operand_type)):
                return operand_type
            self._error(
                f"Cannot negate '{operand_type}'",
                expr.operand.span,
                "E261",
            )
            return ERROR
        return ERROR

    def _check_argument_expressions(self, expr: FunctionCall) -> None:
        """Type check every positional and named argument expression of a call."""
        for arg in expr.arguments:
            self._check_expression(arg)
        for arg in expr.named_arguments.values():
            self._check_expression(arg)

    def _check_call_arguments(self, func_type: FunctionType, expr: FunctionCall) -> Type:
        """Check call arguments against ``func_type`` and return its return type.

        Positional arguments are checked against ``param_types`` by position.
        As built in this module, a FunctionType carries only parameter types,
        so named arguments and surplus positional arguments are checked as
        expressions only.
        """
        param_types = func_type.param_types
        for i, arg in enumerate(expr.arguments):
            arg_type = self._check_expression(arg)
            if i < len(param_types) and not self._is_compatible(param_types[i], arg_type):
                self._error(
                    f"Argument {i+1} expects '{param_types[i]}', got '{arg_type}'",
                    arg.span,
                    "E274",
                )
        for arg in expr.named_arguments.values():
            self._check_expression(arg)
        return func_type.return_type

    def _check_function_call(self, expr: FunctionCall) -> Type:
        """Check a function call.

        Resolution order for a named callee: built-in function, type
        constructor, then symbol table. A symbol with a function type can be
        a command, native function, or a variable/parameter holding a lambda.
        """
        callee = expr.callee
        if not isinstance(callee, Identifier):
            # Complex callee expression, e.g. the result of another call.
            callee_type = self._check_expression(callee)
            if isinstance(callee_type, FunctionType):
                return self._check_call_arguments(callee_type, expr)
            self._check_argument_expressions(expr)
            return UNKNOWN

        callee_name = callee.name
        builtin = self.symbols.lookup_builtin(callee_name)
        if builtin is not None:
            return self._check_builtin_call(builtin, expr)

        type_val = resolve_type_name(callee_name)
        if type_val is not None:
            # Type constructor: arguments are only checked loosely.
            self._check_argument_expressions(expr)
            return type_val

        symbol = self.symbols.lookup(callee_name)
        if symbol is not None:
            if isinstance(symbol.type, FunctionType):
                return self._check_call_arguments(symbol.type, expr)
            if self._is_indeterminate(symbol.type):
                # e.g. an unannotated parameter that holds a function.
                self._check_argument_expressions(expr)
                return symbol.type
            if symbol.kind in (SymbolKind.COMMAND, SymbolKind.FUNCTION):
                self._check_argument_expressions(expr)
                return UNKNOWN

        self._error(
            f"Unknown function '{callee_name}'",
            callee.span,
            "E270",
        )
        # Still visit the arguments so errors inside them are reported.
        self._check_argument_expressions(expr)
        return ERROR

    def _check_builtin_call(self, sig: FunctionSignature, expr: FunctionCall) -> Type:
        """Check a call to a built-in function.

        ``sig.params`` holds ``(name, type, default)`` tuples. A ``None``
        default marks the parameter as required. A required parameter counts
        as supplied whether it is passed positionally or by name.
        """
        params = list(sig.params)
        required_names = [name for name, _, default in params if default is None]
        supplied = {name for name, _, _ in params[:len(expr.arguments)]}
        supplied.update(expr.named_arguments)
        if any(name not in supplied for name in required_names):
            supplied_count = len(expr.arguments) + len(expr.named_arguments)
            self._error(
                f"Function '{sig.name}' requires at least {len(required_names)} "
                f"arguments, got {supplied_count}",
                expr.span,
                "E271",
            )

        # Check positional argument types, keeping them for return-type inference.
        arg_types: List[Type] = []
        for i, arg in enumerate(expr.arguments):
            arg_type = self._check_expression(arg)
            arg_types.append(arg_type)
            if i < len(params):
                param_name, param_type, _ = params[i]
                if not self._is_compatible(param_type, arg_type):
                    self._error(
                        f"Argument '{param_name}' expects '{param_type}', "
                        f"got '{arg_type}'",
                        arg.span,
                        "E272",
                    )

        # Check named arguments: the name must exist and the type must fit.
        param_types_by_name = {name: param_type for name, param_type, _ in params}
        for name, arg in expr.named_arguments.items():
            arg_type = self._check_expression(arg)
            if name not in param_types_by_name:
                self._error(
                    f"Unknown parameter '{name}' for function '{sig.name}'",
                    arg.span,
                    "E273",
                )
            elif not self._is_compatible(param_types_by_name[name], arg_type):
                self._error(
                    f"Argument '{name}' expects '{param_types_by_name[name]}', "
                    f"got '{arg_type}'",
                    arg.span,
                    "E272",
                )

        # Functions declared as returning list<unknown> infer it from the first argument.
        return_type = sig.return_type
        if (isinstance(return_type, ListType) and return_type.element_type == UNKNOWN
                and sig.name in ('concat', 'reverse', 'flatten')
                and arg_types and isinstance(arg_types[0], ListType)):
            if sig.name == 'flatten' and isinstance(arg_types[0].element_type, ListType):
                # flatten: list<list<T>> -> list<T>
                return_type = arg_types[0].element_type
            else:
                # concat, reverse: list<T> -> list<T>
                return_type = arg_types[0]
        return return_type

    def _check_method_call(self, expr: MethodCall) -> Type:
        """Check a method call against the receiver type's method signature."""
        obj_type = self._check_expression(expr.object)
        method_sig = get_method_signature(obj_type, expr.method)
        if method_sig is None:
            if not self._is_indeterminate(obj_type):
                self._error(
                    f"Type '{obj_type}' has no method '{expr.method}'",
                    expr.span,
                    "E280",
                )
            # Still visit the arguments so errors inside them are reported.
            for arg in expr.arguments:
                self._check_expression(arg)
            return UNKNOWN if obj_type == UNKNOWN else ERROR

        for i, arg in enumerate(expr.arguments):
            arg_type = self._check_expression(arg)
            if i < len(method_sig.params):
                param_name, param_type, _ = method_sig.params[i]
                if not self._is_compatible(param_type, arg_type):
                    self._error(
                        f"Method argument '{param_name}' expects '{param_type}', "
                        f"got '{arg_type}'",
                        arg.span,
                        "E281",
                    )
        return method_sig.return_type

    def _check_member_access(self, expr: MemberAccess) -> Type:
        """Check member access (e.g. point.x): x/y/z on point*/vector* types are float."""
        obj_type = self._check_expression(expr.object)
        if (isinstance(obj_type, GeometricPrimitiveType)
                and (obj_type.name.startswith("point") or obj_type.name.startswith("vector"))
                and expr.member in ("x", "y", "z")):
            return FLOAT
        if self._is_indeterminate(obj_type):
            return obj_type
        self._error(
            f"Type '{obj_type}' has no member '{expr.member}'",
            expr.span,
            "E290",
        )
        return ERROR

    def _check_index_access(self, expr: IndexAccess) -> Type:
        """Check index access (e.g. list[0]): integer index into a list."""
        obj_type = self._check_expression(expr.object)
        idx_type = self._check_expression(expr.index)
        if idx_type != INT and not self._is_indeterminate(idx_type):
            self._error(
                f"Index must be integer, got '{idx_type}'",
                expr.index.span,
                "E291",
            )
        if isinstance(obj_type, ListType):
            return obj_type.element_type
        if self._is_indeterminate(obj_type):
            return obj_type
        self._error(
            f"Cannot index into '{obj_type}'",
            expr.object.span,
            "E292",
        )
        return ERROR

    def _check_list_literal(self, expr: ListLiteral) -> Type:
        """Check a list literal.

        The element type starts as the first element's type and widens via
        common_type when later elements are not directly assignable.
        """
        if not expr.elements:
            return make_list_type(UNKNOWN)
        elem_type = self._check_expression(expr.elements[0])
        for i, elem in enumerate(expr.elements[1:], start=1):
            t = self._check_expression(elem)
            if self._is_compatible(elem_type, t):
                continue
            ct = common_type(elem_type, t)
            if ct is not None:
                elem_type = ct
            else:
                self._error(
                    f"List element {i} has type '{t}', expected '{elem_type}'",
                    elem.span,
                    "E293",
                )
        return make_list_type(elem_type)

    def _check_list_comprehension(self, expr: ListComprehension) -> Type:
        """Check a list comprehension; its variable is scoped to the comprehension."""
        iterable_type = self._check_expression(expr.iterable)
        iter_elem_type = self._iteration_element_type(iterable_type)
        self.symbols.push_scope("list comprehension")
        try:
            self.symbols.define(Symbol(
                name=expr.variable,
                kind=SymbolKind.VARIABLE,
                type=iter_elem_type,
                span=expr.span,
            ))
            if expr.condition is not None:
                self._check_condition(expr.condition, "Comprehension condition", "E294")
            elem_type = self._check_expression(expr.element_expr)
        finally:
            self.symbols.pop_scope()
        return make_list_type(elem_type)

    def _check_range_expr(self, expr: RangeExpr) -> Type:
        """Check a range expression (e.g. 0..10); both bounds must be int."""
        start_type = self._check_expression(expr.start)
        end_type = self._check_expression(expr.end)
        if start_type != INT and not self._is_indeterminate(start_type):
            self._error(
                f"Range start must be integer, got '{start_type}'",
                expr.start.span,
                "E295",
            )
        if end_type != INT and not self._is_indeterminate(end_type):
            self._error(
                f"Range end must be integer, got '{end_type}'",
                expr.end.span,
                "E296",
            )
        return make_list_type(INT)

    def _check_if_expr(self, expr: IfExpr) -> Type:
        """Check an if expression.

        Each branch is checked in its own scope, so bindings do not leak and
        both branches may bind the same name. The result is the branches'
        common type; if there is none, the then-branch type is used.
        """
        self._check_condition(expr.condition, "If condition", "E297")
        then_type = self._check_scoped_block(expr.then_branch, "if expr then")
        if expr.else_branch is None:
            return then_type if then_type is not None else UNKNOWN
        if isinstance(expr.else_branch, Block):
            else_type = self._check_scoped_block(expr.else_branch, "if expr else")
        else:  # IfExpr (else if)
            else_type = self._check_if_expr(expr.else_branch)
        if then_type is None or else_type is None:
            return then_type if then_type is not None else UNKNOWN
        ct = common_type(then_type, else_type)
        if ct is not None:
            return ct
        if not (self._is_indeterminate(then_type) or self._is_indeterminate(else_type)):
            self._warning(
                f"If branches have different types: '{then_type}' and '{else_type}'",
                expr.span,
                "W230",
            )
        return then_type

    def _check_match_expr(self, expr: MatchExpr) -> Type:
        """Check a match expression; each arm gets its own scope for pattern bindings."""
        subject_type = self._check_expression(expr.subject)
        result_type: Optional[Type] = None
        for arm in expr.arms:
            self.symbols.push_scope("match arm")
            try:
                self._check_pattern(arm.pattern, subject_type)
                arm_type = self._check_expression(arm.body)
            finally:
                self.symbols.pop_scope()
            if result_type is None:
                result_type = arm_type
            else:
                ct = common_type(result_type, arm_type)
                if ct is not None:
                    result_type = ct
        return result_type if result_type is not None else UNKNOWN

    def _check_pattern(self, pattern: Pattern, expected_type: Type) -> None:
        """Check a match pattern, binding identifier patterns in the current scope."""
        if isinstance(pattern, LiteralPattern):
            lit_type = self._check_literal(pattern.value)
            if not self._is_compatible(expected_type, lit_type):
                self._warning(
                    f"Pattern type '{lit_type}' may not match subject type '{expected_type}'",
                    pattern.span,
                    "W231",
                )
        elif isinstance(pattern, IdentifierPattern):
            # Binding pattern: introduces a variable of the subject's type.
            # NOTE(review): assumes the bound name is stored in `.name` (as on
            # Identifier). If the attribute is absent, nothing is bound.
            name = getattr(pattern, "name", None)
            if isinstance(name, str):
                self.symbols.define(Symbol(
                    name=name,
                    kind=SymbolKind.VARIABLE,
                    type=expected_type,
                    span=pattern.span,
                    is_mutable=False,
                ))
        # WildcardPattern always matches and binds nothing.

    def _check_dict_literal(self, expr: DictLiteral) -> Type:
        """Check a dictionary literal's values; the result is always DICT."""
        for value in expr.entries.values():
            self._check_expression(value)
        return DICT

    def _check_lambda_expr(self, expr: LambdaExpr) -> Type:
        """Check a lambda; parameters are untyped (UNKNOWN) and scoped to its body."""
        self.symbols.push_scope("lambda")
        try:
            for param in expr.parameters:
                symbol = Symbol(
                    name=param,
                    kind=SymbolKind.PARAMETER,
                    type=UNKNOWN,  # Type inference would happen here
                    span=expr.span,
                )
                if not self.symbols.define(symbol):
                    self._error(
                        f"Duplicate parameter '{param}'",
                        expr.span,
                        "E202",
                    )
            body_type = self._check_expression(expr.body)
        finally:
            self.symbols.pop_scope()
        param_types = tuple(UNKNOWN for _ in expr.parameters)
        return FunctionType(param_types, body_type)

    def _check_python_expr(self, expr: PythonExpr) -> Type:
        """Flag a Python expression for review; its type is the declared annotation."""
        self._has_python_blocks = True
        self._warning(
            "Python expression requires manual review for type safety",
            expr.span,
            "W211",
        )
        return self._resolve_type_node(expr.return_type)

    # =========================================================================
    # Type Resolution
    # =========================================================================

    def _resolve_type_node(self, node: Optional[TypeNode]) -> Type:
        """Resolve an AST type node to a Type instance (memoised per node).

        ``None`` (no annotation) resolves to UNKNOWN.
        """
        if node is None:
            return UNKNOWN  # Type will be inferred
        cached = self._resolved_types.get(id(node))
        if cached is not None and cached[0] is node:
            return cached[1]
        resolved = self._resolve_type_node_uncached(node)
        self._resolved_types[id(node)] = (node, resolved)
        return resolved

    def _resolve_type_node_uncached(self, node: TypeNode) -> Type:
        """Resolve ``node`` and report invalid annotations (E297/E298/E299)."""
        if isinstance(node, SimpleType):
            resolved = resolve_type_name(node.name)
            if resolved is None:
                self._error(
                    f"Unknown type '{node.name}'",
                    node.span,
                    "E299",
                )
                return ERROR
            return resolved
        if isinstance(node, GenericType):
            if node.name != "list":
                self._error(
                    f"Unknown generic type '{node.name}'",
                    node.span,
                    "E297",
                )
                return ERROR
            if len(node.type_args) != 1:
                self._error(
                    "list<T> requires exactly one type argument",
                    node.span,
                    "E298",
                )
                return ERROR
            return make_list_type(self._resolve_type_node(node.type_args[0]))
        if isinstance(node, AstOptionalType):
            return make_optional_type(self._resolve_type_node(node.inner))
        return ERROR

    # =========================================================================
    # Diagnostics
    # =========================================================================

    def _error(self, message: str, span: SourceSpan, code: str) -> None:
        """Record an error diagnostic."""
        self.diagnostics.add(Diagnostic(
            code=code,
            message=message,
            severity=ErrorSeverity.ERROR,
            span=span,
        ))

    def _warning(self, message: str, span: SourceSpan, code: str) -> None:
        """Record a warning diagnostic."""
        self.diagnostics.add(Diagnostic(
            code=code,
            message=message,
            severity=ErrorSeverity.WARNING,
            span=span,
        ))
def check(module: Module, max_errors: int = 20) -> CheckResult:
    """
    Convenience function to type check a module.

    A fresh TypeChecker is created for every call, so no symbols or
    diagnostics carry over between modules.

    Args:
        module: The parsed module AST
        max_errors: Maximum errors before stopping (default 20)

    Returns:
        CheckResult with diagnostics
    """
    checker = TypeChecker(max_errors=max_errors)
    return checker.check(module)