From f4de9d11587ed17feef210b098dab8f03cc4626a Mon Sep 17 00:00:00 2001 From: Shautvast Date: Sun, 30 Nov 2025 09:11:25 +0100 Subject: [PATCH] if-else expression, simple case --- src/compiler/assembly_pass.rs | 42 +++++++------ src/compiler/ast_pass.rs | 71 ++++++++++++---------- src/errors.rs | 3 + src/symbol_builder.rs | 109 ++++++++++++++++++++-------------- 4 files changed, 126 insertions(+), 99 deletions(-) diff --git a/src/compiler/assembly_pass.rs b/src/compiler/assembly_pass.rs index 86426a1..a620077 100644 --- a/src/compiler/assembly_pass.rs +++ b/src/compiler/assembly_pass.rs @@ -229,14 +229,12 @@ impl AsmPass { self.emit(Goto(0)); let goto_addr2 = self.chunk.code.len() - 1; // placeholder self.chunk.code[goto_addr1] = GotoIfNot(self.chunk.code.len()); - if else_branch.is_some() { - self.compile_statements( - else_branch.as_ref().unwrap(), - symbols, - registry, - namespace, - )?; - } + self.compile_statements( + else_branch, + symbols, + registry, + namespace, + )?; self.chunk.code[goto_addr2] = Op::Goto(self.chunk.code.len()); } Expression::LetExpression { @@ -245,12 +243,12 @@ impl AsmPass { let name = name.lexeme.as_str(); let var = symbols.get(name); if let Some(Symbol::Variable { var_type, .. }) = var { - let inferred_type = infer_type(initializer, symbols); + let inferred_type = infer_type(initializer, symbols).map_err(|e| self.error_at_line(e))?; let calculated_type = - calculate_type(var_type, &inferred_type).map_err(|e| self.raise(e))?; + calculate_type(var_type, &inferred_type).map_err(|e| self.error_at_line(e))?; if var_type != &Unknown && var_type != &calculated_type { return Err( - self.raise(IncompatibleTypes(var_type.clone(), calculated_type)) + self.error_at_line(IncompatibleTypes(var_type.clone(), calculated_type)) ); } let name_index = self.chunk.add_var(var_type, name); @@ -258,7 +256,7 @@ impl AsmPass { self.compile_expression(namespace, initializer, symbols, registry)?; self.emit(Assign(name_index)); } else { - return Err(self.raise(UndeclaredVariable(name.to_string()))); + return Err(self.error_at_line(UndeclaredVariable(name.to_string()))); } } Expression::FunctionCall { @@ -290,7 +288,7 @@ impl AsmPass { self.emit(Call(name_index, fun.arity())); } else { return Err( - self.raise(CompilerError::FunctionNotFound(name.to_string())) + self.error_at_line(CompilerError::FunctionNotFound(name.to_string())) ); } } @@ -303,7 +301,7 @@ impl AsmPass { .. } => { self.compile_expression(namespace, receiver, symbols, registry)?; - let receiver_type = infer_type(receiver, symbols).to_string(); + let receiver_type = infer_type(receiver, symbols).map_err(|e|self.error_at_line(e))?.to_string(); let type_index = self.chunk.find_constant(&receiver_type).unwrap_or_else(|| { self.chunk @@ -314,9 +312,9 @@ impl AsmPass { self.chunk .add_constant(Value::String(method_name.to_string())) }); - let signature = lookup(&receiver_type, method_name).map_err(|e| self.raise(e))?; + let signature = lookup(&receiver_type, method_name).map_err(|e| self.error_at_line(e))?; if signature.arity() != arguments.len() { - return Err(self.raise(CompilerError::IllegalArgumentsException( + return Err(self.error_at_line(CompilerError::IllegalArgumentsException( format!("{}.{}", receiver_type, method_name), signature.parameters.len(), arguments.len(), @@ -336,7 +334,7 @@ impl AsmPass { if let Some(name_index) = name_index { self.emit(Get(*name_index)); } else { - return Err(self.raise(UndeclaredVariable(name.to_string()))); + return Err(self.error_at_line(UndeclaredVariable(name.to_string()))); } } Expression::Assignment { @@ -349,7 +347,7 @@ impl AsmPass { if let Some(name_index) = name_index { self.emit(Assign(*name_index)); } else { - return Err(self.raise(UndeclaredVariable(variable_name.to_string()))); + return Err(self.error_at_line(UndeclaredVariable(variable_name.to_string()))); } } Expression::Literal { value, .. } => { @@ -402,7 +400,7 @@ impl AsmPass { self.emit(Assign(*index)); self.emit(Pop); } else { - return Err(self.raise(UndeclaredVariable("".to_string()))); + return Err(self.error_at_line(UndeclaredVariable("".to_string()))); } } TokenType::EqualEqual => self.emit(Equal), @@ -489,10 +487,10 @@ impl AsmPass { for parameter in parameters { if let NamedParameter { name, value, .. } = argument { if name.lexeme == parameter.name.lexeme { - let value_type = infer_type(value, symbols); + let value_type = infer_type(value, symbols).map_err(|e| self.error_at_line(e))?; if parameter.var_type != value_type { return Err(self - .raise(IncompatibleTypes(parameter.var_type.clone(), value_type))); + .error_at_line(IncompatibleTypes(parameter.var_type.clone(), value_type))); } else { self.compile_expression(namespace, argument, symbols, registry)?; break; @@ -517,7 +515,7 @@ impl AsmPass { index } - fn raise(&self, error: CompilerError) -> CompilerErrorAtLine { + fn error_at_line(&self, error: CompilerError) -> CompilerErrorAtLine { CompilerErrorAtLine::raise(error, self.current_line) } } diff --git a/src/compiler/ast_pass.rs b/src/compiler/ast_pass.rs index a088538..71fe8c7 100644 --- a/src/compiler/ast_pass.rs +++ b/src/compiler/ast_pass.rs @@ -87,11 +87,11 @@ impl AstCompiler { debug!("AST {:?}", statements); Ok(statements) } else { - Err(self.raise(CompilerError::Failure)) + Err(self.error_at_line(CompilerError::Failure)) } } - fn raise(&self, error: CompilerError) -> CompilerErrorAtLine { + fn error_at_line(&self, error: CompilerError) -> CompilerErrorAtLine { CompilerErrorAtLine::raise(error, self.current_line()) } @@ -111,7 +111,7 @@ impl AstCompiler { indent_on_line += 1; } if indent_on_line > expected_indent { - Err(self.raise(UnexpectedIndent(indent_on_line, expected_indent))) + Err(self.error_at_line(UnexpectedIndent(indent_on_line, expected_indent))) } else if indent_on_line < expected_indent { self.indent.pop(); Ok(None) @@ -148,7 +148,7 @@ impl AstCompiler { } else if self.match_token(&[TokenType::Question]) { self.query_guard_expr(symbol_table) } else { - Err(self.raise(Expected("-> or ?"))) + Err(self.error_at_line(Expected("-> or ?"))) }; } Ok(Stop { @@ -181,7 +181,7 @@ impl AstCompiler { } fn match_expression(&mut self) -> Expr { - Err(self.raise(Expected("unimplemented"))) + Err(self.error_at_line(Expected("unimplemented"))) } fn object_declaration(&mut self, symbol_table: &mut SymbolTable) -> Stmt { @@ -208,7 +208,7 @@ impl AstCompiler { if field_type.is_type() { self.advance(); } else { - Err(self.raise(Expected("a type")))? + Err(self.error_at_line(Expected("a type")))? } fields.push(Parameter { name: field_name, @@ -236,7 +236,7 @@ impl AstCompiler { fn function_declaration(&mut self, symbol_table: &mut SymbolTable) -> Stmt { let name_token = self.consume(&Identifier, Expected("function name."))?; if GLOBAL_FUNCTIONS.contains_key(name_token.lexeme.as_str()) { - return Err(self.raise(CompilerError::ReservedFunctionName( + return Err(self.error_at_line(CompilerError::ReservedFunctionName( name_token.lexeme.clone(), ))); } @@ -244,7 +244,7 @@ impl AstCompiler { let mut parameters = vec![]; while !self.check(&RightParen) { if parameters.len() >= 25 { - return Err(self.raise(TooManyParameters)); + return Err(self.error_at_line(TooManyParameters)); } let parm_name = self.consume(&Identifier, Expected("a parameter name."))?; @@ -292,11 +292,7 @@ impl AstCompiler { let expr = if self.match_token(&[For]) { self.for_expression(symbol_table)? } else if self.match_token(&[Let]) { - let expr = self.let_exp(symbol_table)?; - if !self.is_at_end() { - self.consume(&Eol, Expected("end of line after expression."))?; - } - expr + self.let_exp(symbol_table)? } else { self.expression(symbol_table)? }; @@ -321,9 +317,11 @@ impl AstCompiler { fn let_exp(&mut self, symbol_table: &mut SymbolTable) -> Expr { if self.peek().token_type.is_type() { - return Err(self.raise(CompilerError::KeywordNotAllowedAsIdentifier( - self.peek().token_type.clone(), - ))); + return Err( + self.error_at_line(CompilerError::KeywordNotAllowedAsIdentifier( + self.peek().token_type.clone(), + )), + ); } let name_token = self.consume(&Identifier, Expected("variable name."))?; @@ -337,9 +335,10 @@ impl AstCompiler { if self.match_token(&[Equal]) { let initializer = self.expression(symbol_table)?; let declared_type = declared_type.unwrap_or(Unknown); - let inferred_type = infer_type(&initializer, symbol_table); - let var_type = - calculate_type(&declared_type, &inferred_type).map_err(|e| self.raise(e))?; + let inferred_type = + infer_type(&initializer, symbol_table).map_err(|e| self.error_at_line(e))?; + let var_type = calculate_type(&declared_type, &inferred_type) + .map_err(|e| self.error_at_line(e))?; symbol_table.insert( name_token.lexeme.clone(), Symbol::Variable { @@ -354,7 +353,7 @@ impl AstCompiler { initializer: Box::new(initializer), }) } else { - Err(self.raise(UninitializedVariable))? + Err(self.error_at_line(UninitializedVariable))? } } @@ -399,7 +398,7 @@ impl AstCompiler { value: Box::new(right), }) } else { - Err(self.raise(CompilerError::Failure)) + Err(self.error_at_line(CompilerError::Failure)) } } else { Ok(expr) @@ -503,7 +502,7 @@ impl AstCompiler { Ok(IfElseExpression { condition: Box::new(condition), then_branch, - else_branch: Some(self.compile(symbol_table)?), + else_branch: self.compile(symbol_table)?, }) } else { Ok(IfExpression { @@ -555,10 +554,16 @@ impl AstCompiler { key: Box::new(index), }, _ => { - return Err(self.raise(CompilerError::IllegalTypeToIndex(var_type.to_string()))); + return Err( + self.error_at_line(CompilerError::IllegalTypeToIndex(var_type.to_string())) + ); } }, - _ => return Err(self.raise(CompilerError::IllegalTypeToIndex("Unknown".to_string()))), + _ => { + return Err( + self.error_at_line(CompilerError::IllegalTypeToIndex("Unknown".to_string())) + ); + } }; self.consume(&RightBracket, Expected("']' after index."))?; Ok(get) @@ -613,7 +618,7 @@ impl AstCompiler { self.previous() .lexeme .parse() - .map_err(|e| self.raise(ParseError(format!("{:?}", e))))?, + .map_err(|e| self.error_at_line(ParseError(format!("{:?}", e))))?, ), } } else if self.match_token(&[U32]) { @@ -622,7 +627,7 @@ impl AstCompiler { literaltype: Integer, value: Value::U32( u32::from_str_radix(self.previous().lexeme.trim_start_matches("0x"), 16) - .map_err(|e| self.raise(ParseError(format!("{:?}", e))))?, + .map_err(|e| self.error_at_line(ParseError(format!("{:?}", e))))?, ), } } else if self.match_token(&[U64]) { @@ -631,7 +636,7 @@ impl AstCompiler { literaltype: Integer, value: Value::U64( u64::from_str_radix(self.previous().lexeme.trim_start_matches("0x"), 16) - .map_err(|e| self.raise(ParseError(format!("{:?}", e))))?, + .map_err(|e| self.error_at_line(ParseError(format!("{:?}", e))))?, ), } } else if self.match_token(&[FloatingPoint]) { @@ -642,7 +647,7 @@ impl AstCompiler { self.previous() .lexeme .parse() - .map_err(|e| self.raise(ParseError(format!("{:?}", e))))?, + .map_err(|e| self.error_at_line(ParseError(format!("{:?}", e))))?, ), } } else if self.match_token(&[StringType]) { @@ -663,7 +668,9 @@ impl AstCompiler { literaltype: DateTime, value: Value::DateTime(Box::new( chrono::DateTime::parse_from_str(&self.previous().lexeme, DATE_FORMAT_TIMEZONE) - .map_err(|_| self.raise(ParseError(self.previous().lexeme.clone())))? + .map_err(|_| { + self.error_at_line(ParseError(self.previous().lexeme.clone())) + })? .into(), )), } @@ -766,7 +773,7 @@ impl AstCompiler { let mut arguments = vec![]; while !self.match_token(&[RightParen]) { if arguments.len() >= 25 { - return Err(self.raise(TooManyParameters)); + return Err(self.error_at_line(TooManyParameters)); } let arg = self.expression(symbol_table)?; arguments.push(arg); @@ -789,7 +796,7 @@ impl AstCompiler { self.advance(); } else { self.had_error = true; - return Err(self.raise(message)); + return Err(self.error_at_line(message)); } Ok(self.previous().clone()) } @@ -969,7 +976,7 @@ pub enum Expression { IfElseExpression { condition: Box, then_branch: Vec, - else_branch: Option>, + else_branch: Vec, }, LetExpression { name: Token, diff --git a/src/errors.rs b/src/errors.rs index 6cf92c4..7d8f7b1 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -73,6 +73,9 @@ pub enum CompilerError { IllegalArgumentsException(String, usize, usize), #[error("Function name {0} is a global function and cannot be used here.")] ReservedFunctionName(String), + #[error("The if expression branches do not match. Was {0} and {1}")] + IfElseBranchesDoNotMatch(TokenType, TokenType), + } #[derive(Error, Debug, PartialEq)] diff --git a/src/symbol_builder.rs b/src/symbol_builder.rs index 0c398eb..06d72e8 100644 --- a/src/symbol_builder.rs +++ b/src/symbol_builder.rs @@ -118,7 +118,7 @@ pub fn calculate_type( }) } -pub fn infer_type(expr: &Expression, symbols: &HashMap) -> TokenType { +pub fn infer_type(expr: &Expression, symbols: &HashMap) -> Result { match expr { Expression::Binary { left, @@ -126,35 +126,35 @@ pub fn infer_type(expr: &Expression, symbols: &HashMap) -> Token right, .. } => { - let left_type = infer_type(left, symbols); - let right_type = infer_type(right, symbols); + let left_type = infer_type(left, symbols)?; + let right_type = infer_type(right, symbols)?; if [Greater, Less, GreaterEqual, LessEqual].contains(&operator.token_type) { - Bool + Ok(Bool) } else if left_type == right_type { // map to determined numeric type if yet undetermined (32 or 64 bits) - match left_type { + Ok(match left_type { FloatingPoint => F64, Integer => I64, _ => left_type, - } + }) } else if let Plus = operator.token_type { // includes string concatenation with numbers // followed by type coercion to 64 bits for numeric types debug!("coerce {} : {}", left_type, right_type); match (left_type, right_type) { - (_, StringType) => StringType, - (StringType, _) => StringType, - (FloatingPoint, _) => F64, - (Integer, FloatingPoint) => F64, - (Integer, _) => I64, - (I64, Integer) => I64, - (F64, _) => F64, - (U64, U32) => U64, - (I64, I32) => I64, + (_, StringType) => Ok(StringType), + (StringType, _) => Ok(StringType), + (FloatingPoint, _) => Ok(F64), + (Integer, FloatingPoint) => Ok(F64), + (Integer, _) => Ok(I64), + (I64, Integer) => Ok(I64), + (F64, _) => Ok(F64), + (U64, U32) => Ok(U64), + (I64, I32) => Ok(I64), // could add a date and a duration. future work // could add a List and a value. also future work // could add a Map and a tuple. Will I add tuple types? Future work! - _ => panic!("Unexpected coercion"), + _ => Err(CompilerError::Failure), //TODO better error message } // could have done some fall through here, but this will fail less gracefully, // so if my thinking is wrong or incomplete it will panic @@ -162,40 +162,40 @@ pub fn infer_type(expr: &Expression, symbols: &HashMap) -> Token // type coercion to 64 bits for numeric types debug!("coerce {} : {}", left_type, right_type); match (left_type, right_type) { - (FloatingPoint, _) => F64, - (Integer, FloatingPoint) => F64, - (Integer, I64) => I64, - (I64, FloatingPoint) => F64, - (F64, _) => F64, - (U64, U32) => U64, - (I64, I32) => I64, - (I64, Integer) => I64, - _ => panic!("Unexpected coercion"), + (FloatingPoint, _) => Ok(F64), + (Integer, FloatingPoint) => Ok(F64), + (Integer, I64) => Ok(I64), + (I64, FloatingPoint) => Ok(F64), + (F64, _) => Ok(F64), + (U64, U32) => Ok(U64), + (I64, I32) => Ok(I64), + (I64, Integer) => Ok(I64), + _ => Err(CompilerError::Failure), // TODO } } } Expression::Grouping { expression, .. } => infer_type(expression, symbols), - Expression::Literal { literaltype, .. } => literaltype.clone(), - Expression::List { literaltype, .. } => literaltype.clone(), - Expression::Map { literaltype, .. } => literaltype.clone(), + Expression::Literal { literaltype, .. } => Ok(literaltype.clone()), + Expression::List { literaltype, .. } => Ok(literaltype.clone()), + Expression::Map { literaltype, .. } => Ok(literaltype.clone()), Expression::Unary { right, operator, .. } => { - let literal_type = infer_type(right, symbols); + let literal_type = infer_type(right, symbols)?; if literal_type == Integer && operator.token_type == Minus { - SignedInteger + Ok(SignedInteger) } else { - UnsignedInteger + Ok(UnsignedInteger) } } - Expression::Variable { var_type, .. } => var_type.clone(), + Expression::Variable { var_type, .. } => Ok(var_type.clone()), Expression::Assignment { value, .. } => infer_type(value, symbols), Expression::FunctionCall { name, .. } => { let symbol = symbols.get(name); match symbol { - Some(Symbol::Function { return_type, .. }) => return_type.clone(), - Some(Symbol::Object { name, .. }) => ObjectType(name.clone()), - _ => Unknown, + Some(Symbol::Function { return_type, .. }) => Ok(return_type.clone()), + Some(Symbol::Object { name, .. }) => Ok(ObjectType(name.clone())), + _ => Err(CompilerError::Failure), // TODO } } Expression::MethodCall { @@ -205,7 +205,7 @@ pub fn infer_type(expr: &Expression, symbols: &HashMap) -> Token } => { if let Expression::Literal { value, .. } = receiver.deref() { if let Ok(signature) = lookup(&value.to_string(), method_name) { - signature.return_type.clone() + Ok(signature.return_type.clone()) } else { unreachable!() //? } @@ -213,15 +213,34 @@ pub fn infer_type(expr: &Expression, symbols: &HashMap) -> Token infer_type(receiver, symbols) } } - Expression::Stop { .. } => Unknown, - Expression::NamedParameter { .. } => Unknown, - Expression::ListGet { .. } => Unknown, - Expression::MapGet { .. } => Unknown, - Expression::FieldGet { .. } => Unknown, + Expression::Stop { .. } => Ok(Unknown), + Expression::NamedParameter { .. } => Ok(Unknown), + Expression::ListGet { .. } => Ok(Unknown), + Expression::MapGet { .. } => Ok(Unknown), + Expression::FieldGet { .. } => Ok(Unknown), Expression::Range { lower, .. } => infer_type(lower, symbols), - Expression::IfExpression { .. } => Unknown, - Expression::IfElseExpression { .. } => Unknown, - Expression::LetExpression { .. } => Void, - Expression::ForStatement { .. } => Void, + Expression::IfExpression { .. } => Ok(Unknown), + Expression::IfElseExpression { then_branch, else_branch, .. } => { + let mut then_type = Void; + for statement in then_branch { + if let Statement::ExpressionStmt { expression } = statement { + then_type = infer_type(expression, symbols)? + } + } + + let mut else_type = Void; + for statement in else_branch { + if let Statement::ExpressionStmt { expression } = statement { + else_type = infer_type(expression, symbols)? + } + } + if then_type != else_type{ + Err(CompilerError::IfElseBranchesDoNotMatch(then_type, else_type)) + } else { + Ok(then_type) + } + }, + Expression::LetExpression { .. } => Ok(Void), + Expression::ForStatement { .. } => Ok(Void), } }