Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ public KnowledgeBuilderConfigurationImpl(ClassLoader... classLoaders) {

/**
* Programmatic properties file, added with lease precedence
* @param properties
*/
public KnowledgeBuilderConfigurationImpl(Properties properties) {
init(properties,
Expand All @@ -170,8 +169,6 @@ public KnowledgeBuilderConfigurationImpl(Properties properties) {

/**
* Programmatic properties file, added with lease precedence
* @param classLoaders
* @param properties
*/
public KnowledgeBuilderConfigurationImpl(Properties properties,
ClassLoader... classLoaders) {
Expand Down Expand Up @@ -537,21 +534,6 @@ private void buildAccumulateFunctionsMap() {
}
}

/**
* This method is deprecated and will be removed
* @return
*
* @deprecated
*/
public Map<String, String> getAccumulateFunctionsMap() {
Map<String, String> result = new HashMap<String, String>();
for (Map.Entry<String, AccumulateFunction> entry : this.accumulateFunctions.entrySet()) {
result.put(entry.getKey(),
entry.getValue().getClass().getName());
}
return result;
}

public void addAccumulateFunction(String identifier,
String className) {
this.accumulateFunctions.put(identifier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
*/
public class MemoryResourceReader implements ResourceReader {

private Map resources;
private Map<String, byte[]> resources;

private Set<String> modifiedResourcesSinceLastMark;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.drools.compiler.rule.builder.RuleBuildContext;
import org.drools.compiler.rule.builder.RuleConditionBuilder;
import org.drools.compiler.rule.builder.dialect.java.parser.JavaLocalDeclarationDescr;
import org.drools.compiler.rule.builder.dialect.mvel.MVELExprAnalyzer;
import org.drools.compiler.rule.builder.util.PackageBuilderUtil;
import org.drools.core.base.accumulators.JavaAccumulatorFunctionExecutor;
import org.drools.core.base.extractors.ArrayElementReader;
Expand All @@ -47,6 +48,8 @@
import org.drools.core.util.index.IndexUtil;
import org.kie.api.runtime.rule.AccumulateFunction;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
Expand Down Expand Up @@ -124,12 +127,12 @@ public RuleConditionElement build( final RuleBuildContext context,
return accumulate;
}

private Accumulate buildExternalFunctionCall( final RuleBuildContext context,
final AccumulateDescr accumDescr,
final RuleConditionElement source,
private Accumulate buildExternalFunctionCall( RuleBuildContext context,
AccumulateDescr accumDescr,
RuleConditionElement source,
Map<String, Declaration> declsInScope,
Map<String, Class< ? >> declCls,
final boolean readLocalsFromTuple) {
boolean readLocalsFromTuple) {
// list of functions to build
final List<AccumulateFunctionCallDescr> funcCalls = accumDescr.getFunctions();
// list of available source declarations
Expand All @@ -150,7 +153,7 @@ private Accumulate buildExternalFunctionCall( final RuleBuildContext context,

int index = 0;
for ( AccumulateFunctionCallDescr fc : funcCalls ) {
AccumulateFunction function = getAccumulateFunction(context, accumDescr, fc);
AccumulateFunction function = getAccumulateFunction(context, accumDescr, fc, source, declCls);
if (function == null) {
return null;
}
Expand All @@ -164,7 +167,7 @@ private Accumulate buildExternalFunctionCall( final RuleBuildContext context,
accumulators );
} else {
AccumulateFunctionCallDescr fc = accumDescr.getFunctions().get(0);
AccumulateFunction function = getAccumulateFunction(context, accumDescr, fc);
AccumulateFunction function = getAccumulateFunction(context, accumDescr, fc, source, declCls);
if (function == null) {
return null;
}
Expand Down Expand Up @@ -211,22 +214,45 @@ private void bindReaderToDeclaration( RuleBuildContext context, AccumulateDescr
}
}

private AccumulateFunction getAccumulateFunction(RuleBuildContext context, AccumulateDescr accumDescr, AccumulateFunctionCallDescr fc) {
private AccumulateFunction getAccumulateFunction(RuleBuildContext context,
AccumulateDescr accumDescr,
AccumulateFunctionCallDescr fc,
RuleConditionElement source,
Map<String, Class< ? >> declCls) {
String functionName = getFunctionName( context, fc, source, declCls );

// find the corresponding function
AccumulateFunction function = context.getConfiguration().getAccumulateFunction( fc.getFunction() );
AccumulateFunction function = context.getConfiguration().getAccumulateFunction( functionName );
if( function == null ) {
// might have been imported in the package
function = context.getKnowledgeBuilder().getPackage().getAccumulateFunctions().get(fc.getFunction());
function = context.getKnowledgeBuilder().getPackage().getAccumulateFunctions().get( functionName );
}
if ( function == null ) {
context.addError( new DescrBuildError( accumDescr,
context.getRuleDescr(),
null,
"Unknown accumulate function: '" + fc.getFunction() + "' on rule '" + context.getRuleDescr().getName() + "'. All accumulate functions must be registered before building a resource." ) );
"Unknown accumulate function: '" + functionName + "' on rule '" + context.getRuleDescr().getName() + "'. All accumulate functions must be registered before building a resource." ) );
}
return function;
}

private String getFunctionName( RuleBuildContext context, AccumulateFunctionCallDescr fc, RuleConditionElement source, Map<String, Class<?>> declCls ) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The getFunctionName doesn't deal with averageBD, maxDB, etc, so I suspect that the BigDecimalAverageAccumulateFunction. We or QA will need to build a unit test for very combination of functionName and type, just to be sure that they all work properly.

String functionName = fc.getFunction();
if (functionName.equals( "sum" )) {
Class<?> exprClass = MVELExprAnalyzer.getExpressionType( context, declCls, source, fc.getParams()[0] );
if (exprClass == int.class || exprClass == Integer.class) {
functionName = "sumI";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarity, I believe they should have been named "sumInteger" etc, but I believe that ship has already sailed, didn't it?

} else if (exprClass == long.class || exprClass == Long.class) {
functionName = "sumL";
} else if (exprClass == BigInteger.class) {
functionName = "sumBI";
} else if (exprClass == BigDecimal.class) {
functionName = "sumBD";
}
}
return functionName;
}

private Accumulator buildAccumulator(RuleBuildContext context, AccumulateDescr accumDescr, Map<String, Declaration> declsInScope, Map<String, Class<?>> declCls, boolean readLocalsFromTuple, Declaration[] sourceDeclArr, Set<Declaration> requiredDecl, AccumulateFunctionCallDescr fc, AccumulateFunction function) {
// analyze the expression
final JavaAnalysisResult analysis = (JavaAnalysisResult) context.getDialect().analyzeBlock( context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ public class MVELDialect
initBuilder();
}

private static final MVELExprAnalyzer analyzer = new MVELExprAnalyzer();

private final Map interceptors = MVELCompilationUnit.INTERCEPTORS;

protected List<KnowledgeBuilderResult> results;
Expand Down Expand Up @@ -509,12 +507,12 @@ public AnalysisResult analyzeExpression(final PackageBuildContext context,
BaseDescr temp = context.getParentDescr();
context.setParentDescr( descr );
try {
result = analyzer.analyzeExpression( context,
(String) content,
availableIdentifiers,
localTypes,
"drools",
KnowledgeHelper.class );
result = MVELExprAnalyzer.analyzeExpression( context,
(String) content,
availableIdentifiers,
localTypes,
"drools",
KnowledgeHelper.class );
} catch ( final Exception e ) {
DialectUtil.copyErrorLocation( e, descr );
context.addError( new DescrBuildError( context.getParentDescr(),
Expand Down Expand Up @@ -547,12 +545,12 @@ public AnalysisResult analyzeBlock(final PackageBuildContext context,
String contextIndeifier,
Class kcontextClass) {

return analyzer.analyzeExpression( context,
text,
availableIdentifiers,
localTypes,
contextIndeifier,
kcontextClass );
return MVELExprAnalyzer.analyzeExpression( context,
text,
availableIdentifiers,
localTypes,
contextIndeifier,
kcontextClass );
}

public MVELCompilationUnit getMVELCompilationUnit(final String expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@
import org.drools.compiler.rule.builder.RuleBuildContext;
import org.drools.compiler.rule.builder.dialect.DialectUtil;
import org.drools.core.base.EvaluatorWrapper;
import org.drools.core.rule.Declaration;
import org.drools.core.rule.MVELDialectRuntimeData;
import org.drools.core.rule.RuleConditionElement;
import org.kie.api.definition.rule.Rule;
import org.mvel2.MVEL;
import org.mvel2.ParserConfiguration;
import org.mvel2.ParserContext;
import org.mvel2.optimizers.OptimizerFactory;
import org.mvel2.util.PropertyTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -45,6 +49,8 @@
*/
public class MVELExprAnalyzer {

private static final Logger log = LoggerFactory.getLogger( MVELExprAnalyzer.class );

static {
// always use mvel reflective optimizer
OptimizerFactory.setDefaultOptimizer(OptimizerFactory.SAFE_REFLECTIVE);
Expand All @@ -71,12 +77,12 @@ public MVELExprAnalyzer() {
* If an error occurs in the parser.
*/
@SuppressWarnings("unchecked")
public MVELAnalysisResult analyzeExpression(final PackageBuildContext context,
final String expr,
final BoundIdentifiers availableIdentifiers,
final Map<String, Class< ? >> localTypes,
String contextIndeifier,
Class kcontextClass) {
public static MVELAnalysisResult analyzeExpression(final PackageBuildContext context,
final String expr,
final BoundIdentifiers availableIdentifiers,
final Map<String, Class< ? >> localTypes,
String contextIndeifier,
Class kcontextClass) {
if ( expr.trim().length() <= 0 ) {
MVELAnalysisResult result = analyze( (Set<String>) Collections.EMPTY_SET, availableIdentifiers );
result.setMvelVariables( new HashMap<String, Class< ? >>() );
Expand Down Expand Up @@ -246,8 +252,8 @@ public MVELAnalysisResult analyzeExpression(final PackageBuildContext context,
* @throws RecognitionException
* If an error occurs in the parser.
*/
private MVELAnalysisResult analyze(final Set<String> identifiers,
final BoundIdentifiers availableIdentifiers) {
private static MVELAnalysisResult analyze(final Set<String> identifiers,
final BoundIdentifiers availableIdentifiers) {

MVELAnalysisResult result = new MVELAnalysisResult();
result.setIdentifiers( identifiers );
Expand Down Expand Up @@ -290,4 +296,28 @@ private MVELAnalysisResult analyze(final Set<String> identifiers,

return result;
}

public static Class<?> getExpressionType(PackageBuildContext context,
Map<String, Class< ? >> declCls,
RuleConditionElement source,
String expression) {
MVELDialectRuntimeData data = ( MVELDialectRuntimeData) context.getPkg().getDialectRuntimeRegistry().getDialectData( "mvel" );
ParserConfiguration conf = data.getParserConfiguration();
conf.setClassLoader( context.getKnowledgeBuilder().getRootClassLoader() );
ParserContext pctx = new ParserContext( conf );
pctx.setStrongTyping(true);
pctx.setStrictTypeEnforcement(true);
for (Map.Entry<String, Class< ? >> entry : declCls.entrySet()) {
pctx.addInput(entry.getKey(), entry.getValue());
}
for (Declaration decl : source.getOuterDeclarations().values()) {
pctx.addInput(decl.getBindingName(), decl.getDeclarationClass());
}
try {
return MVEL.analyze( expression, pctx );
} catch (Exception e) {
log.warn( "Unable to parse expression: " + expression, e );
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ drools.accumulate.function.average = org.drools.core.base.accumulators.AverageAc
drools.accumulate.function.max = org.drools.core.base.accumulators.MaxAccumulateFunction
drools.accumulate.function.min = org.drools.core.base.accumulators.MinAccumulateFunction
drools.accumulate.function.count = org.drools.core.base.accumulators.CountAccumulateFunction
drools.accumulate.function.sum = org.drools.core.base.accumulators.SumAccumulateFunction
drools.accumulate.function.collectList = org.drools.core.base.accumulators.CollectListAccumulateFunction
drools.accumulate.function.collectSet = org.drools.core.base.accumulators.CollectSetAccumulateFunction
drools.accumulate.function.sumBD = org.drools.core.base.accumulators.BigDecimalSumAccumulateFunction
drools.accumulate.function.averageBD = org.drools.core.base.accumulators.BigDecimalAverageAccumulateFunction
drools.accumulate.function.sum = org.drools.core.base.accumulators.SumAccumulateFunction
drools.accumulate.function.sumI = org.drools.core.base.accumulators.IntegerSumAccumulateFunction
drools.accumulate.function.sumL = org.drools.core.base.accumulators.LongSumAccumulateFunction
drools.accumulate.function.sumBI = org.drools.core.base.accumulators.BigIntegerSumAccumulateFunction
drools.accumulate.function.sumBD = org.drools.core.base.accumulators.BigDecimalSumAccumulateFunction

drools.evaluator.coincides = org.drools.core.base.evaluators.CoincidesEvaluatorDefinition
drools.evaluator.before = org.drools.core.base.evaluators.BeforeEvaluatorDefinition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ public void testAccnSharingWithMixedDormantAndActive() {
list.add( act.getRule().getName() + ":" + act.getDeclarationValue( "$s1" ) + ":" + act.isQueued() );
}

assertContains( new String[]{"rule1:6.0:true", "rule2:6.0:true", "rule3:6.0:false"},
assertContains( new String[]{"rule1:6:true", "rule2:6:true", "rule3:6:false"},
list );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2003,7 +2003,7 @@ public void testAccumulateWithBoundExpression() {
ksession.dispose();
assertEquals( 1,
results.size() );
assertEquals( 9.0,
assertEquals( 9L,
results.get( 0 ) );
}

Expand Down Expand Up @@ -3012,7 +3012,7 @@ public void testAccumulateWithOr() {
.build()
.newKieSession();

List<Double> list = new ArrayList<Double>();
List<Integer> list = new ArrayList<Integer>();
ksession.setGlobal( "list", list );

ksession.insert( 1 );
Expand All @@ -3021,7 +3021,7 @@ public void testAccumulateWithOr() {
ksession.fireAllRules();

assertEquals( 1, list.size() );
assertEquals( "hello".length(), (double)list.get(0), 0.01 );
assertEquals( "hello".length(), (int)list.get(0), 0.01 );
}

@Test
Expand Down Expand Up @@ -3054,7 +3054,7 @@ public void testMvelAccumulateWithOr() {
ksession.fireAllRules();

assertEquals( 1, list.size() );
assertEquals( "hello".length(), (double)list.get(0), 0.01 );
assertEquals( "hello".length(), list.get(0), 0.01 );
}

public static class Converter {
Expand Down Expand Up @@ -3090,4 +3090,32 @@ public void testNormalizeStagedTuplesInAccumulate() {
ksession.fireAllRules();
assertEquals( 1, list.size() );
}

@Test
public void testTypedSumOnAccumulate() {
// DROOLS-1175
String drl1 =
"global java.util.List list;\n" +
"rule R when\n" +
" $i : Integer()\n" +
" accumulate ( $s : String(), $result : sum( $s.length() ) )\n" +
"then\n" +
" list.add($result);\n" +
"end";

KieSession ksession = new KieHelper().addContent( drl1, ResourceType.DRL )
.build()
.newKieSession();

List<Integer> list = new ArrayList<Integer>();
ksession.setGlobal( "list", list );

ksession.insert( 1 );
ksession.insert( "hello" );
ksession.insert( "hi" );
ksession.fireAllRules();

assertEquals( 1, list.size() );
assertEquals( "hello".length() + "hi".length(), (int)list.get(0) );
}
}
Loading