diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index d8c6f5ee6d82..389c1bdb722e 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1080,11 +1080,7 @@ impl TryInto for &protobuf::LogicalExprNode { } // argo engine add start ExprType::AggregateUdfExpr(expr) => { - let gpm = global_plugin_manager("").lock().unwrap(); - let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); - if let Some(udf_plugin_manager) = - plugin_registrar.as_any().downcast_ref::() - { + if let Some(udf_plugin_manager) = get_udf_plugin_manager("") { let fun = udf_plugin_manager .aggregate_udfs .get(expr.fun_name.as_str()) @@ -1106,11 +1102,7 @@ impl TryInto for &protobuf::LogicalExprNode { } } ExprType::ScalarUdfProtoExpr(expr) => { - let gpm = global_plugin_manager("").lock().unwrap(); - let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); - if let Some(udf_plugin_manager) = - plugin_registrar.as_any().downcast_ref::() - { + if let Some(udf_plugin_manager) = get_udf_plugin_manager("") { let fun = udf_plugin_manager .scalar_udfs .get(expr.fun_name.as_str()) @@ -1335,7 +1327,7 @@ impl TryInto for &protobuf::Field { use crate::serde::protobuf::ColumnStats; use datafusion::physical_plan::{aggregates, windows}; use datafusion::plugin::plugin_manager::global_plugin_manager; -use datafusion::plugin::udf::UDFPluginManager; +use datafusion::plugin::udf::{get_udf_plugin_manager, UDFPluginManager}; use datafusion::plugin::PluginEnum; use datafusion::prelude::{ array, date_part, date_trunc, length, lower, ltrim, md5, rtrim, sha224, sha256, diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 72a0455ea396..38c04ef29dd3 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -83,7 +83,7 @@ use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr, }; use datafusion::plugin::plugin_manager::global_plugin_manager; -use datafusion::plugin::udf::UDFPluginManager; +use datafusion::plugin::udf::{get_udf_plugin_manager, UDFPluginManager}; use datafusion::plugin::PluginEnum; use datafusion::prelude::CsvReadOptions; use log::debug; @@ -320,10 +320,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { ExprType::AggregateUdfExpr(agg_node) => { let name = agg_node.fun_name.as_str(); let udaf_fun_name = &name[0..name.find('(').unwrap()]; - let gpm = global_plugin_manager("").lock().unwrap(); - let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); - if let Some(udf_plugin_manager) = plugin_registrar.as_any().downcast_ref::() - { + if let Some(udf_plugin_manager) = get_udf_plugin_manager("") { let fun = udf_plugin_manager.aggregate_udfs.get(udaf_fun_name).ok_or_else(|| { proto_error(format!( "can not get udaf:{} from plugins!", @@ -584,11 +581,7 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { } // argo engine add. ExprType::ScalarUdfProtoExpr(e) => { - let gpm = global_plugin_manager("").lock().unwrap(); - let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); - if let Some(udf_plugin_manager) = - plugin_registrar.as_any().downcast_ref::() - { + if let Some(udf_plugin_manager) = get_udf_plugin_manager("") { let fun = udf_plugin_manager .scalar_udfs .get(&e.fun_name) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 6e542c7c6804..5e4e3ef6b711 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -83,9 +83,7 @@ use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::physical_plan::udf::ScalarUDF; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::PhysicalPlanner; -use crate::plugin::plugin_manager::global_plugin_manager; -use crate::plugin::udf::UDFPluginManager; -use crate::plugin::PluginEnum; +use crate::plugin::udf::get_udf_plugin_manager; use crate::sql::{ parser::{DFParser, FileType}, planner::{ContextProvider, SqlToRel}, @@ -198,13 +196,9 @@ impl ExecutionContext { })), }; - let gpm = global_plugin_manager(config.plugin_dir.as_str()); - // register udf - let gpm_guard = gpm.lock().unwrap(); - let plugin_registrar = gpm_guard.plugin_managers.get(&PluginEnum::UDF).unwrap(); if let Some(udf_plugin_manager) = - plugin_registrar.as_any().downcast_ref::() + get_udf_plugin_manager(config.plugin_dir.as_str()) { udf_plugin_manager .scalar_udfs diff --git a/datafusion/src/plugin/mod.rs b/datafusion/src/plugin/mod.rs index 67d6655a2b07..1450749a2afc 100644 --- a/datafusion/src/plugin/mod.rs +++ b/datafusion/src/plugin/mod.rs @@ -1,7 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use crate::error::Result; use crate::plugin::udf::UDFPluginManager; +use libloading::Library; use std::any::Any; use std::env; +use std::sync::Arc; /// plugin manager pub mod plugin_manager; @@ -47,17 +66,15 @@ pub struct PluginDeclaration { /// One of PluginEnum pub plugin_type: unsafe extern "C" fn() -> PluginEnum, - - /// `register` is a function which impl PluginRegistrar. It will be call when plugin load. - pub register: unsafe extern "C" fn(&mut Box), } /// Plugin Registrar , Every plugin need implement this trait pub trait PluginRegistrar: Send + Sync + 'static { - /// The implementer of the plug-in needs to call this interface to report his own information to the plug-in manager - fn register_plugin(&mut self, plugin: Box) -> Result<()>; + /// # Safety + /// load plugin from library + unsafe fn load(&mut self, library: Arc) -> Result<()>; - /// Returns the plugin registrar as [`Any`](std::any::Any) so that it can be + /// Returns the plugin as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; } @@ -66,22 +83,12 @@ pub trait PluginRegistrar: Send + Sync + 'static { /// /// # Notes /// -/// This works by automatically generating an `extern "C"` function with a +/// This works by automatically generating an `extern "C"` function named `get_plugin_type` with a /// pre-defined signature and symbol name. And then generating a PluginDeclaration. /// Therefore you will only be able to declare one plugin per library. #[macro_export] macro_rules! declare_plugin { - ($plugin_type:expr, $curr_plugin_type:ty, $constructor:path) => { - #[no_mangle] - pub extern "C" fn register_plugin( - registrar: &mut Box, - ) { - // make sure the constructor is the correct type. - let constructor: fn() -> $curr_plugin_type = $constructor; - let object = constructor(); - registrar.register_plugin(Box::new(object)).unwrap(); - } - + ($plugin_type:expr) => { #[no_mangle] pub extern "C" fn get_plugin_type() -> $crate::plugin::PluginEnum { $plugin_type @@ -93,7 +100,6 @@ macro_rules! declare_plugin { rustc_version: $crate::plugin::RUSTC_VERSION, core_version: $crate::plugin::CORE_VERSION, plugin_type: get_plugin_type, - register: register_plugin, }; }; } diff --git a/datafusion/src/plugin/plugin_manager.rs b/datafusion/src/plugin/plugin_manager.rs index a8a19e4ac8d9..764e94157a00 100644 --- a/datafusion/src/plugin/plugin_manager.rs +++ b/datafusion/src/plugin/plugin_manager.rs @@ -1,3 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. use crate::error::{DataFusionError, Result}; use crate::plugin::{PluginDeclaration, CORE_VERSION, RUSTC_VERSION}; use crate::plugin::{PluginEnum, PluginRegistrar}; @@ -13,11 +29,14 @@ use once_cell::sync::OnceCell; /// To prevent the library from being loaded multiple times, we use once_cell defines a Arc> /// Because datafusion is a library, not a service, users may not need to load all plug-ins in the process. /// So fn global_plugin_manager return Arc>. In this way, users can load the required library through the load method of GlobalPluginManager when needed +static INSTANCE: OnceCell>> = OnceCell::new(); + +/// global_plugin_manager pub fn global_plugin_manager( plugin_path: &str, ) -> &'static Arc> { - static INSTANCE: OnceCell>> = OnceCell::new(); INSTANCE.get_or_init(move || unsafe { + println!("====================init==================="); let mut gpm = GlobalPluginManager::default(); gpm.load(plugin_path).unwrap(); Arc::new(Mutex::new(gpm)) @@ -38,6 +57,9 @@ impl GlobalPluginManager { /// # Safety /// find plugin file from `plugin_path` and load it . unsafe fn load(&mut self, plugin_path: &str) -> Result<()> { + if "".eq(plugin_path) { + return Ok(()); + } // find library file from udaf_plugin_path info!("load plugin from dir:{}", plugin_path); println!("load plugin from dir:{}", plugin_path); @@ -54,18 +76,19 @@ impl GlobalPluginManager { let library = Arc::new(library); - // get a pointer to the plugin_declaration symbol. - let dec = library - .get::<*mut PluginDeclaration>(b"plugin_declaration\0") - .map_err(|e| { - DataFusionError::IoError(io::Error::new( - io::ErrorKind::Other, - format!("not found plugin_declaration in the library: {}", e), - )) - })? - .read(); - - // version checks to prevent accidental ABI incompatibilities + let dec = library.get::<*mut PluginDeclaration>(b"plugin_declaration\0"); + if dec.is_err() { + info!( + "not found plugin_declaration in the library: {}", + plugin_file.path().to_str().unwrap() + ); + return Ok(()); + } + + let dec = dec.unwrap().read(); + + // ersion checks to prevent accidental ABI incompatibilities + if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION { return Err(DataFusionError::IoError(io::Error::new( io::ErrorKind::Other, @@ -82,8 +105,7 @@ impl GlobalPluginManager { } Some(manager) => manager, }; - - (dec.register)(curr_plugin_manager); + curr_plugin_manager.load(library)?; self.plugin_files .push(plugin_file.path().to_str().unwrap().to_string()); } @@ -112,17 +134,20 @@ impl GlobalPluginManager { if let Some(path) = item.path().extension() { if let Some(suffix) = path.to_str() { if suffix == "dylib" || suffix == "so" || suffix == "dll" { - info!("load plugin from library file:{}", path.to_str().unwrap()); + info!( + "load plugin from library file:{}", + item.path().to_str().unwrap() + ); println!( "load plugin from library file:{}", - path.to_str().unwrap() + item.path().to_str().unwrap() ); return Some(item); } } } - return None; + None }) { plugin_files.push(entry); } diff --git a/datafusion/src/plugin/udf.rs b/datafusion/src/plugin/udf.rs index ffbb928fbd0f..7f223ee69570 100644 --- a/datafusion/src/plugin/udf.rs +++ b/datafusion/src/plugin/udf.rs @@ -1,8 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. use crate::error::{DataFusionError, Result}; use crate::physical_plan::udaf::AggregateUDF; use crate::physical_plan::udf::ScalarUDF; -use crate::plugin::{Plugin, PluginRegistrar}; -use libloading::Library; +use crate::plugin::plugin_manager::global_plugin_manager; +use crate::plugin::{Plugin, PluginEnum, PluginRegistrar}; +use libloading::{Library, Symbol}; use std::any::Any; use std::collections::HashMap; use std::io; @@ -24,7 +41,7 @@ pub trait UDFPlugin: Plugin { } /// UDFPluginManager -#[derive(Default)] +#[derive(Default, Clone)] pub struct UDFPluginManager { /// scalar udfs pub scalar_udfs: HashMap>, @@ -37,52 +54,98 @@ pub struct UDFPluginManager { } impl PluginRegistrar for UDFPluginManager { - fn register_plugin(&mut self, plugin: Box) -> Result<()> { - if let Some(udf_plugin) = plugin.as_any().downcast_ref::>() { - udf_plugin - .udf_names() - .unwrap() - .iter() - .try_for_each(|udf_name| { - if self.scalar_udfs.contains_key(udf_name) { - Err(DataFusionError::IoError(io::Error::new( - io::ErrorKind::Other, - format!("udf name: {} already exists", udf_name), - ))) - } else { - let scalar_udf = udf_plugin.get_scalar_udf_by_name(udf_name)?; - self.scalar_udfs - .insert(udf_name.to_string(), Arc::new(scalar_udf)); - Ok(()) - } - })?; + unsafe fn load(&mut self, library: Arc) -> Result<()> { + type PluginRegister = unsafe fn() -> Box; + let register_fun: Symbol = + library.get(b"registrar_udf_plugin\0").map_err(|e| { + DataFusionError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("not found fn registrar_udf_plugin in the library: {}", e), + )) + })?; - udf_plugin - .udaf_names() - .unwrap() - .iter() - .try_for_each(|udaf_name| { - if self.aggregate_udfs.contains_key(udaf_name) { - Err(DataFusionError::IoError(io::Error::new( - io::ErrorKind::Other, - format!("udaf name: {} already exists", udaf_name), - ))) - } else { - let aggregate_udf = - udf_plugin.get_aggregate_udf_by_name(udaf_name)?; - self.aggregate_udfs - .insert(udaf_name.to_string(), Arc::new(aggregate_udf)); - Ok(()) - } - })?; - } - Err(DataFusionError::IoError(io::Error::new( - io::ErrorKind::Other, - format!("expected plugin type is 'dyn UDFPlugin', but it's not"), - ))) + let udf_plugin: Box = register_fun(); + udf_plugin + .udf_names() + .unwrap() + .iter() + .try_for_each(|udf_name| { + if self.scalar_udfs.contains_key(udf_name) { + Err(DataFusionError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("udf name: {} already exists", udf_name), + ))) + } else { + let scalar_udf = udf_plugin.get_scalar_udf_by_name(udf_name)?; + self.scalar_udfs + .insert(udf_name.to_string(), Arc::new(scalar_udf)); + Ok(()) + } + })?; + + udf_plugin + .udaf_names() + .unwrap() + .iter() + .try_for_each(|udaf_name| { + if self.aggregate_udfs.contains_key(udaf_name) { + Err(DataFusionError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("udaf name: {} already exists", udaf_name), + ))) + } else { + let aggregate_udf = + udf_plugin.get_aggregate_udf_by_name(udaf_name)?; + self.aggregate_udfs + .insert(udaf_name.to_string(), Arc::new(aggregate_udf)); + Ok(()) + } + })?; + Ok(()) } fn as_any(&self) -> &dyn Any { self } } + +/// Declare a udf plugin registrar callback +/// +/// # Notes +/// +/// This works by automatically generating an `extern "C"` function named `registrar_udf_plugin` with a +/// pre-defined signature and symbol name. +/// Therefore you will only be able to declare one plugin per library. +#[macro_export] +macro_rules! declare_udf_plugin { + ($curr_plugin_type:ty, $constructor:path) => { + #[no_mangle] + pub extern "C" fn registrar_udf_plugin() -> Box { + // make sure the constructor is the correct type. + let constructor: fn() -> $curr_plugin_type = $constructor; + let object = constructor(); + Box::new(object) + } + + $crate::declare_plugin!($crate::plugin::PluginEnum::UDF); + }; +} + +/// get a Option of Immutable UDFPluginManager +pub fn get_udf_plugin_manager(path: &str) -> Option { + let udf_plugin_manager_opt = { + let gpm = global_plugin_manager(path).lock().unwrap(); + let plugin_registrar_opt = gpm.plugin_managers.get(&PluginEnum::UDF); + if let Some(plugin_registrar) = plugin_registrar_opt { + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + return Some(udf_plugin_manager.clone()); + } else { + return None; + } + } + None + }; + udf_plugin_manager_opt +}