diff --git a/crates/fmt/src/state/sol.rs b/crates/fmt/src/state/sol.rs index 9b8a478726666..433fc71d16d40 100644 --- a/crates/fmt/src/state/sol.rs +++ b/crates/fmt/src/state/sol.rs @@ -491,14 +491,12 @@ impl<'ast> State<'_, 'ast> { let params_format = match header_style { MultilineFuncHeaderStyle::ParamsAlways => ListFormat::always_break(), MultilineFuncHeaderStyle::All - if header.parameters.len() > 1 - && !self.can_header_be_inlined(header, body.is_some()) => + if header.parameters.len() > 1 && !self.can_header_be_inlined(func) => { ListFormat::always_break() } MultilineFuncHeaderStyle::AllParams - if !header.parameters.is_empty() - && !self.can_header_be_inlined(header, body.is_some()) => + if !header.parameters.is_empty() && !self.can_header_be_inlined(func) => { ListFormat::always_break() } @@ -552,7 +550,7 @@ impl<'ast> State<'_, 'ast> { let attrib_box = self.config.multiline_func_header.params_first() || (self.config.multiline_func_header.attrib_first() - && !self.can_header_params_be_inlined(header)); + && !self.can_header_params_be_inlined(func)); if attrib_box { self.s.cbox(0); } @@ -2554,16 +2552,28 @@ impl<'ast> State<'_, 'ast> { els_opt.is_none_or(|els| self.is_inline_stmt(els, 6)) } - fn can_header_be_inlined(&mut self, header: &ast::FunctionHeader<'_>, has_body: bool) -> bool { + fn can_header_be_inlined(&mut self, func: &ast::ItemFunction<'_>) -> bool { + self.estimate_header_size(func) <= self.space_left() + } + + fn can_header_params_be_inlined(&mut self, func: &ast::ItemFunction<'_>) -> bool { + self.estimate_header_params_size(func) <= self.space_left() + } + + fn estimate_header_size(&mut self, func: &ast::ItemFunction<'_>) -> usize { + let ast::ItemFunction { kind: _, ref header, ref body, body_span: _ } = *func; + // ' ' + visibility let visibility = header.visibility.map_or(0, |v| self.estimate_size(v.span) + 1); // ' ' + state mutability let mutability = header.state_mutability.map_or(0, |sm| self.estimate_size(sm.span) + 1); // ' ' + modifier + (' ' + modifier) - let modifiers = - header.modifiers.iter().fold(0, |len, m| len + self.estimate_size(m.span())) + 1; + let m = header.modifiers.iter().fold(0, |len, m| len + self.estimate_size(m.span())); + let modifiers = if m != 0 { m + 1 } else { 0 }; // ' ' + override let override_ = header.override_.as_ref().map_or(0, |o| self.estimate_size(o.span) + 1); + // ' ' + virtual + let virtual_ = if header.virtual_.is_none() { 0 } else { 8 }; // ' returns(' + var + (', ' + var) + ')' let returns = header.returns.as_ref().map_or(0, |ret| { ret.vars @@ -2571,23 +2581,29 @@ impl<'ast> State<'_, 'ast> { .fold(0, |len, p| if len != 0 { len + 2 } else { 10 } + self.estimate_size(p.span)) }); // ' {' or ';' - let end = if has_body { 2 } else { 1 }; + let end = if body.is_some() { 2 } else { 1 }; - self.estimate_header_params_size(header) // accounts for 'function name(..)' + self.estimate_header_params_size(func) + visibility + mutability + modifiers + override_ + + virtual_ + returns + end - <= self.space_left() } - fn can_header_params_be_inlined(&mut self, header: &ast::FunctionHeader<'_>) -> bool { - self.estimate_header_params_size(header) <= self.space_left() - } + fn estimate_header_params_size(&mut self, func: &ast::ItemFunction<'_>) -> usize { + let ast::ItemFunction { kind, ref header, body: _, body_span: _ } = *func; + + let kw = match kind { + ast::FunctionKind::Constructor => 11, // 'constructor' + ast::FunctionKind::Function => 9, // 'function ' + ast::FunctionKind::Modifier => 9, // 'modifier ' + ast::FunctionKind::Fallback => 8, // 'fallback' + ast::FunctionKind::Receive => 7, // 'receive' + }; - fn estimate_header_params_size(&mut self, header: &ast::FunctionHeader<'_>) -> usize { // '(' + param + (', ' + param) + ')' let params = header .parameters @@ -2595,8 +2611,7 @@ impl<'ast> State<'_, 'ast> { .iter() .fold(0, |len, p| if len != 0 { len + 2 } else { 2 } + self.estimate_size(p.span)); - // 'function ' + name + ' ' + params - 9 + header.name.map_or(0, |name| self.estimate_size(name.span) + 1) + params + kw + header.name.map_or(0, |name| self.estimate_size(name.span)) + std::cmp::max(2, params) } fn estimate_lhs_size(&self, expr: &ast::Expr<'_>, parent_op: &ast::BinOp) -> usize { @@ -2958,3 +2973,117 @@ pub(super) fn get_callee_head_size(callee: &ast::Expr<'_>) -> usize { _ => 0, } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{FormatterConfig, InlineConfig}; + use foundry_common::comments::Comments; + use solar::{ + interface::{Session, source_map::FileName}, + sema::Compiler, + }; + use std::sync::Arc; + + /// This helper extracts function headers from the AST and passes them to the test function. + fn parse_and_test(source: &str, test_fn: F) + where + F: FnOnce(&mut State<'_, '_>, &ast::ItemFunction<'_>) + Send, + { + let session = Session::builder().with_buffer_emitter(Default::default()).build(); + let mut compiler = Compiler::new(session); + + compiler + .enter_mut(|c| -> solar::interface::Result<()> { + let mut pcx = c.parse(); + pcx.set_resolve_imports(false); + + // Create a source file using stdin as the filename + let file = c + .sess() + .source_map() + .new_source_file(FileName::Stdin, source) + .map_err(|e| c.sess().dcx.err(e.to_string()).emit())?; + + pcx.add_file(file.clone()); + pcx.parse(); + c.dcx().has_errors()?; + + // Get AST from parsed source and setup the formatter + let gcx = c.gcx(); + let (_, source_obj) = gcx.get_ast_source(&file.name).expect("Failed to get AST"); + let ast = source_obj.ast.as_ref().expect("No AST found"); + let comments = + Comments::new(&source_obj.file, gcx.sess.source_map(), true, false, None); + let config = Arc::new(FormatterConfig::default()); + let inline_config = InlineConfig::default(); + let mut state = State::new(gcx.sess.source_map(), config, inline_config, comments); + + // Extract the first function header (either top-level or inside a contract) + let func = ast + .items + .iter() + .find_map(|item| match &item.kind { + ast::ItemKind::Function(func) => Some(func), + ast::ItemKind::Contract(contract) => { + contract.body.iter().find_map(|contract_item| { + match &contract_item.kind { + ast::ItemKind::Function(func) => Some(func), + _ => None, + } + }) + } + _ => None, + }) + .expect("No function found in source"); + + // Run the closure + test_fn(&mut state, func); + + Ok(()) + }) + .expect("Test failed"); + } + + #[test] + fn test_estimate_header_sizes() { + let test_cases = [ + ("function foo();", 14, 15), + ("function foo() {}", 14, 16), + ("function foo() public {}", 14, 23), + ("function foo(uint256 a) public {}", 23, 32), + ("function foo(uint256 a, address b, bool c) public {}", 42, 51), + ("function foo() public pure {}", 14, 28), + ("function foo() public virtual {}", 14, 31), + ("function foo() public override {}", 14, 32), + ("function foo() public onlyOwner {}", 14, 33), + ("function foo() public returns(uint256) {}", 14, 40), + ("function foo() public returns(uint256, address) {}", 14, 49), + ("function foo(uint256 a) public virtual override returns(uint256) {}", 23, 66), + ("function foo() external payable {}", 14, 33), + // other function types + ("contract C { constructor() {} }", 13, 15), + ("contract C { constructor(uint256 a) {} }", 22, 24), + ("contract C { modifier onlyOwner() {} }", 20, 22), + ("contract C { modifier onlyRole(bytes32 role) {} }", 31, 33), + ("contract C { fallback() external payable {} }", 10, 29), + ("contract C { receive() external payable {} }", 9, 28), + ]; + + for (source, expected_params, expected_header) in &test_cases { + parse_and_test(source, |state, func| { + let params_size = state.estimate_header_params_size(func); + assert_eq!( + params_size, *expected_params, + "Failed params size: expected {expected_params}, got {params_size} for source: {source}", + ); + + let header_size = state.estimate_header_size(func); + assert_eq!( + header_size, *expected_header, + "Failed header size: expected {expected_header}, got {header_size} for source: {source}", + ); + }); + } + } +}