diff --git a/src/container.ts b/src/container.ts index 2a288bde8..da74d0c9a 100644 --- a/src/container.ts +++ b/src/container.ts @@ -31,6 +31,7 @@ const defaultContainer: { get(someClass: { new (...args: any[]): T } | Functi })(); let userContainer: { get(someClass: { new (...args: any[]): T } | Function): T }; +let userAsyncContainer: { get(someClass: { new (...args: any[]): T } | Function): Promise }; let userContainerOptions: UseContainerOptions; /** @@ -41,6 +42,11 @@ export function useContainer(iocContainer: { get(someClass: any): any }, options userContainerOptions = options; } +export function useAsyncContainer(iocContainer: { get(someClass: any): Promise }, options?: UseContainerOptions): void { + userAsyncContainer = iocContainer; + userContainerOptions = options; +} + /** * Gets the IOC container used by this library. */ @@ -57,3 +63,17 @@ export function getFromContainer(someClass: { new (...args: any[]): T } | Fun } return defaultContainer.get(someClass); } + +export async function getFromAsyncContainer(someClass: { new (...args: any[]): T } | Function): Promise { + if (userAsyncContainer) { + try { + const instance = await userAsyncContainer.get(someClass); + if (instance) return instance; + + if (!userContainerOptions || !userContainerOptions.fallback) return instance; + } catch (error) { + if (!userContainerOptions || !userContainerOptions.fallbackOnErrors) throw error; + } + } + return getFromContainer(someClass); +} diff --git a/src/index.ts b/src/index.ts index 34aa0f38b..8fa62be30 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,7 +3,7 @@ import { ValidatorOptions } from './validation/ValidatorOptions'; import { ValidationSchema } from './validation-schema/ValidationSchema'; import { getMetadataStorage } from './metadata/MetadataStorage'; import { Validator } from './validation/Validator'; -import { getFromContainer } from './container'; +import { getFromAsyncContainer, getFromContainer } from './container'; // ------------------------------------------------------------------------- // Export everything api users needs @@ -49,11 +49,12 @@ export function validate( maybeValidatorOptions?: ValidatorOptions ): Promise { if (typeof schemaNameOrObject === 'string') { - return getFromContainer(Validator).validate( - schemaNameOrObject, - objectOrValidationOptions as object, - maybeValidatorOptions - ); + return getFromAsyncContainer(Validator) + .then((validator) => validator.validate( + schemaNameOrObject, + objectOrValidationOptions as object, + maybeValidatorOptions + )); } else { return getFromContainer(Validator).validate(schemaNameOrObject, objectOrValidationOptions as ValidatorOptions); } @@ -82,11 +83,12 @@ export function validateOrReject( maybeValidatorOptions?: ValidatorOptions ): Promise { if (typeof schemaNameOrObject === 'string') { - return getFromContainer(Validator).validateOrReject( - schemaNameOrObject, - objectOrValidationOptions as object, - maybeValidatorOptions - ); + return getFromAsyncContainer(Validator).then((validator) => validator + .validateOrReject( + schemaNameOrObject, + objectOrValidationOptions as object, + maybeValidatorOptions + )); } else { return getFromContainer(Validator).validateOrReject( schemaNameOrObject,