diff --git a/README.md b/README.md index ae8281b..aed3384 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ import track, { useTracking } from 'react-tracking'; - `dispatch`, which is a function to use instead of the default dispatch behavior. See the section on custom `dispatch()` later in this document. - `dispatchOnMount`, when set to `true`, dispatches the tracking data when the component mounts to the DOM. When provided as a function will be called on componentDidMount with all of the tracking context data as the only argument. - `process`, which is a function that can be defined once on some top-level component, used for selectively dispatching tracking events based on each component's tracking data. See more details later in this document. + - `forwardRef`, when set to `true`, adding a ref to the wrapped component will actually return the instance of the underlying component. Default is `false`. #### `tracking` prop diff --git a/src/__tests__/e2e.test.js b/src/__tests__/e2e.test.js index a9fea5e..7965fc6 100644 --- a/src/__tests__/e2e.test.js +++ b/src/__tests__/e2e.test.js @@ -866,4 +866,37 @@ describe('e2e', () => { status: 'failed', }); }); + + it('can access wrapped component by ref', async () => { + const focusFn = jest.fn(); + @track({}, { forwardRef: true }) + class Child extends React.Component { + focus = focusFn; + + render() { + return 'child'; + } + } + + class Parent extends React.Component { + componentDidMount() { + this.child.focus(); + } + + render() { + return ( + { + this.child = el; + }} + /> + ); + } + } + + const parent = await mount(); + + expect(parent.instance().child).not.toBeNull(); + expect(focusFn).toHaveBeenCalledTimes(1); + }); }); diff --git a/src/withTrackingComponentDecorator.js b/src/withTrackingComponentDecorator.js index ac0d897..510f7d6 100644 --- a/src/withTrackingComponentDecorator.js +++ b/src/withTrackingComponentDecorator.js @@ -15,13 +15,18 @@ export const ReactTrackingContext = React.createContext({}); export default function withTrackingComponentDecorator( trackingData = {}, - { dispatch = dispatchTrackingEvent, dispatchOnMount = false, process } = {} + { + dispatch = dispatchTrackingEvent, + dispatchOnMount = false, + process, + forwardRef = false, + } = {} ) { return DecoratedComponent => { const decoratedComponentName = DecoratedComponent.displayName || DecoratedComponent.name || 'Component'; - function WithTracking(props) { + function WithTracking({ rtFwdRef, ...props }) { const { tracking } = useContext(ReactTrackingContext); const latestProps = useRef(props); @@ -118,20 +123,32 @@ export default function withTrackingComponentDecorator( [getTrackingDispatcher, getTrackingDataFn, getProcessFn] ); + const propsToBePassed = useMemo( + () => (forwardRef ? { ...props, ref: rtFwdRef } : props), + [props, rtFwdRef] + ); + return useMemo( () => ( - + ), - [contextValue, props, trackingProp] + [contextValue, trackingProp, propsToBePassed] ); } - WithTracking.displayName = `WithTracking(${decoratedComponentName})`; + if (forwardRef) { + const forwarded = React.forwardRef((props, ref) => ( + + )); + forwarded.displayName = `WithTracking(${decoratedComponentName})`; + hoistNonReactStatic(forwarded, DecoratedComponent); + return forwarded; + } + WithTracking.displayName = `WithTracking(${decoratedComponentName})`; hoistNonReactStatic(WithTracking, DecoratedComponent); - return WithTracking; }; }