diff --git a/packages/sdk/src/utils/Mapping.ts b/packages/sdk/src/utils/Mapping.ts index 1bb6a9e4f1..6f0f239f86 100644 --- a/packages/sdk/src/utils/Mapping.ts +++ b/packages/sdk/src/utils/Mapping.ts @@ -13,8 +13,9 @@ interface ValueWrapper { */ export class Mapping { - private delegate: Map> = new Map() - private valueFactory: (...args: K) => Promise + private readonly delegate: Map> = new Map() + private readonly pendingPromises: Map> = new Map() + private readonly valueFactory: (...args: K) => Promise constructor(valueFactory: (...args: K) => Promise) { this.valueFactory = valueFactory @@ -22,13 +23,25 @@ export class Mapping { async get(...args: K): Promise { const key = formLookupKey(...args) - let valueWrapper = this.delegate.get(key) - if (valueWrapper === undefined) { - const value = await this.valueFactory(...args) - valueWrapper = { value } - this.delegate.set(key, valueWrapper) + const pendingPromise = this.pendingPromises.get(key) + if (pendingPromise !== undefined) { + return await pendingPromise + } else { + let valueWrapper = this.delegate.get(key) + if (valueWrapper === undefined) { + const promise = this.valueFactory(...args) + this.pendingPromises.set(key, promise) + let value + try { + value = await promise + } finally { + this.pendingPromises.delete(key) + } + valueWrapper = { value } + this.delegate.set(key, valueWrapper) + } + return valueWrapper.value } - return valueWrapper.value } values(): V[] { diff --git a/packages/sdk/test/unit/Mapping.test.ts b/packages/sdk/test/unit/Mapping.test.ts index 73488036d2..f30676478e 100644 --- a/packages/sdk/test/unit/Mapping.test.ts +++ b/packages/sdk/test/unit/Mapping.test.ts @@ -1,3 +1,4 @@ +import { wait } from '@streamr/utils' import { Mapping } from '../../src/utils/Mapping' describe('Mapping', () => { @@ -27,7 +28,49 @@ describe('Mapping', () => { const mapping = new Mapping(valueFactory) expect(await mapping.get('foo')).toBe(undefined) expect(await mapping.get('foo')).toBe(undefined) - expect(valueFactory).toBeCalledTimes(1) + expect(valueFactory).toHaveBeenCalledTimes(1) }) + it('rejections are not cached', async () => { + const valueFactory = jest.fn().mockImplementation(async (p1: string, p2: number) => { + throw new Error(`error ${p1}-${p2}`) + }) + const mapping = new Mapping(valueFactory) + await expect(mapping.get('foo', 1)).rejects.toEqual(new Error('error foo-1')) + await expect(mapping.get('foo', 1)).rejects.toEqual(new Error('error foo-1')) + expect(valueFactory).toHaveBeenCalledTimes(2) + }) + + it('throws are not cached', async () => { + const valueFactory = jest.fn().mockImplementation((p1: string, p2: number) => { + throw new Error(`error ${p1}-${p2}`) + }) + const mapping = new Mapping(valueFactory) + await expect(mapping.get('foo', 1)).rejects.toEqual(new Error('error foo-1')) + await expect(mapping.get('foo', 1)).rejects.toEqual(new Error('error foo-1')) + expect(valueFactory).toHaveBeenCalledTimes(2) + }) + + it('concurrency', async () => { + const valueFactory = jest.fn().mockImplementation(async (p1: string, p2: number) => { + await wait(50) + return `${p1}${p2}` + }) + const mapping = new Mapping(valueFactory) + const results = await Promise.all([ + mapping.get('foo', 1), + mapping.get('foo', 2), + mapping.get('foo', 2), + mapping.get('foo', 1), + mapping.get('foo', 1) + ]) + expect(valueFactory).toHaveBeenCalledTimes(2) + expect(results).toEqual([ + 'foo1', + 'foo2', + 'foo2', + 'foo1', + 'foo1' + ]) + }) })