Skip to content

Commit dcc9a91

Browse files
authored
[android] Add support for float16 tensor (pytorch#15479)
The Android binding exposes helpers for feeding IEEE-754 half-precision (FP16) inputs directly. Use `Tensor.fromBlob(shortArray, shape)` or reuse a direct `ShortBuffer` created via `Tensor.allocateHalfBuffer(numElements)` to avoid extra copies: ```kotlin val shape = longArrayOf(24, 4096) val halfData: ShortArray = buildHalfEncodedData() val tensor = Tensor.fromBlob(halfData, shape) val buffer = Tensor.allocateHalfBuffer(halfData.size) buffer.put(halfData) buffer.rewind() val tensorNoCopy = Tensor.fromBlob(buffer, shape) ``` All buffers must be direct and use the native byte order; the helper above takes care of this.
1 parent d9a8f2d commit dcc9a91

File tree

3 files changed

+218
-4
lines changed

3 files changed

+218
-4
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.nio.FloatBuffer;
1919
import java.nio.IntBuffer;
2020
import java.nio.LongBuffer;
21+
import java.nio.ShortBuffer;
2122
import java.util.Arrays;
2223
import java.util.Locale;
2324
import org.pytorch.executorch.annotations.Experimental;
@@ -57,6 +58,7 @@ public abstract class Tensor {
5758
private static final int BYTE_SIZE_BYTES = 1;
5859
private static final int INT_SIZE_BYTES = 4;
5960
private static final int LONG_SIZE_BYTES = 8;
61+
private static final int HALF_SIZE_BYTES = 2;
6062
private static final int FLOAT_SIZE_BYTES = 4;
6163
private static final int DOUBLE_SIZE_BYTES = 8;
6264

@@ -107,6 +109,18 @@ public static LongBuffer allocateLongBuffer(int numElements) {
107109
.asLongBuffer();
108110
}
109111

112+
/**
113+
* Allocates a new direct {@link ShortBuffer} with native byte order and specified capacity that
114+
* can be used in {@link Tensor#fromBlob(ShortBuffer, long[])}.
115+
*
116+
* @param numElements capacity (number of elements) of result buffer.
117+
*/
118+
public static ShortBuffer allocateHalfBuffer(int numElements) {
119+
return ByteBuffer.allocateDirect(numElements * HALF_SIZE_BYTES)
120+
.order(ByteOrder.nativeOrder())
121+
.asShortBuffer();
122+
}
123+
110124
/**
111125
* Allocates a new direct {@link DoubleBuffer} with native byte order with specified capacity that
112126
* can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}.
@@ -187,6 +201,23 @@ public static Tensor fromBlob(float[] data, long[] shape) {
187201
return new Tensor_float32(floatBuffer, shape);
188202
}
189203

204+
/**
205+
* Creates a new Tensor instance with dtype torch.float16 with specified shape and data as array
206+
* of IEEE-754 half-precision values encoded in {@code short}s.
207+
*
208+
* @param data Tensor elements encoded as 16-bit floats.
209+
* @param shape Tensor shape
210+
*/
211+
public static Tensor fromBlob(short[] data, long[] shape) {
212+
checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL);
213+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
214+
checkShape(shape);
215+
checkShapeAndDataCapacityConsistency(data.length, shape);
216+
final ShortBuffer shortBuffer = allocateHalfBuffer((int) numel(shape));
217+
shortBuffer.put(data);
218+
return new Tensor_float16(shortBuffer, shape);
219+
}
220+
190221
/**
191222
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of
192223
* longs.
@@ -301,6 +332,26 @@ public static Tensor fromBlob(FloatBuffer data, long[] shape) {
301332
return new Tensor_float32(data, shape);
302333
}
303334

335+
/**
336+
* Creates a new Tensor instance with dtype torch.float16 with specified shape and data.
337+
*
338+
* @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)}
339+
* elements encoded as IEEE-754 half-precision floats. The buffer is used directly without
340+
* copying.
341+
* @param shape Tensor shape
342+
*/
343+
public static Tensor fromBlob(ShortBuffer data, long[] shape) {
344+
checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL);
345+
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
346+
checkShape(shape);
347+
checkShapeAndDataCapacityConsistency(data.capacity(), shape);
348+
checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT);
349+
checkArgument(
350+
(data.order() == ByteOrder.nativeOrder()),
351+
ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER);
352+
return new Tensor_float16(data, shape);
353+
}
354+
304355
/**
305356
* Creates a new Tensor instance with dtype torch.int64 with specified shape and data.
306357
*
@@ -388,6 +439,16 @@ public byte[] getDataAsByteArray() {
388439
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
389440
}
390441

442+
/**
443+
* @return a Java short array that contains the tensor data interpreted as IEEE-754 half-precision
444+
* bit patterns. This may be a copy or reference.
445+
* @throws IllegalStateException if it is called for a non-float16 tensor.
446+
*/
447+
public short[] getDataAsShortArray() {
448+
throw new IllegalStateException(
449+
"Tensor of type " + getClass().getSimpleName() + " cannot return data as short array.");
450+
}
451+
391452
/**
392453
* @return a Java byte array that contains the tensor data. This may be a copy or reference.
393454
* @throws IllegalStateException if it is called for a non-uint8 tensor.
@@ -569,6 +630,74 @@ public String toString() {
569630
}
570631
}
571632

633+
static class Tensor_float16 extends Tensor {
634+
private final ShortBuffer data;
635+
636+
private Tensor_float16(ShortBuffer data, long[] shape) {
637+
super(shape);
638+
this.data = data;
639+
}
640+
641+
@Override
642+
public DType dtype() {
643+
return DType.HALF;
644+
}
645+
646+
@Override
647+
Buffer getRawDataBuffer() {
648+
return data;
649+
}
650+
651+
@Override
652+
public short[] getDataAsShortArray() {
653+
data.rewind();
654+
short[] arr = new short[data.remaining()];
655+
data.get(arr);
656+
return arr;
657+
}
658+
659+
@Override
660+
public float[] getDataAsFloatArray() {
661+
data.rewind();
662+
int remaining = data.remaining();
663+
float[] arr = new float[remaining];
664+
for (int i = 0; i < remaining; i++) {
665+
arr[i] = halfBitsToFloat(data.get());
666+
}
667+
return arr;
668+
}
669+
670+
@Override
671+
public String toString() {
672+
return String.format("Tensor(%s, dtype=torch.float16)", Arrays.toString(shape));
673+
}
674+
675+
private static float halfBitsToFloat(short halfBits) {
676+
int h = halfBits & 0xFFFF;
677+
int sign = (h >>> 15) & 0x1;
678+
int exp = (h >>> 10) & 0x1F;
679+
int mant = h & 0x3FF;
680+
681+
if (exp == 0) {
682+
if (mant == 0) {
683+
return sign == 0 ? 0.0f : -0.0f;
684+
}
685+
float result = mant * 5.9604645e-8f; // 2^-24
686+
return sign == 0 ? result : -result;
687+
} else if (exp == 0x1F) {
688+
if (mant == 0) {
689+
return sign == 0 ? Float.POSITIVE_INFINITY : Float.NEGATIVE_INFINITY;
690+
}
691+
int bits = (sign << 31) | 0x7f800000 | (mant << 13);
692+
return Float.intBitsToFloat(bits);
693+
} else {
694+
int exp32 = exp + 112; // 127 (float bias) - 15 (half bias)
695+
int bits = (sign << 31) | (exp32 << 23) | (mant << 13);
696+
return Float.intBitsToFloat(bits);
697+
}
698+
}
699+
}
700+
572701
static class Tensor_int64 extends Tensor {
573702
private final LongBuffer data;
574703

@@ -691,6 +820,8 @@ private static Tensor nativeNewTensor(
691820

692821
if (DType.FLOAT.jniCode == dtype) {
693822
tensor = new Tensor_float32(data.asFloatBuffer(), shape);
823+
} else if (DType.HALF.jniCode == dtype) {
824+
tensor = new Tensor_float16(data.asShortBuffer(), shape);
694825
} else if (DType.INT32.jniCode == dtype) {
695826
tensor = new Tensor_int32(data.asIntBuffer(), shape);
696827
} else if (DType.INT64.jniCode == dtype) {
@@ -727,6 +858,11 @@ public byte[] toByteArray() {
727858
tensorAsByteArray = new byte[(int) numel()];
728859
Tensor_int8 thiz = (Tensor_int8) this;
729860
ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray());
861+
} else if (dtype() == DType.HALF) {
862+
dtypeSize = HALF_SIZE_BYTES;
863+
tensorAsByteArray = new byte[(int) numel() * dtypeSize];
864+
Tensor_float16 thiz = (Tensor_float16) this;
865+
ByteBuffer.wrap(tensorAsByteArray).asShortBuffer().put(thiz.getDataAsShortArray());
730866
} else if (dtype() == DType.INT16) {
731867
throw new IllegalArgumentException("DType.INT16 is not supported in Java so far");
732868
} else if (dtype() == DType.INT32) {
@@ -794,6 +930,8 @@ public static Tensor fromByteArray(byte[] bytes) {
794930
return new Tensor_uint8(buffer, shape);
795931
} else if (dtype == DType.INT8.jniCode) {
796932
return new Tensor_int8(buffer, shape);
933+
} else if (dtype == DType.HALF.jniCode) {
934+
return new Tensor_float16(buffer.asShortBuffer(), shape);
797935
} else if (dtype == DType.INT32.jniCode) {
798936
return new Tensor_int32(buffer.asIntBuffer(), shape);
799937
} else if (dtype == DType.INT64.jniCode) {

extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
*/
88
package org.pytorch.executorch
99

10+
import java.nio.ByteOrder
1011
import org.assertj.core.api.Assertions.assertThatThrownBy
12+
import org.junit.Assert.assertArrayEquals
1113
import org.junit.Assert.assertEquals
14+
import org.junit.Assert.assertTrue
1215
import org.junit.Test
1316
import org.junit.runner.RunWith
1417
import org.junit.runners.JUnit4
@@ -184,6 +187,65 @@ class TensorTest {
184187
assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong())
185188
}
186189

190+
@Test
191+
fun testHalfTensorFromShortArrayAndBuffer() {
192+
val data =
193+
shortArrayOf(
194+
0x3C00.toShort(), // 1.0
195+
0xC000.toShort(), // -2.0
196+
0x0000.toShort(), // 0.0
197+
0x7C00.toShort(), // +inf
198+
)
199+
val shape = longArrayOf(2, 2)
200+
var tensor = Tensor.fromBlob(data, shape)
201+
assertEquals(DType.HALF, tensor.dtype())
202+
assertEquals(shape[0], tensor.shape()[0])
203+
assertEquals(shape[1], tensor.shape()[1])
204+
assertEquals(4, tensor.numel())
205+
assertArrayEquals(data, tensor.dataAsShortArray)
206+
val floats = tensor.dataAsFloatArray
207+
assertEquals(1.0f.toDouble(), floats[0].toDouble(), 1e-6)
208+
assertEquals((-2.0f).toDouble(), floats[1].toDouble(), 1e-6)
209+
assertEquals(0.0f.toDouble(), floats[2].toDouble(), 1e-6)
210+
assertEquals(Float.POSITIVE_INFINITY.toDouble(), floats[3].toDouble(), 0.0)
211+
212+
val buffer = Tensor.allocateHalfBuffer(data.size)
213+
assertTrue(buffer.isDirect)
214+
assertEquals(ByteOrder.nativeOrder(), buffer.order())
215+
buffer.put(data)
216+
buffer.rewind()
217+
218+
tensor = Tensor.fromBlob(buffer, longArrayOf(data.size.toLong()))
219+
assertEquals(DType.HALF, tensor.dtype())
220+
assertEquals(data.size.toLong(), tensor.shape()[0])
221+
assertEquals(data.size.toLong(), tensor.numel())
222+
assertArrayEquals(data, tensor.dataAsShortArray)
223+
val raw = tensor.rawDataBuffer as java.nio.ShortBuffer
224+
assertTrue(raw === buffer)
225+
}
226+
227+
@Test
228+
fun testHalfTensorSerializationRoundTrip() {
229+
val data =
230+
shortArrayOf(
231+
0x0000.toShort(),
232+
0x0400.toShort(),
233+
0x3C00.toShort(),
234+
0x7BFF.toShort(),
235+
)
236+
val shape = longArrayOf(2, 2)
237+
val tensor = Tensor.fromBlob(data, shape)
238+
val serialized = tensor.toByteArray()
239+
val deserialized = Tensor.fromByteArray(serialized)
240+
241+
assertEquals(DType.HALF, deserialized.dtype())
242+
assertEquals(shape[0], deserialized.shape()[0])
243+
assertEquals(shape[1], deserialized.shape()[1])
244+
assertEquals(4, deserialized.numel())
245+
assertArrayEquals(data, deserialized.dataAsShortArray)
246+
assertEquals(1.0f.toDouble(), deserialized.dataAsFloatArray[2].toDouble(), 1e-6)
247+
}
248+
187249
@Test
188250
fun testIllegalDataTypeException() {
189251
val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE)

extension/android/jni/jni_layer.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/extension/module/module.h>
1414
#include <executorch/extension/runner_util/inputs.h>
1515
#include <executorch/extension/tensor/tensor.h>
16+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
1617
#include <executorch/runtime/core/portable_type/tensor_impl.h>
1718
#include <executorch/runtime/platform/log.h>
1819
#include <executorch/runtime/platform/platform.h>
@@ -117,7 +118,7 @@ class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
117118
std::vector<executorch::aten::SizesType> shape_vec;
118119
shape_vec.reserve(rank);
119120

120-
auto numel = 1;
121+
int64_t numel = 1;
121122
for (int i = 0; i < rank; i++) {
122123
shape_vec.push_back(shapeArr[i]);
123124
}
@@ -132,11 +133,24 @@ class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
132133
static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
133134
}
134135
ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
135-
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
136-
if (dataCapacity != numel) {
136+
const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
137+
if (dataCapacity < 0) {
138+
std::stringstream ss;
139+
ss << "Tensor buffer is not direct or has invalid capacity";
140+
jni_helper::throwExecutorchException(
141+
static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
142+
}
143+
const size_t elementSize = executorch::runtime::elementSize(scalar_type);
144+
const jlong expectedElements = static_cast<jlong>(numel);
145+
const jlong expectedBytes =
146+
expectedElements * static_cast<jlong>(elementSize);
147+
const bool matchesElements = dataCapacity == expectedElements;
148+
const bool matchesBytes = dataCapacity == expectedBytes;
149+
if (!matchesElements && !matchesBytes) {
137150
std::stringstream ss;
138151
ss << "Tensor dimensions(elements number: " << numel
139-
<< "inconsistent with buffer capacity " << dataCapacity << "]";
152+
<< ") inconsistent with buffer capacity " << dataCapacity
153+
<< " (element size bytes: " << elementSize << ")";
140154
jni_helper::throwExecutorchException(
141155
static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
142156
}

0 commit comments

Comments
 (0)