Skip to content
163 changes: 146 additions & 17 deletions crates/fmt/src/state/sol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,14 +491,12 @@ impl<'ast> State<'_, 'ast> {
let params_format = match header_style {
MultilineFuncHeaderStyle::ParamsAlways => ListFormat::always_break(),
MultilineFuncHeaderStyle::All
if header.parameters.len() > 1
&& !self.can_header_be_inlined(header, body.is_some()) =>
if header.parameters.len() > 1 && !self.can_header_be_inlined(func) =>
{
ListFormat::always_break()
}
MultilineFuncHeaderStyle::AllParams
if !header.parameters.is_empty()
&& !self.can_header_be_inlined(header, body.is_some()) =>
if !header.parameters.is_empty() && !self.can_header_be_inlined(func) =>
{
ListFormat::always_break()
}
Expand Down Expand Up @@ -552,7 +550,7 @@ impl<'ast> State<'_, 'ast> {

let attrib_box = self.config.multiline_func_header.params_first()
|| (self.config.multiline_func_header.attrib_first()
&& !self.can_header_params_be_inlined(header));
&& !self.can_header_params_be_inlined(func));
if attrib_box {
self.s.cbox(0);
}
Expand Down Expand Up @@ -2554,49 +2552,66 @@ impl<'ast> State<'_, 'ast> {
els_opt.is_none_or(|els| self.is_inline_stmt(els, 6))
}

fn can_header_be_inlined(&mut self, header: &ast::FunctionHeader<'_>, has_body: bool) -> bool {
fn can_header_be_inlined(&mut self, func: &ast::ItemFunction<'_>) -> bool {
self.estimate_header_size(func) <= self.space_left()
}

fn can_header_params_be_inlined(&mut self, func: &ast::ItemFunction<'_>) -> bool {
self.estimate_header_params_size(func) <= self.space_left()
}

fn estimate_header_size(&mut self, func: &ast::ItemFunction<'_>) -> usize {
let ast::ItemFunction { kind: _, ref header, ref body, body_span: _ } = *func;

// ' ' + visibility
let visibility = header.visibility.map_or(0, |v| self.estimate_size(v.span) + 1);
// ' ' + state mutability
let mutability = header.state_mutability.map_or(0, |sm| self.estimate_size(sm.span) + 1);
// ' ' + modifier + (' ' + modifier)
let modifiers =
header.modifiers.iter().fold(0, |len, m| len + self.estimate_size(m.span())) + 1;
let m = header.modifiers.iter().fold(0, |len, m| len + self.estimate_size(m.span()));
let modifiers = if m != 0 { m + 1 } else { 0 };
// ' ' + override
let override_ = header.override_.as_ref().map_or(0, |o| self.estimate_size(o.span) + 1);
// ' ' + virtual
let virtual_ = if header.virtual_.is_none() { 0 } else { 8 };
// ' returns(' + var + (', ' + var) + ')'
let returns = header.returns.as_ref().map_or(0, |ret| {
ret.vars
.iter()
.fold(0, |len, p| if len != 0 { len + 2 } else { 10 } + self.estimate_size(p.span))
});
// ' {' or ';'
let end = if has_body { 2 } else { 1 };
let end = if body.is_some() { 2 } else { 1 };

self.estimate_header_params_size(header) // accounts for 'function name(..)'
self.estimate_header_params_size(func)
+ visibility
+ mutability
+ modifiers
+ override_
+ virtual_
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

+ returns
+ end
<= self.space_left()
}

fn can_header_params_be_inlined(&mut self, header: &ast::FunctionHeader<'_>) -> bool {
self.estimate_header_params_size(header) <= self.space_left()
}
fn estimate_header_params_size(&mut self, func: &ast::ItemFunction<'_>) -> usize {
let ast::ItemFunction { kind, ref header, body: _, body_span: _ } = *func;

let kw = match kind {
ast::FunctionKind::Constructor => 11, // 'constructor'
ast::FunctionKind::Function => 9, // 'function '
ast::FunctionKind::Modifier => 9, // 'modifier '
ast::FunctionKind::Fallback => 8, // 'fallback'
ast::FunctionKind::Receive => 7, // 'receive'
};

fn estimate_header_params_size(&mut self, header: &ast::FunctionHeader<'_>) -> usize {
// '(' + param + (', ' + param) + ')'
let params = header
.parameters
.vars
.iter()
.fold(0, |len, p| if len != 0 { len + 2 } else { 2 } + self.estimate_size(p.span));

// 'function ' + name + ' ' + params
9 + header.name.map_or(0, |name| self.estimate_size(name.span) + 1) + params
kw + header.name.map_or(0, |name| self.estimate_size(name.span)) + std::cmp::max(2, params)
}

fn estimate_lhs_size(&self, expr: &ast::Expr<'_>, parent_op: &ast::BinOp) -> usize {
Expand Down Expand Up @@ -2958,3 +2973,117 @@ pub(super) fn get_callee_head_size(callee: &ast::Expr<'_>) -> usize {
_ => 0,
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{FormatterConfig, InlineConfig};
use foundry_common::comments::Comments;
use solar::{
interface::{Session, source_map::FileName},
sema::Compiler,
};
use std::sync::Arc;

/// This helper extracts function headers from the AST and passes them to the test function.
fn parse_and_test<F>(source: &str, test_fn: F)
where
F: FnOnce(&mut State<'_, '_>, &ast::ItemFunction<'_>) + Send,
{
let session = Session::builder().with_buffer_emitter(Default::default()).build();
let mut compiler = Compiler::new(session);

compiler
.enter_mut(|c| -> solar::interface::Result<()> {
let mut pcx = c.parse();
pcx.set_resolve_imports(false);

// Create a source file using stdin as the filename
let file = c
.sess()
.source_map()
.new_source_file(FileName::Stdin, source)
.map_err(|e| c.sess().dcx.err(e.to_string()).emit())?;

pcx.add_file(file.clone());
pcx.parse();
c.dcx().has_errors()?;

// Get AST from parsed source and setup the formatter
let gcx = c.gcx();
let (_, source_obj) = gcx.get_ast_source(&file.name).expect("Failed to get AST");
let ast = source_obj.ast.as_ref().expect("No AST found");
let comments =
Comments::new(&source_obj.file, gcx.sess.source_map(), true, false, None);
let config = Arc::new(FormatterConfig::default());
let inline_config = InlineConfig::default();
let mut state = State::new(gcx.sess.source_map(), config, inline_config, comments);

// Extract the first function header (either top-level or inside a contract)
let func = ast
.items
.iter()
.find_map(|item| match &item.kind {
ast::ItemKind::Function(func) => Some(func),
ast::ItemKind::Contract(contract) => {
contract.body.iter().find_map(|contract_item| {
match &contract_item.kind {
ast::ItemKind::Function(func) => Some(func),
_ => None,
}
})
}
_ => None,
})
.expect("No function found in source");

// Run the closure
test_fn(&mut state, func);

Ok(())
})
.expect("Test failed");
}

#[test]
fn test_estimate_header_sizes() {
let test_cases = [
("function foo();", 14, 15),
("function foo() {}", 14, 16),
("function foo() public {}", 14, 23),
("function foo(uint256 a) public {}", 23, 32),
("function foo(uint256 a, address b, bool c) public {}", 42, 51),
("function foo() public pure {}", 14, 28),
("function foo() public virtual {}", 14, 31),
("function foo() public override {}", 14, 32),
("function foo() public onlyOwner {}", 14, 33),
("function foo() public returns(uint256) {}", 14, 40),
("function foo() public returns(uint256, address) {}", 14, 49),
("function foo(uint256 a) public virtual override returns(uint256) {}", 23, 66),
("function foo() external payable {}", 14, 33),
// other function types
("contract C { constructor() {} }", 13, 15),
("contract C { constructor(uint256 a) {} }", 22, 24),
("contract C { modifier onlyOwner() {} }", 20, 22),
("contract C { modifier onlyRole(bytes32 role) {} }", 31, 33),
("contract C { fallback() external payable {} }", 10, 29),
("contract C { receive() external payable {} }", 9, 28),
];

for (source, expected_params, expected_header) in &test_cases {
parse_and_test(source, |state, func| {
let params_size = state.estimate_header_params_size(func);
assert_eq!(
params_size, *expected_params,
"Failed params size: expected {expected_params}, got {params_size} for source: {source}",
);

let header_size = state.estimate_header_size(func);
assert_eq!(
header_size, *expected_header,
"Failed header size: expected {expected_header}, got {header_size} for source: {source}",
);
});
}
}
}
Loading