Softmax activation
数据位宽为 32,10分类。
softmax函数 ,输入归一化,求得各类概率。
1 指数运算
2 计算指数和
3 求指数和倒数
4 计算每个元素的softmax值
10分类有10个指数模块。
将多个输入输入到各自的指数模块 即exponet模块来求指数。 然后通过加法器求所有指数的和,再由floatReciprocal来计算指数和的倒数,最后通过乘法器来计算各输入的softmax值并输出。
指数模块 exponent
x - e^x
指数函数,求解e^x
逻辑: 用泰勒展开拟合,包含两个乘法器和一个加法器。
时序电路
迭代七次:
需要八个周期去处理,七个周期计算,一个周期输出。
10个指数模块,一个乘法器,一个加法器,floatReciprocal计数指数和的倒数。 softmax.v顶层模块
module softmax(inputs,clk,enable,outputs,ackSoft);
parameter DATA_WIDTH=32;
localparam inputNum=10; // 10分类
input [DATA_WIDTH*inputNum-1:0] inputs;
input clk;
input enable;
output reg [DATA_WIDTH*inputNum-1:0] outputs;
output reg ackSoft; // softmax的应答信号,表示计算完毕
wire [DATA_WIDTH-1:0] expSum;
wire [DATA_WIDTH-1:0] expReciprocal; // 指数之和的倒数
wire [DATA_WIDTH-1:0] outMul;
wire [DATA_WIDTH*inputNum-1:0] exponents ; //所有指数模块的输出整合
wire [inputNum-1:0] acksExp; //acknowledge signals of exponents 指数模块的应答信号
wire ackDiv; //ack signal of the division unit倒数模块的应答信号
reg enableDiv; //signal to enable division unit initially zero
reg [DATA_WIDTH-1:0] outExpReg;
reg [3:0] mulCounter;
reg [3:0] addCounter;
// 一个输入对应一个指数模块
// generate 可以被综合 例化10个指数模块
genvar i;
generate
for (i = 0; i < inputNum; i = i + 1) begin
exponent #(.DATA_WIDTH(DATA_WIDTH)) exp (
.x(inputs[DATA_WIDTH*i+:DATA_WIDTH]),
.enable(enable),
.clk(clk),
.output_exp(exponents[DATA_WIDTH*i+:DATA_WIDTH]),
.ack(acksExp[i]));
end
endgenerate
// 计算所有指数模块的和
floatAdd FADD1 (exponents[DATA_WIDTH*addCounter+:DATA_WIDTH],outExpReg,expSum);
// 计算和的倒数
floatReciprocal #(.DATA_WIDTH(DATA_WIDTH)) FR (.number(expSum),.clk(clk),.output_rec(expReciprocal),.ack(ackDiv),.enable(enableDiv));
// 计算softmax值,计算每个种类的分类概率
floatMult FM1 (exponents[DATA_WIDTH*mulCounter+:DATA_WIDTH],expReciprocal,outMul); //multiplication with reciprocal
always @ (negedge clk) begin
if(enable==1'b1) begin
if(ackSoft==1'b0) begin
//等待指数模块计算完毕
if(acksExp[0]==1'b1) begin //if the exponents finished
//指数总和 的倒数还没开始计算,则求和得到expSum
if(enableDiv==1'b0) begin //division still did not start
if(addCounter<4'b1001) begin
addCounter=addCounter+1;
outExpReg=expSum;
end
// expSum 计算好了,开始算倒数
else begin
enableDiv=1'b1;
end
end
// 倒数也计算好了,开始计算每个类别的概率
else if(ackDiv==1'b1) begin //check if the reciprocal is ready
if(mulCounter<4'b1010) begin
outputs[DATA_WIDTH*mulCounter+:DATA_WIDTH]=outMul;
mulCounter=mulCounter+1;
end
else begin
ackSoft=1'b1;
end
end
end
end
end
else begin
//未使能则保持复位状态
//if enable is off reset all counters and acks
mulCounter=4'b0000;
addCounter=4'b0000;
outExpReg=32'b00000000000000000000000000000000;
ackSoft=1'b0;
enableDiv=1'b0;
end
end
endmodule
指数模块
指数函数,通过泰勒展开实现
处理速度: 包括复位周期在内的八个时钟周期,7个周期计算,1周期输出结果
module exponent (x,clk,enable,output_exp,ack);
parameter DATA_WIDTH=32;
localparam taylor_iter=7; // 迭代次数
input [DATA_WIDTH-1:0] x;
input clk;
input enable;
output reg ack;
output reg [DATA_WIDTH-1:0] output_exp;
reg [DATA_WIDTH*taylor_iter-1:0] divisors; // 1/6 1/5 1/4 1/3 1/2 1 1
reg [DATA_WIDTH-1:0] mult1; //第一周期为1,之后的为乘法器2的输出,用来迭代
reg [DATA_WIDTH-1:0] one_or_x; //输入,第一周期为1,其余时刻为x
wire [DATA_WIDTH-1:0] out_m1; //output of the first multiplication which is either with 1 or x
wire [DATA_WIDTH-1:0] out_m2; //the output of the second muliplication and the input of the first
wire [DATA_WIDTH-1:0] output_add1;
reg [DATA_WIDTH-1:0] out_reg; //缓存加法器输出结果
floatMult FM1 (mult1,one_or_x,out_m1);
floatMult FM2 (out_m1,divisors[31:0],out_m2);
floatAdd FADD1 (out_m2,out_reg,output_add1);
always @ (posedge clk) begin
if(enable==1'b0) begin
one_or_x=32'b00111111100000000000000000000000; //initially 1
mult1=32'b00111111100000000000000000000000; //initially 1
out_reg=32'b00000000000000000000000000000000; //initially 0
output_exp=32'b00000000000000000000000000000000; //output zero until ack is 1
divisors=224'b00111110001010101010101010101011_00111110010011001100110011001101_00111110100000000000000000000000_00111110101010101010101010101011_00111111000000000000000000000000_00111111100000000000000000000000_00111111100000000000000000000000;
ack=1'b0; // acknowledge is 0 at the beginning
end
else begin
one_or_x=x;
mult1=out_m2; //获得乘法器2的输出,用来迭代
divisors=divisors>>32; //右移32位,把已经乘过的分数去掉
out_reg=output_add1;
// 迭代结束 if(divisors==224'b00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000)
begin
output_exp=output_add1;
ack=1'b1;
end
end
end
endmodule