diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..45b92cf
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,78 @@
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+
+name: Continuous integration
+
+jobs:
+  check:
+    name: Check
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest, macOS-latest]
+        rust: [stable]
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: ${{ matrix.rust }}
+          override: true
+      - uses: actions-rs/cargo@v1
+        with:
+          command: check
+          args: --workspace
+
+  test:
+    name: Test Suite
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest, macOS-latest]
+        rust: [stable]
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: ${{ matrix.rust }}
+          override: true
+      - uses: actions-rs/cargo@v1
+        with:
+          command: test
+          args: --workspace
+
+  fmt:
+    name: Rustfmt
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          override: true
+      - run: rustup component add rustfmt
+      - uses: actions-rs/cargo@v1
+        with:
+          command: fmt
+          args: --all -- --check
+
+  clippy:
+    name: Clippy
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          override: true
+      - run: rustup component add clippy
+      - uses: actions-rs/cargo@v1
+        with:
+          command: clippy
+          args: --workspace --tests --examples -- -D warnings
diff --git a/src/lib.rs b/src/lib.rs
index a4e9e66..e1ee38e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -22,4 +22,4 @@ impl LinearLayerLike for Linear {
     fn bias(&self) -> Option<&Tensor> {
         self.bias()
     }
-}
\ No newline at end of file
+}
diff --git a/src/loralinear.rs b/src/loralinear.rs
index 2d70fb5..ae89562 100644
--- a/src/loralinear.rs
+++ b/src/loralinear.rs
@@ -16,7 +16,12 @@ pub struct LoraLinear {
 }
 
 impl LoraLinear {
-    pub fn new(old: &Box<dyn LinearLayerLike>, rank: usize, alpha: usize, device: &Device) -> Result<Self> {
+    pub fn new(
+        old: &dyn LinearLayerLike,
+        rank: usize,
+        alpha: usize,
+        device: &Device,
+    ) -> Result<Self> {
         let map = VarMap::new();
         let a_weight = map.get(
             (rank, rank),
@@ -42,9 +47,9 @@ impl LoraLinear {
 
         let a = Trc::new(Linear::new(a_weight, None));
         let b = Trc::new(Linear::new(b_weight, Some(b_bias)));
-        
+
         Ok(LoraLinear {
-            old: Trc::new(NonTrainableLinear::new_from_linear(&old)?),
+            old: Trc::new(NonTrainableLinear::new_from_linear(old)?),
             a,
             b,
             scale: alpha / rank,
@@ -73,4 +78,4 @@ impl LinearLayerLike for LoraLinear {
     fn weight(&self) -> &Tensor {
         unimplemented!("Cannot get weight of LoraLinear layer");
     }
-}
\ No newline at end of file
+}
diff --git a/src/main.rs b/src/main.rs
index 1204ecf..923f5e5 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,7 +2,8 @@ use std::collections::HashMap;
 
 use candle_core::{DType, Device, Result, Tensor};
 use candle_lora::{
-    loralinear::{LoraLinear, ALPHA_DEFAULT}, LinearLayerLike,
+    loralinear::{LoraLinear, ALPHA_DEFAULT},
+    LinearLayerLike,
 };
 use candle_nn::{linear_no_bias, Module, VarBuilder};
 
@@ -38,7 +39,7 @@ fn main() -> Result<()> {
     println!("Digit {digit:?} digit");
 
     LoraLinear::new(
-        &model.layer,
+        &*model.layer,
         model.layer.weight().rank(),
         ALPHA_DEFAULT,
         &device,
diff --git a/src/nontrainlinear.rs b/src/nontrainlinear.rs
index 0af46a7..0f789c4 100644
--- a/src/nontrainlinear.rs
+++ b/src/nontrainlinear.rs
@@ -19,7 +19,7 @@ impl NonTrainableLinear {
         }
     }
 
-    pub fn new_from_linear(old: &Box<dyn LinearLayerLike>) -> Result<Self> {
+    pub fn new_from_linear(old: &dyn LinearLayerLike) -> Result<Self> {
         Ok(Self::new(
             old.weight().detach()?,
             match old.bias() {
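Note on the signature change above: LoraLinear::new and NonTrainableLinear::new_from_linear now borrow the trait object directly (&dyn LinearLayerLike) instead of taking &Box<dyn LinearLayerLike>, which is why the call site in src/main.rs passes &*model.layer. Below is a minimal, self-contained sketch of the same pattern; the trait Layer, type Dense, and function describe are placeholders for illustration, not items from candle-lora.

    // Stand-in trait; in the patch above the real trait is LinearLayerLike.
    trait Layer {
        fn name(&self) -> &str;
    }

    struct Dense;

    impl Layer for Dense {
        fn name(&self) -> &str {
            "dense"
        }
    }

    // Accepting &dyn Layer instead of &Box<dyn Layer> lets callers pass a
    // dereferenced Box (&*boxed) or a plain reference, avoiding the extra
    // indirection through the Box.
    fn describe(layer: &dyn Layer) -> String {
        format!("layer: {}", layer.name())
    }

    fn main() {
        let boxed: Box<dyn Layer> = Box::new(Dense);
        println!("{}", describe(&*boxed)); // mirrors `&*model.layer` above
        println!("{}", describe(&Dense));  // no Box needed at all
    }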