tensor2image and dataloader length #2889
Unanswered
wangjiawen2013
asked this question in
Q&A
Replies: 3 comments
Here's an implementation I use:

use std::path::Path;

use burn::tensor::{backend::Backend, Int, Tensor};
use image::error::ParameterError;
use image::error::ParameterErrorKind;
use image::ImageBuffer;
use image::ImageError;
use image::ImageFormat;
use image::Luma;
use image::Rgb;

/// Example usage:
///
/// let device = WgpuDevice::default();
/// // Your tensor with shape [channels, height, width]
/// let tensor = /* ... */;
/// save_image(
///     tensor,
///     "output.png",
///     ImageFormat::Png
/// ).expect("Failed to save image");
pub fn save_image<B: Backend, Q: AsRef<Path>>(
    image_tensor: Tensor<B, 3, Int>,
    image_path: Q,
    image_format: ImageFormat,
) -> Result<(), ImageError> {
    let width = image_tensor.dims()[2] as u32;
    let height = image_tensor.dims()[1] as u32;
    let channels = image_tensor.dims()[0] as u32;
    // channels must be 1 or 3
    if channels != 1 && channels != 3 {
        return Err(ImageError::Parameter(ParameterError::from_kind(
            ParameterErrorKind::Generic("Unsupported number of channels".to_string()),
        )));
    }
    // ImageBuffer expects interleaved pixels, so reorder [channels, height, width]
    // to [height, width, channels] before flattening.
    let image: Vec<u8> = image_tensor
        .permute([1, 2, 0])
        .into_data()
        .iter::<u8>()
        .collect();
    let image_path_ref = image_path.as_ref(); // Get a reference to the path
    // Shared error for a buffer whose length does not match width * height * channels
    let buffer_error = || {
        ImageError::Parameter(ParameterError::from_kind(ParameterErrorKind::Generic(
            "Failed to create image buffer".to_string(),
        )))
    };
    if channels == 1 {
        // Grayscale image
        let image_buf = ImageBuffer::<Luma<u8>, Vec<u8>>::from_vec(width, height, image)
            .ok_or_else(buffer_error)?;
        if image_buf.is_empty() {
            return Err(ImageError::Parameter(ParameterError::from_kind(
                ParameterErrorKind::Generic("Image buffer is empty".to_string()),
            )));
        }
        image_buf.save_with_format(image_path_ref, image_format)?;
    } else {
        // RGB image (channels == 3)
        let image_buf = ImageBuffer::<Rgb<u8>, Vec<u8>>::from_vec(width, height, image)
            .ok_or_else(buffer_error)?;
        if image_buf.is_empty() {
            return Err(ImageError::Parameter(ParameterError::from_kind(
                ParameterErrorKind::Generic("Image buffer is empty".to_string()),
            )));
        }
        image_buf.save_with_format(image_path_ref, image_format)?;
    }
    Ok(())
}
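One thing worth checking with a [channels, height, width] tensor: flattening it yields all of channel 0, then all of channel 1, and so on, while `ImageBuffer::from_vec` reads interleaved pixels (R, G, B, R, G, B, ...). Below is a minimal sketch of that reordering on a plain byte buffer, with no burn involved. `chw_to_hwc` is a hypothetical helper name used only for illustration, not an API from burn or the image crate.

```rust
/// Reorders a flat channel-major (CHW) buffer into interleaved (HWC) order.
/// Hypothetical helper for illustration; not part of burn or the image crate.
fn chw_to_hwc(data: &[u8], channels: usize, height: usize, width: usize) -> Vec<u8> {
    assert_eq!(data.len(), channels * height * width, "buffer size mismatch");
    let mut out = Vec::with_capacity(data.len());
    for y in 0..height {
        for x in 0..width {
            for c in 0..channels {
                // In CHW order, channel `c` occupies one contiguous plane
                // of `height * width` values.
                out.push(data[c * height * width + y * width + x]);
            }
        }
    }
    out
}
```

For a 1x2 RGB image stored as planes `[R0, R1, G0, G1, B0, B1]`, this returns `[R0, G0, B0, R1, G1, B1]`. For a single channel the two layouts coincide, so grayscale images are unaffected.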
There is also an implementation in https://github.com/tracel-ai/burn/blob/main/examples/wgan/src/training.rs:

// Imports the snippet needs when used on its own
use std::path::Path;

use burn::tensor::{backend::Backend, Tensor};
use image::{buffer::ConvertBuffer, ImageResult, Rgb32FImage, RgbImage};

// `images` is indexed as [batch, height, width, channels]
pub fn save_image<B: Backend, Q: AsRef<Path>>(
    images: Tensor<B, 4>,
    nrow: u32,
    path: Q,
) -> ImageResult<()> {
    let ncol = (images.dims()[0] as f32 / nrow as f32).ceil() as u32;
    let width = images.dims()[2] as u32;
    let height = images.dims()[1] as u32;
    // Supports both 1 and 3 channels image.
    // Note: `channels` is the repeat factor used below, not the channel count:
    // a 1-channel value is repeated 3 times to fill RGB, a 3-channel value is kept once.
    let channels = match images.dims()[3] {
        1 => 3,
        3 => 1,
        _ => panic!("Wrong channels number"),
    };
    let mut imgbuf = RgbImage::new(nrow * width, ncol * height);
    // Write images into a nrow*ncol grid layout
    for row in 0..nrow {
        for col in 0..ncol {
            let image: Tensor<B, 3> = images
                .clone()
                .slice((row * nrow + col) as usize..(row * nrow + col + 1) as usize)
                .squeeze(0);
            // The Rgb32 should be in range 0.0-1.0
            let image = image.into_data().iter::<f32>().collect::<Vec<f32>>();
            // Supports both 1 and 3 channels image
            let image = image
                .into_iter()
                .flat_map(|n| std::iter::repeat(n).take(channels))
                .collect();
            let image = Rgb32FImage::from_vec(width, height, image).unwrap();
            let image: RgbImage = image.convert();
            for (x, y, pixel) in image.enumerate_pixels() {
                imgbuf.put_pixel(row * width + x, col * height + y, *pixel);
            }
        }
    }
    imgbuf.save(path)
}
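The `channels` match in that snippet can look inverted at first glance, because the value is used as a repeat count rather than a channel count. Here is a small standalone sketch of the same `flat_map` trick. `expand_channels` is a hypothetical name for illustration and does not appear in the burn example.

```rust
/// Repeats every value `repeat` times, e.g. to turn grayscale samples into
/// RGB triples. Hypothetical helper; it mirrors the `flat_map` step above.
fn expand_channels(pixels: Vec<f32>, repeat: usize) -> Vec<f32> {
    pixels
        .into_iter()
        .flat_map(|p| std::iter::repeat(p).take(repeat))
        .collect()
}
```

With `repeat = 3`, two grayscale samples `[0.0, 1.0]` become `[0.0, 0.0, 0.0, 1.0, 1.0, 1.0]`. With `repeat = 1`, an already interleaved RGB buffer passes through unchanged.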
Hi,
I have two questions now:
1. `from torchvision.utils import save_image`. How do I convert a tensor to an image and save it to disk using burn?
2. `len(dataloader)`. How do I get the number of batches in a dataloader in burn? It seems that `dataloader.num_items` only returns the total number of samples, not the number of batches.
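On the second question, the batch count can be derived from the item count, assuming the dataloader uses a fixed batch size and keeps the final partial batch. This is a sketch of the arithmetic only, not a burn API: `num_batches` is a hypothetical helper, and `batch_size` is whatever value the dataloader was built with.

```rust
/// Number of batches needed to cover `num_items` samples.
/// Hypothetical helper; assumes a fixed `batch_size` and that the last,
/// possibly smaller, batch is kept rather than dropped.
fn num_batches(num_items: usize, batch_size: usize) -> usize {
    assert!(batch_size > 0, "batch_size must be positive");
    // Ceiling division: round up when the last batch is partial.
    (num_items + batch_size - 1) / batch_size
}
```

For example, 10 samples with a batch size of 3 give 4 batches (3 + 3 + 3 + 1). If the dataloader drops the incomplete batch instead, plain integer division `num_items / batch_size` would be the matching count.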