Skip to content

Commit

Permalink
Add CI
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Sep 9, 2023
1 parent 9ba2336 commit 866ad2c
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 8 deletions.
78 changes: 78 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
on:
push:
branches:
- master
pull_request:

name: Continuous integration

jobs:
check:
name: Check
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ matrix.rust }}
override: true
- uses: actions-rs/cargo@v1
with:
command: check
args: --workspace

test:
name: Test Suite
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ matrix.rust }}
override: true
- uses: actions-rs/cargo@v1
with:
command: test
args: --workspace

fmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add rustfmt
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check

clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add clippy
- uses: actions-rs/cargo@v1
with:
command: clippy
args: --workspace --tests --examples -- -D warnings
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ impl LinearLayerLike for Linear {
fn bias(&self) -> Option<&Tensor> {
self.bias()
}
}
}
13 changes: 9 additions & 4 deletions src/loralinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ pub struct LoraLinear {
}

impl LoraLinear {
pub fn new(old: &Box<dyn LinearLayerLike>, rank: usize, alpha: usize, device: &Device) -> Result<Self> {
pub fn new(
old: &dyn LinearLayerLike,
rank: usize,
alpha: usize,
device: &Device,
) -> Result<Self> {
let map = VarMap::new();
let a_weight = map.get(
(rank, rank),
Expand All @@ -42,9 +47,9 @@ impl LoraLinear {

let a = Trc::new(Linear::new(a_weight, None));
let b = Trc::new(Linear::new(b_weight, Some(b_bias)));

Ok(LoraLinear {
old: Trc::new(NonTrainableLinear::new_from_linear(&old)?),
old: Trc::new(NonTrainableLinear::new_from_linear(old)?),
a,
b,
scale: alpha / rank,
Expand Down Expand Up @@ -73,4 +78,4 @@ impl LinearLayerLike for LoraLinear {
fn weight(&self) -> &Tensor {
unimplemented!("Cannot get weight of LoraLinear layer");
}
}
}
5 changes: 3 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::collections::HashMap;

use candle_core::{DType, Device, Result, Tensor};
use candle_lora::{
loralinear::{LoraLinear, ALPHA_DEFAULT}, LinearLayerLike,
loralinear::{LoraLinear, ALPHA_DEFAULT},
LinearLayerLike,
};
use candle_nn::{linear_no_bias, Module, VarBuilder};

Expand Down Expand Up @@ -38,7 +39,7 @@ fn main() -> Result<()> {
println!("Digit {digit:?} digit");

LoraLinear::new(
&model.layer,
&*model.layer,
model.layer.weight().rank(),
ALPHA_DEFAULT,
&device,
Expand Down
2 changes: 1 addition & 1 deletion src/nontrainlinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl NonTrainableLinear {
}
}

pub fn new_from_linear(old: &Box<dyn LinearLayerLike>) -> Result<Self> {
pub fn new_from_linear(old: &dyn LinearLayerLike) -> Result<Self> {
Ok(Self::new(
old.weight().detach()?,
match old.bias() {
Expand Down

0 comments on commit 866ad2c

Please sign in to comment.