Skip to content

Commit

Permalink
Merge pull request #893 from jquesnelle/var-store-kind
Browse files Browse the repository at this point in the history
add method to set default kind of new variables in VarStore
  • Loading branch information
LaurentMazare committed Sep 12, 2024
2 parents 2687871 + 40ec5f7 commit 9c498e6
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 18 deletions.
28 changes: 14 additions & 14 deletions src/nn/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,24 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
};

/// Creates a new float tensor with the specified shape, device, and initialization.
pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError> {
pub fn f_init(i: Init, dims: &[i64], device: Device, kind: Kind) -> Result<Tensor, TchError> {
match i {
Init::Const(cst) => {
// Optimize the case for which a single C++ code can be done.
if cst == 0. {
Tensor::f_zeros(dims, (Kind::Float, device))
Tensor::f_zeros(dims, (kind, device))
} else if (cst - 1.).abs() <= f64::EPSILON {
Tensor::f_ones(dims, (Kind::Float, device))
Tensor::f_ones(dims, (kind, device))
} else {
Tensor::f_ones(dims, (Kind::Float, device)).map(|t| t * cst)
Tensor::f_ones(dims, (kind, device)).map(|t| t * cst)
}
}
Init::Uniform { lo, up } => {
Tensor::f_zeros(dims, (Kind::Float, device))?.f_uniform_(lo, up)
}
Init::Uniform { lo, up } => Tensor::f_zeros(dims, (kind, device))?.f_uniform_(lo, up),
Init::Randn { mean, stdev } => {
if mean == 0. && (stdev - 1.).abs() <= f64::EPSILON {
Tensor::f_randn(dims, (Kind::Float, device))
Tensor::f_randn(dims, (kind, device))
} else {
Tensor::f_randn(dims, (Kind::Float, device)).map(|t| t * stdev + mean)
Tensor::f_randn(dims, (kind, device)).map(|t| t * stdev + mean)
}
}
Init::Kaiming { dist, fan, non_linearity } => {
Expand All @@ -130,10 +128,10 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
match dist {
NormalOrUniform::Uniform => {
let bound = 3f64.sqrt() * std;
Tensor::f_zeros(dims, (Kind::Float, device))?.f_uniform_(-bound, bound)
Tensor::f_zeros(dims, (kind, device))?.f_uniform_(-bound, bound)
}
NormalOrUniform::Normal => {
let randn = Tensor::f_randn(dims, (Kind::Float, device))?;
let randn = Tensor::f_randn(dims, (kind, device))?;
Ok(randn * std)
}
}
Expand All @@ -148,7 +146,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
let cols: i64 = dims.iter().skip(1).product();

let mut flattened =
Tensor::f_empty([rows, cols], (Kind::Float, device))?.f_normal_(0.0, 1.0)?;
Tensor::f_empty([rows, cols], (kind, device))?.f_normal_(0.0, 1.0)?;
let flattened = if rows < cols { flattened.f_t_()? } else { flattened };

let (mut q, r) = Tensor::f_linalg_qr(&flattened, "reduced")?;
Expand All @@ -166,7 +164,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>

/// Creates a new float tensor with the specified shape, device, and initialization.
pub fn init(i: Init, dims: &[i64], device: Device) -> Tensor {
f_init(i, dims, device).unwrap()
f_init(i, dims, device, Kind::Float).unwrap()
}

impl Init {
Expand Down Expand Up @@ -197,7 +195,9 @@ impl Init {
tensor.copy_(&(tensor.randn_like() * stdev + mean));
}
Init::Orthogonal { gain } => {
let q = f_init(Init::Orthogonal { gain }, &tensor.size(), tensor.device()).unwrap();
let q =
f_init(Init::Orthogonal { gain }, &tensor.size(), tensor.device(), Kind::Float)
.unwrap();
crate::no_grad(|| tensor.view_as(&q).copy_(&q));
}
}
Expand Down
19 changes: 16 additions & 3 deletions src/nn/var_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct Variables {
pub struct VarStore {
pub variables_: Arc<Mutex<Variables>>,
device: Device,
kind: Kind,
}

/// A variable store with an associated path for variables naming.
Expand All @@ -57,7 +58,7 @@ impl VarStore {
pub fn new(device: Device) -> VarStore {
let variables =
Variables { named_variables: HashMap::new(), trainable_variables: Vec::new() };
VarStore { variables_: Arc::new(Mutex::new(variables)), device }
VarStore { variables_: Arc::new(Mutex::new(variables)), device, kind: Kind::Float }
}

pub fn merge(var_stores: Vec<(VarStore, Option<&str>)>) -> Result<VarStore, TchError> {
Expand Down Expand Up @@ -110,6 +111,11 @@ impl VarStore {
self.device
}

/// Gets the default kind of new variables
pub fn kind(&self) -> Kind {
self.kind
}

/// Returns the number of tensors currently stored on this var-store.
pub fn len(&self) -> usize {
let variables = self.variables_.lock().unwrap();
Expand Down Expand Up @@ -322,13 +328,15 @@ impl VarStore {
}
}

/// Casts all variables in a var store to the target kind.
/// Casts all variables in a var store to the target kind and sets the default kind
/// for new variables.
///
/// For floating-point conversion, methods `half`, `bfloat16`, `float` and `double`
/// should be preferred as they ensure only float-like variables will be converted
/// to the target type.
pub fn set_kind(&mut self, kind: Kind) {
self.root().set_kind(kind);
self.kind = kind;
}

/// Casts all float-like variable of a var store to half-precision (Half kind).
Expand Down Expand Up @@ -410,6 +418,11 @@ impl<'a> Path<'a> {
self.var_store.device
}

/// Gets the default kind of new variables
pub fn kind(&self) -> Kind {
self.var_store.kind
}

pub fn path(&self, name: &str) -> String {
if name.chars().any(|x| x == SEP) {
panic!("variable name cannot contain {SEP} {name}");
Expand Down Expand Up @@ -551,7 +564,7 @@ impl<'a> Path<'a> {
/// The variable uses a float tensor initialized as per the
/// related argument.
pub fn f_var(&self, name: &str, dims: &[i64], init: Init) -> Result<Tensor, TchError> {
let v = super::f_init(init, dims, self.device())?;
let v = super::f_init(init, dims, self.device(), self.kind())?;
Ok(self.add(name, v, true))
}

Expand Down
3 changes: 2 additions & 1 deletion tests/var_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ fn init_test() {
"{}",
"ortho_norm initialization failed {ortho_norm}"
);
let ortho_shape_fail = tch::nn::f_init(Init::Orthogonal { gain: 1.0 }, &[10], Device::Cpu);
let ortho_shape_fail =
tch::nn::f_init(Init::Orthogonal { gain: 1.0 }, &[10], Device::Cpu, tch::Kind::Float);
assert!(ortho_shape_fail.is_err());
let kaiming_u = vs.root().var("kaiming_u", &[20, 100], nn::init::DEFAULT_KAIMING_UNIFORM);
assert!(f64::abs(f64_from(&kaiming_u.mean(Kind::Float))) < 5e-3);
Expand Down

0 comments on commit 9c498e6

Please sign in to comment.