
Commit 410754a

remove knn gpu splitting
1 parent e5d1771 commit 410754a

3 files changed (+34, -153 lines)

3 files changed

+34
-153
lines changed

tensorboard/plugins/projector/vz_projector/data.ts

Lines changed: 1 addition & 1 deletion

@@ -475,7 +475,7 @@ export class DataSet {
     } else {
       const knnGpuEnabled = (await util.hasWebGLSupport()) && !IS_FIREFOX;
       const result = await (knnGpuEnabled
-        ? knn.findKNNGPUCosDistNorm(data, nNeighbors, (d) => d.vector)
+        ? knn.findKNNTFCosDistNorm(data, nNeighbors, (d) => d.vector)
         : knn.findKNN(
             data,
             nNeighbors,
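For reference, a minimal sketch of the selection pattern in this hunk: use the tensorflow.js path when WebGL is available, otherwise fall back to the pure-CPU k-NN. The names `hasWebGL`, `knnWithTf`, and `knnOnCpu` are hypothetical stand-ins for `util.hasWebGLSupport`, `knn.findKNNTFCosDistNorm`, and `knn.findKNN`, injected as parameters so the sketch is self-contained rather than the projector's actual code.

type NearestEntry = {index: number; dist: number};

// Sketch only: mirrors the ternary in the hunk above with injected
// dependencies so it can run outside the projector.
async function findNeighborsSketch(
  hasWebGL: () => Promise<boolean>,
  knnWithTf: () => Promise<NearestEntry[][]>,
  knnOnCpu: () => Promise<NearestEntry[][]>
): Promise<NearestEntry[][]> {
  // The real check also excludes Firefox (IS_FIREFOX) before enabling
  // the tensorflow.js path.
  const useTf = await hasWebGL();
  return useTf ? knnWithTf() : knnOnCpu();
}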

tensorboard/plugins/projector/vz_projector/knn.ts

Lines changed: 25 additions & 111 deletions
@@ -22,53 +22,33 @@ export type NearestEntry = {
   index: number;
   dist: number;
 };
-/**
- * Optimal size for the height of the matrix when doing computation on the GPU
- * using WebGL. This was found experimentally.
- *
- * This also guarantees that for computing pair-wise distance for up to 10K
- * vectors, no more than 40MB will be allocated in the GPU. Without the
- * allocation limit, we can freeze the graphics of the whole OS.
- */
-const OPTIMAL_GPU_BLOCK_SIZE = 256;
-/** Id of message box used for knn gpu progress bar. */
-const KNN_GPU_MSG_ID = 'knn-gpu';
+
+/** Id of message box used for knn. */
+const KNN_MSG_ID = 'knn';
 
 /**
  * Returns the K nearest neighbors for each vector where the distance
- * computation is done on the GPU (WebGL) using cosine distance.
+ * computation is done using tensorflow.js using cosine distance.
  *
  * @param dataPoints List of data points, where each data point holds an
  * n-dimensional vector. Assumes that the vector is already normalized to unit
  * norm.
  * @param k Number of nearest neighbors to find.
  * @param accessor A method that returns the vector, given the data point.
  */
-export function findKNNGPUCosDistNorm<T>(
+export function findKNNTFCosDistNorm<T>(
   dataPoints: T[],
   k: number,
   accessor: (dataPoint: T) => Float32Array
 ): Promise<NearestEntry[][]> {
   const N = dataPoints.length;
   const dim = accessor(dataPoints[0]).length;
   // The goal is to compute a large matrix multiplication A*A.T where A is of
-  // size NxD and A.T is its transpose. This results in a NxN matrix which
-  // could be too big to store on the GPU memory. To avoid memory overflow, we
-  // compute multiple A*partial_A.T where partial_A is of size BxD (B is much
-  // smaller than N). This results in storing only NxB size matrices on the GPU
-  // at a given time.
+  // size NxD and A.T is its transpose. This results in a NxN matrix.
   // A*A.T will give us NxN matrix holding the cosine distance between every
   // pair of points, which we sort using KMin data structure to obtain the
   // K nearest neighbors for each point.
   const nearest: NearestEntry[][] = new Array(N);
-  let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE);
-  const actualPieceSize = Math.floor(N / numPieces);
-  const modulo = N % actualPieceSize;
-  numPieces += modulo ? 1 : 0;
-  let offset = 0;
-  let progress = 0;
-  let progressDiff = 1 / (2 * numPieces);
-  let piece = 0;
 
   const typedArray = vector.toTypedArray(dataPoints, accessor);
   const bigMatrix = tf.tensor(typedArray, [N, dim]);
@@ -77,80 +57,50 @@ export function findKNNGPUCosDistNorm<T>(
   const bigMatrixSquared = tf.matMul(bigMatrix, bigMatrixTransposed);
   const cosDistMatrix = tf.sub(1, bigMatrixSquared);
 
-  let maybePaddedCosDistMatrix = cosDistMatrix;
-  if (actualPieceSize * numPieces > N) {
-    // Expect the input to be rank 2 (though it is not typed that way) so we
-    // want to pad the first dimension so we split very evenly (all splitted
-    // tensor have exactly the same dimesion).
-    const padding: Array<[number, number]> = [
-      [0, actualPieceSize * numPieces - N],
-      [0, 0],
-    ];
-    maybePaddedCosDistMatrix = tf.pad(cosDistMatrix, padding);
-  }
-  const splits = tf.split(
-    maybePaddedCosDistMatrix,
-    new Array(numPieces).fill(actualPieceSize),
-    0
-  );
-
   function step(resolve: (result: NearestEntry[][]) => void) {
-    let progressMsg =
-      'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%';
     util
       .runAsyncTask(
-        progressMsg,
+        'Finding nearest neighbors...',
         async () => {
           // `.data()` returns flattened Float32Array of B * N dimension.
           // For matrix of
           // [ 1 2 ]
           // [ 3 4 ],
           // `.data()` returns [1, 2, 3, 4].
-          const partial = await splits[piece].data();
-          progress += progressDiff;
-          for (let i = 0; i < actualPieceSize; i++) {
+          const partial = await cosDistMatrix.data();
+          for (let i = 0; i < N; i++) {
             let kMin = new KMin<NearestEntry>(k);
-            let iReal = offset + i;
-            if (iReal >= N) break;
             for (let j = 0; j < N; j++) {
               // Skip diagonal entries.
-              if (j === iReal) {
+              if (j === i) {
                 continue;
               }
               // Access i * N's row at `j` column.
               // Reach row has N entries and j-th index has cosine distance
-              // between iReal vs. j-th vectors.
+              // between i-th vs. j-th vectors.
               const cosDist = partial[i * N + j];
               if (cosDist >= 0) {
                 kMin.add(cosDist, {index: j, dist: cosDist});
               }
             }
-            nearest[iReal] = kMin.getMinKItems();
+            nearest[i] = kMin.getMinKItems();
           }
-          progress += progressDiff;
-          offset += actualPieceSize;
-          piece++;
         },
-        KNN_GPU_MSG_ID
+        KNN_MSG_ID,
       )
       .then(
         () => {
-          if (piece < numPieces) {
-            step(resolve);
-          } else {
-            logging.setModalMessage(null!, KNN_GPU_MSG_ID);
-            // Discard all tensors and free up the memory.
-            bigMatrix.dispose();
-            bigMatrixTransposed.dispose();
-            bigMatrixSquared.dispose();
-            cosDistMatrix.dispose();
-            splits.forEach((split) => split.dispose());
-            resolve(nearest);
-          }
+          logging.setModalMessage(null!, KNN_MSG_ID);
+          // Discard all tensors and free up the memory.
+          bigMatrix.dispose();
+          bigMatrixTransposed.dispose();
+          bigMatrixSquared.dispose();
+          cosDistMatrix.dispose();
+          resolve(nearest);
         },
         (error) => {
-          // GPU failed. Reverting back to CPU.
-          logging.setModalMessage(null!, KNN_GPU_MSG_ID);
+          // TensorFlow.js failed. Reverting back to CPU.
+          logging.setModalMessage(null!, KNN_MSG_ID);
           let distFunc = (a, b, limit) => vector.cosDistNorm(a, b);
           findKNN(dataPoints, k, accessor, distFunc).then((nearest) => {
             resolve(nearest);
@@ -212,47 +162,12 @@ export function findKNN<T>(
       for (let i = 0; i < N; i++) {
         nearest[i] = kMin[i].getMinKItems();
       }
+      logging.setModalMessage(null!, KNN_MSG_ID);
       return nearest;
-    }
+    },
+    KNN_MSG_ID,
   );
 }
-/** Calculates the minimum distance between a search point and a rectangle. */
-function minDist(
-  point: [number, number],
-  x1: number,
-  y1: number,
-  x2: number,
-  y2: number
-) {
-  let x = point[0];
-  let y = point[1];
-  let dx1 = x - x1;
-  let dx2 = x - x2;
-  let dy1 = y - y1;
-  let dy2 = y - y2;
-  if (dx1 * dx2 <= 0) {
-    // x is between x1 and x2
-    if (dy1 * dy2 <= 0) {
-      // (x,y) is inside the rectangle
-      return 0; // return 0 as point is in rect
-    }
-    return Math.min(Math.abs(dy1), Math.abs(dy2));
-  }
-  if (dy1 * dy2 <= 0) {
-    // y is between y1 and y2
-    // We know it is already inside the rectangle
-    return Math.min(Math.abs(dx1), Math.abs(dx2));
-  }
-  let corner: [number, number];
-  if (x > x2) {
-    // Upper-right vs lower-right.
-    corner = y > y2 ? [x2, y2] : [x2, y1];
-  } else {
-    // Upper-left vs lower-left.
-    corner = y > y2 ? [x1, y2] : [x1, y1];
-  }
-  return Math.sqrt(vector.dist22D([x, y], corner));
-}
 /**
  * Returns the nearest neighbors of a particular point.
  *
@@ -282,4 +197,3 @@ export function findKNNofPoint<T>(
   return kMin.getMinKItems();
 }
 
-export const TEST_ONLY = {OPTIMAL_GPU_BLOCK_SIZE};
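For context, the simplified approach above can be illustrated with a small standalone sketch (not the projector's code; `knnCosDistNormSketch` and its helper names are hypothetical). For unit-normalized rows, cosine distance is 1 - A * A.T, so a single tensorflow.js matMul produces the full NxN distance matrix, which is then scanned on the CPU for the k smallest entries per row; a plain sort stands in for the KMin heap used above.

import * as tf from '@tensorflow/tfjs';

type Neighbor = {index: number; dist: number};

// Sketch only: compute every pairwise cosine distance in one matMul
// (1 - A*A.T for unit-normalized rows), then pick the k nearest per row.
async function knnCosDistNormSketch(
  vectors: Float32Array[],
  k: number
): Promise<Neighbor[][]> {
  const n = vectors.length;
  const dim = vectors[0].length;

  // Flatten the input into a single typed array and build an [n, dim] tensor.
  const flat = new Float32Array(n * dim);
  vectors.forEach((v, i) => flat.set(v, i * dim));
  const a = tf.tensor2d(flat, [n, dim]);

  // NxN cosine distances; `.data()` returns the flattened row-major result.
  const cosDist = tf.sub(1, tf.matMul(a, a, false, true));
  const data = await cosDist.data();
  a.dispose();
  cosDist.dispose();

  const nearest: Neighbor[][] = [];
  for (let i = 0; i < n; i++) {
    const row: Neighbor[] = [];
    for (let j = 0; j < n; j++) {
      // Skip the diagonal: distance to self is not a meaningful neighbor.
      if (j !== i) row.push({index: j, dist: data[i * n + j]});
    }
    row.sort((x, y) => x.dist - y.dist);
    nearest.push(row.slice(0, k));
  }
  return nearest;
}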

tensorboard/plugins/projector/vz_projector/knn_test.ts

Lines changed: 8 additions & 41 deletions
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-import {findKNN, findKNNGPUCosDistNorm, NearestEntry, TEST_ONLY} from './knn';
+import {findKNN, findKNNTFCosDistNorm, NearestEntry} from './knn';
 import {cosDistNorm, unit} from './vector';
 
 describe('projector knn test', () => {
@@ -28,9 +28,9 @@ describe('projector knn test', () => {
     return vector;
   }
 
-  describe('#findKNNGPUCosDistNorm', () => {
+  describe('#findKNNTFCosDistNorm', () => {
     it('finds n-nearest neighbor for each item', async () => {
-      const values = await findKNNGPUCosDistNorm(
+      const values = await findKNNTFCosDistNorm(
         [
           {a: unitVector(new Float32Array([1, 2, 0]))},
           {a: unitVector(new Float32Array([1, 1, 3]))},
@@ -54,7 +54,7 @@ describe('projector knn test', () => {
     });
 
     it('returns less than N when number of item is lower', async () => {
-      const values = await findKNNGPUCosDistNorm(
+      const values = await findKNNTFCosDistNorm(
         [
           unitVector(new Float32Array([1, 2, 0])),
           unitVector(new Float32Array([1, 1, 3])),
@@ -65,29 +65,13 @@ describe('projector knn test', () => {
 
       expect(getIndices(values)).toEqual([[1], [0]]);
     });
-
-    it('splits a large data into one that would fit into GPU memory', async () => {
-      const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
-      const data = new Array(size).fill(
-        unitVector(new Float32Array([1, 1, 1]))
-      );
-      const values = await findKNNGPUCosDistNorm(data, 1, (a) => a);
-
-      expect(getIndices(values)).toEqual([
-        // Since distance to the diagonal entries (distance to self is 0) is
-        // non-sensical, the diagonal entires are ignored. So for the first
-        // item, the nearest neighbor should be 2nd item (index 1).
-        [1],
-        ...new Array(size - 1).fill([0]),
-      ]);
-    });
   });
 
   describe('#findKNN', () => {
-    // Covered by equality tests below (#findKNNGPUCosDistNorm == #findKNN).
+    // Covered by equality tests below (#findKNNTFCosDistNorm == #findKNN).
   });
 
-  describe('#findKNNGPUCosDistNorm and #findKNN', () => {
+  describe('#findKNNTFCosDistNorm and #findKNN', () => {
     it('returns same value when dist metrics are cosine', async () => {
       const data = [
         unitVector(new Float32Array([1, 2, 0])),
@@ -97,7 +81,7 @@ describe('projector knn test', () => {
         unitVector(new Float32Array([100, 10, 0])),
         unitVector(new Float32Array([95, 23, 100])),
       ];
-      const findKnnGpuCosVal = await findKNNGPUCosDistNorm(data, 2, (a) => a);
+      const findKnnTFCosVal = await findKNNTFCosDistNorm(data, 2, (a) => a);
       const findKnnVal = await findKNN(
         data,
         2,
@@ -106,24 +90,7 @@ describe('projector knn test', () => {
       );
 
       // Floating point precision makes it hard to test. Just assert indices.
-      expect(getIndices(findKnnGpuCosVal)).toEqual(getIndices(findKnnVal));
-    });
-
-    it('splits a large data without the result being wrong', async () => {
-      const size = TEST_ONLY.OPTIMAL_GPU_BLOCK_SIZE + 5;
-      const data = Array.from(new Array(size)).map((_, index) => {
-        return unitVector(new Float32Array([index + 1, index + 1]));
-      });
-
-      const findKnnGpuCosVal = await findKNNGPUCosDistNorm(data, 2, (a) => a);
-      const findKnnVal = await findKNN(
-        data,
-        2,
-        (a) => a,
-        (a, b, limit) => cosDistNorm(a, b)
-      );
-
-      expect(getIndices(findKnnGpuCosVal)).toEqual(getIndices(findKnnVal));
+      expect(getIndices(findKnnTFCosVal)).toEqual(getIndices(findKnnVal));
     });
   });
 });
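As a hypothetical usage of the `knnCosDistNormSketch` helper sketched after the knn.ts diff above, mirroring the equality test: unit-normalize a few vectors, then ask for each one's 2 nearest neighbors and inspect the indices.

(async () => {
  const raw = [
    [1, 2, 0],
    [1, 1, 3],
    [100, 30, 0],
    [95, 23, 100],
  ];
  // Unit-normalize each vector, as `unit` from ./vector does in the test.
  const data = raw.map((v) => {
    const arr = new Float32Array(v);
    const norm = Math.hypot(...arr);
    return arr.map((x) => x / norm);
  });
  const neighbors = await knnCosDistNormSketch(data, 2);
  // Log only the neighbor indices; exact distances are subject to floating
  // point noise, which is also why the test above asserts indices only.
  console.log(neighbors.map((row) => row.map((entry) => entry.index)));
})();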
